(译) Pytorch 教程:迁移学习

栏目: Python · 发布时间: 6年前

内容简介:在这篇教程中,你将会学到如何利用迁移学习来训练你的网络。你可以通过引用在实践中,很少有人会从头开始训练一个卷积神经网络(随机初始化),因为你很难拥有一个足够大的数据集。事实上,更常见的做法是先在一个非常大的数据集(比如 ImageNet,该数据集含有涵盖了 1000 个类别的 120 万张图片)上预训练一个卷积神经网络,然后利用该网络中参数作为初始参数,或者把该网络当作另一项任务的固定特征提取器。

这篇文章翻译自 Pytorch 官方教程 Transfer Learning Tutorial

原作者: Sasank Chilamkurthy

Note:点击下载完整示例代码

在这篇教程中,你将会学到如何利用迁移学习来训练你的网络。你可以通过 cs231n notes 了解更多关于迁移学习的信息。

引用 cs231n notes 中的一段话

在实践中,很少有人会从头开始训练一个卷积神经网络(随机初始化),因为你很难拥有一个足够大的数据集。事实上,更常见的做法是先在一个非常大的数据集(比如 ImageNet,该数据集含有涵盖了 1000 个类别的 120 万张图片)上预训练一个卷积神经网络,然后利用该网络中参数作为初始参数,或者把该网络当作另一项任务的固定特征提取器。

迁移学习主要在以下两个场景下使用:

  • 网络调优: 使用预训练网络(比如在 ImageNet 上训练的网络)中的参数作为初始参数,而不是随机初始化。其余部分的训练流程和往常一样。
  • 固定特征提取器: 除了最后的全连接层,我们会冻结网络中其余部分的参数,最后的全连接层中的参数会重新随机初始化,只有该层中的参数会在训练中更新。
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # 交互模式

读取数据

我们将使用 torchvisiontorch.utils.data 两个 packages 来读取数据。

我们今天的目标是建立一个可以分辨 蚂蚁蜜蜂 的分类器,但是我们只有蚂蚁和蜜蜂的图片各约 120 张用于训练,75 张用于验证集。通常来说,如果要从头训练一个模型,这个数据集是非常小的。因此我们要利用迁移学习。

这个数据集是 ImageNet 的一个很小的子集。

Note:从 这里 下载数据并将其解压到当前文件夹。

# 对训练集使用 data augmentation 和 normalization
# 对验证集只使用 normalization
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224  ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224  , 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224  ),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224  , 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化一些图像

为了理解 data augmentation,我们来看看一些图像。

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224  , 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# 获取训练数据中的一个 batch
inputs, classes = next(iter(dataloaders['train']))

out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

(译) Pytorch 教程:迁移学习

训练模型

现在,为了训练模型,我们应该写一些通用函数。在这里我们将阐述以下两点

  • Scheduling 学习率
  • 保存最好的模型

下面参数中的 schedulertorch.optim.lr_scheduler 包中的 LR scheduler 对象

def train_model(model, criterion, optimizer, scheduler, num_> Epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for > Epoch in range(num_> Epochs):
        print('> Epoch {}/{}'.format(> Epoch, num_> Epochs - 1))
        print('-' * 10)

        # 每次遍历都要经过训练集和验证集
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train()  # 设置模型为训练模式
            else:
                model.eval()   # 设置模型为验证模式

            running_loss = 0.0
            running_corrects = 0

            # 迭代
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 清零梯度
                optimizer.zero_grad()

                # 前向传播
                # 只在训练时计算梯度
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 只有在训练时才进行反向传播和参数更新
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 统计
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            > Epoch_loss = running_loss / dataset_sizes[phase]
            > Epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, > Epoch_loss, > Epoch_acc))

            # 找到最好的模型
            if phase == 'val' and > Epoch_acc > best_acc:
                best_acc = > Epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 读取最好模型
    model.load_state_dict(best_model_wts)
    return model

观察模型给出的预测

这是一个用于展示预测结果的通用函数

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

网络调优

读取一个预训练网络并重置最后的全连接层

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# 这里所有参数都会更新
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 学习率每 7 次迭代以 0.1 为因子衰减
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练与验证

在 CPU 上训练会花费大约 15-25 分钟,而在 GPU 上则要不了一分钟。

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_> Epochs=25)

输出:

Epoch 0/24 —————— train Loss: 0.5900 Acc: 0.7131 val Loss: 0.2508 Acc: 0.9020
Epoch 1/24 —————— train Loss: 0.6034 Acc: 0.7828 val Loss: 0.3181 Acc: 0.8627 
Epoch 2/24 —————— train Loss: 0.6150 Acc: 0.7582 val Loss: 0.4903 Acc: 0.8366 
Epoch 3/24 —————— train Loss: 0.6650 Acc: 0.7377 val Loss: 0.6294 Acc: 0.7582 
Epoch 4/24 —————— train Loss: 0.4935 Acc: 0.7828 val Loss: 0.2644 Acc: 0.8889 
Epoch 5/24 —————— train Loss: 0.3841 Acc: 0.8238 val Loss: 0.24 08 Acc: 0.9216 
Epoch 6/24 —————— train Loss: 0.5352 Acc: 0.8156 val Loss: 0.2250 Acc: 0.9150 
Epoch 7/24 —————— train Loss: 0.2252 Acc: 0.9385 val Loss: 0.1917 Acc: 0.9477 
Epoch 8/24 —————— train Loss: 0.3395 Acc: 0.8197 val Loss: 0.1738 Acc: 0.9477 
Epoch 9/24 —————— train Loss: 0.3363 Acc: 0.8607 val Loss: 0.2522 Acc: 0.9216 
Epoch 10/24 —————— train Loss: 0.2878 Acc: 0.8607 val Loss: 0.1787 Acc: 0.9412 
Epoch 11/24 —————— train Loss: 0.2831 Acc: 0.8770 val Loss: 0.1805 Acc: 0.9346 
Epoch 12/24 —————— train Loss: 0.2290 Acc: 0.9016 val Loss: 0.1898 Acc: 0.9412 
Epoch 13/24 —————— train Loss: 0.24 94 Acc: 0.9016 val Loss: 0.1729 Acc: 0.9412 
Epoch 14/24 —————— train Loss: 0.3435 Acc: 0.8689 val Loss: 0.1736 Acc: 0.9412 
Epoch 15/24 —————— train Loss: 0.2274 Acc: 0.9057 val Loss: 0.1692 Acc: 0.9542 
Epoch 16/24 —————— train Loss: 0.3154 Acc: 0.8689 val Loss: 0.1742 Acc: 0.9412 
Epoch 17/24 —————— train Loss: 0.2749 Acc: 0.8893 val Loss: 0.1826 Acc: 0.9412 
Epoch 18/24 —————— train Loss: 0.2673 Acc: 0.8770 val Loss: 0.1731 Acc: 0.9281 
Epoch 19/24 —————— train Loss: 0.2865 Acc: 0.8730 val Loss: 0.1867 Acc: 0.9346 
Epoch 20/24 —————— train Loss: 0.3061 Acc: 0.8648 val Loss: 0.1966 Acc: 0.9477 
Epoch 21/24 —————— train Loss: 0.2638 Acc: 0.9016 val Loss: 0.1973 Acc: 0.9477 
Epoch 22/24 —————— train Loss: 0.2602 Acc: 0.8893 val Loss: 0.1769 Acc: 0.9281 
Epoch 23/24 —————— train Loss: 0.2817 Acc: 0.9016 val Loss: 0.1756 Acc: 0.9412 
Epoch 24 /24 —————— train Loss: 0.2959 Acc: 0.8730 val Loss: 0.1790 Acc: 0.9281
Training complete in 1m 8s Best val Acc: 0.95424 8
visualize_model(model_ft)

(译) Pytorch 教程:迁移学习

固定特征提取器

现在,除了最后的全连接层,我们要冻结网络中其余部分的所有参数。我们使用 requires_grad = False 来冻结参数, bachward() 便不会计算这些参数的梯度。

你可以在 这里 读到更多信息。

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# 新构建模块中的参数的 requires_grad 默认为 True
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# 现在只有最后的全连接层的参数会更新
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# 学习率每 7 次迭代以 0.1 为因子衰减
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

训练与验证

在 CPU 上,这会花费约之前一半的时间,因为大多数参数的梯度不用计算了,不过这些参数仍然会参与前向传播。

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_> Epochs=25)

输出:

Epoch 0/24 —————— train Loss: 0.6463 Acc: 0.6803 val Loss: 0.1949 Acc: 0.9477 
Epoch 1/24 —————— train Loss: 0.4923 Acc: 0.8033 val Loss: 0.1696 Acc: 0.9477 
Epoch 2/24 —————— train Loss: 0.4234 Acc: 0.8115 val Loss: 0.4379 Acc: 0.7712 
Epoch 3/24 —————— train Loss: 0.5606 Acc: 0.7582 val Loss: 0.6383 Acc: 0.7451 
Epoch 4/24 —————— train Loss: 0.7560 Acc: 0.7295 val Loss: 0.1888 Acc: 0.9412 
Epoch 5/24 —————— train Loss: 0.4316 Acc: 0.8197 val Loss: 0.1999 Acc: 0.9477 
Epoch 6/24 —————— train Loss: 0.7722 Acc: 0.7131 val Loss: 0.1975 Acc: 0.9477 
Epoch 7/24 —————— train Loss: 0.3685 Acc: 0.8607 val Loss: 0.2000 Acc: 0.9477 
Epoch 8/24 —————— train Loss: 0.2968 Acc: 0.8811 val Loss: 0.1916 Acc: 0.9477 
Epoch 9/24 —————— train Loss: 0.3396 Acc: 0.8525 val Loss: 0.2165 Acc: 0.9542 
Epoch 10/24 —————— train Loss: 0.3885 Acc: 0.8320 val Loss: 0.2109 Acc: 0.9542 
Epoch 11/24 —————— train Loss: 0.4107 Acc: 0.8156 val Loss: 0.1881 Acc: 0.9477 
Epoch 12/24 —————— train Loss: 0.3249 Acc: 0.8730 val Loss: 0.1747 Acc: 0.9542 
Epoch 13/24 —————— train Loss: 0.3439 Acc: 0.8525 val Loss: 0.1950 Acc: 0.9477 
Epoch 14/24 —————— train Loss: 0.3641 Acc: 0.8443 val Loss: 0.1992 Acc: 0.9412 
Epoch 15/24 —————— train Loss: 0.3272 Acc: 0.8443 val Loss: 0.2320 Acc: 0.9412 
Epoch 16/24 —————— train Loss: 0.3102 Acc: 0.8730 val Loss: 0.1867 Acc: 0.9477 
Epoch 17/24 —————— train Loss: 0.4226 Acc: 0.8238 val Loss: 0.1872 Acc: 0.9542 
Epoch 18/24 —————— train Loss: 0.3452 Acc: 0.8443 val Loss: 0.1812 Acc: 0.9542 
Epoch 19/24 —————— train Loss: 0.3697 Acc: 0.8525 val Loss: 0.1890 Acc: 0.9477 
Epoch 20/24 —————— train Loss: 0.3078 Acc: 0.8607 val Loss: 0.1976 Acc: 0.9608 
Epoch 21/24 —————— train Loss: 0.3161 Acc: 0.8770 val Loss: 0.1982 Acc: 0.9412 
Epoch 22/24 —————— train Loss: 0.3749 Acc: 0.8320 val Loss: 0.2035 Acc: 0.9477 
Epoch 23/24 —————— train Loss: 0.3298 Acc: 0.8525 val Loss: 0.1855 Acc: 0.9477 
Epoch 24/24 —————— train Loss: 0.3597 Acc: 0.8402 val Loss: 0.1878 Acc: 0.9542 
Training complete in 0m 34s Best val Acc: 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()

(译) Pytorch 教程:迁移学习

下载 Python 源代码:transfer_learning_tutorial.py

下载 Jupyter Notebook: transfer_learning_tutorial.ipynb


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Advanced Web Metrics with Google Analytics, 2nd Edition

Advanced Web Metrics with Google Analytics, 2nd Edition

Brian Clifton / Sybex / 2010-3-15 / USD 39.99

Valuable tips and tricks for using the latest version of Google Analytics Packed with insider tips and tricks, this how-to guide is fully revised to cover the latest version of Google Analytics and sh......一起来看看 《Advanced Web Metrics with Google Analytics, 2nd Edition》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

SHA 加密
SHA 加密

SHA 加密工具