RevTorch
Framework for creating (partially) reversible neural networks with PyTorch
RevTorch is introduced and explained in our paper A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation, which was accepted for presentation at MICCAI 2019.
If you find this code helpful in your research, please cite the following paper:
@article{PartiallyRevUnet2019Bruegger,
  author  = {Br{\"u}gger, Robin and Baumgartner, Christian F. and Konukoglu, Ender},
  title   = {A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation},
  journal = {arXiv:1906.06148},
  year    = {2019},
}
Installation
Use pip to install RevTorch:
$ pip install revtorch
RevTorch requires PyTorch. However, PyTorch is not listed as a dependency, since the required PyTorch version depends on your system. Please install PyTorch by following the instructions on the PyTorch website.
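For example, a typical CPU-only install looks like the command below; treat it only as an illustration, since the right command for your OS and CUDA version comes from the selector on the PyTorch website:

$ pip install torch torchvision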
Usage
This example shows how to use the RevTorch framework.
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import revtorch as rv

def train():
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms.ToTensor())
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

    net = PartiallyReversibleNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # logging stuff
            running_loss += loss.item()
            LOG_INTERVAL = 200
            if i % LOG_INTERVAL == (LOG_INTERVAL - 1):  # print every 200 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / LOG_INTERVAL))
                running_loss = 0.0

class PartiallyReversibleNet(nn.Module):
    def __init__(self):
        super(PartiallyReversibleNet, self).__init__()

        # initial non-reversible convolution to get to 32 channels
        self.conv1 = nn.Conv2d(3, 32, 3)

        # construct reversible sequence with 4 reversible blocks
        blocks = []
        for i in range(4):
            # f and g must both be an nn.Module whose output has the same shape as its input
            f_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
            g_func = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

            # we construct a reversible block with our F and G functions
            blocks.append(rv.ReversibleBlock(f_func, g_func))

        # pack all reversible blocks into a reversible sequence
        self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks))

        # non-reversible convolution to get to 10 channels (one for each label)
        self.conv2 = nn.Conv2d(32, 10, 3)

    def forward(self, x):
        x = self.conv1(x)

        # the reversible sequence can be used like any other nn.Module.
        # Memory-saving backpropagation is used automatically
        x = self.sequence(x)

        x = self.conv2(F.relu(x))
        x = F.avg_pool2d(x, (x.shape[2], x.shape[3]))
        x = x.view(x.shape[0], x.shape[1])
        return x

if __name__ == "__main__":
    train()
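To see why f and g operate on 16 channels while conv1 outputs 32, here is a minimal sketch of the RevNet-style additive coupling that a reversible block computes. This is an illustration of the technique, not RevTorch's exact implementation: the input is split into two halves along the channel dimension, and because each half can be reconstructed exactly from the output, the forward activations do not need to be stored for backpropagation.

import torch

def reversible_forward(x, f, g):
    # split the input into two halves along the channel dimension
    # (32 channels -> two halves of 16, which is why f and g use 16 channels)
    x1, x2 = torch.chunk(x, 2, dim=1)
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return torch.cat([y1, y2], dim=1)

def reversible_inverse(y, f, g):
    # the inputs are recomputed exactly from the outputs, so the
    # activations can be discarded during the forward pass instead of stored
    y1, y2 = torch.chunk(y, 2, dim=1)
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return torch.cat([x1, x2], dim=1)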
Python version
Tested with Python 3.6 and PyTorch 1.1.0. Should work with any version of Python 3.
Changelog
Version 0.2.4
Added option to disable eager discarding of variables to allow for multiple backward() calls
Version 0.2.3
Added option to use the same random seed for the forward and backward pass (Pull request)
Version 0.2.1
Added option to select the dimension along which the tensor is split (Pull request)
Version 0.2.0
Fixed memory leak when not consuming output of the reversible block (Issue)
Version 0.1.0
Initial release
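Taken together, the options mentioned in the changelog are passed as keyword arguments when the blocks and the sequence are constructed. Below is a sketch; the parameter names are assumptions inferred from the changelog descriptions, so check the source for the exact signatures:

import torch.nn as nn
import revtorch as rv

f = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
g = nn.Sequential(nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

block = rv.ReversibleBlock(
    f, g,
    split_along_dim=1,     # assumed name: dimension along which the tensor is split (0.2.1)
    fix_random_seed=True,  # assumed name: same seed for forward and backward pass (0.2.3)
)

sequence = rv.ReversibleSequence(
    nn.ModuleList([block]),
    eagerly_discard_variables=False,  # assumed name: keep variables to allow multiple backward() calls (0.2.4)
)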