博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
rnn梯度弥散 LSTM无梯度弥散
阅读量:4146 次
发布时间:2019-05-25

本文共 3322 字,大约阅读时间需要 11 分钟。

之前看过,现在突然想不起,真的是好记性不如烂笔头,希望大家在看的时候能够拿笔和纸跟着推导一遍,加深理解。

转自:

1.RNN梯度弥散和爆炸的原因

经典的RNN结构如下图所示:

在这里插入图片描述
假设我们的时间序列只有三段, S 0 S_0 S0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
在这里插入图片描述
假设在t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_3=\frac{1}{2}(Y_3-O_3)^2 L3=21(Y3O3)2

则对于一次训练任务的损失函数为 L = ∑ t = 1 T L t L=\sum_{t=1}^{T}L_t L=t=1TLt,即每一时刻损失值的累加。

使用随机梯度下降法训练RNN其实就是对 W x W_x Wx W s W_s Ws W 0 W_0 W0 以及 b 1 b_1 b1 b 2 b_2 b2 求偏导,并不断调整它们以使L尽可能达到最小的过程。

现在假设我们我们的时间序列只有三段,t1,t2,t3。

我们只对t3时刻的 [公式] 求偏导(其他时刻类似):

在这里插入图片描述
可以看出对于 W 0 W_0 W0 求偏导并没有长期依赖,但是对于 W x W_x Wx W s W_s Ws求偏导,会随着时间序列产生长期依赖。因为 S t S_t St随着时间序列向前传播,而 S t S_t St又是 W x W_x Wx W s W_s Ws的函数。

根据上述求偏导的过程,我们可以得出任意时刻对 W x W_x Wx W s W_s Ws求偏导的公式:

在这里插入图片描述

任意时刻对 W s W_s Ws 求偏导的公式同上。

如果加上激活函数, S j = t a n h ( W x X j + W s S j − 1 + b 1 ) S_j=tanh(W_xX_j+W_sS_{j-1}+b_1) Sj=tanh(WxXj+WsSj1+b1)

∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh^{'}W_s j=k+1tSj1Sj=j=k+1ttanhWs
激活函数tanh和它的导数图像如下。

在这里插入图片描述

由上图可以看出 t a n h ′ ≤ 1 tanh^{'} \leq1 tanh1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_xX_j+W_sS_{j-1}+b_1=0 WxXj+WsSj1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s j=k+1ttanhWs,就会趋近于0,和 0.01^{50} 趋近与0是一个道理。同理当 W s W_s Ws很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s j=k+1ttanhWs 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。

至于怎么避免这种现象,让我在看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial{L_t}}{\partial{W_x}}=\sum_{k=0}^{t}\frac{\partial{L_t}}{\partial{O_t}}\frac{\partial{O_t}}{\partial{S_t}}(\prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}})\frac{\partial{S_k}}{\partial{W_x}} WxLt=k=0tOtLtStOt(j=k+1tSj1Sj)WxSk梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} j=k+1tSj1Sj这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx1 j=k+1tSj1Sj1另一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 0 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx0 j=k+1tSj1Sj0 。其实这就是LSTM做的事情,至于细节问题下节将进行介绍。

2.LSTM如何解决梯度消失问题

先上一张LSTM的经典图:

在这里插入图片描述

而LSTM可以抽象成这样:

在这里插入图片描述

三个×分别代表的就是forget gate,input gate,output gate,而我认为LSTM最关键的就是forget gate这个部件。这三个gate是如何控制流入流出的呢,其实就是通过下面 f t , i t , o t f_t,i_t,o_t ft,it,ot 三个函数来控制,因为 σ ( x ) \sigma(x) σ(x)(代表sigmoid函数) 的值是介于0到1之间的,刚好用趋近于0时表示流入不能通过gate,趋近于1时表示流入可以通过gate。

在这里插入图片描述
当前的状态 S t = f t S t − 1 + i t X t S_t=f_tS_{t-1}+i_tX_t St=ftSt1+itXt类似与传统RNN S t = W s S t − 1 + W x X t + b 1 S_t=W_sS_{t-1}+W_xX_t+b_1 St=WsSt1+WxXt+b1。将LSTM的状态表达式展开后得:
在这里插入图片描述
如果加上激活函数, S t = t a n h [ σ ( W f X t + b f ) S t − 1 + σ ( W i X t + b i ) X t ] S_t=tanh[\sigma(W_fX_t+b_f)S_{t-1}+\sigma(W_iX_t+b_i)X_t] St=tanh[σ(WfXt+bf)St1+σ(WiXt+bi)Xt]

这篇文章中传统RNN求偏导的过程包含

在这里插入图片描述

对于LSTM同样也包含这样的一项,但是在LSTM中

在这里插入图片描述

假设 Z = t a n h ′ ( x ) σ ( y ) Z=tanh'(x)\sigma(y) Z=tanh(x)σ(y) ,则 Z Z Z的函数图像如下图所示:
在这里插入图片描述
可以看到该函数值基本上不是0就是1。

这篇文章中传统RNN的求偏导过程:

在这里插入图片描述

如果在LSTM中上式可能就会变成:

在这里插入图片描述

因为 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ σ ( W f X t + b f ) ≈ 0 ∣ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh'\sigma(W_fX_t+b_f)\approx0|1 j=k+1tSj1Sj=j=k+1ttanhσ(WfXt+bf)01 ,这样就解决了传统RNN中梯度消失的问题。

转载地址:http://cwnti.baihongyu.com/

你可能感兴趣的文章
前阿里手淘前端负责人@winter:前端人如何保持竞争力?
查看>>
【JavaScript 教程】面向对象编程——实例对象与 new 命令
查看>>
我在网易做了6年前端,想给求职者4条建议
查看>>
SQL1015N The database is in an inconsistent state. SQLSTATE=55025
查看>>
RQP-DEF-0177
查看>>
MySQL字段类型的选择与MySQL的查询效率
查看>>
Java的Properties配置文件用法【续】
查看>>
JAVA操作properties文件的代码实例
查看>>
IPS开发手记【一】
查看>>
Java通用字符处理类
查看>>
文件上传时生成“日期+随机数”式文件名前缀的Java代码
查看>>
Java代码检查工具Checkstyle常见输出结果
查看>>
北京十大情人分手圣地
查看>>
Android自动关机代码
查看>>
Android中启动其他Activity并返回结果
查看>>
2009年33所高校被暂停或被限制招生
查看>>
GlassFish 部署及应用入门
查看>>
X-code7 beta error: warning: Is a directory
查看>>
Error: An App ID with identifier "*****" is not avaliable. Please enter a different string.
查看>>
3.5 YOLO9000: Better,Faster,Stronger(YOLO9000:更好,更快,更强)
查看>>