PyTorch 初体验

栏目: Python · 发布时间: 7年前

内容简介:PyTorch 初体验

这两周简单看了下 pytorch ,虽然说还没有非常系统的、全方面的认识,但姑且总结一下好了。

基础模块构成

最核心的模型组件都在 torch.nn 这个模块里,这个模块里包含了

  • 不同类型的网络结构,如:Embedding, LSTM, Conv1d, MaxPool1d, Linear
  • 不同类型的激活函数,如:RELU, SELU, Sigmoid, Tanh
  • 不同类型的目标函数,如:CrossEntropyLoss, MSELoss, HingeEmbeddingLoss

不一一列举,总之,如果是想构建起一个网络,不考虑训练的话,那么只用 torch.nn 这个模块里的东西就足够了。

比较重要的是 torch.nn.Module 这个类,上述的网络结构、激活函数、目标函数都继承自这个类,如果是想自定义模型、激活函数、目标函数的话,继承这个类就好。所以这个类的行为和内在机制有必要好好了解一下。

其次就是 torch.autograd 这个模块,其中的 Variable 是 torch 里的输入、输出数据的标准类型,也就是说,我们定义好一个模型后,如果想输入东西,就得把数据都转成 Variable 类型的值。

torch.optim 里则定义了常用的一些优化方法,不多,罗列如下

  • Adadelta
  • Adagrad
  • Adam
  • SparseAdam
  • Adamax
  • ASGD
  • SGD
  • Rprop
  • RMSprop
  • Optimizer
  • LBFGS

差不多就是这个样子。

其他

暂时了解还不多,就不长篇大论了,这里随便写写。

在模型层面,pytorch 使用起来确实舒服很多。主要的点有这些

  • 不需要管 session、graph 这些东西,定义好的网络结构,直接就能接受输入并得到输出
  • 模块都继承自 torch.nn.Module 这个类,而这个类被设计成了 picklable 的,我们直接用 pickle.dump 和 pickle.load 就能对模型进行保存和加载,相比之下,tensorflow 默认将模型拆成若干个文件然后通过 saver 的方式来保存和加载一直让我非常抗拒 —— 倒不是说我认为模型存成多个文件就不好,但至少提供让我不存成多个文件的选项吧?在 tensorflow 里想要自己去把模型结构和模型参数拿出来按自己的想法存储这件事情,我是一直都没有成功过……
  • torch.nn.Module 类有一个 bool 类型的 training 成员,如果将其设置成 False,那么 Dropout、BatchNorm 之类的层就会失效,这个虽然是个很小的点但也是非常让人舒服的一点,tensorflow 里为了解决这个问题,通常就得自己来设置选项来保证在训练和预测的时候产生两张不同的图,贼恶心

    详情见 这个帖子这个帖子

当然,也有一些坑,或者说我觉得不太满意的地方吧

  • 目前 pypi 上的版本落后于官网版本,按照官网上的安装方法要下几百兆的东西……吐血……
  • 没有对整个项目结构和机制的总体介绍,当然,毕竟版本还在 0.3.0……
  • 模型内部的数据类型好像有不一致的地方,我在默认行为下,遇到过什么 DoubleTensor 的错误,可能和这个 issue 有关
  • 需要手动选择是否使用 CUDA,但是我明明看到 torch.cuda 下有个 is_available 的方法
  • RNN 的输入和输出,默认第二个维度是 batch,如果想让第一个维度表示 batch,就得手动把 batch_first 设置成 True,挺奇怪的

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

四维人类

四维人类

(英)劳伦斯·斯科特 / 祝锦杰 / 浙江教育出版社 / 2018-10 / 79.90元

数字技术如何重新定义 我们的思维方式与生存方式?一起来看看 《四维人类》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具