基于长短期记忆网络的短时雷达外推算法

栏目: Python · 发布时间: 6年前

内容简介:本文由机器之心经授权转载自本文主要介绍的是利用现有的pytorch框架,实现ConvLSTM和ConvGRU内核,并实现一个多层RNN的封装结构层,方便使用者快速的堆叠多层的RNNCell。得益于pytorch的便利,我们只需要按照公式写出forward的过程,后续的backward将由框架本身给我们完成。同时,作者还基于这些网络结构,搭建了一个简单的图像时序预测模型,方便读者理解每一结构之间的作用和联系。

本文由机器之心经授权转载自 墨迹天气TechInfo(ID:moji_techinfo) ,未经授权禁止二次转载。

基于长短期记忆网络的短时雷达外推算法

本文主要介绍的是利用现有的pytorch框架,实现ConvLSTM和ConvGRU内核,并实现一个多层RNN的封装结构层,方便使用者快速的堆叠多层的RNNCell。得益于pytorch的便利,我们只需要按照公式写出forward的过程,后续的backward将由框架本身给我们完成。同时,作者还基于这些网络结构,搭建了一个简单的图像时序预测模型,方便读者理解每一结构之间的作用和联系。

首先是ConvLSTM,其单元结构如下图所示:

基于长短期记忆网络的短时雷达外推算法

在公式中可以明显看出i,f,g,o,中的计算过程大多一致,因此我们利用一个卷积层,多个卷积核的方式来完成计算。对于每个门内部的计算我们也进行合并。如下图所示:

拼接输入数据X,h:

基于长短期记忆网络的短时雷达外推算法

combined =torch.cat((input, hidden), 1)

计算每一个门的输出:

基于长短期记忆网络的短时雷达外推算法

基于长短期记忆网络的短时雷达外推算法

A = self.conv(combined)
(ai, af,ao, ag)= torch.split(A, self.num_features, dim=1)  # it should return 4 tensors

这样,我们就完成了LSTM中所有门的计算,在利用pytorch的支持下,我们只使用三行代码就完成了基础的门运算操作。

然后我们参照原始公式,给每个门的输出数据加上激活函数和dropout层:

i =torch.sigmoid(ai)
i = self.dropout(i)
f = torch.sigmoid(af)
f = self.dropout(f)
o = torch.sigmoid(ao)
o = self.dropout(o)
g = torch.tanh(ag)
g = self.dropout(g)

得到f,g,i,o之后可以继续计算Ct,Ht:

next_c = f * c+ i * g
next_h = o * torch.tanh(next_c)
next_h = self.dropout(next_h)

得到Ct,Ht之后整个LSTM单元的计算过程变结束了,LSTM的主要创新之处在于存储单元Ct,Ct在整个神经元中充当了状态信息的累加器,神经元通过各种参数化学习而来的控制门对信息进行保存和削减。

为了配合后面的多层RNN封装结构,我们将神经元的输出封装为如下格式:

return next_h, (next_h,next_c)

至此,一个完整的 LSTM 单元 计算过程就实现完成了。

GRU单元的实现过程和LSTM类似,这里只给出计算公式和对应的代码实现:

基于长短期记忆网络的短时雷达外推算法

def forward(self, input, hidden):
    c1 = self.ConvGates(torch.cat((input, hidden), 1))
    (rt, ut)= c1.chunk(2, 1)
    reset_gate = self.dropout(f.sigmoid(rt))
    update_gate = self.dropout(f.sigmoid(ut))
    gated_hidden = torch.mul(reset_gate, hidden)
    p1 = self.Conv_ct(torch.cat((input, gated_hidden), 1))
    ct = f.tanh(p1)
    next_h = torch.mul(update_gate, hidden)+ (1 - update_gate) * ct
    return next_h

ConvGRUCell的具体定义过程可以查考文末的GitHub源码地址。

接下来我们继续承接上面的输入,实现一个 多层 RNN 的封装结构层

其单元示意图如下所示:

基于长短期记忆网络的短时雷达外推算法

MultiRNNCell层中封装了两个RNN单元Cell1和Cell2,数据x1首先传送给Cell1,Cell1有其初始化的状态信息h1(本文的LSTM结构状态信息有两个,分别为h&c,下文中不再赘述),经过Cell1神经元的处理,得到新的状态信息h1,并传给下一个单元Cell2,经过处理得到新的h2。两个cell的状态信息将会被保留,在序列中的下一个时次数据x2到来之后,继续参与计算。

forward的处理过程如下:

def forward(self, input, hidden_state):
    cur_inp = input
    new_states = []

new_states用来保存每个cell计算得来的新状态信息,hidden_state保存的是上一时次的状态信息,t为0时,hidden_state为初始化的状态信息。

for i,cell in enumerate(self._cells):
    cur_state = hidden_state[i]

    cur_inp, new_state= cell(cur_inp, cur_state)

    new_states.append(new_state)

MultiRNNCell对象的self._cells中保存了多层封装结构中的所有Cell,具体初始化方式和类成员可以在文末的github源码地址中找到。计算过程中我们取每个cell对应的状态信息h,和输入数据x,当前Cell计算后,下一个Cell的输入x使用上一个单元的输出h。

在源码中我提供了一个简单的图像时序预测模型,在这里简单讲解一下它的数据流:

def forward(self, data):
    new_state = self.stacked_lstm.init_hidden(data.size()[1])
     data的shape为(num_seqs,batch_size,channels_img,size_H,size_W)

new_state为初始化的状态信息

self.stacked_lstm是多层RNN封装单元,其内部提供了一个初始化参数的方法。

x_unwrap = []
for i inxrange(self.input_num_seqs + self.output_num_seqs):
    # print i
    if i< self.input_num_seqs:
        y_1, new_state= self.stacked_lstm(data[i], new_state)
    else:
        y_1, new_state= self.stacked_lstm(x_1, new_state)
    # print y_1.size()
    x_1 = self.deconv1(y_1)
    # print x_1.size()
    if i>= self.input_num_seqs:
        x_unwrap.append(x_1)

return x_unwrap

input_num_seqs是模型的序列输入长度,output_num_seqs是模型的序列输出长度,在输入的过程中每个多层RNN结构体的输入是我们的data,在输出过程中每个多层RNN结构体的输入是我们前一个时次的输出。

基于长短期记忆网络的短时雷达外推算法

在后续的loss计算过程中,我们计算的是实际的输出序列和我们预测出的x_unwarp序列之间的误差。其数据流如下图所示:

至此,整个基于pytorch实现的ConvLSTM和ConvGRU内核,多层封装结构层,时序预测网络已经完成了,受限于作者本身的水平,如有不妥之处可以在下方留言提出。

本文中涉及的网络结构代码地址为:https://github.com/chencodeX/RNN_Pytorch

在源码中还有一些基于上述结构实现的编码预测网络(encoder-forecaster),加权均方误差等,在本文中不再赘述,也欢迎大家fork。

本文作者简介:

陈子豪

墨迹风云科技股份有限公司算法工程师,专注于气象与深度学习领域融合应用。

基于长短期记忆网络的短时雷达外推算法

本文由机器之心经授权转载自 墨迹天气TechInfo(ID:moji_techinfo)

原文链接: https://mp.weixin.qq.com/s/TeXtXc9snP05LqrjLHzJvg


以上所述就是小编给大家介绍的《基于长短期记忆网络的短时雷达外推算法》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Web信息架构(第3版)

Web信息架构(第3版)

Peter Morville、Louis Rosenfeld / 陈建勋 / 电子工业出版社 / 2008年8月 / 85.00

本书涵盖了信息架构基本原理和实践应用的方方面面。全书共7个部分,包括信息架构概述、信息架构的基本原理、信息架构的开发流程和方法论、信息架构实践、信息架构与组织、两个案例研究,以及参考资料清单。 本书兼具较高的理论价值和实用价值,曾被Web设计领域多本书籍重点推荐,是信息架构领域公认的经典书,不论新手还是专家都能各取所需。本书可供Web设计与开发者、Web架构师、网站管理者及信息管理相关人员参......一起来看看 《Web信息架构(第3版)》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具