FaceBook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

栏目: 编程工具 · 发布时间: 5年前

内容简介:近日,PyTorch 社区可复现性是许多研究领域的基本要求,这其中当然包括基于机器学习技术的研究领域。然而, 许多机器学习相关论文要么无法复现,要么难以重现。随着论文数量的持续增长,包括目前在 arXiv 上预印刷的数万份论文以及提交给会议的论文,研究工作的可复现性变得越来越重要。虽然其中许多论文都附有代码以及训练好的模型,但这种帮助显然非常有限,复现过程中仍有大量需要读者自己摸索的步骤。下面让我们来看一下如何通过 PyTorch Hub 这一利器完成快速的模型发布与工作复现。

近日,PyTorch 社区 发布 了一个深度学习 工具 包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作。PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以及新的研究。同时它还内置了对 Google Colab 的支持,并与 Papers With Code 集成。目前 PyTorchHub 包括了一系列与图像分类、分割、生成以及转换相关的模型。

可复现性是许多研究领域的基本要求,这其中当然包括基于机器学习技术的研究领域。然而, 许多机器学习相关论文要么无法复现,要么难以重现。随着论文数量的持续增长,包括目前在 arXiv 上预印刷的数万份论文以及提交给会议的论文,研究工作的可复现性变得越来越重要。虽然其中许多论文都附有代码以及训练好的模型,但这种帮助显然非常有限,复现过程中仍有大量需要读者自己摸索的步骤。下面让我们来看一下如何通过 PyTorch Hub 这一利器完成快速的模型发布与工作复现。

FaceBook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

如何快速发布模型

这部分主要介绍了对于模型发布者来说如何快速高效的将自己的模型加入 PyTorch Hub 库。PyTorch Hub 支持通过添加简单的 hubconf.py 文件将预先训练的模型(模型定义和预先训练重)发布到 GitHub 存储库。这提供了模型列表以及其依赖库列表。一些示例可以在 torchvisionhuggingface-bertgan-model-zoo 存储库中找到。

Pytoch 社区给出了 torchvision 的 hubconf.py 文件的示例:

复制代码

# Optional list ofdependenciesrequired by thepackage
dependencies= ['torch']

fromtorchvision.models.alexnetimportalexnet
fromtorchvision.models.densenetimportdensenet121, densenet169, densenet201, densenet161
fromtorchvision.models.inceptionimportinception_v3
fromtorchvision.models.resnetimportresnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d
fromtorchvision.models.squeezenetimportsqueezenet1_0, squeezenet1_1
fromtorchvision.models.vggimportvgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
fromtorchvision.models.segmentationimportfcn_resnet101, deeplabv3_resnet101
fromtorchvision.models.googlenetimportgooglenet
fromtorchvision.models.shufflenetv2importshufflenet_v2_x0_5, shufflenet_v2_x1_0
fromtorchvision.models.mobilenetimportmobilenet_v2

在 torchvision 中,模型有以下特性:

  • 每个模型文件可以被独立执行或实现某个功能
  • 不需要除了 PyTorch 之外的任何软件包(在 hubconf.py 中编码为 dependencies[‘torch’])
  • 他们不需要单独的入口点,因为模型在创建时可以无缝地开箱即用。

PyTroch 社区认为最小化包依赖性可减少用户加载模型时遇到的困难。这里他们给出了一个更为复杂的例子——HuggingFace’s BERT 模型,它的 hubconf.py 如下:

复制代码

dependencies= ['torch','tqdm','boto3','requests','regex']

fromhubconfs.bert_hubconfimport(
bertTokenizer,
bertModel,
bertForNextSentencePrediction,
bertForPreTraining,
bertForMaskedLM,
bertForSequenceClassification,
bertForMultipleChoice,
bertForQuestionAnswering,
bertForTokenClassification
)

此外,对于每个模型,PyTorch 官方提到都需要为其创建一个入口点。下面是一个用于指定 bertForMaskedLM 模型的入口点的代码片段,这部分代码完成的功能是返回加载了预训练参数的模型。

复制代码

defbertForMaskedLM(*args, **kwargs):
"""
BertForMaskedLM includes the BertModel Transformer followed by the
pre-trained masked language modeling head.
Example:
...
"""
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
returnmodel

这些入口点可以看成是复杂的模型结构的一种封装形式。它们可以在提供简洁高效的帮助文档的同时完成下载预训练权重的功能(例如,通过 pretrained = True),也可以集成其他特定功能,例如可视化。

通过 hubconf.py ,模型发布者可以在 Github 上基于 template 提交他们的合并请求。PyTorch 社区希望通过 PyTorch Hub 创建一系列高质量、易复现且效果好的模型以提高研究工作的复现性。因此,PyTorch 会通过与模型发布者合作的方式以完善请求,并有可能会在某些情况下拒绝发布一些低质量的模型。一旦 PyTorch 社区接受了模型发布者的请求,这些新的模型将会很快出现在 PyTorch Hub 的网页上以供用户浏览。

用户工作流

对于想使用 PyTorch Hub 对别人的工作进行复现的用户,PyTorch Hub 提供了以下几个步骤:1)浏览可用的模型;2)加载模型;3)探索已加载的模型。下面让我们来浏览几个例子。

浏览可用的入口点

用户可以使用 torch.hub.list() API 列出仓库中的所有可用入口点。

复制代码

>>> torch.hub.list('pytorch/vision')
>>>
['alexnet',
'deeplabv3_resnet101',
'densenet121',
...
'vgg16',
'vgg16_bn',
'vgg19',
'vgg19_bn']

注意,PyTorch Hub 还允许辅助入口点(除了预训练模型),例如,用于 BERT 模型预处理的 bertTokenizer,它可以使用户工作流程更加顺畅。

加载模型

对于 PyTroch Hub 中可用的模型,用户可以使用 torch.hub.load() API 加载模型入口点。此外,torch.hub.help() API 可以提供有关如何实例化模型的有用信息。

复制代码

print(torch.hub.help('pytorch/vision','deeplabv3_resnet101'))
model = torch.hub.load('pytorch/vision','deeplabv3_resnet101', pretrained=True)

由于仓库的持有者会不断添加错误修复以及性能改进,PyTorch Hub 允许用户通过调用以下内容简单地获取最新更新:

复制代码

model = torch.hub.load(...,force_reload=True)

这一举措可以有效地减轻仓库持有者重复发布模型的负担,从而使他们能够更专注于自己的研究工作。同时,也确保了用户可以获得最新版本的模型。

此外,对于用户来说,稳定性也是一个重要问题。因此,某些模型所有者会从特征的分支或标签为他们提供服务,以确保代码的稳定性。例如,pytorch_GAN_zoo 会从 hub 分支为他们提供服务:

复制代码

model= torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub','DCGAN', pretrained=True, useGPU=False)

这里,传递给 hub.load() 的 * args,** kwargs 用于实例化模型。在上面的示例中,pretrained = True 和 useGPU = False 被传递给模型的入口点。

探索已加载的模型

从 PyTorch Hub 加载模型后,用户可以使用以下工作流查看已加载模型的可用方法,并更好地了解运行它所需的参数。

其中,dir(model) 可以查看模型中可用的方法。下面是 bertForMaskedLM 的一些方法:

复制代码

>>> dir(model)
>>>
['forward'
...
'to'
'state_dict',
]

help(model.forward)则会提供使已加载的模型运行时所需参数的视图:

复制代码

>>> help(model.forward)
>>>
Helponmethodforwardinmodulepytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...

更多细节可以查看 BERTDeepLabV3 页面:

其他探索方式与相关资源

PyTorch Hub 中提供的模型也支持 Colab,并且会直接链接在 Papers With Code 上,用户只需单击链接即可开始使用:

FaceBook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易

PyTorch 提供了一些相关资源帮助用户快速上手 PyTorch Hub:

FAQ

问:如果我们想贡献一个 Hub 中已经有了的模型,但也许我的模型具有更高的准确性,我还应该贡献吗?

答:是的,请提交您的模型,Hub 的下一步是开发投票系统以展示最佳模型。

问:谁负责保管 PyTorch Hub 的模型权重?

答:作为贡献者,您负责保管模型权重。您可以在您喜欢的云存储中托管您的模型,或者如果它符合限制,则可以在 GitHub 上托管您的模型。 如果您无法保管权重,请通过 Hub 仓库中提交问题的方式与我们联系。

问:如果我的模型使用了私有化数据进行训练怎么办?我还应该贡献这个模型吗?

答:请不要提交您的模型!PyTorch Hub 以开源研究为中心,并扩展到使用公开数据集来训练这些模型。如果提交了私有模型的合并请求,我们将恳请您重新提交使用公开数据进行训练后的模型。

问:我下载的模型保存在哪里?

答:我们遵循 XDG 基本目录规范,并遵循缓存文件和目录的通用标准。这些位置按以下顺序使用:

  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)
  • 如果环境变量了 TORCH_HOME,则为 $TORCH_HOME/hub。
  • 如果设置了环境变量 XDG_CACHE_HOME,则为 $ XDG_CACHE_HOME / torch / hub。
  • ~/.cache/torch/hub

以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

精益数据分析

精益数据分析

[加] 阿利斯泰尔·克罗尔、[加] 本杰明·尤科维奇 / 韩知白、王鹤达 / 人民邮电出版社 / 2014-12 / 79.00元

本书展示了如何验证自己的设想、找到真正的客户、打造能赚钱的产品,以及提升企业知名度。30多个案例分析,全球100多位知名企业家的真知灼见,为你呈现来之不易、经过实践检验的创业心得和宝贵经验,值得每位创业家和企业家一读。 深入理解精益创业、数据分析基础,和数据驱动的思维模式 如何将六个典型的商业模式应用到各种规模的新企业 找到你的第一关键指标 确定底线,找到出发点 在大......一起来看看 《精益数据分析》 这本书的介绍吧!

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试

HSV CMYK 转换工具
HSV CMYK 转换工具

HSV CMYK互换工具