DenseNet网络搭建

2018年10月20日15:11:46DenseNet网络搭建已关闭评论 355

代码如下:

 

  1. import torch  
  2. import torch.nn as nn  
  3. #首先定义一个卷积块,其顺序是bn->relu->conv  
  4. def conv_block(in_channel, out_channel):  
  5.      layer = nn.Sequential(  
  6.           nn.BatchNorm2d(in_channel),  
  7.           nn.ReLU(True),  
  8.           nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)  
  9.      )  
  10.      return layer  
  11. class dense_block(nn.Module):  
  12.      def __init__(self, in_channel, growth_rate, num_layers):  
  13.           super(dense_block, self).__init__()  
  14.           block = []  
  15.           channel = in_channel  
  16.           for i in range(num_layers):  
  17.                block.append(conv_block(channel, growth_rate))  
  18.                channel += growth_rate  
  19.           self.net = nn.Sequential(*block)  
  20.             
  21.      def forward(self, x):  
  22.           for layer in self.net:  
  23.                out = layer(x)  
  24.                x = torch.cat((out, x), dim=1)  
  25.           return x  
  26. #除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet   
  27. #会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,  
  28. #为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用  
  29. # 1 x 1 的卷积  
  30. def transition(in_channel, out_channel):  
  31.      trans_layer = nn.Sequential(  
  32.           nn.BatchNorm2d(in_channel),  
  33.           nn.ReLU(True),  
  34.           nn.Conv2d(in_channel, out_channel, 1),  
  35.           nn.AvgPool2d(2, 2)  
  36.           )  
  37.      return trans_layer  
  38.   
  39. class DenseNet(nn.Module):  
  40.      def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):  
  41.           super(DenseNet, self).__init__()  
  42.           self.block1 = nn.Sequential(  
  43.                nn.Conv2d(in_channel, 64, 7, 2, 3),  
  44.                nn.BatchNorm2d(64),  
  45.                nn.ReLU(True),  
  46.                nn.MaxPool2d(3, 2, padding=1)  
  47.           )  
  48.             
  49.           channels = 64  
  50.           block = []  
  51.           for i, layers in enumerate(block_layers):  
  52.                block.append(dense_block(channels, growth_rate, layers))  
  53.                channels += layers * growth_rate  
  54.                if i!= len(block_layers) - 1:  
  55.                     block.append(transition(channels, channels // 2)) #通过transition 层将大小减半,通道数减半  
  56.                     channels = channels // 2  
  57.                       
  58.           self.block2 = nn.Sequential(*block)  
  59.           self.block2.add_module('bn', nn.BatchNorm2d(channels))  
  60.           self.block2.add_module('relu', nn.ReLU(True))  
  61.           self.block2.add_module('avg_pool', nn.AvgPool2d(3))  
  62.             
  63.           self.classifier = nn.Linear(channels, num_classes)  
  64.             
  65.      def forward(self, x):  
  66.           x = self.block1(x)  
  67.           x = self.block2(x)  
  68.             
  69.           x = x.view(x.shape[0], -1)  
  70.           x = self.classifier(x)  
  71.           return x  
  72.       
  73. densenet=DenseNet(3,10)  
  74. print(densenet)  

历史上的今天:

当时只道是寻常