ResNet的搭建

发布时间:2018年10月15日 15:10:13 | 标题:ResNet的搭建 | 已关闭评论 | 阅读:514

ResNet34 的 PyTorch 实现代码如下:

 

  1. import torch  
  2. import torch.nn as nn  
  3. import torch.nn.functional as F  
  4.    
  5. class ResidualBlock(nn.Module):  
  6.       
  7.     ''' 
  8.     实现子module: Residual Block 
  9.     '''  
  10.       
  11.     def __init__(self,inchannel,outchannel,stride=1,shortcut=None):  
  12.           
  13.         super(ResidualBlock,self).__init__()  
  14.           
  15.         self.left=nn.Sequential(  
  16.             nn.Conv2d(inchannel,outchannel,3,stride,1,bias=False),  
  17.             nn.BatchNorm2d(outchannel),  
  18.             nn.ReLU(inplace=True),  
  19.             nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),  
  20.             nn.BatchNorm2d(outchannel)  
  21.         )  
  22.         self.right=shortcut  
  23.       
  24.     def forward(self,x):  
  25.           
  26.         out=self.left(x)  
  27.         residual=x if self.right is None else self.right(x)  
  28.         out+=residual  
  29.         return F.relu(out)  
  30.       
  31. class ResNet(nn.Module):  
  32.       
  33.     ''' 
  34.     实现主module:ResNet34 
  35.     ResNet34 包含多个layer,每个layer又包含多个residual block 
  36.     用子module来实现residual block,用_make_layer函数来实现layer 
  37.     '''  
  38.       
  39.     def __init__(self,num_classes=1000):  
  40.           
  41.         super(ResNet,self).__init__()  
  42.           
  43.         # 前几层图像转换  
  44.         self.pre=nn.Sequential(  
  45.             nn.Conv2d(3,64,7,2,3,bias=False),  
  46.             nn.BatchNorm2d(64),  
  47.             nn.ReLU(inplace=True),  
  48.             nn.MaxPool2d(3,2,1)  
  49.         )  
  50.           
  51.         # 重复的layer,分别有3,4,6,3个residual block  
  52.         self.layer1=self._make_layer(64,64,3)  
  53.         self.layer2=self._make_layer(64,128,4,stride=2)  
  54.         self.layer3=self._make_layer(128,256,6,stride=2)  
  55.         self.layer4=self._make_layer(256,512,3,stride=2)  
  56.           
  57.         #分类用的全连接  
  58.         self.fc=nn.Linear(512,num_classes)  
  59.       
  60.     def _make_layer(self,inchannel,outchannel,bloch_num,stride=1):  
  61.           
  62.         ''' 
  63.         构建layer,包含多个residual block 
  64.         '''  
  65.         shortcut=nn.Sequential(  
  66.             nn.Conv2d(inchannel,outchannel,1,stride,bias=False),  
  67.             nn.BatchNorm2d(outchannel)  
  68.         )  
  69.         layers=[]  
  70.         layers.append(ResidualBlock(inchannel,outchannel,stride,shortcut))  
  71.         for i in range(1,bloch_num):  
  72.             layers.append(ResidualBlock(outchannel,outchannel))  
  73.         return nn.Sequential(*layers)  
  74.       
  75.     def forward(self,x):  
  76.           
  77.         x=self.pre(x)  
  78.           
  79.         x=self.layer1(x)  
  80.         x=self.layer2(x)  
  81.         x=self.layer3(x)  
  82.         x=self.layer4(x)  
  83.           
  84.         x=F.avg_pool2d(x,7)  
  85.         x=x.view(x.size(0),-1)  
  86.         return self.fc(x)  
  87. resnet=ResNet()  
  88. print(resnet)  

历史上的今天:

当时只道是寻常