再探反向传播算法(推导)

之前也写过关于反向传播算法中几个公式的推导,最近总被人问到其中推导的细节,发现之前写的内容某在些地方很牵强,很突兀,没有一步一步紧跟逻辑(我也不准备修正,因为它也代表了一种思考方式)。这两天又重新回顾了一下反向传播算法,所有就再次来说说反向传播算法。这篇博文的目的在于要交代清楚为什么要引入反向传播算法,以及为什么它叫反向传播。

1.从前(正)向传播谈起

在谈反向传播算法之前,我们先来简单回顾一下正向传播(详细版戳此处)。假设有如下网络结构:
这里写图片描述

其中:

L = 神经网络总共包含的层数 S l = l 层的神经元数目 K = 输出层的神经元数,亦即分类的数目 w i j l = l j l + 1 i

即对如上网络结构来说, L = 3 , s 1 = 3 , s 2 = 2 , s 3 = K = 2 a i l 表示第 l 层第 i 个神经元的激活值, b l 表示第 l 层的偏置。

则有如下正向传播过程:

z 1 2 = a 1 1 w 11 1 + a 2 1 w 12 1 + a 3 1 w 13 1 + b 1 z 2 2 = a 1 1 w 21 1 + a 2 1 w 22 1 + a 3 1 w 23 1 + b 1 [ z 1 2 z 2 2 ] = [ w 11 1 w 12 1 w 13 1 w 21 1 w 22 1 w 23 1 ] 2 × 3 × [ a 1 1 a 2 1 a 3 1 ] 3 × 1 + [ b 1 b 1 ] z 2 = a 1 w 1 + b 1 a 2 = f ( z 2 ) z 3 = a 2 w 2 + b 2 a 3 = f ( z 3 )

所以可以得出正向传播过程几个公式:

(1) z i l + 1 = a 1 l w i 1 l + a 2 l w i 2 l + + a S l l w i S l l + b l (2) z l + 1 = a l w l + b l (3) a l = f ( z l )

其中, f ( ) 表示激活函数,如sigmoid函数。

现在我们已经知道了正向传播的过程,也就是说当我们训练得到参数 w 之后,就可以用正向传播通过网络来预测了。但是大家有没有想过,参数 w 是怎么训练得到的?那第一反应肯定是运用梯度下降算法。既然是用梯度下降算法来求解参数,那第一步当然就是求解梯度了。

2.求解梯度

为了方便阅读,在这个位置再插入一张上面同样的网络结结构图:

这里写图片描述

此时,我们假设网络的目标函数为误差平方函数,且暂时不管正则化,同时只考虑一个样本即:

J = 1 2 ( h w , b ( x ) y ) 2

且此处 h w , b ( x ) = a 3
由此,我们可以发现:如果 J w 11 1 求导,则 J 是关于 a 3 的函数, a 3 是关于 z 3 的函数, z 3 是关于 a 2 的函数, a 2 是关于 z 2 的函数, w 11 1 是关于 z 2 的函数。

为了更加清晰下面的求导过程,我们先来举两个例子,看看链式求导的过程(如果熟悉链式求导规则,请直接忽略)。


例1:
假设有如下函数:

f = s i n ( t ) , t = x 2 , x = 5 w f w = f t t x x w = c o s ( t ) 2 x 5 = c o s ( x 2 ) 2 x 5 = c o s ( 25 w 2 ) 10 w 5 = 50 w c o s ( 25 w 2 )

作为验证,我们直接将 t , x 带入 f 然后求导:

f = s i n ( x 2 ) = s i n ( 25 w 2 ) f w = c o s ( 25 w 2 ) 50 w = 50 w c o s ( 25 w 2 )

例2:
我们再来看一个抽象的,没有表达式得链式求导,假设有如下函数表达式:

f = g ( t ) , t = ϕ ( x + y ) , x = h ( w ) , y = μ ( w )

则我们可以画出如下关系图:
这里写图片描述
即, t f 的函数, y x 都是 t 的函数, w 分别又都是 y x 的函数,也就是说我们有两条路径可以到达 w ,所以
f w = f t t y y w + f t t x x w = f t ( t y y w + t x x w )


所以有:

J w 11 1 = J a 1 3 a 1 3 z 1 3 z 1 3 a 1 2 a 2 z 1 2 z 1 2 w 11 1 + J a 2 3 a 2 3 z 2 3 z 2 3 a 1 2 a 2 z 1 2 z 1 2 w 11 1 J w 12 1 = J a 1 3 a 1 3 z 1 3 z 1 3 a 1 2 a 2 z 1 2 z 1 2 w 12 1 + J a 2 3 a 2 3 z 2 3 z 2 3 a 1 2 a 2 z 1 2 z 1 2 w 12 1 J w 22 2 = J a 2 3 a 2 3 z 2 3 z 2 3 w 22 2

我们可以发现,当 J 对第2层的参数求导还相对不麻烦,但当 J 对第1层的参数求导的时候就做了很多重复的计算;并且这还是网络相对简单的时候,要是网络相对复杂一点,这个过程简直就是难以下手。这也是为什么神经网络在一段时间发展缓慢的原因,就是因为没有一种高效的计算梯度的方式。

3.一种高效的梯度求解办法

J w 11 1 = ( J a 1 3 a 1 3 z 1 3 z 1 3 a 1 2 a 2 z 1 2 ) z 1 2 w 11 1 + ( J a 2 3 a 2 3 z 2 3 z 2 3 a 1 2 a 2 z 1 2 ) z 1 2 w 11 1

从上面的求导公式可以看出,不管你是从哪一条路径过来,在对 w 11 1 求导之前都会先到达 z 1 2 ,即先对 z 1 2 求导之后,才会有 z 1 2 w 11 1 。也就是说,我不管你是经过什么样的路径,在对连接第 l 层第j个神经元与第 l + 1 i 个神经元的参数 w i j l 求导之前,肯定会先对 z i l + 1 求导。因此,对任意参数的求导过程,可以改写为:

(4) J w i j l = J z i l + 1 z i l + 1 w i j l = J z i l + 1 a j l

例如:

J w 11 1 = J z 1 1 + 1 z 1 1 + 1 w 11 1 = J z 1 2 z 1 2 w 11 1

所以,现在的问题变成了如何求解红色部分了,即:

J z i l + 1 = ? ? ?

从网络结构图可以, J 对任意 z i l 求导,求导路径必定会经过第 l + 1 层的所有神经元,于是有:

J z i l = J z 1 l + 1 z 1 l + 1 z i l + J z 2 l + 1 z 2 l + 1 z i l + + J z S l + 1 l + 1 z S l + 1 l + 1 z i l = k = 1 S l + 1 J z k l + 1 z k l + 1 z i l = k = 1 S l + 1 J z k l + 1 z i l ( a 1 l w k 1 l + a 2 l w k 2 l + + a S l l w k S l l + b l ) 1 = k = 1 S l + 1 J z k l + 1 z i l j = 1 S l a j l w k j l = k = 1 S l + 1 J z k l + 1 z i l j = 1 S l f ( z j l ) w k j l (5) = k = 1 S l + 1 J z k l + 1 f ( z i l ) w k i l

于是我们得到:

(6) J z i l = k = 1 S l + 1 J z k l + 1 f ( z i l ) w k i l

因此

J z i l + 1 = k = 1 S l + 2 J z k l + 2 f ( z i l + 1 ) w k i l + 1

为了便于书写和观察规律,我们引入一个中间变量 δ i l = J z i l ,则(5)得:

(7) δ i l = J z i l = k = 1 S l + 1 δ k l + 1 f ( z i l ) w k i l ( l <= L 1 )

注:之所以要 l <= L 1 ,是因为由(5)得推导过程可知, l 最大只能取到 L 1 ,第L层后面没有网络层了。

所以:

δ i L = J z i L = z i L [ 1 2 k = 1 S L ( h k ( x ) y k ) 2 ] = z i L [ 1 2 k = 1 S L ( f ( z k L ) y k ) 2 ] = [ f ( z i L ) y i ] f ( z i L ) (8) = [ a i L y i ] f ( z i L )

同时将(7)带入(4)可知:

(9) J w i j l = δ i l + 1 a j l

通过上面的所有推导,我们可以得到如下3个公式:

J w i j l = δ i l + 1 a j l δ i l = J z i l = k = 1 S l + 1 δ k l + 1 f ( z i l ) w k i l ( 0 < l L 1 ) δ i L = [ a i L y i ] f ( z i L )

且经过适量化后为:

(10) J w l = δ l + 1 ( a l ) T (11) δ l = ( w l ) T δ l + 1 f ( z l ) (12) δ L = [ a L y ] f ( z L )

符号 表示矩阵乘法;符号 表示两个矩阵相同位置的元素对应相乘

由(10)(11)(12)分析可知,欲求 J w l 的导数,必先知道 δ l + 1 ;而欲知 δ l + 1 ,必先求 δ l + 2 ,以此类推……
由此可知对于整个求导过程,一定是先求 δ L ,再求 δ L 1 ,一直到 δ 2

为了方便阅读,在这个位置再插入一张上面同样的网络结结构图:

这里写图片描述

对于这样一个网络结构,整个求导过程(不含 b l )如下:

S t e p 1 : δ 3 = [ a 3 y ] f ( z 3 ) S t e p 2 : J w 2 = δ 3 ( a 2 ) T S t e p 3 : δ 2 = ( w 2 ) T δ 3 f ( z 2 ) S t e p 4 : J w 1 = δ 2 ( a 1 ) T

于是我们终于发现了这么一个不争的事实:
1.最先求解出导数的参数一定位于第 L 1 层上(如此处的 w 2 );
2.要想求解第 l 层参数的导数,一定会用到第 l + 1 层上的中间变量 δ l + 1 (如此处求解 w 1 的导数,用到了 δ 2 );
3.整个过程是从后往前的;

所以,该过程被形象的称为反向(后向)传播算法。
另: δ l 被称为第 l 层的“残差”

一个重要的结论:
反向传播算法是用来求解梯度的!

反向传播算法是用来求解梯度的!

反向传播算法是用来求解梯度的!

重要的话说三遍,因为不少人总是把梯度下降和反向传播两个搞得稀里糊涂的。

4.总结

通过举例对平方误差目标函数反向传播算算法公式的推导,我们可以总结出更为一般的情况,即:

(13) J w l = δ l + 1 ( a l ) T (14) δ l = ( w l ) T δ l + 1 f ( z l ) (15) δ i L = J z i L = J a i L a i L z i L = J a i L f ( z i L ) z i L = J a i L f ( z i L ) (16) J b l = δ l + 1

我们可以看到,仅仅只有公式(15)才依赖于不同的目标函数;比如在交叉熵中 δ i L = a L y 推导戳此处.

关于反向传播算法的推导基本上可以告一段落了,下一篇我们将通过一个例子用python来实现,这样就会更清楚了 。

相关文章
相关标签/搜索