用DenseNet训练数据集CIFAR-10

Miracle
1135
文章
59
评论
2018年10月22日18:44:52 2 6355字阅读21分11秒

pytorch代码如下:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Sep  5 09:10:52 2018
  4. @author: www
  5. """
  6. import sys
  7. sys.path.append('...')
  8. import numpy as np
  9. import torch
  10. from torch import nn
  11. from torch.autograd import Variable
  12. from torchvision.datasets import CIFAR10
  13. # 首先定义一个卷积块,其顺序是bn->relu->conv
  14. def conv_block(in_channel, out_channel):
  15.     layer = nn.Sequential(
  16.         nn.BatchNorm2d(in_channel),
  17.         nn.ReLU(True),
  18.         nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)
  19.     )
  20.     return layer
  21. class dense_block(nn.Module):
  22.     def __init__(self, in_channel, growth_rate, num_layers):
  23.         super(dense_block, self).__init__()
  24.         block = []
  25.         channel = in_channel
  26.         for i in range(num_layers):
  27.             block.append(conv_block(channel, growth_rate))
  28.             channel += growth_rate
  29.         self.net = nn.Sequential(*block)
  30.     def forward(self, x):
  31.         for layer in self.net:
  32.             out = layer(x)
  33.             x = torch.cat((out, x), dim=1)
  34.         return x
  35. # 验证输出是否正确
  36. test_net = dense_block(3, 12, 3)
  37. test_x = Variable(torch.zeros(1, 3, 96, 96))
  38. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  39. test_y = test_net(test_x)
  40. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  41. # 除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet
  42. # 会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,
  43. # 为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用
  44. # 1 x 1 的卷积
  45. def transition(in_channel, out_channel):
  46.     trans_layer = nn.Sequential(
  47.         nn.BatchNorm2d(in_channel),
  48.         nn.ReLU(True),
  49.         nn.Conv2d(in_channel, out_channel, 1),
  50.         nn.AvgPool2d(2, 2)
  51.     )
  52.     return trans_layer
  53. # 验证一下过渡层是否正确
  54. test_net = transition(3, 12)
  55. test_x = Variable(torch.zeros(1, 3, 96, 96))
  56. print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))
  57. test_y = test_net(test_x)
  58. print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))
  59. class densenet(nn.Module):
  60.     def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
  61.         super(densenet, self).__init__()
  62.         self.block1 = nn.Sequential(
  63.             nn.Conv2d(in_channel, 64, 7, 2, 3),
  64.             nn.BatchNorm2d(64),
  65.             nn.ReLU(True),
  66.             nn.MaxPool2d(3, 2, padding=1)
  67.         )
  68.         channels = 64
  69.         block = []
  70.         for i, layers in enumerate(block_layers):
  71.             block.append(dense_block(channels, growth_rate, layers))
  72.             channels += layers * growth_rate
  73.             if i != len(block_layers) - 1:
  74.                 block.append(transition(channels, channels // 2))  # 通过transition 层将大小减半,通道数减半
  75.                 channels = channels // 2
  76.         self.block2 = nn.Sequential(*block)
  77.         self.block2.add_module('bn', nn.BatchNorm2d(channels))
  78.         self.block2.add_module('relu', nn.ReLU(True))
  79.         self.block2.add_module('avg_pool', nn.AvgPool2d(3))
  80.         self.classifier = nn.Linear(channels, num_classes)
  81.     def forward(self, x):
  82.         x = self.block1(x)
  83.         x = self.block2(x)
  84.         x = x.view(x.shape[0], -1)
  85.         x = self.classifier(x)
  86.         return x
  87. test_net = densenet(3, 10)
  88. test_x = Variable(torch.zeros(1, 3, 96, 96))
  89. test_y = test_net(test_x)
  90. print('output: {}'.format(test_y.shape))
  91. # 数据预处理函数
  92. def data_tf(x):
  93.     x = x.resize((96, 96), 2)  # 将图片放大到 96 x 96
  94.     x = np.array(x, dtype='float32') / 255
  95.     x = (x - 0.5) / 0.5  # 标准化,这个技巧之后会讲到
  96.     x = x.transpose((2, 0, 1))  # 将 channel 放到第一维,只是 pytorch 要求的输入方式
  97.     x = torch.from_numpy(x)
  98.     return x
  99. train_set = CIFAR10('./data', train=True, transform=data_tf)
  100. train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
  101. test_set = CIFAR10('./data', train=False, transform=data_tf)
  102. test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  103. net = densenet(3, 10)
  104. optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
  105. criterion = nn.CrossEntropyLoss()
  106. from datetime import datetime
  107. def get_acc(output, label):
  108.     total = output.shape[0]
  109.     _, pred_label = output.max(1)
  110.     num_correct = (pred_label == label).sum().data[0]
  111.     return num_correct / total
  112. def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  113.     if torch.cuda.is_available():
  114.         net = net.cuda()
  115.     prev_time = datetime.now()
  116.     for epoch in range(num_epochs):
  117.         train_loss = 0
  118.         train_acc = 0
  119.         net = net.train()
  120.         for im, label in train_data:
  121.             if torch.cuda.is_available():
  122.                 im = Variable(im.cuda())
  123.                 label = Variable(label.cuda())
  124.             else:
  125.                 im = Variable(im)
  126.                 label = Variable(label)
  127.             # forward
  128.             output = net(im)
  129.             loss = criterion(output, label)
  130.             # forward
  131.             optimizer.zero_grad()
  132.             loss.backward()
  133.             optimizer.step()
  134.             train_loss += loss.data[0]
  135.             train_acc += get_acc(output, label)
  136.         cur_time = datetime.now()
  137.         h, remainder = divmod((cur_time - prev_time).seconds, 3600)
  138.         m, s = divmod(remainder, 60)
  139.         time_str = "Time %02d:%02d:%02d" % (h, m, s)
  140.         if valid_data is not None:
  141.             valid_loss = 0
  142.             valid_acc = 0
  143.             net = net.eval()
  144.             for im, label in valid_data:
  145.                 if torch.cuda.is_available():
  146.                     im = Variable(im.cuda(), volatile=True)
  147.                     label = Variable(label.cuda(), volatile=True)
  148.                 else:
  149.                     im = Variable(im, volatile=True)
  150.                     label = Variable(label, volatile=True)
  151.                 output = net(im)
  152.                 loss = criterion(output, label)
  153.                 valid_loss += loss.item()
  154.                 valid_acc += get_acc(output, label)
  155.             epoch_str = (
  156.                     "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
  157.                     % (epoch, train_loss / len(train_data),
  158.                        train_acc / len(train_data), valid_loss / len(valid_data),
  159.                        valid_acc / len(valid_data)))
  160.         else:
  161.             epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
  162.                          (epoch, train_loss / len(train_data),
  163.                           train_acc / len(train_data)))
  164.         prev_time = cur_time
  165.         print(epoch_str + time_str)
  166. train(net, train_data, test_data, 20, optimizer, criterion)
继续阅读
历史上的今天
十月
22
  • 版权声明: 发表于 2018年10月22日18:44:52
  • 转载注明:https://x1995.cn/3424.html
Densely Connected Convolutional Networks翻译 深度学习

Densely Connected Convolutional Networks翻译

Abstract 最近的成果显示,如果神经网络各层到输入和输出层采用更短的连接,那么网络可以设计的更深、更准确且训练起来更有效率。本文根据这个现象,提出了Dense Convolutional Net...
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

评论:2   其中:访客  1   博主  1
    • 游客 游客 @回复 0

      这么厉害啊