12 Common CNN Models: A Paper Collection with PyTorch Implementations

Category: Python · Published: 6 years ago


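I recently came across a nice piece of source code. The author uses PyTorch to implement today's mainstream convolutional neural network (CNN) architectures, 12 model architectures in all. All of the code uses the CIFAR datasets.

Project URL: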

https://github.com/BIGBALLON/CIFAR-ZOO

Classic CNN papers

The project implements the mainstream CNN models. The papers it covers are:

1. CNN models (12 papers)

(lenet) LeNet-5, convolutional neural networks

Paper: http://yann.lecun.com/exdb/lenet/

(alexnet) ImageNet Classification with Deep Convolutional Neural Networks

Paper: https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks

(vgg) Very Deep Convolutional Networks for Large-Scale Image Recognition

Paper: https://arxiv.org/abs/1409.1556

(resnet) Deep Residual Learning for Image Recognition

Paper: https://arxiv.org/abs/1512.03385

(preresnet) Identity Mappings in Deep Residual Networks

Paper: https://arxiv.org/abs/1603.05027

(resnext) Aggregated Residual Transformations for Deep Neural Networks

Paper: https://arxiv.org/abs/1611.05431

(densenet) Densely Connected Convolutional Networks

Paper: https://arxiv.org/abs/1608.06993

(senet) Squeeze-and-Excitation Networks

Paper: https://arxiv.org/abs/1709.01507

(bam) BAM: Bottleneck Attention Module

Paper: https://arxiv.org/abs/1807.06514

(cbam) CBAM: Convolutional Block Attention Module

Paper: https://arxiv.org/abs/1807.06521

(genet) Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks

Paper: https://arxiv.org/abs/1810.12348

(sknet) SKNet: Selective Kernel Networks

Paper: https://arxiv.org/abs/1903.06586
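SENet recalibrates feature channels with a squeeze-and-excitation step, and BAM and CBAM also compute channel attention from globally pooled features. As a rough illustration of that idea, here is a minimal SE block in PyTorch. It was written for this article and is not taken from CIFAR-ZOO. The class name SEBlock, the reduction ratio of 16 and the bias-free linear layers are assumptions of this sketch.

```python
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """Minimal Squeeze-and-Excitation block (illustrative sketch, not the repo's code)."""

    def __init__(self, channels, reduction=16):  # reduction=16 is an assumed default
        super().__init__()
        # Squeeze: global average pooling reduces each channel to a single number.
        self.pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: a bottleneck MLP yields one gate in (0, 1) per channel.
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.pool(x).view(b, c)       # (B, C) channel descriptors
        w = self.fc(w).view(b, c, 1, 1)   # (B, C, 1, 1) channel gates
        return x * w                      # rescale every channel of x


se = SEBlock(64)
x = torch.randn(2, 64, 8, 8)
y = se(x)
print(y.shape)
```

The block keeps the input shape, so it can be placed after the last convolution of a residual branch, before the shortcut addition.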

2. Regularization (3 papers)

(shake-shake) Shake-Shake regularization

Paper: https://arxiv.org/abs/1705.07485

(cutout) Improved Regularization of Convolutional Neural Networks with Cutout

Paper: https://arxiv.org/abs/1708.04552

(mixup) mixup: Beyond Empirical Risk Minimization

Paper: https://arxiv.org/abs/1710.09412
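To make the last two regularizers concrete, here is a minimal sketch of Cutout and mixup in plain PyTorch. It is written for illustration and is not the CIFAR-ZOO code. The function names and the defaults (a 16x16 cutout square and alpha = 1.0 for the Beta distribution) are assumptions. Shake-Shake is left out because it modifies the network's residual branches rather than the training data.

```python
import random

import torch
import torch.nn.functional as F


def cutout(img, length=16):  # length=16 is an assumed default
    """Zero one length x length square at a random center of a (C, H, W) image.

    The square is clipped at the border, so near the edges fewer pixels are masked.
    """
    _, h, w = img.shape
    cy = torch.randint(h, (1,)).item()
    cx = torch.randint(w, (1,)).item()
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = img.clone()  # leave the caller's tensor untouched
    out[:, y1:y2, x1:x2] = 0.0
    return out


def mixup(x, y, alpha=1.0):  # alpha=1.0 is an assumed default
    """Blend a batch with a shuffled copy of itself; lam ~ Beta(alpha, alpha)."""
    lam = random.betavariate(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_loss(logits, y_a, y_b, lam):
    # The loss uses the same convex combination as the inputs.
    return lam * F.cross_entropy(logits, y_a) + (1.0 - lam) * F.cross_entropy(logits, y_b)


img = torch.ones(3, 32, 32)
cut = cutout(img)

x = torch.randn(4, 3, 32, 32)
y = torch.tensor([0, 1, 2, 3])
mixed_x, y_a, y_b, lam = mixup(x, y)
logits = torch.randn(4, 10)  # stand-in for model(mixed_x)
loss = mixup_loss(logits, y_a, y_b, lam)
```

Cutout acts on single images, like any other transform. mixup acts on whole batches inside the training loop, which is why it also changes the loss.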

3. Learning rate schedulers (2 papers)

(cos_lr) SGDR: Stochastic Gradient Descent with Warm Restarts

Paper: https://arxiv.org/abs/1608.03983

(htd_lr) Stochastic Gradient Descent with Hyperbolic-Tangent Decay on Classification

Paper: https://arxiv.org/abs/1806.01593
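Both schedulers can be written as functions of training progress t/T. The sketch below assumes two things. The first is a single cosine cycle, leaving out SGDR's warm restarts. The second is the hyperbolic-tangent decay lr = base_lr / 2 * (1 - tanh(L + (U - L) * t / T)) with L = -6 and U = 3. These bounds, the base rate of 0.1 and the function names are all assumptions, so check the experiment configs for the values the project actually uses.

```python
import math


def cosine_lr(epoch, total_epochs, base_lr=0.1, min_lr=0.0):
    """One cosine annealing cycle (SGDR without restarts); base_lr is an assumed default."""
    ratio = epoch / total_epochs
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * ratio)) / 2


def htd_lr(epoch, total_epochs, base_lr=0.1, lower=-6.0, upper=3.0):
    """Hyperbolic-tangent decay; lower/upper (L, U) and base_lr are assumed defaults."""
    ratio = epoch / total_epochs
    return base_lr / 2 * (1 - math.tanh(lower + (upper - lower) * ratio))


for e in (0, 50, 100, 150, 200):
    print(e, round(cosine_lr(e, 200), 4), round(htd_lr(e, 200), 4))
```

Both curves start near base_lr and decay smoothly toward zero. With these bounds, HTD stays high for longer at the start and drops more steeply toward the end.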

Requirements and usage

1. Requirements

The environment needed to run the code:

  • Python >= 3.5
  • PyTorch >= 0.4
  • TensorFlow/TensorBoard
  • Other dependencies (pyyaml, easydict, tensorboardX)

The author provides a one-command way to install the dependencies:

pip install -r requirements.txt

2. Model code

The author keeps all of the models in the model folder. Here is the PyTorch implementation of the ResNet architecture:

# -*-coding:utf-8-*-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['resnet20', 'resnet32', 'resnet44',
           'resnet56', 'resnet110', 'resnet1202']


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    # Two 3x3 convs; the first one may downsample via stride.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv_1 = conv3x3(inplanes, planes, stride)
        self.bn_1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_2 = conv3x3(planes, planes)
        self.bn_2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv_1(x)
        out = self.bn_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.bn_2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # 1x1 reduce, 3x3 (may downsample via stride), 1x1 expand to planes * 4.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv_1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(planes)
        self.conv_2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
        self.bn_2 = nn.BatchNorm2d(planes)
        self.conv_3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn_3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv_1(x)
        out = self.bn_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.bn_2(out)
        out = self.relu(out)

        out = self.conv_3(out)
        out = self.bn_3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, depth, num_classes, block_name='BasicBlock'):
        super(ResNet, self).__init__()
        # Model type specifies number of layers for CIFAR-10 model
        if block_name == 'BasicBlock':
            assert (
                depth - 2) % 6 == 0, 'depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
            n = (depth - 2) // 6
            block = BasicBlock
        elif block_name == 'Bottleneck':
            assert (
                depth - 2) % 9 == 0, 'depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            raise ValueError('block_name should be BasicBlock or Bottleneck')

        self.inplanes = 16
        self.conv_1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                                bias=False)
        self.bn_1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.stage_1 = self._make_layer(block, 16, n)
        self.stage_2 = self._make_layer(block, 32, n, stride=2)
        self.stage_3 = self._make_layer(block, 64, n, stride=2)
        self.avgpool = nn.AvgPool2d(8)  # 8x8 map -> 1x1 for 32x32 inputs
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.xavier_normal(m.weight.data)
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        # Use a 1x1 conv + BN shortcut when the spatial size or channel count changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.bn_1(x)
        x = self.relu(x)    # 32x32

        x = self.stage_1(x)  # 32x32
        x = self.stage_2(x)  # 16x16
        x = self.stage_3(x)  # 8x8

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet20(num_classes):
    return ResNet(depth=20, num_classes=num_classes)


def resnet32(num_classes):
    return ResNet(depth=32, num_classes=num_classes)


def resnet44(num_classes):
    return ResNet(depth=44, num_classes=num_classes)


def resnet56(num_classes):
    return ResNet(depth=56, num_classes=num_classes)


def resnet110(num_classes):
    return ResNet(depth=110, num_classes=num_classes)


def resnet1202(num_classes):
    return ResNet(depth=1202, num_classes=num_classes)

The other models are in the same folder.
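The depth checks in ResNet.__init__ come from simple layer counting. The network has three stages of n blocks each, plus the first 3x3 convolution and the final fc layer. The helper below is a hypothetical illustration of that arithmetic; the name blocks_per_stage does not come from the repository. By the usual convention, the 1x1 shortcut convolutions are not counted.

```python
def blocks_per_stage(depth, block_name="BasicBlock"):
    """Hypothetical helper: return n, the number of residual blocks per stage.

    Weighted layers = 1 stem conv + 3 stages * n blocks * convs per block + 1 fc,
    i.e. 6n + 2 for BasicBlock (2 convs) and 9n + 2 for Bottleneck (3 convs).
    """
    convs_per_block = {"BasicBlock": 2, "Bottleneck": 3}[block_name]
    layers_per_n = 3 * convs_per_block
    if depth < 2 + layers_per_n or (depth - 2) % layers_per_n != 0:
        raise ValueError("depth should be {}n+2 for {}".format(layers_per_n, block_name))
    return (depth - 2) // layers_per_n


for d in (20, 32, 44, 56, 110, 1202):
    print("resnet{}: {} BasicBlocks per stage".format(d, blocks_per_stage(d)))
```

For example, resnet1202 stacks 200 BasicBlocks in each of its three stages.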

3. Usage

To start training, run commands such as the following:

## 1 GPU for lenet
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet

## resume from ckpt
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet --resume

## 2 GPUs for preresnet1202
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --work-path ./experiments/cifar10/preresnet1202

## 4 GPUs for densenet190bc
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --work-path ./experiments/cifar10/densenet190bc

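Parameters are stored in a YAML file, config.yaml. For details, see any of the files under ./experiments. You can view the training curves in TensorBoard with tensorboard --logdir path-to-event --port your-port. The training log is written through logging; check log.txt in your work path.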
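Python's standard logging module can produce a log.txt like this. The sketch below is one possible setup and is not the project's actual code. The helper name make_train_logger and the message format are assumptions.

```python
import logging
import os
import tempfile


def make_train_logger(work_path, name="train"):
    """Hypothetical helper: send log records to the console and to <work_path>/log.txt."""
    os.makedirs(work_path, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep records out of the root logger
    # Remove handlers from an earlier call so lines are not written twice.
    for old in list(logger.handlers):
        logger.removeHandler(old)
        old.close()
    fmt = logging.Formatter("[%(asctime)s] %(message)s")
    for handler in (logging.StreamHandler(),
                    logging.FileHandler(os.path.join(work_path, "log.txt"))):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger


# Usage: a throwaway directory stands in for the --work-path directory.
work = tempfile.mkdtemp()
logger = make_train_logger(work)
logger.info("training started")
```

Writing to both the console and a file keeps live progress visible while leaving a persistent record next to the experiment's config.yaml.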

Results on the CIFAR datasets

1. The 12 CNN models:

[Image: results of the 12 CNN models on CIFAR]

2. Regularization

The default data augmentation is RandomCrop + RandomHorizontalFlip + Normalize. A √ marks which additional methods are applied.

[Image: results with additional regularization methods]

PS: Shake_Resnet26_2X64d reaches 97.71% test accuracy with cutout and mixup. Pretty cool, right?

3. Different learning rate schedulers
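For reference, the default augmentation pipeline (RandomCrop + RandomHorizontalFlip + Normalize) can be sketched directly on tensors. In torchvision it is usually written as transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor() and transforms.Normalize(mean, std). The 4-pixel zero padding and the CIFAR-10 mean/std values below are common choices assumed here; they were not read from the CIFAR-ZOO configs.

```python
import torch
import torch.nn.functional as F

# Assumed CIFAR-10 per-channel statistics (commonly quoted values).
MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
STD = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)


def default_augment(img, padding=4, size=32):  # padding=4 is an assumption
    """RandomCrop + RandomHorizontalFlip + Normalize on a (3, H, W) tensor in [0, 1]."""
    # RandomCrop: zero-pad every side, then cut out a random size x size window.
    padded = F.pad(img, (padding, padding, padding, padding))
    top = torch.randint(padded.shape[1] - size + 1, (1,)).item()
    left = torch.randint(padded.shape[2] - size + 1, (1,)).item()
    out = padded[:, top:top + size, left:left + size]
    # RandomHorizontalFlip: mirror the width axis with probability 0.5.
    if torch.rand(1).item() < 0.5:
        out = torch.flip(out, dims=[2])
    # Normalize: per-channel (x - mean) / std.
    return (out - MEAN) / STD


img = torch.rand(3, 32, 32)
aug = default_augment(img)
print(aug.shape)
```

At test time, only the Normalize step is normally kept.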

[Image: results with different learning rate schedulers]

Finally, here is the project URL once more:

https://github.com/BIGBALLON/CIFAR-ZOO
