代码如下:文章源自联网快讯-https://x1995.cn/3411.html
文章源自联网快讯-https://x1995.cn/3411.html
- import torch
- import torch.nn as nn
- #首先定义一个卷积块,其顺序是bn->relu->conv
- def conv_block(in_channel, out_channel):
- layer = nn.Sequential(
- nn.BatchNorm2d(in_channel),
- nn.ReLU(True),
- nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)
- )
- return layer
- class dense_block(nn.Module):
- def __init__(self, in_channel, growth_rate, num_layers):
- super(dense_block, self).__init__()
- block = []
- channel = in_channel
- for i in range(num_layers):
- block.append(conv_block(channel, growth_rate))
- channel += growth_rate
- self.net = nn.Sequential(*block)
- def forward(self, x):
- for layer in self.net:
- out = layer(x)
- x = torch.cat((out, x), dim=1)
- return x
- #除了 dense block,DenseNet 中还有一个模块叫过渡层(transition block),因为 DenseNet
- #会不断地对维度进行拼接, 所以当层数很高的时候,输出的通道数就会越来越大,参数和计算量也会越来越大,
- #为了避免这个问题,需要引入过渡层将输出通道降低下来,同时也将输入的长宽减半,这个过渡层可以使用
- # 1 x 1 的卷积
- def transition(in_channel, out_channel):
- trans_layer = nn.Sequential(
- nn.BatchNorm2d(in_channel),
- nn.ReLU(True),
- nn.Conv2d(in_channel, out_channel, 1),
- nn.AvgPool2d(2, 2)
- )
- return trans_layer
- class DenseNet(nn.Module):
- def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
- super(DenseNet, self).__init__()
- self.block1 = nn.Sequential(
- nn.Conv2d(in_channel, 64, 7, 2, 3),
- nn.BatchNorm2d(64),
- nn.ReLU(True),
- nn.MaxPool2d(3, 2, padding=1)
- )
- channels = 64
- block = []
- for i, layers in enumerate(block_layers):
- block.append(dense_block(channels, growth_rate, layers))
- channels += layers * growth_rate
- if i!= len(block_layers) - 1:
- block.append(transition(channels, channels // 2)) #通过transition 层将大小减半,通道数减半
- channels = channels // 2
- self.block2 = nn.Sequential(*block)
- self.block2.add_module('bn', nn.BatchNorm2d(channels))
- self.block2.add_module('relu', nn.ReLU(True))
- self.block2.add_module('avg_pool', nn.AvgPool2d(3))
- self.classifier = nn.Linear(channels, num_classes)
- def forward(self, x):
- x = self.block1(x)
- x = self.block2(x)
- x = x.view(x.shape[0], -1)
- x = self.classifier(x)
- return x
- densenet=DenseNet(3,10)
- print(densenet)
继续阅读
评论