用DenseNet训练数据集CIFAR-10

Miracle
1040
文章
51
评论
2018年10月22日18:44:52 2 6758字阅读22分31秒

pytorch代码如下:

  1. # -*- coding: utf-8 -*-  
  2. """ 
  3. Created on Wed Sep  5 09:10:52 2018 
  4. @author: www 
  5. """  
  6.   
  7. import sys  
  8.   
  9. sys.path.append('...')  
  10.   
  11. import numpy as np  
  12. import torch  
  13. from torch import nn  
  14. from torch.autograd import Variable  
  15. from torchvision.datasets import CIFAR10  
  16.   
  17.   
  18. # 首先定义一个卷积块,其顺序是bn->relu->conv  
  19. def conv_block(in_channel, out_channel):  
  20.     layer = nn.Sequential(  
  21.         nn.BatchNorm2d(in_channel),  
  22.         nn.ReLU(True),  
  23.         nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)  
  24.     )  
  25.     return layer  
  26.   
  27.   
  28. class dense_block(nn.Module):  
  29.     def __init__(self, in_channel, growth_rate, num_layers):  
  30.         super(dense_block, self).__init__()  
  31.         block = []  
  32.         channel = in_channel  
  33.         for i in range(num_layers):  
  34.             block.append(conv_block(channel, growth_rate))  
  35.             channel += growth_rate  
  36.         self.net = nn.Sequential(*block)  
  37.   
  38.     def forward(self, x):  
  39.         for layer in self.net:  
  40.             out = layer(x)  
  41.             x = torch.cat((out, x), dim=1)  
  42.         return x  
  43.   
  44.   
  45. # 验证输出是否正确  
  46. test_net = dense_block(3, 12, 3)  
  47. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  48. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))  
  49. test_y = test_net(test_x)  
  50. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))  
  51.   
  52.   
  53. # 除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet  
  54. # 会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,  
  55. # 为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用  
  56. # 1 x 1 的卷积  
  57. def transition(in_channel, out_channel):  
  58.     trans_layer = nn.Sequential(  
  59.         nn.BatchNorm2d(in_channel),  
  60.         nn.ReLU(True),  
  61.         nn.Conv2d(in_channel, out_channel, 1),  
  62.         nn.AvgPool2d(2, 2)  
  63.     )  
  64.     return trans_layer  
  65.   
  66.   
  67. # 验证一下过渡层是否正确  
  68. test_net = transition(3, 12)  
  69. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  70. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))  
  71. test_y = test_net(test_x)  
  72. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))  
  73.   
  74.   
  75. class densenet(nn.Module):  
  76.     def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):  
  77.         super(densenet, self).__init__()  
  78.         self.block1 = nn.Sequential(  
  79.             nn.Conv2d(in_channel, 64, 7, 2, 3),  
  80.             nn.BatchNorm2d(64),  
  81.             nn.ReLU(True),  
  82.             nn.MaxPool2d(3, 2, padding=1)  
  83.         )  
  84.   
  85.         channels = 64  
  86.         block = []  
  87.         for i, layers in enumerate(block_layers):  
  88.             block.append(dense_block(channels, growth_rate, layers))  
  89.             channels += layers * growth_rate  
  90.             if i != len(block_layers) - 1:  
  91.                 block.append(transition(channels, channels // 2))  # 通过transition 层将大小减半,通道数减半  
  92.                 channels = channels // 2  
  93.   
  94.         self.block2 = nn.Sequential(*block)  
  95.         self.block2.add_module('bn', nn.BatchNorm2d(channels))  
  96.         self.block2.add_module('relu', nn.ReLU(True))  
  97.         self.block2.add_module('avg_pool', nn.AvgPool2d(3))  
  98.   
  99.         self.classifier = nn.Linear(channels, num_classes)  
  100.   
  101.     def forward(self, x):  
  102.         x = self.block1(x)  
  103.         x = self.block2(x)  
  104.   
  105.         x = x.view(x.shape[0], -1)  
  106.         x = self.classifier(x)  
  107.         return x  
  108.   
  109.   
  110. test_net = densenet(3, 10)  
  111. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  112. test_y = test_net(test_x)  
  113. print('output: {}'.format(test_y.shape))  
  114.   
  115.   
  116. # 数据预处理函数  
  117. def data_tf(x):  
  118.     x = x.resize((96, 96), 2)  # 将图片放大到 96 x 96  
  119.     x = np.array(x, dtype='float32') / 255  
  120.     x = (x - 0.5) / 0.5  # 标准化,这个技巧之后会讲到  
  121.     x = x.transpose((2, 0, 1))  # 将 channel 放到第一维,只是 pytorch 要求的输入方式  
  122.     x = torch.from_numpy(x)  
  123.     return x  
  124.   
  125.   
  126. train_set = CIFAR10('./data', train=True, transform=data_tf)  
  127. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)  
  128. test_set = CIFAR10('./data', train=False, transform=data_tf)  
  129. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)  
  130.   
  131. net = densenet(3, 10)  
  132. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  
  133. criterion = nn.CrossEntropyLoss()  
  134.   
  135. from datetime import datetime  
  136.   
  137.   
  138. def get_acc(output, label):  
  139.     total = output.shape[0]  
  140.     _, pred_label = output.max(1)  
  141.     num_correct = (pred_label == label).sum().data[0]  
  142.     return num_correct / total  
  143.   
  144.   
  145. def train(net, train_data, valid_data, num_epochs, optimizer, criterion):  
  146.     if torch.cuda.is_available():  
  147.         net = net.cuda()  
  148.     prev_time = datetime.now()  
  149.     for epoch in range(num_epochs):  
  150.         train_loss = 0  
  151.         train_acc = 0  
  152.         net = net.train()  
  153.         for im, label in train_data:  
  154.             if torch.cuda.is_available():  
  155.                 im = Variable(im.cuda())  
  156.                 label = Variable(label.cuda())  
  157.             else:  
  158.                 im = Variable(im)  
  159.                 label = Variable(label)  
  160.             # forward  
  161.             output = net(im)  
  162.             loss = criterion(output, label)  
  163.             # forward  
  164.             optimizer.zero_grad()  
  165.             loss.backward()  
  166.             optimizer.step()  
  167.   
  168.             train_loss += loss.data[0]  
  169.             train_acc += get_acc(output, label)  
  170.         cur_time = datetime.now()  
  171.         h, remainder = divmod((cur_time - prev_time).seconds, 3600)  
  172.         m, s = divmod(remainder, 60)  
  173.         time_str = "Time %02d:%02d:%02d" % (h, m, s)  
  174.         if valid_data is not None:  
  175.             valid_loss = 0  
  176.             valid_acc = 0  
  177.             net = net.eval()  
  178.             for im, label in valid_data:  
  179.                 if torch.cuda.is_available():  
  180.                     im = Variable(im.cuda(), volatile=True)  
  181.                     label = Variable(label.cuda(), volatile=True)  
  182.                 else:  
  183.                     im = Variable(im, volatile=True)  
  184.                     label = Variable(label, volatile=True)  
  185.                 output = net(im)  
  186.                 loss = criterion(output, label)  
  187.                 valid_loss += loss.item()  
  188.                 valid_acc += get_acc(output, label)  
  189.             epoch_str = (  
  190.                     "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "  
  191.                     % (epoch, train_loss / len(train_data),  
  192.                        train_acc / len(train_data), valid_loss / len(valid_data),  
  193.                        valid_acc / len(valid_data)))  
  194.         else:  
  195.             epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %  
  196.                          (epoch, train_loss / len(train_data),  
  197.                           train_acc / len(train_data)))  
  198.   
  199.         prev_time = cur_time  
  200.         print(epoch_str + time_str)  
  201.   
  202.   
  203. train(net, train_data, test_data, 20, optimizer, criterion)  
继续阅读
历史上的今天
十月
22
  • 版权声明: 发表于 2018年10月22日18:44:52
  • 转载注明:https://x1995.cn/3424.html
Densely Connected Convolutional Networks翻译 深度学习

Densely Connected Convolutional Networks翻译

Abstract 最近的成果显示,如果神经网络各层到输入和输出层采用更短的连接,那么网络可以设计的更深、更准确且训练起来更有效率。本文根据这个现象,提出了Dense Convolutional Net...
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

评论:2   其中:访客  1   博主  1
    • 游客 游客 @回复 0

      这么厉害啊