GoogLeNet的搭建

2018年10月13日10:24:43 评论 328

废话不多说,上代码:

 

  1. import torch
  2. import torch.nn as nn
  # Basic layer unit used throughout the network: a 2-D convolution followed
  # by batch normalisation and a ReLU activation.
  def conv_relu(in_channels, out_channels, kernel, stride=1, padding=0):
      """Build a ``Conv2d -> BatchNorm2d -> ReLU`` stack.

      Args:
          in_channels: number of channels of the input feature map.
          out_channels: number of channels produced by the convolution.
          kernel: convolution kernel size, forwarded to ``nn.Conv2d``.
          stride: convolution stride (default 1).
          padding: zero padding added by the convolution (default 0).

      Returns:
          An ``nn.Sequential`` holding the three layers in the order above.
      """
      convolution = nn.Conv2d(
          in_channels,
          out_channels,
          kernel_size=kernel,
          stride=stride,
          padding=padding,
      )
      # eps is raised from the PyTorch default to 1e-3.
      normalisation = nn.BatchNorm2d(out_channels, eps=1e-3)
      # The activation overwrites its input tensor in place.
      activation = nn.ReLU(inplace=True)
      return nn.Sequential(convolution, normalisation, activation)
  class inception(nn.Module):
      """Inception module: four parallel branches joined along the channel axis.

      Every branch keeps the spatial size of its input (stride 1 with matching
      padding), so the outputs can be concatenated on ``dim=1``. The result has
      ``out1_1 + out2_3 + out3_5 + out4_1`` channels.

      Args:
          in_channel: channels of the incoming feature map.
          out1_1: output channels of the 1x1 branch.
          out2_1: channels after the 1x1 reduction in the 3x3 branch.
          out2_3: output channels of the 3x3 branch.
          out3_1: channels after the 1x1 reduction in the 5x5 branch.
          out3_5: output channels of the 5x5 branch.
          out4_1: output channels of the 1x1 convolution in the pooling branch.
      """

      def __init__(self, in_channel=3, out1_1=64, out2_1=48, out2_3=64, out3_1=64, out3_5=96, out4_1=32):
          super().__init__()

          # Branch 1: a single 1x1 convolution.
          self.branch1x1 = conv_relu(in_channel, out1_1, kernel=1)

          # Branch 2: 1x1 channel reduction, then a 3x3 convolution.
          reduce_3x3 = conv_relu(in_channel, out2_1, kernel=1)
          expand_3x3 = conv_relu(out2_1, out2_3, kernel=3, padding=1)
          self.branch3x3 = nn.Sequential(reduce_3x3, expand_3x3)

          # Branch 3: 1x1 channel reduction, then a 5x5 convolution.
          reduce_5x5 = conv_relu(in_channel, out3_1, kernel=1)
          expand_5x5 = conv_relu(out3_1, out3_5, kernel=5, padding=2)
          self.branch5x5 = nn.Sequential(reduce_5x5, expand_5x5)

          # Branch 4: 3x3 max pooling (stride 1), then a 1x1 convolution.
          pooling = nn.MaxPool2d(3, stride=1, padding=1)
          projection = conv_relu(in_channel, out4_1, kernel=1)
          self.branch_pool = nn.Sequential(pooling, projection)

      def forward(self, x):
          """Run ``x`` through all four branches and concatenate the results."""
          branches = (
              self.branch1x1,
              self.branch3x3,
              self.branch5x5,
              self.branch_pool,
          )
          return torch.cat([branch(x) for branch in branches], dim=1)
  class googlenet(nn.Module):
      """GoogLeNet-style image classifier built from ``inception`` modules.

      The network takes 3-channel images, runs them through five blocks and
      feeds the flattened result to a linear classifier. The classifier expects
      exactly 1024 features, so the feature map leaving ``block5`` must be 1x1
      spatially; a 96x96 input, for example, satisfies this.

      Args:
          verbose: when true, ``forward`` prints the tensor shape after each
              block so the shrinking spatial size and growing channel count
              can be observed.
          num_classes: number of output scores produced by the classifier
              (default 10).
      """

      def __init__(self, verbose=False, num_classes=10):
          # The original call read ``super(googlenetself)`` (missing comma) and
          # raised NameError. The zero-argument form also keeps working if the
          # module-level name ``googlenet`` is later rebound.
          super().__init__()
          self.verbose = verbose

          # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
          self.block1 = nn.Sequential(
              conv_relu(3, out_channels=64, kernel=7, stride=2, padding=3),
              nn.MaxPool2d(3, 2),
          )
          self.block2 = nn.Sequential(
              conv_relu(64, 64, kernel=1),
              conv_relu(64, 192, kernel=3, padding=1),
              nn.MaxPool2d(3, 2),
          )
          # Each inception's input channel count equals the sum of the four
          # branch outputs of the module before it (e.g. 64+128+32+32 = 256).
          self.block3 = nn.Sequential(
              inception(192, 64, 96, 128, 16, 32, 32),
              inception(256, 128, 128, 192, 32, 96, 64),
              nn.MaxPool2d(3, 2),
          )
          self.block4 = nn.Sequential(
              inception(480, 192, 96, 208, 16, 48, 64),
              inception(512, 160, 112, 224, 24, 64, 64),
              inception(512, 128, 128, 256, 24, 64, 64),
              inception(512, 112, 144, 288, 32, 64, 64),
              inception(528, 256, 160, 320, 32, 128, 128),
              nn.MaxPool2d(3, 2),
          )
          self.block5 = nn.Sequential(
              inception(832, 256, 160, 320, 32, 128, 128),
              # NOTE(review): 182 is kept exactly as written. It is only an
              # internal reduction width, so shapes still line up, but it looks
              # like a typo for 192 - confirm before changing, since a change
              # alters parameter shapes of existing checkpoints.
              inception(832, 384, 182, 384, 48, 128, 128),
              nn.AvgPool2d(2),
          )
          # 384 + 384 + 128 + 128 = 1024 channels leave the last inception.
          self.classifier = nn.Linear(1024, num_classes)

      def forward(self, x):
          """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
          stages = (
              ("block1", self.block1),
              ("block2", self.block2),
              ("block3", self.block3),
              ("block4", self.block4),
              ("block5", self.block5),
          )
          for label, stage in stages:
              x = stage(x)
              if self.verbose:
                  print("{} output: {}".format(label, tuple(x.shape)))
          # Flatten everything except the batch dimension for the linear layer.
          x = x.view(x.size(0), -1)
          return self.classifier(x)
  # As data flows through the network the spatial size keeps shrinking while
  # the channel dimension keeps growing. Printing the model lists every layer.
  # The instance gets its own name so that the ``googlenet`` class is not
  # overwritten and can still be instantiated afterwards.
  net = googlenet(True)
  print(net)

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: