用 DenseNet 训练 CIFAR-10 数据集

作者：Miracle｜发布于 2018年10月22日 18:44:52｜分类：深度学习｜阅读 29,412｜评论 7｜全文 6355 字，阅读约需 21 分 11 秒

PyTorch 代码如下（文章源自联网快讯：https://x1995.cn/3424.html）：

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Sep  5 09:10:52 2018
  4. @author: www
  5. """
  6. import sys
  7. sys.path.append('...')
  8. import numpy as np
  9. import torch
  10. from torch import nn
  11. from torch.autograd import Variable
  12. from torchvision.datasets import CIFAR10
  13. # 首先定义一个卷积块,其顺序是bn->relu->conv
  14. def conv_block(in_channel, out_channel):
  15.     layer = nn.Sequential(
  16.         nn.BatchNorm2d(in_channel),
  17.         nn.ReLU(True),
  18.         nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)
  19.     )
  20.     return layer
  21. class dense_block(nn.Module):
  22.     def __init__(self, in_channel, growth_rate, num_layers):
  23.         super(dense_block, self).__init__()
  24.         block = []
  25.         channel = in_channel
  26.         for i in range(num_layers):
  27.             block.append(conv_block(channel, growth_rate))
  28.             channel += growth_rate
  29.         self.net = nn.Sequential(*block)
  30.     def forward(self, x):
  31.         for layer in self.net:
  32.             out = layer(x)
  33.             x = torch.cat((out, x), dim=1)
  34.         return x
  35. # 验证输出是否正确
  36. test_net = dense_block(3, 12, 3)
  37. test_x = Variable(torch.zeros(1, 3, 96, 96))
  38. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  39. test_y = test_net(test_x)
  40. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  41. # 除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet
  42. # 会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,
  43. # 为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用
  44. # 1 x 1 的卷积
  45. def transition(in_channel, out_channel):
  46.     trans_layer = nn.Sequential(
  47.         nn.BatchNorm2d(in_channel),
  48.         nn.ReLU(True),
  49.         nn.Conv2d(in_channel, out_channel, 1),
  50.         nn.AvgPool2d(2, 2)
  51.     )
  52.     return trans_layer
  53. # 验证一下过渡层是否正确
  54. test_net = transition(3, 12)
  55. test_x = Variable(torch.zeros(1, 3, 96, 96))
  56. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  57. test_y = test_net(test_x)
  58. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  59. class densenet(nn.Module):
  60.     def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
  61.         super(densenet, self).__init__()
  62.         self.block1 = nn.Sequential(
  63.             nn.Conv2d(in_channel, 64, 7, 2, 3),
  64.             nn.BatchNorm2d(64),
  65.             nn.ReLU(True),
  66.             nn.MaxPool2d(3, 2, padding=1)
  67.         )
  68.         channels = 64
  69.         block = []
  70.         for i, layers in enumerate(block_layers):
  71.             block.append(dense_block(channels, growth_rate, layers))
  72.             channels += layers * growth_rate
  73.             if i != len(block_layers) - 1:
  74.                 block.append(transition(channels, channels // 2))  # 通过transition 层将大小减半,通道数减半
  75.                 channels = channels // 2
  76.         self.block2 = nn.Sequential(*block)
  77.         self.block2.add_module('bn', nn.BatchNorm2d(channels))
  78.         self.block2.add_module('relu', nn.ReLU(True))
  79.         self.block2.add_module('avg_pool', nn.AvgPool2d(3))
  80.         self.classifier = nn.Linear(channels, num_classes)
  81.     def forward(self, x):
  82.         x = self.block1(x)
  83.         x = self.block2(x)
  84.         x = x.view(x.shape[0], -1)
  85.         x = self.classifier(x)
  86.         return x
  87. test_net = densenet(3, 10)
  88. test_x = Variable(torch.zeros(1, 3, 96, 96))
  89. test_y = test_net(test_x)
  90. print('output: {}'.format(test_y.shape))
  91. # 数据预处理函数
  92. def data_tf(x):
  93.     x = x.resize((96, 96), 2)  # 将图片放大到 96 x 96
  94.     x = np.array(x, dtype='float32') / 255
  95.     x = (x - 0.5) / 0.5  # 标准化,这个技巧之后会讲到
  96.     x = x.transpose((2, 0, 1))  # 将 channel 放到第一维,只是 pytorch 要求的输入方式
  97.     x = torch.from_numpy(x)
  98.     return x
  99. train_set = CIFAR10('./data', train=True, transform=data_tf)
  100. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  101. test_set = CIFAR10('./data', train=False, transform=data_tf)
  102. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  103. net = densenet(3, 10)
  104. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  105. criterion = nn.CrossEntropyLoss()
  106. from datetime import datetime
  107. def get_acc(output, label):
  108.     total = output.shape[0]
  109.     _, pred_label = output.max(1)
  110.     num_correct = (pred_label == label).sum().data[0]
  111.     return num_correct / total
  112. def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  113.     if torch.cuda.is_available():
  114.         net = net.cuda()
  115.     prev_time = datetime.now()
  116.     for epoch in range(num_epochs):
  117.         train_loss = 0
  118.         train_acc = 0
  119.         net = net.train()
  120.         for im, label in train_data:
  121.             if torch.cuda.is_available():
  122.                 im = Variable(im.cuda())
  123.                 label = Variable(label.cuda())
  124.             else:
  125.                 im = Variable(im)
  126.                 label = Variable(label)
  127.             # forward
  128.             output = net(im)
  129.             loss = criterion(output, label)
  130.             # forward
  131.             optimizer.zero_grad()
  132.             loss.backward()
  133.             optimizer.step()
  134.             train_loss += loss.data[0]
  135.             train_acc += get_acc(output, label)
  136.         cur_time = datetime.now()
  137.         h, remainder = divmod((cur_time - prev_time).seconds, 3600)
  138.         m, s = divmod(remainder, 60)
  139.         time_str = "Time %02d:%02d:%02d" % (h, m, s)
  140.         if valid_data is not None:
  141.             valid_loss = 0
  142.             valid_acc = 0
  143.             net = net.eval()
  144.             for im, label in valid_data:
  145.                 if torch.cuda.is_available():
  146.                     im = Variable(im.cuda(), volatile=True)
  147.                     label = Variable(label.cuda(), volatile=True)
  148.                 else:
  149.                     im = Variable(im, volatile=True)
  150.                     label = Variable(label, volatile=True)
  151.                 output = net(im)
  152.                 loss = criterion(output, label)
  153.                 valid_loss += loss.item()
  154.                 valid_acc += get_acc(output, label)
  155.             epoch_str = (
  156.                     "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
  157.                     % (epoch, train_loss / len(train_data),
  158.                        train_acc / len(train_data), valid_loss / len(valid_data),
  159.                        valid_acc / len(valid_data)))
  160.         else:
  161.             epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
  162.                          (epoch, train_loss / len(train_data),
  163.                           train_acc / len(train_data)))
  164.         prev_time = cur_time
  165.         print(epoch_str + time_str)
  166. train(net, train_data, test_data, 20, optimizer, criterion)
文章源自联网快讯：https://x1995.cn/3424.html
继续阅读
  • 本文由 Miracle 发表于 2018年10月22日 18:44:52