用Pytorch构建一个自动解码器

栏目: Python · 发布时间: 5年前

内容简介:本文为 AI 研习社编译的技术博客,原标题 :Building Autoencoder in Pytorch

用Pytorch构建一个自动解码器

本文为 AI 研习社编译的技术博客,原标题 :

Building Autoencoder in Pytorch

作者 |  Vipul Vaibhaw

翻译 |  邓普斯•杰弗、 酱番梨、 向日魁

校对 | 邓普斯•杰弗        整理 | 菠萝妹

原文链接:

https://medium.com/@vaibhaw.vipul/building-autoencoder-in-pytorch-34052d1d280c

这篇文章中,我们将利用 CIFAR-10 数据集通过 Pytorch 构建一个简单的卷积自编码器。

用Pytorch构建一个自动解码器

引用维基百科的定义,”自编码器是一种人工神经网络,在无监督学习中用于有效编码。自编码的目的是通过一组数据学习出一种特征(编码),通常用于降维。“

为了建立一个自编码器,我们需要三件事:一个编码函数,一个解码函数,和一个衡量压缩特征和解压缩特征间信息损失的距离函数(也称为损失函数)。

如果我们要在 Pytorch 中编写自动编码器,我们需要有一个自动编码器类,并且必须使用super()从父类继承__init__。

我们通过导入必要的 Pytorch 模块开始编写卷积自动编码器。

import torchimport torchvision as tvimport torchvision.transforms as transformsimport torch.nn as nnimport torch.nn.functional as Ffrom torch.autograd import Variablefrom torchvision.utils import save_image

现在我们设置下载CIFAR-10数据集并将其转换应用于它。

我们对数据集应用了两个转换 -

  1. ToTensor() - 它将 PIL图像或者 [0,255]范围内的 numpy.ndarray(H x W x C)转换成 Torch 。 [0.0,1.0]范围内的形状 FloatTensor。

  2. Normalize() - 使用均值和标准差对张量图像进行标准化。

基本上在应用变换之后,我们得到(-2,2)范围内的值 。

# Loading and Transforming datatransform = transforms.Compose([transforms.ToTensor(),  transforms.Normalize((0.4914, 0.4822, 0.4466), (0.247,            0.243, 0.261))])trainTransform  = tv.transforms.Compose([tv.transforms.ToTensor(), tv.transforms.Normalize((0.4914, 0.4822, 0.4466), (0.247, 0.243, 0.261))])trainset = tv.datasets.CIFAR10(root='./data',  train=True,download=True, transform=transform)dataloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=False, num_workers=4)testset = tv.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

你可以在这里阅读更多关于上述变换的内容。 现在下一步是编写自动编码类。

# Writing our modelclass Autoencoder(nn.Module):    def __init__(self):        super(Autoencoder,self).__init__()                self.encoder = nn.Sequential(            nn.Conv2d(3, 6, kernel_size=5),            nn.ReLU(True),            nn.Conv2d(6,16,kernel_size=5),            nn.ReLU(True))        self.decoder = nn.Sequential(                         nn.ConvTranspose2d(16,6,kernel_size=5),            nn.ReLU(True),            nn.ConvTranspose2d(6,3,kernel_size=5),            nn.ReLU(True),            nn.Sigmoid())    def forward(self,x):        x = self.encoder(x)        x = self.decoder(x)        return x

卷积编码器神经网络具有一些 Conv2d,并且我们有使用ReLU激活功能正在被使用。 现在我们定义一些参数 -

#defining some paramsnum_epochs = 5 #you can go for more epochs, I am using a macbatch_size = 128

然后是时候设置训练模型了。我们调用模型并将其配置为在 cpu 上运行。如果你有一个 gpu,你可以使用 cuda。

我们使用 Mean Squared Error 作为损失函数。对于优化器,我们使用 adam。

model = Autoencoder().cpu()distance = nn.MSELoss()optimizer = torch.optim.Adam(model.parameters(),weight_decay=1e-5)

让咱们开始训练吧!

for epoch in range(num_epochs):    for data in dataloader:        img, _ = data        img = Variable(img).cpu()        # ===================forward=====================        output = model(img)        loss = distance(output, img)        # ===================backward====================        optimizer.zero_grad()        loss.backward()        optimizer.step()    # ===================log========================    print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, num_epochs, loss.data[0]))

这是我写的一个简单的博客,展示了如何在 Pytorch 中构建自动编码器。 但是,如果要在模型中包含 MaxPool2d(),请确保设置 return_indices = True,然后在解码器中使用 MaxUnpool2d()图层。

持续的学习和分享,可以在 github,Stack Overflow,LinkedIn,或者 Twitter 上 Follow 我。

想要继续查看该篇文章相关链接和参考文献?

长按链接点击打开或点击【 用Pytorch构建一个自动解码器 】:

https://ai.yanxishe.com/page/TextTranslation/1284

AI研习社每日更新精彩内容,观看更多精彩内容: 雷锋网雷锋网 (公众号:雷锋网) 雷锋网

命名实体识别(NER)综述

杰出数据科学家的关键技能是什么?

初学者怎样使用Keras进行迁移学习

如果你想学数据科学,这 7 类资源千万不能错过

等你来译:

深度学习目标检测算法综述

一文教你如何用PyTorch构建 Faster RCNN

高级DQNs:利用深度强化学习玩吃豆人游戏

用于深度强化学习的结构化控制网络 (ICML 论文讲解)


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

数据资本时代

数据资本时代

Viktor Mayer-Schnberger / 李晓霞、周涛 / 中信出版集团股份有限公司 / 2018-11-1 / CNY 58.00

【编辑推荐】 大数据除了能对我们的生活、工作、思维产生重大变革外,还能够做什么?畅销书《大数据时代》作者舍恩伯格在新书《数据资本时代》中,展示了大数据将如何从根本上改变经济——这并不是因为数据是一种新型石油,而是因为数据是一种新型润滑脂,它将给市场带来巨大能量,给公司带来巨大压力,使金融资本的作用大大削弱。赢家是市场,而并非资本。 这本书在当下国内出版,可以说恰逢其时。时下,中国经济正......一起来看看 《数据资本时代》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

在线进制转换器
在线进制转换器

各进制数互转换器