BERT是目前最強(qiáng)大的NLP預(yù)訓(xùn)練模型,也是工業(yè)界目前最耗時(shí)的應(yīng)用,計(jì)算量遠(yuǎn)高于ImageNet。谷歌的研究人員提出新的優(yōu)化器,使用1024塊TPU,將BERT的訓(xùn)練時(shí)間從3天成功縮短到76分鐘,提速 65.2 倍!
去年,谷歌發(fā)布了最強(qiáng)預(yù)訓(xùn)練模型 BERT,宣告了NLP領(lǐng)域的一項(xiàng)重大突破。
BERT 在 33 億文本的語(yǔ)料上訓(xùn)練語(yǔ)言模型,再分別在不同的下游任務(wù)上微調(diào),在11個(gè)不同的 NLP 任務(wù)均得到了目前為止最好的結(jié)果。
不過,在 33 億文本的語(yǔ)料上預(yù)訓(xùn)練一個(gè) BERT 模型的成本是非常大的,谷歌用了 16 個(gè)自己的 TPU 集群(一共 64 塊 TPU)來訓(xùn)練大號(hào)版本的 BERT,一共花了約4天的時(shí)間。
如此巨大的訓(xùn)練成本,讓普通研究者難以嘗試自己去訓(xùn)練一個(gè)BERT。
有沒有辦法加快BERT的訓(xùn)練呢?近日,來自Google、UC Berkeley、UCLA的幾位研究人員提出新的優(yōu)化器——LAMB優(yōu)化器,將訓(xùn)練的batch size推到硬件的極限,使用 TPU Pod ( 1024塊 TPUv3 芯片),將BERT的訓(xùn)練時(shí)間從3天縮短到了76分鐘!
論文地址:
https://arxiv.org/pdf/1904.00962.pdf
其中一作尤洋(Yang You)來自UC Berkeley,這項(xiàng)工作于他在Google Brain實(shí)習(xí)期間完成。
接下來,新智元帶來對(duì)這篇論文的譯介:
加快深度神經(jīng)網(wǎng)絡(luò)最有效的方法
大批量訓(xùn)練(large-batch training)是加快大型分布式系統(tǒng)中深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練的關(guān)鍵。然而,large-batch訓(xùn)練是很困難的,因?yàn)樗鼤?huì)產(chǎn)生一種泛化差距(generalization gap)。直接優(yōu)化通常會(huì)導(dǎo)致測(cè)試集的準(zhǔn)確性下降。
BERT是一種最先進(jìn)的深度學(xué)習(xí)模型,建立在用于語(yǔ)言理解的深度雙向transformers之上。對(duì)BERT來說,當(dāng)擴(kuò)大批大小(例如超過8192)時(shí),以前的large-batch訓(xùn)練技術(shù)效果并不好。BERT的預(yù)訓(xùn)練也需要很長(zhǎng)時(shí)間才能完成(使用16個(gè)TPUv3芯片大約需要3天)。
為了解決這個(gè)問題,我們提出了LAMB優(yōu)化器,它幫助我們將批大小擴(kuò)大到65536,而不會(huì)丟失準(zhǔn)確性。
LAMB是一個(gè)通用的優(yōu)化器,適用于小批量和大批量,并且除了學(xué)習(xí)率外不需要超參數(shù)調(diào)優(yōu)?;€BERT-Large模型需要100萬(wàn)次迭代才能完成預(yù)訓(xùn)練,而batch size為65536/32768的LAMB只需要8599次迭代。我們將batch size推到TPUv3 pod的內(nèi)存上限,可以在76分鐘內(nèi)完成BERT訓(xùn)練(表1)。
表1:我們使用SQuAD-v1的F1 score作為精度指標(biāo)。F1的基線成績(jī)是由BERT的公共github提供的預(yù)訓(xùn)練模型(BERT- large)實(shí)現(xiàn)的(截止到2019年2月1日)。我們?cè)趯?shí)驗(yàn)中使用tpuv3。我們使用了與基線相同的設(shè)置:總epochs的前9/10使用序列長(zhǎng)度128,最后1/10使用序列長(zhǎng)度512。所有的實(shí)驗(yàn)運(yùn)行相同數(shù)量的epochs。
深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練是十分耗時(shí)的。目前,減少訓(xùn)練時(shí)間最有效的方法是使用多個(gè)芯片(如CPU、GPU和TPU)來并行化SGD變體的優(yōu)化過程。由于前向傳播和反向傳播中不同層之間的數(shù)據(jù)依賴關(guān)系,使得跨層的并行化效率并不高。相反,研究人員在每次迭代中并行化小批量中的數(shù)據(jù)點(diǎn)。如果確定了訓(xùn)練的epochs的數(shù)量,那么線性地增大batch size意味著會(huì)線性地減少迭代次數(shù)(即更新權(quán)重的次數(shù))。為了最小化訓(xùn)練時(shí)間,最大化batch size將是理想的。
然而,大批量的訓(xùn)練是困難的。例如,使用大小為512的batch size訓(xùn)練在ImageNet上訓(xùn)練AlexNet,能實(shí)現(xiàn)80%以上的top-5測(cè)試精度。但將batch size擴(kuò)大到4096之后,直接訓(xùn)練可能只能獲得50% ~ 60%的top 5精度。
Keskar等人(10)認(rèn)為在大批量訓(xùn)練中存在一個(gè)泛化差距(generalization gap)。Hoffer等人(6)認(rèn)為,訓(xùn)練時(shí)間越長(zhǎng),泛化差距越小。然而,訓(xùn)練時(shí)間過長(zhǎng)意味著進(jìn)行大批量訓(xùn)練就沒有好處了。
因此,大批量訓(xùn)練的目標(biāo)是在一定數(shù)量的epochs內(nèi)達(dá)到可觀的精度。通過設(shè)計(jì)一系列的學(xué)習(xí)率計(jì)劃表,研究者已經(jīng)可以將ImageNet訓(xùn)練的batch size擴(kuò)大到32K,并且精度損失較小。據(jù)我們所知,Ying et al.實(shí)現(xiàn)了目前最快的ImageNet訓(xùn)練速度,并且達(dá)到了76+%的top-1精度。通過使用LARS優(yōu)化器,將batch size擴(kuò)展到32K,,Ying等人使用TPUv3 Pod,在2.2分鐘內(nèi)完成了ResNet-50的ImageNet訓(xùn)練。(最新,富士通研究院刷新了這一速度,將ImageNet訓(xùn)練時(shí)間降到74.7秒)
BERT是目前最先進(jìn)的深度學(xué)習(xí)語(yǔ)言模型。BERT建立在用于語(yǔ)言理解的深度雙向transformers之上。對(duì)BERT來說,當(dāng)將batch size擴(kuò)大到非常大時(shí)(例如超過8192),以前的large-batch訓(xùn)練技術(shù)效果并不好。BERT的預(yù)訓(xùn)練也需要很長(zhǎng)時(shí)間才能完成(使用16個(gè)TPUv3芯片大約需要3天)。
為了擴(kuò)大BERT的batch size,本文提出LAMB優(yōu)化器。LAMB支持自適應(yīng)element-wise updating和精確的逐層修正(layer-wise correction)。
LAMB是一個(gè)適用于小批量和大批量的通用優(yōu)化器。用戶只需要調(diào)整學(xué)習(xí)率,不需要調(diào)其他超參數(shù)。使用LAMB,我們可以將BERT預(yù)訓(xùn)練的批大小擴(kuò)大到64K,而不會(huì)丟失準(zhǔn)確性。
BERT預(yù)訓(xùn)練包括兩個(gè)階段:
(1)前9/10的epochs使用128的序列長(zhǎng)度;
(2)后1/10 epochs使用512的序列長(zhǎng)度。
baseline需要100萬(wàn)次迭代來完成BERT預(yù)訓(xùn)練,但我們只需要8599次迭代,這使我們能夠?qū)ERT訓(xùn)練時(shí)間從3天減少到76分鐘。
我們將批大小推到了TPU Pod的硬件極限。批大小大于32768(序列長(zhǎng)度為512)的話將耗盡內(nèi)存。批大小大于65536(序列長(zhǎng)度為128)則不會(huì)帶來任何加速。我們的優(yōu)化器可以將批大小擴(kuò)大到128k,甚至更大。由于硬件限制,序列長(zhǎng)度為512的設(shè)置下,我們?cè)谂笮∵_(dá)到32768時(shí)停下,在序列長(zhǎng)度為128的設(shè)置下,批大小達(dá)到65536時(shí)停止。
本文中所有的BERT模型都指BERT-Large模型。為了進(jìn)行公平的比較,本文所有的實(shí)驗(yàn)都運(yùn)行相同數(shù)量的epochs(即固定數(shù)量的浮點(diǎn)運(yùn)算)。我們的結(jié)果如表1所示。
LAMB優(yōu)化器
LAMB的全稱是Layer-wise Adaptive Moments optimizer for Batch training。
BERT訓(xùn)練的基線使用權(quán)重衰減的Adam作為優(yōu)化器,這是Adam優(yōu)化器的一個(gè)變體。另一個(gè)成功應(yīng)用于大批量卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練的自適應(yīng)優(yōu)化器是LARS。
這些優(yōu)化器啟發(fā)我們提出了新的優(yōu)化器,用于大批量BERT訓(xùn)練。我們提出的LAMB優(yōu)化器的概述如算法1所示。
實(shí)驗(yàn)和結(jié)果
常規(guī)訓(xùn)練
TPU是浮點(diǎn)運(yùn)算的強(qiáng)大計(jì)算硬件。我們?cè)谒械膶?shí)驗(yàn)中都使用了TPUv3。TPUv3 Pod有1024個(gè)芯片,可以為混合精度計(jì)算提供超過100 petaflops的性能。我們的結(jié)果如表1所示。基線模型在預(yù)訓(xùn)練時(shí)使用Wikipedia和BooksCorpus數(shù)據(jù)集。
我們使用了與原始BERT模型相同的數(shù)據(jù)集,即Wikipedia和BooksCorpus,分別有2.5B和8億單詞。原始BERT模型的作者首先以128的序列長(zhǎng)度進(jìn)行了900k次迭代訓(xùn)練,然后以512的序列長(zhǎng)度進(jìn)行了100k迭代訓(xùn)練。
16個(gè)TPUv3芯片的總訓(xùn)練時(shí)間約為3天。我們使用SQuAD-v1的F1分?jǐn)?shù)作為精度指標(biāo)。F1得分越高,準(zhǔn)確度越高。斯坦福問答數(shù)據(jù)集(SQuAD)是一個(gè)閱讀理解數(shù)據(jù)集,包含眾包工作者從維基百科的文章中提出的問題,每一個(gè)問題的答案都是對(duì)應(yīng)閱讀文章的一段文字,或者該問題無法回答。我們從BERT的公開GitHub庫(kù)上下載了預(yù)訓(xùn)練好的模型。
使用作者提供的腳本,baseline的F1得分為90.395。在我們的代碼中,我們使用了BERT的作者提供的數(shù)據(jù)集和基線模型,只修改了優(yōu)化器。通過使用LAMB優(yōu)化器,我們能夠在批大小為32768的15625次迭代中獲得91.460的F1分?jǐn)?shù)(序列長(zhǎng)度為128的14063次迭代和序列長(zhǎng)度為512的1562次迭代)。
我們把訓(xùn)練時(shí)間從3天減少到100分鐘左右。我們將批大小推到了TPU Pod的硬件極限。批大小大于32768時(shí)(序列長(zhǎng)度為512)將導(dǎo)致TPU Pod耗盡內(nèi)存。
我們實(shí)現(xiàn)了76.7%的弱擴(kuò)展效率(49.1倍的加速,64倍的計(jì)算資源)。由于我們?cè)赥PU Pod上使用同步數(shù)據(jù)并行來進(jìn)行分布式訓(xùn)練,因此在互連上傳輸梯度會(huì)帶來通信開銷。梯度的大小與訓(xùn)練后的模型相同。
Mixed-Batch訓(xùn)練
如前所述,BERT預(yù)訓(xùn)練包括兩個(gè)階段:
(1)前9/10的epoch使用128的序列長(zhǎng)度,
(2)最后1/10的epoch使用512的序列長(zhǎng)度。
對(duì)于第二階段,由于內(nèi)存限制,TPUv3 Pod上的最大批大小為32768,因此我們將第二階段在批大小達(dá)到32768時(shí)停止。
對(duì)于第一階段,由于內(nèi)存限制,TPUv3 Pod上的最大批大小是131072。但是,當(dāng)我們將批大小從65536增加到131072時(shí),并沒有看到加速,因此我們?cè)诘谝浑A段批大小達(dá)到65536時(shí)停止。
此前,Smith等人也研究了混合批訓(xùn)練。但是,他們?cè)谟?xùn)練中增大了批大小,而我們減小了批大小。
我們能夠從頭到尾充分利用硬件資源。Smith等人的研究只在最后階段充分利用了硬件資源。增加批大小可以warm-up和穩(wěn)定優(yōu)化過程,但是減小批大小會(huì)給優(yōu)化過程帶來混亂,導(dǎo)致訓(xùn)練不收斂。
在實(shí)驗(yàn)中,我們發(fā)現(xiàn)了一種有助于穩(wěn)定第二階段優(yōu)化的方法。由于我們切換到一個(gè)不同的優(yōu)化問題,有必要重新warm-up優(yōu)化過程。在第二階段,我們沒有降低學(xué)習(xí)率,而是將學(xué)習(xí)率從零開始增加(re-warm-up)。
通過這些改變,我們只需要8599次迭代,可以在76分鐘左右完成BERT訓(xùn)練,實(shí)現(xiàn)了101.8%的弱縮放效率(weak scaling efficiency),提速65.2倍,利用了64倍的計(jì)算資源。
結(jié)論
Large batch技術(shù)是加快神經(jīng)網(wǎng)絡(luò)深度訓(xùn)練的關(guān)鍵。在本文中,我們提出了支持adaptive element-wise updating和layer-wise correction的LAMB優(yōu)化器。LAMB是一個(gè)通用的優(yōu)化器,適用于小批量和大批量。通過使用LAMB,我們可以將BERT預(yù)訓(xùn)練的batch size擴(kuò)展到64K,而不會(huì)丟失準(zhǔn)確性。我們將BERT的訓(xùn)練時(shí)間從3天減少到76分鐘左右,并將批大小推到了TPU Pod的硬件極限。我們正在研究LAMB優(yōu)化器的理論分析。
-
谷歌
+關(guān)注
關(guān)注
27文章
6203瀏覽量
106100 -
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4785瀏覽量
101246 -
訓(xùn)練模型
+關(guān)注
關(guān)注
1文章
36瀏覽量
3889
原文標(biāo)題:BERT訓(xùn)練猛提速!谷歌新研究將BERT預(yù)訓(xùn)練時(shí)間從3天縮短到76分鐘
文章出處:【微信號(hào):AI_era,微信公眾號(hào):新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
想使ADCEXT1和ADCEXT2的采樣時(shí)間間隔縮短到最小,應(yīng)該怎么做?
PLL鎖定時(shí)間從4.5ms縮短到360μs的手動(dòng)方法
好奇~!谷歌的 Edge TPU 專用 ASIC 旨在將機(jī)器學(xué)習(xí)推理能力引入邊緣設(shè)備
請(qǐng)問如何將光纖斷點(diǎn)定位偏差縮短到最?。?/a>
使用JN5189和Micro MAC庫(kù)進(jìn)行開發(fā),如何將時(shí)間縮短到1.5毫秒以下嗎?
鋰電池充電時(shí)間急速縮短 只需10分鐘
Tesla 超級(jí)充電站將大幅升級(jí),充電時(shí)間縮短到5-10分鐘
FPGA在實(shí)時(shí)基因組測(cè)序計(jì)算大顯身手,把測(cè)序時(shí)間從30小時(shí)縮短到26分鐘!
FPGA能在實(shí)時(shí)基因組測(cè)序計(jì)算中大顯身手,大大縮短時(shí)間
1024塊TPU在燃燒!將BERT預(yù)訓(xùn)練模型的訓(xùn)練時(shí)長(zhǎng)從3天縮減到了76分鐘
AI協(xié)助將8年藥物研發(fā)時(shí)間縮短到46天!
如何將PLL鎖定時(shí)間從4.5毫秒縮短到360微秒
![如何<b class='flag-5'>將</b>PLL鎖定<b class='flag-5'>時(shí)間</b><b class='flag-5'>從</b>4.5毫秒<b class='flag-5'>縮短到</b>360微秒](https://file.elecfans.com/web1/M00/CA/8E/pIYBAF-JRx6AYWF5AAQ2WjGPxuQ771.png)
微軟一直在努力縮短更新Win10系統(tǒng)所需時(shí)間
嵌入式AI簡(jiǎn)報(bào) |特斯拉發(fā)布AI訓(xùn)練芯片Dojo D1
![嵌入式AI簡(jiǎn)報(bào) |特斯拉發(fā)布AI<b class='flag-5'>訓(xùn)練</b>芯片Dojo D1](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
評(píng)論