cuda out of memory(PyTorch)

网友投稿 728 2022-11-23

cuda out of memory(PyTorch)

cuda out of memory(PyTorch)

文章目录

​​情况1​​​​情况2​​

​​解法1​​​​解法2​​

情况1

model.forward()过程中,中间变量过多,导致GPU使用量增大,如下所示:

def forward(self, x): batch_size = x.shape[0] x0 = self.base_model(x) # Add positional info x1 = self.up1(x0) # 512, 32, 32 x2 = self.up2(x1) # 256, 64, 64 x3 = self.up3(x2) # 256, 128, 128 outc = self.outc(x3) # 1, 128, 128 outr = self.outr(x3) # 2, 128, 128 return outc, outr

将中间传递的变量统一为x:

def forward(self, x): batch_size = x.shape[0] x = self.base_model(x) # Add positional info x = self.up1(x) # 512, 32, 32 x = self.up2(x) # 256, 64, 64 x = self.up3(x) # 256, 128, 128 outc = self.outc(x) # 1, 128, 128 outr = self.outr(x) # 2, 128, 128 return outc, outr

情况2

程序运行过程中会产生很多中间变量,pytorch不会清理这些中间变量,就会爆显存。

解法1

loss = self.criteration(output, label)loss_sum += loss####更改为loss = self.criteration(output, label)loss_sum += loss.item()

解法2

torch.cuda.empty_cache() 可清理缓存,应该是最有效最便捷的

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:Leetcode题目71. 简化路径
下一篇:遥感影像32位转8位(python)
相关文章

 发表评论

暂时没有评论,来抢沙发吧~