關於梯度消失與梯度爆炸

Qi Fong
8 min readMay 14, 2021

相信有做過深度學習的人都看過兩個現象:

一、loss下不去,但是卻也不夠小,導致在training的時候就train不起來,這有可能是碰到了網路中的權重幾乎沒有辦法更新,卻也fit不好。

二、不但沒有收斂,loss還盪來盪去的。

這到底發生什麼事呢?

首先講deep learning有什麼作用及會碰到什麼問題。

一個deep的神經網路,藉由非線性的activation function,例如sigmoid, tanh, relu,甚至LeakyRelu,能夠起到特徵映射的作用,能夠逐層的把特徵擷取或關聯出來,這比一般線性模型有更好的預測或特徵擷取能力。

訓練階段會初始化你的神經網路的連結,也叫做weights,至於初始化的方法有很多種,不管你是要用uniform distribution, normal distribution, 或是xavier 方法,這些初始化的方法往往還要考慮你用了什麼activation function。

梯度消失與梯度爆炸

但不管怎麼樣,更新weights的方法通常是用梯度下降(gradient descent)和back-propagation這項技術,也就是計算完預測值和真實的label的loss,藉由計算導函數,由最後輸出端的神經網路逐層向輸入端更新權重的過程。

這中間會碰到鏈鎖律的操作,導致越靠近輸入端的更新有越多的導數相乘,這就碰到一個問題了,若是這些導數絕對值都是小於1的話,那結果就會指數速度的靠近0,使得輸入端的weights很難更新,但靠近輸出端的weights會更新的比較快,這種情況叫做梯度消失。

梯度爆炸則是,當許多導數絕對值都大於1時,那會指數的膨脹,導致更新的步伐跨太大,導致loss一直在震盪。

這幾乎是傳統深度神經網路一定會碰到的問題,至於該怎麼解決呢?現在常見的有以下幾種方法

改善梯度消失的方法

一、改activation function

由於是因為導數的關係產生了梯度消失與爆炸,這在sigmoid和tanh這種在飽和區域的微分值會趨近於0的函數很常見,所以後來有新的啟動函數ReLu來克服這個問題,ReLu的圖如下。

從圖形不難觀察出,只要輸入值大於0,就直接輸出原本的輸入,反之直接輸出0,這個機制有很多好處,例如

  1. 運算快速,只要判斷正負號就好
  2. 正的區域微分值就是1

這兩點都是很強大的好處,特別是第2點,因為這個特性,他可以避免掉導數相乘後,指數的爆炸或消失。

當然Relu也有其他的問題,例如假設他死了(dead),就很難再復活,因此後來有LeakyRelu解決這個問題,但實作上還是Relu用的比較多。

二、改用ResNet網路架構

ResNet全名叫做Residual Nets,圖像化機制如下,可參考此paper

寫成數學式的話則是h(x)=F(x)+x

如此對其求微分值,必定保證會有右端的1存在,這能有效的改善梯度消失的問題,那這樣會梯度爆炸嗎?我們知道大於1的連續相乘就有可能會發生,但是這個問題相對好解決,你可以直接上一個上限值給他就可以,但是下限值卻很難設定。

三、改用memory cell

由於memory cell本身有forget gate的機制,直觀上來說,他有更精細的操作,讓他可以自己學習是否要捨棄或納入這個梯度,這使得梯度消失的情況受到改善,但是改善並不是說能完全解決這個問題,是加入了這個機制,有機會減緩這個現象。

梯度爆炸改善方法

一、直接設梯度上限值

這是最直觀和暴力的方法,既然你知道有這個現象,那你就設定小一點。

二、正則化

這邊講的正則化主要是L1和L2正則,數學式如下:

L1正則:

L2正則:

直觀上的解釋是,當你正則化後,會使得權重值縮小了,將整個梯度給拉小了,如此就不會讓梯度膨脹得太劇烈。

對於梯度消失與爆炸還有一個能同時改善的方法,那就是加入

batch normalization機制。

簡單講一下batch normalization的想法。

每次training的時候,他會考慮整個mini-batch在各個維度的平均數和標準差,然後對每一個維度做正規化,數學式如下

i表示維度的編號

這樣做有許多好處,簡單講三點:

  1. 他讓進到下一層的輸入分佈更加均勻和穩定,同時也能解決internal covariate shift[註1]的問題。
  2. 正規化後,會讓平均數回到0,將在某些啟動函數會飽和的數值拉回來。
  3. 使得權重初始化變得比較簡單,例如針對Relu的啟動函數,就要很小心的設計權重初始化,不然很有可能很快讓他的狀態變成dead。

以上就是關於梯度消失、梯度爆炸的介紹,其實個別都還有很多東西可以講,比如batch normalization就可以單獨講一篇,啟動函數的優缺點也可以再開一篇,但是這邊盡量縮減到針對梯度消失和爆炸的問題做討論。

註:

  1. Internal covariate shift: 當上一層的輸出分佈發生改變時,導致下層的網路要重新適應,導致很難收斂的現象。可參考原文

Reference:

  1. https://zhuanlan.zhihu.com/p/70834555
  2. https://medium.com/%E8%BB%9F%E9%AB%94%E4%B9%8B%E5%BF%83/deep-learning-residual-leaning-%E8%AA%8D%E8%AD%98resnet%E8%88%87%E4%BB%96%E7%9A%84%E5%86%A0%E5%90%8D%E5%BE%8C%E7%B9%BC%E8%80%85resnext-resnest-6bedf9389ce
  3. https://zhuanlan.zhihu.com/p/42833949

我的其他文章:

--

--