原始結(jié)構(gòu)的RNN還不夠處理較為復(fù)雜的序列建模問(wèn)題,它存在較為嚴(yán)重的梯度消失問(wèn)題,最直觀的現(xiàn)象就是隨著網(wǎng)絡(luò)層數(shù)增加,網(wǎng)絡(luò)會(huì)逐漸變得無(wú)法訓(xùn)練。長(zhǎng)短期記憶網(wǎng)絡(luò)(Long Short Time Memory,LSTM)正是為了解決梯度消失問(wèn)題而設(shè)計(jì)的一種特殊的RNN結(jié)構(gòu)。
深度神經(jīng)網(wǎng)絡(luò)的困擾:梯度爆炸與梯度消失
在此前的普通深度神經(jīng)網(wǎng)絡(luò)和深度卷積網(wǎng)絡(luò)的講解時(shí),圖1就是一個(gè)簡(jiǎn)單的兩層普通網(wǎng)絡(luò),但當(dāng)網(wǎng)絡(luò)結(jié)構(gòu)變深時(shí),神經(jīng)網(wǎng)絡(luò)在訓(xùn)練時(shí)碰到梯度爆炸或者梯度消失的情況。那么什么是梯度爆炸和梯度消失呢?它們又是怎樣產(chǎn)生的?
鑒于神經(jīng)網(wǎng)絡(luò)的訓(xùn)練機(jī)制,不管是哪種類(lèi)型的神經(jīng)網(wǎng)絡(luò),其訓(xùn)練都是通過(guò)反向傳播計(jì)算梯度來(lái)實(shí)現(xiàn)權(quán)重更新的。通過(guò)設(shè)定損失函數(shù),建立損失函數(shù)關(guān)于各層網(wǎng)絡(luò)輸入輸出的梯度計(jì)算,當(dāng)網(wǎng)絡(luò)訓(xùn)練開(kāi)動(dòng)起來(lái)的時(shí)候,系統(tǒng)便按照反向傳播機(jī)制來(lái)不斷更新網(wǎng)絡(luò)各層參數(shù)直到停止訓(xùn)練。但當(dāng)網(wǎng)絡(luò)層數(shù)加深時(shí),這個(gè)訓(xùn)練系統(tǒng)并不是很穩(wěn),經(jīng)常會(huì)出現(xiàn)一些問(wèn)題。其中梯度爆炸和梯度消失便是較為嚴(yán)重的兩個(gè)問(wèn)題。
所謂梯度爆炸就是在神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中,梯度變得越來(lái)越大以使得神經(jīng)網(wǎng)絡(luò)權(quán)重得到瘋狂更新的情形,這種情況很容易發(fā)現(xiàn),因?yàn)樘荻冗^(guò)大,計(jì)算更新得到的參數(shù)也會(huì)大到崩潰,這時(shí)候我們可能看到更新的參數(shù)值中有很多的NaN,這說(shuō)明梯度爆炸已經(jīng)使得參數(shù)更新出現(xiàn)數(shù)值溢出。這便是梯度爆炸的基本情況。
然后是梯度消失。與梯度爆炸相反的是,梯度消失就是在神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中梯度變得越來(lái)越小以至于梯度得不到更新的一種情形。當(dāng)網(wǎng)絡(luò)加深時(shí),網(wǎng)絡(luò)深處的誤差很難因?yàn)樘荻鹊臏p小很難影響到前層網(wǎng)絡(luò)的權(quán)重更新,一旦權(quán)重得不到有效的更新計(jì)算,神經(jīng)網(wǎng)絡(luò)的訓(xùn)練機(jī)制也就失效了。
為什么神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中梯度怎么就會(huì)變得越來(lái)越大或者越來(lái)越?。靠梢杂帽緯?shū)第一講的神經(jīng)網(wǎng)絡(luò)反向傳播推導(dǎo)公式為例來(lái)解釋。
式(11.1)~-式(11.8)是一個(gè)兩層網(wǎng)絡(luò)的反向傳播參數(shù)更新公式推導(dǎo)過(guò)程。離輸出層相對(duì)較遠(yuǎn)的是輸入到隱藏層的權(quán)重參數(shù),可以看到損失函數(shù)對(duì)于隱藏層輸出輸入到隱藏層權(quán)重和偏置的梯度計(jì)算公式,一般而言都會(huì)轉(zhuǎn)換從下一層的權(quán)重乘以激活函數(shù)求導(dǎo)后的式子。如果激活函數(shù)求導(dǎo)后的結(jié)果和下一層權(quán)重的乘積大于1或者說(shuō)遠(yuǎn)遠(yuǎn)大于1的話(huà),在網(wǎng)絡(luò)層數(shù)加深時(shí),層層遞增的網(wǎng)絡(luò)在做梯度更新時(shí)往往就會(huì)出現(xiàn)梯度爆炸的情況。如果激活函數(shù)求導(dǎo)和下一層權(quán)重的乘積小于1的話(huà),在網(wǎng)絡(luò)加深時(shí),淺層的網(wǎng)絡(luò)梯度計(jì)算結(jié)果會(huì)越來(lái)越小往往就會(huì)出現(xiàn)梯度消失的情況。所以可是說(shuō)是反向傳播的機(jī)制本身造就梯度爆炸和梯度消失這兩種不穩(wěn)定因素。例如,一個(gè)100層的深度神經(jīng)網(wǎng)絡(luò),假設(shè)每一層的梯度計(jì)算值都為1.1,經(jīng)過(guò)由輸出到輸入的反向傳播梯度計(jì)算可能最后的梯度值就變成= 13780.61234,這是一個(gè)極大的梯度值了,足以造成計(jì)算溢出問(wèn)題。若是每一層的梯度計(jì)算值為 0.9,反向傳播輸入層的梯度計(jì)算值則可能為= 0.000026561398,足夠小到造成梯度消失。本例只是一個(gè)簡(jiǎn)化的假設(shè)情況,實(shí)際反向傳播計(jì)算要更為復(fù)雜。
所以總體來(lái)說(shuō),神經(jīng)網(wǎng)絡(luò)的訓(xùn)練中梯度過(guò)大或者過(guò)小引起的參數(shù)過(guò)大過(guò)小都會(huì)導(dǎo)致神經(jīng)網(wǎng)絡(luò)失效,那我們的目的就是要讓梯度計(jì)算回歸到正常的區(qū)間范圍,不要過(guò)大也不要過(guò)小,這也是解決這兩個(gè)問(wèn)題的一個(gè)思路。
那么如何解決梯度爆炸和梯度消失問(wèn)題?梯度爆炸較為容易處理,在實(shí)際訓(xùn)練的時(shí)候?qū)μ荻冗M(jìn)行修剪即可,但是梯度消失的處理就比較麻煩了,由上述的分析我們知道梯度消失一個(gè)關(guān)鍵在于激活函數(shù)。Sigmoid激活函數(shù)本身就更容易產(chǎn)生這種問(wèn)題,所以一般而言,我們換上更加魯棒的ReLu激活函數(shù)以及給神經(jīng)網(wǎng)絡(luò)加上歸一化激活函數(shù)層(BN層),一般問(wèn)題都能得到很好的解決,但也不是任何情形下都管用,例如,RNN網(wǎng)絡(luò),具體在下文中我們?cè)僮黾刑接憽?/p>
以上便是梯度爆炸和梯度消失這兩種問(wèn)題的基本解釋?zhuān)旅嫖覀兓貧w正題,來(lái)談?wù)劚疚牡闹鹘恰狶STM。
LSTM:讓RNN具備更好的記憶機(jī)制
前面說(shuō)了很多鋪墊,全部都是為了來(lái)講LSTM。梯度爆炸和梯度消失,普通神經(jīng)網(wǎng)絡(luò)和卷積神經(jīng)網(wǎng)絡(luò)有,那么循環(huán)神經(jīng)網(wǎng)絡(luò)RNN也有嗎?必須有。而且梯度消失和梯度爆炸的問(wèn)題之于RNN來(lái)說(shuō)傷害更大。當(dāng)RNN網(wǎng)絡(luò)加深時(shí),因?yàn)樘荻认У膯?wèn)題使得前層的網(wǎng)絡(luò)權(quán)重得不到更新,RNN就會(huì)在一定程度上丟失記憶性。為此,在傳統(tǒng)的RNN網(wǎng)絡(luò)結(jié)構(gòu)基礎(chǔ)上,研究人員給出一些著名的改進(jìn)方案,因?yàn)檫@些改進(jìn)方案都脫離不了經(jīng)典的RNN架構(gòu),所以一般來(lái)說(shuō)我們也稱(chēng)這些改進(jìn)方案為RNN變種網(wǎng)絡(luò)。比較著名的就是GRU(循環(huán)門(mén)控單元)和LSTM(長(zhǎng)短期記憶網(wǎng)絡(luò))。GRU和LSTM二者結(jié)構(gòu)基本一致,但有部分不同的地方,本講以更有代表性的LSTM來(lái)進(jìn)行詳解。
在正式深入LSTM的技術(shù)細(xì)節(jié)之前,先要明確幾點(diǎn)。第一,LSTM的本質(zhì)是一種RNN網(wǎng)絡(luò)。第二,LSTM在傳統(tǒng)的RNN結(jié)構(gòu)上做了相對(duì)復(fù)雜的改進(jìn),這些改進(jìn)使得LSTM相對(duì)于經(jīng)典RNN能夠很好的解決梯度爆炸和梯度消失問(wèn)題,讓循環(huán)神經(jīng)網(wǎng)絡(luò)具備更強(qiáng)更好的記憶性能,這也是LSTM的價(jià)值所在。那咱們就來(lái)重點(diǎn)看一下LSTM的技術(shù)細(xì)節(jié)。
咱們先擺一張經(jīng)典RNN結(jié)構(gòu)與LSTM結(jié)構(gòu)對(duì)比圖,這樣能夠有一個(gè)宏觀的把握,然后再針對(duì)LSTM結(jié)構(gòu)圖中各個(gè)部分進(jìn)行拆解分析。圖2所示是標(biāo)準(zhǔn)RNN結(jié)構(gòu),圖3所示是LSTM結(jié)構(gòu)。
圖2 RNN結(jié)構(gòu)
圖3 LSTM結(jié)構(gòu)
從圖3中可以看到,相較于RNN單元,LSTM單元要復(fù)雜許多。每個(gè)LSTM單元中包含了4個(gè)交互的網(wǎng)絡(luò)層,現(xiàn)在將LSTM單元放大,并標(biāo)注上各個(gè)結(jié)構(gòu)名稱(chēng),如圖4所示。
圖4 LSTM單元
根據(jù)圖4,一個(gè)完整的LSTM單元可以用式(11.9)~(11.14)來(lái)表示,其中符號(hào)表示兩個(gè)向量合并。
現(xiàn)在我們將LSTM單元結(jié)構(gòu)圖進(jìn)行分解,根據(jù)結(jié)構(gòu)圖和公式來(lái)逐模塊解釋LSTM。
1. 記憶細(xì)胞
如圖5紅色部分所示,可以看到在LSTM單元的最上面部分有一條貫穿的箭頭直線(xiàn),這條直線(xiàn)由輸入到輸出,相較于RNN,LSTM提供了c作為記憶細(xì)胞輸入。記憶細(xì)胞提供了記憶的功能,在網(wǎng)絡(luò)結(jié)構(gòu)加深時(shí)仍能傳遞前后層的網(wǎng)絡(luò)信息。這樣貫穿的直線(xiàn)使得記憶信息在網(wǎng)絡(luò)各層之間保持下去很容易。
圖5 LSTM記憶細(xì)胞
2. 遺忘門(mén)(Forget Gate)
遺忘門(mén)的計(jì)算公式如下:
遺忘門(mén)的作用是要決定從記憶細(xì)胞c中是否丟棄某些信息,這個(gè)過(guò)程可以通過(guò)一個(gè) Sigmoid函數(shù)來(lái)進(jìn)行處理。遺忘門(mén)在整個(gè)結(jié)構(gòu)中的位置如圖11.6所示。可以看到,遺忘門(mén)接受來(lái)自輸入和上一層隱狀態(tài)的值進(jìn)行合并后加權(quán)計(jì)算處理。
3. 記憶細(xì)胞候選值和更新門(mén)
更新門(mén)(Update Gate)表示需要將什么樣的信息能存入記憶細(xì)胞中。除了計(jì)算更新門(mén)之外,還需要使用tanh計(jì)算記憶細(xì)胞的候選值。LSTM中更新門(mén)需要更加細(xì)心一點(diǎn)。記憶細(xì)胞候選值和更新門(mén)的計(jì)算公式如下:
更新門(mén)和tanh在整個(gè)結(jié)構(gòu)中的位置如圖7所示。
圖7 記憶細(xì)胞候選值和更新門(mén)
4. 記憶細(xì)胞更新
結(jié)合遺忘門(mén)、更新門(mén)、上一個(gè)單元記憶細(xì)胞值和記憶細(xì)胞候選值來(lái)共同決定和更新當(dāng)前細(xì)胞狀態(tài):
記憶細(xì)胞更新在LSTM整個(gè)結(jié)構(gòu)中位置如圖8所示:
圖8 記憶細(xì)胞更新
5. 輸出門(mén)
LSTM 提供了單獨(dú)的輸出門(mén)(Output Gate)。計(jì)算公式如下:
輸出門(mén)的位置如圖9所示。
圖9 輸出門(mén)
以上便是完整的LSTM結(jié)構(gòu)。雖然復(fù)雜,但經(jīng)過(guò)逐步解析之后也就基本清晰了。LSTM 在自然語(yǔ)言處理、問(wèn)答系統(tǒng)、股票預(yù)測(cè)等等領(lǐng)域都有著廣泛而深入的應(yīng)用。
責(zé)任編輯:haq
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4782瀏覽量
101226 -
LSTM
+關(guān)注
關(guān)注
0文章
59瀏覽量
3794
原文標(biāo)題:深入理解LSTM
文章出處:【微信號(hào):sessdw,微信公眾號(hào):三星半導(dǎo)體互動(dòng)平臺(tái)】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論