内容简介:本文介绍一下如何使用BiLSTM(基于PyTorch)解决一个实际问题,实现如果不了解LSTM的同学请先看我的这两篇文章LSTM、PyTorch中的LSTM。下面直接开始代码讲解导库
本文介绍一下如何使用BiLSTM(基于PyTorch)解决一个实际问题,实现 给定一个长句子预测下一个单词
如果不了解LSTM的同学请先看我的这两篇文章LSTM、PyTorch中的LSTM。下面直接开始代码讲解
导库
''' code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor ''' import torch import numpy as np import torch.nn as nn import torch.optim as optim import torch.utils.data as Data dtype = torch.FloatTensor
准备数据
sentence = ( 'GitHub Actions makes it easy to automate all your software workflows ' 'from continuous integration and delivery to issue triage and more' ) word2idx = {w: i for i, w in enumerate(list(set(sentence.split())))} idx2word = {i: w for i, w in enumerate(list(set(sentence.split())))} n_class = len(word2idx) # classification problem max_len = len(sentence.split()) n_hidden = 5
我水平不佳,一开始看到这个 sentence
不懂这种写法是什么意思,如果你调用 type(sentence)
以及打印 sentence
就会知道,这其实就是个字符串,就是将上下两行字符串连接在一起的一个大字符串
数据预处理,构建dataset,定义dataloader
def make_data(sentence): input_batch = [] target_batch = [] words = sentence.split() for i in range(max_len - 1): input = [word2idx[n] for n in words[:(i + 1)]] input = input + [0] * (max_len - len(input)) target = word2idx[words[i + 1]] input_batch.append(np.eye(n_class)[input]) target_batch.append(target) return torch.Tensor(input_batch), torch.LongTensor(target_batch) # input_batch: [max_len - 1, max_len, n_class] input_batch, target_batch = make_data(sentence) dataset = Data.TensorDataset(input_batch, target_batch) loader = Data.DataLoader(dataset, 16, True)
这里面的循环还是有点复杂的,尤其是 input
和 input_batch
里面存的东西,很难理解。所以下面我会详细解释
首先开始循环, input
的第一个赋值语句会将第一个词 Github
对应的索引存起来。 input
的第二个赋值语句会将剩下的 max_len - len(input)
都用0去填充
第二次循环, input
的第一个赋值语句会将前两个词 Github
和 Actions
对应的索引存起来。 input
的第二个赋值语句会将剩下的 max_len - len(input)
都用0去填充
每次循环, input
和 target
中所存的 索引转换成word
如下图所示,因为我懒得去查看每个词对应的索引是什么,所以干脆直接写出存在其中的词
从上图可以看出, input
的长度永远保持 max_len(=21)
,并且循环了 max_len-1
次,所以最终 input_batch
的维度是 [max_len - 1, max_len, n_class]
定义网络架构
class BiLSTM(nn.Module): def __init__(self): super(BiLSTM, self).__init__() self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True) # fc self.fc = nn.Linear(n_hidden * 2, n_class) def forward(self, X): # X: [batch_size, max_len, n_class] batch_size = X.shape[0] input = X.transpose(0, 1) # input : [max_len, batch_size, n_class] hidden_state = torch.randn(1*2, batch_size, n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden] cell_state = torch.randn(1*2, batch_size, n_hidden) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden] outputs, (_, _) = self.lstm(input, (hidden_state, cell_state)) outputs = outputs[-1] # [batch_size, n_hidden * 2] model = self.fc(outputs) # model : [batch_size, n_class] return model model = BiLSTM() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001)
Bi-LSTM的网络结构图如下所示,其中Backward Layer意思不是"反向传播",而是将"句子反向输入"。具体流程就是,现有有由四个词构成的一句话"i like your friends"。常规单向LSTM的做法就是直接输入"i like your",然后预测出"friends",而双向LSTM会同时输入"i like your"和"your like i",然后将Forward Layer和Backward Layer的output进行concat(这样做可以理解为同时"汲取"正向和反向的信息),最后预测出"friends"
而正因为多了一个反向的输入,所以整个网络结构中很多隐藏层的输入和输出的某些维度会变为原来的两倍,具体如下图所示。对于双向LSTM来说, num_directions = 2
训练&测试
# Training for epoch in range(10000): for x, y in loader: pred = model(x) loss = criterion(pred, y) if (epoch + 1) % 1000 == 0: print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) optimizer.zero_grad() loss.backward() optimizer.step() # Pred predict = model(input_batch).data.max(1, keepdim=True)[1] print(sentence) print([idx2word[n.item()] for n in predict.squeeze()])
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网
猜你喜欢:- 袜子商店应用:一个云原生参照应用
- Android 应用中跳转到应用市场评分
- 授之以渔-运维平台应用模块一(应用树篇)
- OAM(开放应用模型)——定义云原生应用标准的野望
- ChromeOS 终端应用程序暗示其即将支持 Linux 应用
- Android应用之间数据的交互(一)获取系统应用的数据
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
松本行弘的程式世界
松本行弘 / 鄧瑋敦 / 博碩 / 2010年07月27日
讓Ruby之父教您大師級的程式思考術! 本書以松本行弘先生對程式本質的深層認知、各種技術之優缺點的掌握,闡述Ruby這套程式語言的設計理念,並由此延伸讓您一窺程式設計的奧妙之處。本書內含許多以Ruby、Lisp、Smalltalk、Erlang、JavaScript等動態語言所寫成的範例,從動態語言、函數式程式設計等領域開展您的學習視野。 本書精華: ‧物件導向與抽象化 ‧......一起来看看 《松本行弘的程式世界》 这本书的介绍吧!