本文共 3322 字,大约阅读时间需要 11 分钟。
之前看过,现在突然想不起,真的是好记性不如烂笔头,希望大家在看的时候能够拿笔和纸跟着推导一遍,加深理解。
转自:
经典的RNN结构如下图所示:
假设我们的时间序列只有三段, S 0 S_0 S0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下: 假设在t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_3=\frac{1}{2}(Y_3-O_3)^2 L3=21(Y3−O3)2则对于一次训练任务的损失函数为 L = ∑ t = 1 T L t L=\sum_{t=1}^{T}L_t L=∑t=1TLt,即每一时刻损失值的累加。
使用随机梯度下降法训练RNN其实就是对 W x W_x Wx 、 W s W_s Ws 、 W 0 W_0 W0 以及 b 1 b_1 b1 b 2 b_2 b2 求偏导,并不断调整它们以使L尽可能达到最小的过程。
现在假设我们我们的时间序列只有三段,t1,t2,t3。
我们只对t3时刻的 [公式] 求偏导(其他时刻类似):
可以看出对于 W 0 W_0 W0 求偏导并没有长期依赖,但是对于 W x W_x Wx、 W s W_s Ws求偏导,会随着时间序列产生长期依赖。因为 S t S_t St随着时间序列向前传播,而 S t S_t St又是 W x W_x Wx、 W s W_s Ws的函数。根据上述求偏导的过程,我们可以得出任意时刻对 W x W_x Wx、 W s W_s Ws求偏导的公式:
任意时刻对 W s W_s Ws 求偏导的公式同上。
如果加上激活函数, S j = t a n h ( W x X j + W s S j − 1 + b 1 ) S_j=tanh(W_xX_j+W_sS_{j-1}+b_1) Sj=tanh(WxXj+WsSj−1+b1) ,
则 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh^{'}W_s ∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′Ws 激活函数tanh和它的导数图像如下。由上图可以看出 t a n h ′ ≤ 1 tanh^{'} \leq1 tanh′≤1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_xX_j+W_sS_{j-1}+b_1=0 WxXj+WsSj−1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s ∏j=k+1ttanh′Ws,就会趋近于0,和 0.01^{50} 趋近与0是一个道理。同理当 W s W_s Ws很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s ∏j=k+1ttanh′Ws 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。
至于怎么避免这种现象,让我在看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial{L_t}}{\partial{W_x}}=\sum_{k=0}^{t}\frac{\partial{L_t}}{\partial{O_t}}\frac{\partial{O_t}}{\partial{S_t}}(\prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}})\frac{\partial{S_k}}{\partial{W_x}} ∂Wx∂Lt=∑k=0t∂Ot∂Lt∂St∂Ot(∏j=k+1t∂Sj−1∂Sj)∂Wx∂Sk梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} ∏j=k+1t∂Sj−1∂Sj这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx1 ∏j=k+1t∂Sj−1∂Sj≈1另一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 0 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx0 ∏j=k+1t∂Sj−1∂Sj≈0 。其实这就是LSTM做的事情,至于细节问题下节将进行介绍。
先上一张LSTM的经典图:
而LSTM可以抽象成这样:
三个×分别代表的就是forget gate,input gate,output gate,而我认为LSTM最关键的就是forget gate这个部件。这三个gate是如何控制流入流出的呢,其实就是通过下面 f t , i t , o t f_t,i_t,o_t ft,it,ot 三个函数来控制,因为 σ ( x ) \sigma(x) σ(x)(代表sigmoid函数) 的值是介于0到1之间的,刚好用趋近于0时表示流入不能通过gate,趋近于1时表示流入可以通过gate。
当前的状态 S t = f t S t − 1 + i t X t S_t=f_tS_{t-1}+i_tX_t St=ftSt−1+itXt类似与传统RNN S t = W s S t − 1 + W x X t + b 1 S_t=W_sS_{t-1}+W_xX_t+b_1 St=WsSt−1+WxXt+b1。将LSTM的状态表达式展开后得: 如果加上激活函数, S t = t a n h [ σ ( W f X t + b f ) S t − 1 + σ ( W i X t + b i ) X t ] S_t=tanh[\sigma(W_fX_t+b_f)S_{t-1}+\sigma(W_iX_t+b_i)X_t] St=tanh[σ(WfXt+bf)St−1+σ(WiXt+bi)Xt]这篇文章中传统RNN求偏导的过程包含
对于LSTM同样也包含这样的一项,但是在LSTM中
假设 Z = t a n h ′ ( x ) σ ( y ) Z=tanh'(x)\sigma(y) Z=tanh′(x)σ(y) ,则 Z Z Z的函数图像如下图所示: 可以看到该函数值基本上不是0就是1。这篇文章中传统RNN的求偏导过程:
如果在LSTM中上式可能就会变成:
因为 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ σ ( W f X t + b f ) ≈ 0 ∣ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh'\sigma(W_fX_t+b_f)\approx0|1 ∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′σ(WfXt+bf)≈0∣1 ,这样就解决了传统RNN中梯度消失的问题。
转载地址:http://cwnti.baihongyu.com/