小程序游戏SDK如何推动游戏开发的快速创新与变革
710
2022-11-23
cuda out of memory(PyTorch)
文章目录
情况1情况2
解法1解法2
情况1
model.forward()过程中,中间变量过多,导致GPU使用量增大,如下所示:
def forward(self, x): batch_size = x.shape[0] x0 = self.base_model(x) # Add positional info x1 = self.up1(x0) # 512, 32, 32 x2 = self.up2(x1) # 256, 64, 64 x3 = self.up3(x2) # 256, 128, 128 outc = self.outc(x3) # 1, 128, 128 outr = self.outr(x3) # 2, 128, 128 return outc, outr
将中间传递的变量统一为x:
def forward(self, x): batch_size = x.shape[0] x = self.base_model(x) # Add positional info x = self.up1(x) # 512, 32, 32 x = self.up2(x) # 256, 64, 64 x = self.up3(x) # 256, 128, 128 outc = self.outc(x) # 1, 128, 128 outr = self.outr(x) # 2, 128, 128 return outc, outr
情况2
程序运行过程中会产生很多中间变量,pytorch不会清理这些中间变量,就会爆显存。
解法1
loss = self.criteration(output, label)loss_sum += loss####更改为loss = self.criteration(output, label)loss_sum += loss.item()
解法2
torch.cuda.empty_cache() 可清理缓存,应该是最有效最便捷的
版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。
发表评论
暂时没有评论,来抢沙发吧~