用DenseNet训练数据集CIFAR-10

2018年10月22日18:44:52 2 1,688

pytorch代码如下:

  1. # -*- coding: utf-8 -*-  
  2. """ 
  3. Created on Wed Sep  5 09:10:52 2018 
  4. @author: www 
  5. """  
  6.   
  7. import sys  
  8.   
  9. sys.path.append('...')  
  10.   
  11. import numpy as np  
  12. import torch  
  13. from torch import nn  
  14. from torch.autograd import Variable  
  15. from torchvision.datasets import CIFAR10  
  16.   
  17.   
  18. # 首先定义一个卷积块,其顺序是bn->relu->conv  
  19. def conv_block(in_channel, out_channel):  
  20.     layer = nn.Sequential(  
  21.         nn.BatchNorm2d(in_channel),  
  22.         nn.ReLU(True),  
  23.         nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)  
  24.     )  
  25.     return layer  
  26.   
  27.   
  28. class dense_block(nn.Module):  
  29.     def __init__(self, in_channel, growth_rate, num_layers):  
  30.         super(dense_block, self).__init__()  
  31.         block = []  
  32.         channel = in_channel  
  33.         for i in range(num_layers):  
  34.             block.append(conv_block(channel, growth_rate))  
  35.             channel += growth_rate  
  36.         self.net = nn.Sequential(*block)  
  37.   
  38.     def forward(self, x):  
  39.         for layer in self.net:  
  40.             out = layer(x)  
  41.             x = torch.cat((out, x), dim=1)  
  42.         return x  
  43.   
  44.   
  45. # 验证输出是否正确  
  46. test_net = dense_block(3, 12, 3)  
  47. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  48. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))  
  49. test_y = test_net(test_x)  
  50. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))  
  51.   
  52.   
  53. # 除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet  
  54. # 会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,  
  55. # 为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用  
  56. # 1 x 1 的卷积  
  57. def transition(in_channel, out_channel):  
  58.     trans_layer = nn.Sequential(  
  59.         nn.BatchNorm2d(in_channel),  
  60.         nn.ReLU(True),  
  61.         nn.Conv2d(in_channel, out_channel, 1),  
  62.         nn.AvgPool2d(2, 2)  
  63.     )  
  64.     return trans_layer  
  65.   
  66.   
  67. # 验证一下过渡层是否正确  
  68. test_net = transition(3, 12)  
  69. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  70. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))  
  71. test_y = test_net(test_x)  
  72. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))  
  73.   
  74.   
  75. class densenet(nn.Module):  
  76.     def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):  
  77.         super(densenet, self).__init__()  
  78.         self.block1 = nn.Sequential(  
  79.             nn.Conv2d(in_channel, 64, 7, 2, 3),  
  80.             nn.BatchNorm2d(64),  
  81.             nn.ReLU(True),  
  82.             nn.MaxPool2d(3, 2, padding=1)  
  83.         )  
  84.   
  85.         channels = 64  
  86.         block = []  
  87.         for i, layers in enumerate(block_layers):  
  88.             block.append(dense_block(channels, growth_rate, layers))  
  89.             channels += layers * growth_rate  
  90.             if i != len(block_layers) - 1:  
  91.                 block.append(transition(channels, channels // 2))  # 通过transition 层将大小减半,通道数减半  
  92.                 channels = channels // 2  
  93.   
  94.         self.block2 = nn.Sequential(*block)  
  95.         self.block2.add_module('bn', nn.BatchNorm2d(channels))  
  96.         self.block2.add_module('relu', nn.ReLU(True))  
  97.         self.block2.add_module('avg_pool', nn.AvgPool2d(3))  
  98.   
  99.         self.classifier = nn.Linear(channels, num_classes)  
  100.   
  101.     def forward(self, x):  
  102.         x = self.block1(x)  
  103.         x = self.block2(x)  
  104.   
  105.         x = x.view(x.shape[0], -1)  
  106.         x = self.classifier(x)  
  107.         return x  
  108.   
  109.   
  110. test_net = densenet(3, 10)  
  111. test_x = Variable(torch.zeros(1, 3, 96, 96))  
  112. test_y = test_net(test_x)  
  113. print('output: {}'.format(test_y.shape))  
  114.   
  115.   
  116. # 数据预处理函数  
  117. def data_tf(x):  
  118.     x = x.resize((96, 96), 2)  # 将图片放大到 96 x 96  
  119.     x = np.array(x, dtype='float32') / 255  
  120.     x = (x - 0.5) / 0.5  # 标准化,这个技巧之后会讲到  
  121.     x = x.transpose((2, 0, 1))  # 将 channel 放到第一维,只是 pytorch 要求的输入方式  
  122.     x = torch.from_numpy(x)  
  123.     return x  
  124.   
  125.   
  126. train_set = CIFAR10('./data', train=True, transform=data_tf)  
  127. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)  
  128. test_set = CIFAR10('./data', train=False, transform=data_tf)  
  129. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)  
  130.   
  131. net = densenet(3, 10)  
  132. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  
  133. criterion = nn.CrossEntropyLoss()  
  134.   
  135. from datetime import datetime  
  136.   
  137.   
  138. def get_acc(output, label):  
  139.     total = output.shape[0]  
  140.     _, pred_label = output.max(1)  
  141.     num_correct = (pred_label == label).sum().data[0]  
  142.     return num_correct / total  
  143.   
  144.   
  145. def train(net, train_data, valid_data, num_epochs, optimizer, criterion):  
  146.     if torch.cuda.is_available():  
  147.         net = net.cuda()  
  148.     prev_time = datetime.now()  
  149.     for epoch in range(num_epochs):  
  150.         train_loss = 0  
  151.         train_acc = 0  
  152.         net = net.train()  
  153.         for im, label in train_data:  
  154.             if torch.cuda.is_available():  
  155.                 im = Variable(im.cuda())  
  156.                 label = Variable(label.cuda())  
  157.             else:  
  158.                 im = Variable(im)  
  159.                 label = Variable(label)  
  160.             # forward  
  161.             output = net(im)  
  162.             loss = criterion(output, label)  
  163.             # forward  
  164.             optimizer.zero_grad()  
  165.             loss.backward()  
  166.             optimizer.step()  
  167.   
  168.             train_loss += loss.data[0]  
  169.             train_acc += get_acc(output, label)  
  170.         cur_time = datetime.now()  
  171.         h, remainder = divmod((cur_time - prev_time).seconds, 3600)  
  172.         m, s = divmod(remainder, 60)  
  173.         time_str = "Time %02d:%02d:%02d" % (h, m, s)  
  174.         if valid_data is not None:  
  175.             valid_loss = 0  
  176.             valid_acc = 0  
  177.             net = net.eval()  
  178.             for im, label in valid_data:  
  179.                 if torch.cuda.is_available():  
  180.                     im = Variable(im.cuda(), volatile=True)  
  181.                     label = Variable(label.cuda(), volatile=True)  
  182.                 else:  
  183.                     im = Variable(im, volatile=True)  
  184.                     label = Variable(label, volatile=True)  
  185.                 output = net(im)  
  186.                 loss = criterion(output, label)  
  187.                 valid_loss += loss.item()  
  188.                 valid_acc += get_acc(output, label)  
  189.             epoch_str = (  
  190.                     "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "  
  191.                     % (epoch, train_loss / len(train_data),  
  192.                        train_acc / len(train_data), valid_loss / len(valid_data),  
  193.                        valid_acc / len(valid_data)))  
  194.         else:  
  195.             epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %  
  196.                          (epoch, train_loss / len(train_data),  
  197.                           train_acc / len(train_data)))  
  198.   
  199.         prev_time = cur_time  
  200.         print(epoch_str + time_str)  
  201.   
  202.   
  203. train(net, train_data, test_data, 20, optimizer, criterion)  

评论已关闭!

目前评论:2   其中:访客  1   博主  1

    • 游客 游客 0

      这么厉害啊