【PyTorch】按照论文思想实现通道和空间两种注意力机制

网友投稿 916 2022-09-30

【PyTorch】按照论文思想实现通道和空间两种注意力机制

【PyTorch】按照论文思想实现通道和空间两种注意力机制

from turtle import forwardimport torchfrom torch import nnclass ChannelAttention(nn.Module): # ratio表示MLP中,中间层in_planes缩小的比例 def __init__(self, in_plances, ratio=16) -> None: super().__init__() self.max_pool = nn.AdaptiveMaxPool2d((1,1)) self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) ''' (1) in_plances / ratio, 其结果为小数,导致模型报错; (2) in_plances // ratio, 向下取整; (3) Conv2d中bias为False主要是为了模拟MLP多层感知机的功能; ''' self.mlp = nn.Sequential( # 此处没有中括号 nn.Conv2d(in_plances, in_plances // ratio, 1, bias=False), # 此处为什么卷积不需要偏置,是为了模拟FC nn.ReLU(), nn.Conv2d(in_plances // ratio, in_plances, 1, bias=False) # python 中/与//的区别 ) self.sigmoid = nn.Sigmoid() def forward(self, x): x1 = self.max_pool(x) x1 = self.mlp(x1) x2 = self.avg_pool(x) x2 = self.mlp(x2) # 此处直接相加,而不是拼接 # torch.cat(x1, x2) out = x1 + x2 out = self.sigmoid(out) return outclass SpatialAttention(nn.Module): def __init__(self) -> None: super().__init__() self.conv2d = nn.Conv2d(2, 1, 7, padding=3, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): ''' 注意,此处并不是简单的最大值和均值池化操作,而是cross channel的; ''' avg_pool = torch.mean(x, dim=1, keepdim=True) # Bx1xHxW max_pool, _ = torch.max(x, dim=1, keepdim=True) # Bx1xHxW, 此处非常容易出错,少_ # Bx2xHxW out = torch.cat([avg_pool, max_pool], dim=1) out = self.conv2d(out) out = self.sigmoid(out) return out if __name__ == '__main__': from torchinfo import summary # import hiddenlayer as h device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print('Channel Attention') layer = ChannelAttention(32).to(device) summary(layer, (1, 32, 224, 224)) print('Spatial Attention') layer = SpatialAttention().to(device) summary(layer, (1, 32, 224, 224)) # graph = h.build_graph(layer, torch.zeros([1, 32, 224, 224])) # graph.theme = h.graph.THEMES['blue'].copy() # graph.save('test.png') print('done!')

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:小程序页面跳转有哪几种方式(小程序直接跳转小程序)
下一篇:利用小程序api接口动态获取元素宽高
相关文章

 发表评论

暂时没有评论,来抢沙发吧~