牛刀小试之用 pytorch 实现 LSTM

栏目: Python · 发布时间: 6年前

内容简介:正常运行,说明我们的参数都组织得当,正确使用了pytorch中lstm模型。

牛刀小试之用 pytorch 实现 LSTM

LSTM参数

首先需要定义好循环网络,需要nn.LSTM(),首先介绍一下这个函数里面的参数

牛刀小试之用 pytorch 实现 LSTM

LSTM数据格式:

  • num_layers : 我们构建的循环网络有层lstm

  • num_directions : 当bidirectional=True时,num_directions=2;当bidirectional=False时,num_directions=1

输入LSTM中的数据格式

输入LSTM中的X数据格式尺寸为(seq_len, batch, input_size),此外h0和c0尺寸如下

  • h0(num_layers * num_directions,  batch_size,  hidden_size)

  • c0(num_layers * num_directions,  batch_size,  hidden_size)

LSTM输出数据格式

LSTM输出的数据格式尺寸为(seq_len, batch, hidden_size * num_directions);输出的hn和cn尺寸如下

  • hn(num_layers * num_directions,  batch_size,  hidden_size)

  • cn(num_layers * num_directions,  batch_size,  hidden_size)

import torch.nn as nn
import torch as t

lstm = nn.LSTM(input_size=4,  #输入数据的特征数是4
               hidden_size=10, #输出的特征数(hidden_size)是10
               batch_first= True)  #使用batch_first数据维度表达方式,即(batch_size, 序列长度, 特征数目)
lstm
LSTM(4, 10, batch_first=True)

根据LSTM网络的使用方式,每一层LSTM都有三个外界输入的数据,分别:

  • X: LSTM网络外输入的数据

  • h_0:上一层LSTM输出的结果

  • c_0:上一层LSTM调整后的记忆

照着前面总结的 LSTM输入的数据格式 (这里大家要搞明白batch_size、seq_len、dims各自代表什么),我们定义LSTM输入的数据。

我们定义的LSTM网络也需要输入这三个数据:

  • x: 这里我们将x的尺寸 (batch_size, seq_len, dims)依次是(3, 5, 4)

  • h0:h0的(num_layers * num_directions,  batch_size,  hidden_size)尺寸依次是(1, 3 10)

  • c0: c0的(num_layers * num_directions,  batch_size,  hidden_size)尺寸依次是(1, 3,  10)

#x: 这里我们将x的尺寸 (batch_size, seq_len, dims)依次是(3, 5, 4)
x = t.randn(3, 5, 4) 
x
tensor([[[ 0.1478, -0.7733, -0.3462,  0.0320],
         [-0.0540,  0.4757, -1.2787,  0.6141],
         [ 1.9581,  0.0015,  1.4387, -0.5895],
         [-1.0691, -1.7070,  1.0219, -0.7990],
         [-1.7735,  0.6824,  0.6067, -0.6630]],

        [[-0.3223, -0.6943,  0.1120, -1.7799],
         [-1.0542,  0.2151, -2.2530,  0.2640],
         [-0.0599, -0.1996,  0.9793, -1.4952],
         [-0.2328,  0.2297, -1.4825,  0.0720],
         [ 0.7112, -0.1165,  2.5641, -1.4247]],

        [[-0.4157, -1.1617, -0.7442, -0.8369],
         [ 0.5266,  2.3119,  0.6428,  0.3797],
         [-0.2951, -1.5711,  1.2832, -0.2773],
         [ 0.4760,  0.2403,  0.2923,  2.2315],
         [-0.3348, -0.0976,  0.0388,  0.5948]]])

h0的(num_layers * num_directions,  batch_size,  hidden_size)尺寸依次是(1, 3,  10)

h0 = t.randn(1, 3, 10)
h0
tensor([[[ 1.2469,  1.2457, -1.0390,  0.3173,  1.0083,  0.7610, -0.0088,
           0.0614,  0.5630,  0.7260],
         [-0.6529, -1.4584, -0.7871, -0.4002, -0.4619, -0.2633,  0.2818,
          -0.3486, -1.0637, -1.0772],
         [ 0.6969,  1.2095, -0.9888, -1.1326, -1.1339, -1.0660,  0.9650,
           0.4040, -0.7997, -1.3996]]])

c0的(num_layers * num_directions, batch_size,  hidden_size)尺寸依次是(1, 3,  10)

c0 = t.randn(1, 3, 10)
c0
tensor([[[-1.5377, -0.7845,  0.0971, -0.1659,  1.8828,  1.8013, -0.7545,
           0.7165,  2.1182, -0.7022],
         [ 0.2850,  0.2503, -0.8153,  0.5210,  0.0405,  0.5819, -0.1994,
           0.2940,  0.4487, -0.4580],
         [-0.6478, -1.1122, -0.0021,  0.3013,  1.1450,  0.5811, -0.8989,
          -0.2919, -0.9292,  0.0599]]])

调用之前实例化的lstm, 输入数据x和上一期lstm的 h_0c_0

output = lstm(x, (h0, c0))
output
(tensor([[[-0.2626, -0.1759,  0.0594,  0.0890,  0.4585,  0.5010, -0.2925,
            0.0523,  0.2142, -0.2536],
          [ 0.0161,  0.0077, -0.0074, -0.0148,  0.3526,  0.1798, -0.1033,
           -0.1095,  0.2395, -0.1300],
          [ 0.0354, -0.0625,  0.1339,  0.1123,  0.1212,  0.0853,  0.0660,
           -0.0315,  0.0441, -0.1083],
          [ 0.0188, -0.3257,  0.1776,  0.1890,  0.0583,  0.0848,  0.0848,
           -0.1144,  0.0529, -0.0439],
          [ 0.0346, -0.2408,  0.1594,  0.2038,  0.1914, -0.0385,  0.1791,
           -0.2192,  0.0237, -0.0506]],

         [[ 0.1276, -0.2258, -0.1029,  0.3354, -0.0689,  0.1603,  0.0564,
           -0.0073,  0.1161, -0.0802],
          [ 0.1915, -0.0040, -0.0824,  0.0939, -0.1243,  0.1259, -0.0495,
           -0.2668,  0.1277, -0.0344],
          [ 0.1991, -0.1713,  0.0796,  0.2259, -0.0527,  0.1457,  0.0777,
           -0.2230, -0.0176, -0.0446],
          [ 0.2169,  0.0204,  0.0673,  0.0936, -0.0887,  0.1041,  0.0164,
           -0.3327,  0.0704, -0.0270],
          [ 0.0938, -0.2337,  0.1380,  0.2187,  0.0426,  0.0422,  0.1518,
           -0.1585, -0.0422, -0.0449]],

         [[ 0.0959, -0.3436,  0.1434,  0.1418,  0.0680,  0.2509,  0.0244,
           -0.3842, -0.1125,  0.1595],
          [ 0.0318, -0.0647,  0.0393,  0.0978,  0.2184, -0.0525,  0.2293,
           -0.1959, -0.1001,  0.0704],
          [-0.0548, -0.3070,  0.1436,  0.1868,  0.0755, -0.0966,  0.1553,
           -0.1882,  0.0149,  0.0397],
          [-0.1964,  0.0121,  0.0689,  0.0245,  0.1022, -0.1088,  0.1525,
           -0.0814,  0.1698,  0.0285],
          [-0.0799,  0.0252,  0.0494,  0.0438,  0.1130, -0.1117,  0.1234,
           -0.1222,  0.1560,  0.0400]]], grad_fn=<TransposeBackward0>),
 (tensor([[[ 0.0346, -0.2408,  0.1594,  0.2038,  0.1914, -0.0385,  0.1791,
            -0.2192,  0.0237, -0.0506],
           [ 0.0938, -0.2337,  0.1380,  0.2187,  0.0426,  0.0422,  0.1518,
            -0.1585, -0.0422, -0.0449],
           [-0.0799,  0.0252,  0.0494,  0.0438,  0.1130, -0.1117,  0.1234,
            -0.1222,  0.1560,  0.0400]]], grad_fn=<ViewBackward>),
  tensor([[[ 0.0747, -0.3711,  0.2711,  0.7422,  0.4142, -0.0688,  0.3214,
            -0.5058,  0.0441, -0.1047],
           [ 0.1504, -0.3095,  0.3440,  0.5343,  0.0911,  0.0653,  0.2167,
            -0.3579, -0.1910, -0.1100],
           [-0.1395,  0.0483,  0.0805,  0.1036,  0.2337, -0.2296,  0.2053,
            -0.2010,  0.3182,  0.0997]]], grad_fn=<ViewBackward>)))

从上面的结果看应该分成两大部分,其中第二部分又分为两小部分。

上面的结果对应的正是当前LSTM的输出结果,以及当前 h_outc_out

out, (h_out, c_out) = lstm(x, (h0, c0))
print(out.shape)
print(h_out.shape)
print(c_out.shape)
torch.Size([3, 5, 10])
torch.Size([1, 3, 10])
torch.Size([1, 3, 10])

正常运行,说明我们的参数都组织得当,正确使用了pytorch中lstm模型。

往期文章

《用 Python 做文本分析》视频教程  

10分钟理解深度学习中的~卷积~

深度学习之 图解LSTM

100G Python学习资料(免费下载)

100G 文本分析语料资源(免费下载)     

typing库:让你的代码阅读者再也不用猜猜猜

Seaborn官方教程中文教程(一)

数据清洗 常用正则表达式大全

大邓强力推荐-jupyter notebook使用小技巧

PySimpleGUI: 开发自己第一个软件

深度特征合成:自动生成机器学习中的特征

Python 3.7中dataclass的终极指南(一)

Python 3.7中dataclass的终极指南(二)

15个最好的数据科学领域Python库

使用Pandas更好的做数据科学

[计算消费者的偏好]推荐系统与协同过滤、奇异值分解

机器学习: 识别图片中的数字

应用PCA降维加速模型训练

如何从文本中提取特征信息?

使用sklearn做自然语言处理-1

使用sklearn做自然语言处理-2

机器学习|八大步骤解决90%的NLP问题      

Python圈中的符号计算库-Sympy

Python中处理日期时间库的使用方法  

视频讲解】Scrapy递归抓取简书用户信息

美团商家信息采集神器

用chardect库解决网页乱码问题


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

郎咸平说:新经济颠覆了什么

郎咸平说:新经济颠覆了什么

郎咸平 / 东方出版社 / 2016-8 / 39.00元

正所谓“上帝欲其灭亡,必先令其疯狂”,在当下中国,“互联网+资本催化”的新经济引擎高速运转,大有碾压一切、颠覆一切之势。在新经济狂热之下,每个人都在全力以赴寻找“下一个风口”,幻想成为下一只飞起来的猪。 对此,一向以“危机论”著称的郎咸平教授再次发出盛世危言:新经济光环背后,危机已悄然而至!中国式O2O还能烧多久?P2P监管黑洞有多大?互联网造车为什么不靠谱?共享经济为什么徒有虚名?BAT为......一起来看看 《郎咸平说:新经济颠覆了什么》 这本书的介绍吧!

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具