LSTM, or Long Short-Term Memory, is an important building block of complex, state-of-the-art neural network architectures. The main idea behind this article is to explain the math behind it. To get an initial understanding of what an LSTM is, I would suggest the following blog.
Contents :
A — Concept
- Introduction
- Explanation
- Derivation Prerequisites
B — Derivation
- Output of the LSTM
- Hidden state
- Output gate
- Cell state
- Input gate
- Forget gate
- Input to the LSTM
- Weights and biases
C — Back propagation through time
D — Conclusion
Concept
Introduction
The above is a diagram of a single LSTM cell. I know it looks scary, but we will go through it step by step, and by the end of the article it will hopefully be pretty clear.
Explanation
A single LSTM cell has four components: the forget gate, the input gate, the output gate and the cell state. We will first briefly discuss what each part does (for a detailed explanation please refer to the above blog) and then dive into the math.
Forget gate
As the name suggests, this part decides what information from the last step is thrown away and what is kept. This is done by the first sigmoid layer.
Based on h_t-1 (previous hidden state) and x_t (current input at time-step t), it produces a value between 0 and 1 for each value in the cell state C_t-1.
A value of 1 keeps the corresponding information as it is, a value of 0 discards it, and values in between decide how much of the previous state is carried to the next state.
Input gate
Christopher Olah has a beautiful explanation of what happens in the input gate. To cite his blog:
Now these two values, i.e. i_t and c~_t, combine to decide what new input is fed to the cell state.
Cell state
The cell state serves as the memory of an LSTM. This is where LSTMs perform far better than vanilla RNNs when dealing with longer input sequences. At each time-step the previous cell state (C_t-1) combines with the forget gate to decide what information is carried forward, which in turn combines with the input gate (i_t and c~_t) to form the new cell state, i.e. the new memory of the cell.
Output gate
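Finally, the LSTM cell has to produce some output. The cell state obtained above is passed through the hyperbolic tangent function (tanh), which squashes its values to between -1 and 1. The output gate o_t, computed by another sigmoid layer from h_t-1 and x_t, then scales this result element-wise to give the new hidden state h_t. For details on the different activation functions, this is a nice blog.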
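The four parts described above can be sketched as a single forward step in NumPy. This is only a minimal illustration under assumed shapes and a random initialisation; the function names `lstm_forward_step` and `init_params` and the `params` dictionary are hypothetical, chosen to mirror the notation used in the rest of the article.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_forward_step(x_t, h_prev, C_prev, params):
    """One LSTM time-step. Shapes: x_t (n_x, 1), h_prev and C_prev (n_h, 1)."""
    Z_t = np.vstack([h_prev, x_t])                          # [h_t-1, x_t]
    f_t = sigmoid(params["W_f"] @ Z_t + params["b_f"])      # forget gate
    i_t = sigmoid(params["W_i"] @ Z_t + params["b_i"])      # input gate
    c_tilde = np.tanh(params["W_c"] @ Z_t + params["b_c"])  # candidate cell state
    o_t = sigmoid(params["W_o"] @ Z_t + params["b_o"])      # output gate
    C_t = f_t * C_prev + i_t * c_tilde                      # new cell state (memory)
    h_t = o_t * np.tanh(C_t)                                # new hidden state
    return h_t, C_t, (f_t, i_t, c_tilde, o_t)

def init_params(n_x, n_h, seed=0):
    # Hypothetical initialisation: small random weights, zero biases.
    rng = np.random.default_rng(seed)
    params = {}
    for g in "fico":
        params[f"W_{g}"] = rng.normal(0.0, 0.1, size=(n_h, n_h + n_x))
        params[f"b_{g}"] = np.zeros((n_h, 1))
    return params

params = init_params(n_x=3, n_h=4)
x_t = np.ones((3, 1))
h_prev = np.zeros((4, 1))
C_prev = np.zeros((4, 1))
h_t, C_t, gates = lstm_forward_step(x_t, h_prev, C_prev, params)
```

Because C_prev is zero in this toy call, the forget gate has nothing to keep and the new cell state is just i_t * c~_t.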
Now I hope the basic structure of an LSTM cell is clear, and we can proceed to the derivation of the equations which we will use in our implementation.
Derivation Prerequisites
1. Requirements : The derivation is based on backpropagation, the cost function and the loss. If you are not familiar with these, these are a few links that will help in getting a good understanding. This article also assumes a basic understanding of high school calculus (calculating derivatives and their rules).
2. Variables : For each gate we have a set of weights and biases, which will be denoted as:
- W_f, b_f -> Forget gate weight and bias
- W_i, b_i -> Input gate weight and bias
- W_c, b_c -> Candidate cell state weight and bias
- W_o, b_o -> Output gate weight and bias
- W_v, b_v -> Weight and bias associated with the Softmax layer
- f_t, i_t, c~_t, o_t -> Outputs of the activation functions
- a_f, a_i, a_c, a_o -> Inputs to the activation functions
J is the cost function, with respect to which we will be calculating the derivatives. (Note: the character after an underscore (_) is a subscript.)
3. Forward prop equations:
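For reference, the forward equations can be written out as below. This is a reconstruction from the variable definitions and from the equation numbers the later sections refer to (equation 3 for c~_t, 4 for o_t, 5 for C_t, 6 for h_t, 7 for V_t); the numbering of equations 1 and 2, the symbol \odot for the element-wise product and \hat{y}_t for the softmax output are assumptions.

```latex
\begin{align}
Z_t &= [h_{t-1}, x_t] \notag \\
a_f &= W_f Z_t + b_f, & f_t &= \sigma(a_f) \tag{1} \\
a_i &= W_i Z_t + b_i, & i_t &= \sigma(a_i) \tag{2} \\
a_c &= W_c Z_t + b_c, & \tilde{c}_t &= \tanh(a_c) \tag{3} \\
a_o &= W_o Z_t + b_o, & o_t &= \sigma(a_o) \tag{4} \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{c}_t \tag{5} \\
h_t &= o_t \odot \tanh(C_t) \tag{6} \\
V_t &= W_v h_t + b_v, & \hat{y}_t &= \operatorname{softmax}(V_t) \tag{7}
\end{align}
```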
4. Process for calculation : Let’s take the forget gate as an example to illustrate how the derivatives are calculated. We need to follow the path of red arrows in the below figure.
So we chalk out a path from f_t to our cost function J, i.e.
f_t → C_t → h_t → J.
Backpropagation follows exactly the same path, but in reverse, i.e.
f_t ← C_t ← h_t ← J.
J is differentiated with respect to h_t, h_t with respect to C_t, and C_t with respect to f_t.
So if we observe here, J and h_t form the last step of the cell, and once we have calculated dJ/dh_t, it can be reused for calculations like dJ/dC_t, since:
dJ/dC_t = dJ/dh_t * dh_t/dC_t ( Chain rule )
Similarly, the derivatives will be calculated for all the variables mentioned in point 2.
Now that we have the variables ready and are clear on the forward prop equations, it’s time to derive the derivatives through back-propagation. We will start with the output equations, since, as we saw, the same derivatives are reused in the other equations. This is where the chain rule comes in. So let’s start.
Derivation
Output of the LSTM
The output has two values which we need to calculate.
- Softmax : For the derivative of the cross-entropy loss with softmax, we will be using the final equation directly.
The detailed derivation can be found below:
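As a quick sanity check, the standard final result for softmax combined with cross-entropy, dJ/dV_t = y_hat - y, can be compared against a numerical derivative. The sketch below assumes a one-hot target `y` and a small made-up score vector `V_t`; the helper names are illustrative.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - np.max(v))          # shift for numerical stability
    return e / np.sum(e)

def cross_entropy(v, y):
    return -np.sum(y * np.log(softmax(v)))

V_t = np.array([0.5, -1.2, 2.0])       # made-up scores
y = np.array([0.0, 0.0, 1.0])          # one-hot target

analytic = softmax(V_t) - y            # dJ/dV_t = y_hat - y

# Central finite differences, one component of V_t at a time.
eps = 1e-6
numeric = np.zeros_like(V_t)
for k in range(V_t.size):
    step = np.zeros_like(V_t)
    step[k] = eps
    numeric[k] = (cross_entropy(V_t + step, y) - cross_entropy(V_t - step, y)) / (2 * eps)
```

The two vectors `analytic` and `numeric` should agree up to finite-difference error.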
Hidden State
We have the hidden state as h_t. J is differentiated w.r.t. h_t. According to the chain rule, the derivation can be seen in the below figure. We use the value of V_t as given in Fig 9 equation 7, i.e.:
V_t = W_v.h_t + b_v
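Written out, and assuming the softmax and cross-entropy result dJ/dV_t = \hat{y}_t - y_t (where \hat{y}_t is the softmax output and y_t the one-hot target, both symbols introduced here for illustration), the chain rule gives the following. Note that this covers only the path through V_t at the current time-step.

```latex
\frac{\partial J}{\partial h_t}
  = \left(\frac{\partial V_t}{\partial h_t}\right)^{\!\top} \frac{\partial J}{\partial V_t}
  = W_v^{\top}\,(\hat{y}_t - y_t)
```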
Output gate
Variables associated : a_o and o_t.
o_t: In the below image, the path between o_t and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dV_t * dV_t/dh_t * dh_t/do_t
dJ/dV_t * dV_t/dh_t can be written as dJ/dh_t (we have this value from the hidden state).
The value of h_t = o_t * tanh(C_t) -> Fig 9 equation 6. So we only need to differentiate h_t w.r.t. o_t. The differentiation will be as follows:
a_o: Similarly, the path between a_o and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dV_t * dV_t/dh_t * dh_t/do_t * do_t/da_o
dJ/dV_t * dV_t/dh_t * dh_t/do_t can be written as dJ/do_t (we have this value from o_t above).
o_t = sigmoid(a_o) -> Fig 8 equation 4. So we only need to differentiate o_t w.r.t. a_o. The differentiation will be as follows:
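In symbols, the two output gate results can be sketched as below, using h_t = o_t * tanh(C_t) and the sigmoid derivative \sigma'(a) = \sigma(a)(1 - \sigma(a)). The symbol \odot (element-wise product) is notation introduced here.

```latex
\begin{align}
\frac{\partial J}{\partial o_t} &= \frac{\partial J}{\partial h_t} \odot \tanh(C_t) \\
\frac{\partial J}{\partial a_o} &= \frac{\partial J}{\partial o_t} \odot o_t \odot (1 - o_t)
\end{align}
```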
Cell State
C_t is the cell state of the cell. Along with it, we also handle the candidate cell state variables a_c and c~_t here.
C_t: The derivation for C_t is pretty simple, as the path from C_t to J is short: C_t → h_t → V_t → J. As we already have dJ/dh_t, we directly differentiate h_t w.r.t. C_t.
h_t = o_t * tanh(C_t) -> Fig 9 equation 6. So we only need to differentiate h_t w.r.t. C_t.
Note: The cell state clubbed will be explained at the end of the article.
c~_t: In the below image, the path between c~_t and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/dc~_t
dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from above).
The value of C_t is as shown in Fig 9 equation 5 (the tilde (~) sign is missing from the last c_t in line 3 of the below figure; this is a writing mistake). So we only need to differentiate C_t w.r.t. c~_t.
a_c: In the below image, the path between a_c and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/dc~_t * dc~_t/da_c
dJ/dh_t * dh_t/dC_t * dC_t/dc~_t can be written as dJ/dc~_t (we have this value from above).
The value of c~_t is as shown in Fig 8 equation 3. So we only need to differentiate c~_t w.r.t. a_c.
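The three cell state results can be sketched in symbols as follows. This assumes the standard cell state update C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t and \tilde{c}_t = \tanh(a_c), which is what the text describes as equations 5 and 3, and uses \tanh'(a) = 1 - \tanh^2(a). Only the path through h_t at the current time-step is considered.

```latex
\begin{align}
\frac{\partial J}{\partial C_t} &= \frac{\partial J}{\partial h_t} \odot o_t \odot \left(1 - \tanh^2(C_t)\right) \\
\frac{\partial J}{\partial \tilde{c}_t} &= \frac{\partial J}{\partial C_t} \odot i_t \\
\frac{\partial J}{\partial a_c} &= \frac{\partial J}{\partial \tilde{c}_t} \odot \left(1 - \tilde{c}_t^{\,2}\right)
\end{align}
```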
Input gate
Variables associated : i_t and a_i
i_t: In the below image, the path between i_t and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/di_t
dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from the cell state). So we only need to differentiate C_t w.r.t. i_t.
The value of C_t is as shown in Fig 9 equation 5. So the differentiation will be as follows:
a_i: In the below image, the path between a_i and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/di_t * di_t/da_i
dJ/dh_t * dh_t/dC_t * dC_t/di_t can be written as dJ/di_t (we have this value from above). So we only need to differentiate i_t w.r.t. a_i.
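In symbols, and assuming the cell state update C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t with i_t = \sigma(a_i), the input gate results can be sketched as:

```latex
\begin{align}
\frac{\partial J}{\partial i_t} &= \frac{\partial J}{\partial C_t} \odot \tilde{c}_t \\
\frac{\partial J}{\partial a_i} &= \frac{\partial J}{\partial i_t} \odot i_t \odot (1 - i_t)
\end{align}
```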
Forget Gate
Variables associated : f_t and a_f
f_t: In the below image, the path between f_t and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/df_t
dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from the cell state). So we only need to differentiate C_t w.r.t. f_t.
The value of C_t is as shown in Fig 9 equation 5. So the differentiation will be as follows:
a_f: In the below image, the path between a_f and J is shown. According to the arrows, the full equation for the differentiation will be as follows:
dJ/dh_t * dh_t/dC_t * dC_t/df_t * df_t/da_f
dJ/dh_t * dh_t/dC_t * dC_t/df_t can be written as dJ/df_t (we have this value from above). So we only need to differentiate f_t w.r.t. a_f.
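In symbols, and again assuming C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t with f_t = \sigma(a_f), the forget gate results can be sketched as:

```latex
\begin{align}
\frac{\partial J}{\partial f_t} &= \frac{\partial J}{\partial C_t} \odot C_{t-1} \\
\frac{\partial J}{\partial a_f} &= \frac{\partial J}{\partial f_t} \odot f_t \odot (1 - f_t)
\end{align}
```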
Input to the LSTM
There are two variables associated with the input of each cell, i.e. the previous cell state C_t-1 and the previous hidden state concatenated with the current input, i.e.
[h_t-1 ,x_t] -> Z_t
C_t-1: This is the memory of the LSTM cell. Figure 5 shows the cell state. The derivation for C_t-1 is pretty simple, as only C_t-1 and C_t are involved.
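Since C_t-1 enters the cell state update only through the forget gate term (assuming C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t), the result can be sketched as:

```latex
\frac{\partial J}{\partial C_{t-1}} = \frac{\partial J}{\partial C_t} \odot f_t
```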
Z_t: As shown in the below figure, Z_t goes into 4 different paths: a_f, a_i, a_c and a_o.
Z_t → a_f → f_t → C_t → h_t → J . -> Forget gate
Z_t → a_i → i_t → C_t → h_t → J . -> Input gate
Z_t → a_c → c~_t → C_t → h_t → J . -> Candidate cell state
Z_t → a_o → o_t → h_t → J . -> Output gate
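Because Z_t feeds all four paths, its gradient is the sum of the contributions from each one. Assuming each pre-activation has the form a = W Z_t + b, this can be sketched as:

```latex
\frac{\partial J}{\partial Z_t}
  = W_f^{\top}\frac{\partial J}{\partial a_f}
  + W_i^{\top}\frac{\partial J}{\partial a_i}
  + W_c^{\top}\frac{\partial J}{\partial a_c}
  + W_o^{\top}\frac{\partial J}{\partial a_o}
```

The first n_h rows of this vector correspond to h_t-1 and the remaining rows to x_t, following the order of the concatenation [h_t-1, x_t].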
Weights and biases
The derivation for W and b is straightforward. The below derivation is for the output gate of the LSTM. For the rest of the gates, a similar process is followed for the weights and biases.
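For a gate with pre-activation a_o = W_o Z_t + b_o, the chain rule gives dJ/dW_o = dJ/da_o · Z_t^T and dJ/db_o = dJ/da_o. As a rough check of the whole derivation, the sketch below puts the single-step gradients together and compares them with numerical derivatives. It assumes a softmax layer with cross-entropy loss on top of h_t, random made-up inputs, and no gradient flowing in from later time-steps; the function and dictionary names are hypothetical.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def softmax(v):
    e = np.exp(v - np.max(v))
    return e / np.sum(e)

def forward(p, x_t, h_prev, C_prev, y):
    Z = np.vstack([h_prev, x_t])
    f = sigmoid(p["W_f"] @ Z + p["b_f"])
    i = sigmoid(p["W_i"] @ Z + p["b_i"])
    c = np.tanh(p["W_c"] @ Z + p["b_c"])
    o = sigmoid(p["W_o"] @ Z + p["b_o"])
    C = f * C_prev + i * c
    h = o * np.tanh(C)
    y_hat = softmax(p["W_v"] @ h + p["b_v"])
    J = -np.sum(y * np.log(y_hat))
    return J, (Z, f, i, c, o, C, h, y_hat)

def backward(p, C_prev, y, cache):
    Z, f, i, c, o, C, h, y_hat = cache
    g = {}
    dV = y_hat - y                              # softmax + cross-entropy
    g["W_v"], g["b_v"] = dV @ h.T, dV
    dh = p["W_v"].T @ dV                        # hidden state
    do = dh * np.tanh(C)                        # output gate
    dC = dh * o * (1 - np.tanh(C) ** 2)         # cell state
    dc, di, df = dC * i, dC * c, dC * C_prev    # candidate, input gate, forget gate
    da = {"o": do * o * (1 - o), "c": dc * (1 - c ** 2),
          "i": di * i * (1 - i), "f": df * f * (1 - f)}
    dZ = np.zeros_like(Z)
    for k, da_k in da.items():
        g[f"W_{k}"], g[f"b_{k}"] = da_k @ Z.T, da_k   # weights and biases
        dZ += p[f"W_{k}"].T @ da_k                    # sum over the four paths
    dC_prev = dC * f
    return g, dZ, dC_prev

rng = np.random.default_rng(0)
n_x, n_h, n_y = 3, 4, 5
p = {}
for k in "fico":
    p[f"W_{k}"] = rng.normal(0, 0.5, (n_h, n_h + n_x))
    p[f"b_{k}"] = rng.normal(0, 0.5, (n_h, 1))
p["W_v"] = rng.normal(0, 0.5, (n_y, n_h))
p["b_v"] = rng.normal(0, 0.5, (n_y, 1))
x_t = rng.normal(size=(n_x, 1))
h_prev = rng.normal(size=(n_h, 1))
C_prev = rng.normal(size=(n_h, 1))
y = np.zeros((n_y, 1))
y[2] = 1.0

J, cache = forward(p, x_t, h_prev, C_prev, y)
grads, dZ, dC_prev = backward(p, C_prev, y, cache)

def numeric_grad(name, eps=1e-6):
    # Central finite differences over every entry of one parameter.
    num = np.zeros_like(p[name])
    for idx in np.ndindex(*p[name].shape):
        old = p[name][idx]
        p[name][idx] = old + eps
        J_plus, _ = forward(p, x_t, h_prev, C_prev, y)
        p[name][idx] = old - eps
        J_minus, _ = forward(p, x_t, h_prev, C_prev, y)
        p[name][idx] = old
        num[idx] = (J_plus - J_minus) / (2 * eps)
    return num

max_err = max(np.max(np.abs(grads[n] - numeric_grad(n))) for n in p)
```

If the analytic gradients are right, `max_err` should be close to zero, limited only by finite-difference and floating-point error.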