基于长短期记忆网络的短时雷达外推算法

栏目: Python · 发布时间: 6年前

内容简介:本文由机器之心经授权转载自本文主要介绍的是利用现有的pytorch框架,实现ConvLSTM和ConvGRU内核,并实现一个多层RNN的封装结构层,方便使用者快速的堆叠多层的RNNCell。得益于pytorch的便利,我们只需要按照公式写出forward的过程,后续的backward将由框架本身给我们完成。同时,作者还基于这些网络结构,搭建了一个简单的图像时序预测模型,方便读者理解每一结构之间的作用和联系。

本文由机器之心经授权转载自 墨迹天气TechInfo(ID:moji_techinfo) ,未经授权禁止二次转载。

基于长短期记忆网络的短时雷达外推算法

本文主要介绍的是利用现有的pytorch框架,实现ConvLSTM和ConvGRU内核,并实现一个多层RNN的封装结构层,方便使用者快速的堆叠多层的RNNCell。得益于pytorch的便利,我们只需要按照公式写出forward的过程,后续的backward将由框架本身给我们完成。同时,作者还基于这些网络结构,搭建了一个简单的图像时序预测模型,方便读者理解每一结构之间的作用和联系。

首先是ConvLSTM,其单元结构如下图所示:

基于长短期记忆网络的短时雷达外推算法

在公式中可以明显看出i,f,g,o,中的计算过程大多一致,因此我们利用一个卷积层,多个卷积核的方式来完成计算。对于每个门内部的计算我们也进行合并。如下图所示:

拼接输入数据X,h:

基于长短期记忆网络的短时雷达外推算法

combined =torch.cat((input, hidden), 1)

计算每一个门的输出:

基于长短期记忆网络的短时雷达外推算法

基于长短期记忆网络的短时雷达外推算法

A = self.conv(combined)
(ai, af,ao, ag)= torch.split(A, self.num_features, dim=1)  # it should return 4 tensors

这样,我们就完成了LSTM中所有门的计算,在利用pytorch的支持下,我们只使用三行代码就完成了基础的门运算操作。

然后我们参照原始公式,给每个门的输出数据加上激活函数和dropout层:

i =torch.sigmoid(ai)
i = self.dropout(i)
f = torch.sigmoid(af)
f = self.dropout(f)
o = torch.sigmoid(ao)
o = self.dropout(o)
g = torch.tanh(ag)
g = self.dropout(g)

得到f,g,i,o之后可以继续计算Ct,Ht:

next_c = f * c+ i * g
next_h = o * torch.tanh(next_c)
next_h = self.dropout(next_h)

得到Ct,Ht之后整个LSTM单元的计算过程变结束了,LSTM的主要创新之处在于存储单元Ct,Ct在整个神经元中充当了状态信息的累加器,神经元通过各种参数化学习而来的控制门对信息进行保存和削减。

为了配合后面的多层RNN封装结构,我们将神经元的输出封装为如下格式:

return next_h, (next_h,next_c)

至此,一个完整的 LSTM 单元 计算过程就实现完成了。

GRU单元的实现过程和LSTM类似,这里只给出计算公式和对应的代码实现:

基于长短期记忆网络的短时雷达外推算法

def forward(self, input, hidden):
    c1 = self.ConvGates(torch.cat((input, hidden), 1))
    (rt, ut)= c1.chunk(2, 1)
    reset_gate = self.dropout(f.sigmoid(rt))
    update_gate = self.dropout(f.sigmoid(ut))
    gated_hidden = torch.mul(reset_gate, hidden)
    p1 = self.Conv_ct(torch.cat((input, gated_hidden), 1))
    ct = f.tanh(p1)
    next_h = torch.mul(update_gate, hidden)+ (1 - update_gate) * ct
    return next_h

ConvGRUCell的具体定义过程可以查考文末的GitHub源码地址。

接下来我们继续承接上面的输入,实现一个 多层 RNN 的封装结构层

其单元示意图如下所示:

基于长短期记忆网络的短时雷达外推算法

MultiRNNCell层中封装了两个RNN单元Cell1和Cell2,数据x1首先传送给Cell1,Cell1有其初始化的状态信息h1(本文的LSTM结构状态信息有两个,分别为h&c,下文中不再赘述),经过Cell1神经元的处理,得到新的状态信息h1,并传给下一个单元Cell2,经过处理得到新的h2。两个cell的状态信息将会被保留,在序列中的下一个时次数据x2到来之后,继续参与计算。

forward的处理过程如下:

def forward(self, input, hidden_state):
    cur_inp = input
    new_states = []

new_states用来保存每个cell计算得来的新状态信息,hidden_state保存的是上一时次的状态信息,t为0时,hidden_state为初始化的状态信息。

for i,cell in enumerate(self._cells):
    cur_state = hidden_state[i]

    cur_inp, new_state= cell(cur_inp, cur_state)

    new_states.append(new_state)

MultiRNNCell对象的self._cells中保存了多层封装结构中的所有Cell,具体初始化方式和类成员可以在文末的github源码地址中找到。计算过程中我们取每个cell对应的状态信息h,和输入数据x,当前Cell计算后,下一个Cell的输入x使用上一个单元的输出h。

在源码中我提供了一个简单的图像时序预测模型,在这里简单讲解一下它的数据流:

def forward(self, data):
    new_state = self.stacked_lstm.init_hidden(data.size()[1])
     data的shape为(num_seqs,batch_size,channels_img,size_H,size_W)

new_state为初始化的状态信息

self.stacked_lstm是多层RNN封装单元,其内部提供了一个初始化参数的方法。

x_unwrap = []
for i inxrange(self.input_num_seqs + self.output_num_seqs):
    # print i
    if i< self.input_num_seqs:
        y_1, new_state= self.stacked_lstm(data[i], new_state)
    else:
        y_1, new_state= self.stacked_lstm(x_1, new_state)
    # print y_1.size()
    x_1 = self.deconv1(y_1)
    # print x_1.size()
    if i>= self.input_num_seqs:
        x_unwrap.append(x_1)

return x_unwrap

input_num_seqs是模型的序列输入长度,output_num_seqs是模型的序列输出长度,在输入的过程中每个多层RNN结构体的输入是我们的data,在输出过程中每个多层RNN结构体的输入是我们前一个时次的输出。

基于长短期记忆网络的短时雷达外推算法

在后续的loss计算过程中,我们计算的是实际的输出序列和我们预测出的x_unwarp序列之间的误差。其数据流如下图所示:

至此,整个基于pytorch实现的ConvLSTM和ConvGRU内核,多层封装结构层,时序预测网络已经完成了,受限于作者本身的水平,如有不妥之处可以在下方留言提出。

本文中涉及的网络结构代码地址为:https://github.com/chencodeX/RNN_Pytorch

在源码中还有一些基于上述结构实现的编码预测网络(encoder-forecaster),加权均方误差等,在本文中不再赘述,也欢迎大家fork。

本文作者简介:

陈子豪

墨迹风云科技股份有限公司算法工程师,专注于气象与深度学习领域融合应用。

基于长短期记忆网络的短时雷达外推算法

本文由机器之心经授权转载自 墨迹天气TechInfo(ID:moji_techinfo)

原文链接: https://mp.weixin.qq.com/s/TeXtXc9snP05LqrjLHzJvg


以上所述就是小编给大家介绍的《基于长短期记忆网络的短时雷达外推算法》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

大数据技术原理与应用

大数据技术原理与应用

林子雨 / 人民邮电出版社 / 2015-8-1 / 45.00

大数据作为继云计算、物联网之后IT行业又一颠覆性的技术,备受关注。大数据处不在,包括金融、汽车、零售、餐饮、电信、能源、政务、医疗、体育、娱乐等在内的社会各行各业,都融入了大数据的印迹,大数据对人类的社会生产和生活必将产生重大而深远的影响。 大数据时代的到来,迫切需要高校及时建立大数据技术课程体系,为社会培养和输送一大批具备大数据专业素养的高级人才,满足社会对大数据人才日益旺盛的需求。本书定......一起来看看 《大数据技术原理与应用》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

随机密码生成器
随机密码生成器

多种字符组合密码

MD5 加密
MD5 加密

MD5 加密工具