三种 SENet 的 PyTorch 代码如下:文章源自联网快讯-https://x1995.cn/4534.html
class CSEModule(nn.Module):
    """Channel squeeze-and-excitation (cSE) gate.

    Global-average-pools the input to one value per channel, passes that
    through a two-layer 1x1-conv bottleneck (ReLU, then Sigmoid), and
    rescales every input channel by the resulting weight in (0, 1).

    Args:
        ch: Number of channels of the input feature map.
        re: Reduction ratio of the bottleneck. The hidden width is
            ``ch // re``, clamped to at least 1.
    """

    def __init__(self, ch, re=16):
        super().__init__()
        # Clamp the bottleneck width: with ch < re, ``ch // re`` would be 0
        # and the excitation path would carry no information at all.
        hidden = max(1, ch // re)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, ch, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate.

        The gate has spatial size 1x1 and is broadcast over the spatial
        dimensions of ``x``; the output has the same shape as ``x``.
        """
        return x * self.cSE(x)  # cSE
class SSEModule(nn.Module):
    """Spatial squeeze-and-excitation (sSE) gate.

    Projects all channels to a single-channel map with a 1x1 convolution,
    applies a Sigmoid, and rescales every spatial location of the input by
    the resulting weight in (0, 1).

    Args:
        ch: Number of channels of the input feature map.
        re: Unused. Accepted only so that this module can be constructed
            with the same arguments as ``CSEModule`` / ``SCSEModule``.
    """

    def __init__(self, ch, re=16):
        super().__init__()
        # Squeeze along the channel axis: ch -> 1, so the gate is one value
        # per spatial location (a ch -> ch conv would not be a spatial gate).
        self.sSE = nn.Sequential(
            nn.Conv2d(ch, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-location gate.

        The single-channel gate is broadcast over the channel dimension of
        ``x``; the output has the same shape as ``x``.
        """
        return x * self.sSE(x)  # sSE
class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE).

    Runs a channel gate (cSE) and a spatial gate (sSE) on the same input
    and adds the two recalibrated feature maps together.

    Args:
        ch: Number of channels of the input feature map.
        re: Reduction ratio of the cSE bottleneck. The hidden width is
            ``ch // re``, clamped to at least 1.
    """

    def __init__(self, ch, re=16):
        super().__init__()
        # Clamp the bottleneck width: with ch < re, ``ch // re`` would be 0.
        hidden = max(1, ch // re)
        # Channel gate: one weight per channel, shape (N, ch, 1, 1).
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, ch, 1),
            nn.Sigmoid(),
        )
        # Spatial gate: one weight per location, shape (N, 1, H, W).
        self.sSE = nn.Sequential(
            nn.Conv2d(ch, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sum of the channel-gated and spatially-gated ``x``.

        Both gates are broadcast against ``x``; the output has the same
        shape as ``x``.
        """
        return x * self.cSE(x) + x * self.sSE(x)  # scSE
论文原文:Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks文章源自联网快讯-https://x1995.cn/4534.html
继续阅读
评论