在H100發(fā)布之際,英偉達(dá)還帶來(lái)一個(gè)“重磅產(chǎn)品”——Transformer Engine。在Transformer大火之際推出這么一個(gè)產(chǎn)品,無(wú)疑是煉丹師福音。
當(dāng)時(shí)我還在猜測(cè)它會(huì)以怎么樣的一種形式呈現(xiàn)給用戶,直到最近公開了倉(cāng)庫(kù) NVIDIA/TransformerEngine
這其實(shí)就是PyTorch的一個(gè)拓展,為了利用FP8的特性,針對(duì)Transformer里面的Kernel進(jìn)行了重寫,包含了一系列LayerNorm, GeLU, ScaledSoftmax等。
使用方式也是比較簡(jiǎn)單,使用該拓展額外包的一層Module來(lái)搭建網(wǎng)絡(luò),即可,最后再包一層混合精度訓(xùn)練作用域:
importtorch importtransformer_engine.pytorchaste fromtransformer_engine.commonimportrecipe #Setdimensions. in_features=768 out_features=3072 hidden_size=2048 #Initializemodelandinputs. model=te.Linear(in_features,out_features,use_bias=True) inp=torch.randn(hidden_size,in_features,device="cuda") #創(chuàng)建FP8訓(xùn)練的配置 fp8_recipe=recipe.DelayedScaling(margin=0,interval=1,fp8_format=recipe.Format.E4M3) #FP8的autocast withte.fp8_autocast(enabled=True,fp8_recipe=fp8_recipe): out=model(inp) loss=out.sum() loss.backward()
本篇博客就簡(jiǎn)單介紹下Transformer Engine及其對(duì)應(yīng)實(shí)現(xiàn)原理
官方文檔:https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Transfromer Engine 是干啥的?
在各種以Transformer為基礎(chǔ)的語(yǔ)言模型如GPT3大火后,語(yǔ)言模型的參數(shù)量還在以指數(shù)形式增長(zhǎng):
那么優(yōu)化Transformer性能就顯得格外重要了,其中混合精度訓(xùn)練是一個(gè)很實(shí)用的技巧
在FP16下,其數(shù)據(jù)范圍還是足夠大的,因此在AMP下,我們只在最后的Loss做了一個(gè)scaling,這個(gè)步驟足以保證在整個(gè)模型運(yùn)算過(guò)程中不會(huì)產(chǎn)生溢出
而FP8相比FP16減少了更多有效位,因此不能簡(jiǎn)單地復(fù)用FP16下的策略,需要給每個(gè)FP8 Tensor單獨(dú)設(shè)置一個(gè)合適的scale factor。Transformer Engine 需要?jiǎng)討B(tài)地對(duì)輸入范圍進(jìn)行調(diào)整,如圖所示:
上圖來(lái)自H100白皮書內(nèi)(當(dāng)時(shí)我還天真的以為有一個(gè)專門的硬件做這個(gè)處理。。。)
下面我們簡(jiǎn)單看下其代碼和實(shí)現(xiàn)原理
Kernel實(shí)現(xiàn)
具體到每一個(gè)算子實(shí)現(xiàn)動(dòng)態(tài)范圍調(diào)整的原理其實(shí)很簡(jiǎn)單,通過(guò)記錄歷史的abs max值,來(lái)去調(diào)整最終縮放的范圍。
其主要的Kernel實(shí)現(xiàn)都放在了 common 目錄下,我們以gelu這個(gè)kernel為例,最終它會(huì)調(diào)用到 vectorized_pointwise.h這個(gè)文件,我們主要看 unary_kernel
unary_kernel
這個(gè)核函數(shù)模板跟常規(guī)的elementwise向量化模板是類似的。
首先會(huì)讓每個(gè)線程獲取到scale值
ComputeTypes=0; ifconstexpr(is_fp8::value){ //獲取scale值 if(scale!=nullptr)s=*scale; //將scale取倒數(shù)寫回scale_inv if(blockIdx.x==0&&threadIdx.x==0&&scale_inv!=nullptr){ reciprocal (scale_inv,s); } }
其中在循環(huán)里,線程會(huì)不斷更新他運(yùn)算結(jié)果的最大值,并且最終運(yùn)算結(jié)果要乘上scale值:
//實(shí)際運(yùn)算發(fā)生 ComputeTypetemp=OP(val,p); ifconstexpr(is_fp8::value){ __builtin_assume(max>=0); max=fmaxf(fabsf(temp),max); //縮放 temp=temp*s; }
當(dāng)Kernel主體運(yùn)算完畢后,再也warp為單位做一個(gè)reduce_max,獲取到線程束內(nèi)的最大值,再通過(guò)atomicMax原子操作,不斷更新全局最大值:
ifconstexpr(is_fp8::value){ /*warptileamaxreduce*/ max=reduce_max (max,warp_id); if(threadIdx.x==0&&amax!=nullptr){ static_assert(std::is_same ::value); //更新全局最大值 atomicMaxFloat(amax,max); } }
其他layernorm等Kernel也是諸如類似的邏輯,這里就不再展開了
(1) DelayedScaling
從前面的示例代碼我們可以看到一個(gè)比較重要的API是DelayedScaling,我們可以根據(jù)官方文檔查看各個(gè)參數(shù)含義:
margin 計(jì)算scale的偏移量
interval 控制計(jì)算scale factor的頻率
fp8_format 使用FP8的格式,F(xiàn)P8有E4M3和E5M2,但是現(xiàn)在不支持純E5M2的格式訓(xùn)練
amax_history_len 記錄abs maxval的歷史窗口大小
amax_compute_algo 在窗口里選擇absmax的算法,'max'則是選擇歷史窗口里最大值,'most_recent'則是選擇近期的值,當(dāng)然你也可以傳一個(gè)自定義的函數(shù)
相關(guān)代碼為:
@torch.jit.script def_default_get_amax( amax_history:torch.Tensor, amax_compute_algo:str, )->Tuple[torch.Tensor,torch.Tensor]: """Defaultfunctiontoobtainamaxfromhistory.""" ifamax_compute_algo=="max": amax=torch.max(amax_history,dim=0).values else:#amax_compute_algo=="most_recent" amax=amax_history[0] amax_history=update_amax_history(amax_history) returnamax_history,amax
scaling_factor_compute_algo 計(jì)算scale factor的算法
@torch.jit.script def_default_sf_compute( amax:torch.Tensor, scale:torch.Tensor, fp8_max:float, margin:int, )->torch.Tensor: """Defaultfunctiontoconvertamaxtoscalingfactor.""" exp=torch.floor(torch.log2(fp8_max/amax))-margin sf=torch.round(torch.pow(2,torch.abs(exp))) sf=torch.where(amax>0.0,sf,scale) sf=torch.where(torch.isfinite(amax),sf,scale) sf=torch.where(exp0,?1?/?sf,?sf) ????return?sf
override_linear_precision 由3個(gè)bool值,分別控制fprop前向,dgrad,wgrad三個(gè)矩陣乘是否用更高的精度來(lái)計(jì)算,默認(rèn)都為False
(2) TransformerEngineBaseModule
相關(guān)的Kernel除了要完成自己的計(jì)算任務(wù),也得實(shí)時(shí)維護(hù)amax這些值,因此也需要對(duì)應(yīng)修改nn.Module,這里TransformerEngine繼承了nn.Module,并且增加了一些buffer維護(hù)的機(jī)制,這些buffer用于存儲(chǔ)動(dòng)態(tài)scale的信息:
classTransformerEngineBaseModule(torch.nn.Module,ABC): def__init__(self)->None: ... self.fp8=False self.fp8_meta={} self.fp8_meta["fp8_group"]=None self.fp8_meta["recipe"]=get_default_fp8_recipe() deffp8_init(self,num_gemms:int=1)->None: """Initializefp8relatedmetadataandtensorsduringfprop.""" #Iffp8isn'tenabled,turnoffandreturn. ifnotis_fp8_enabled(): self.fp8=False return #FP8isalreadyenabledandrecipeisthesame,don'tdoanything. ifself.fp8andget_fp8_recipe()==self.fp8_meta["recipe"]: return #SetFP8,recipe,andotherFP8metadata self.fp8=True self.fp8_meta["recipe"]=get_fp8_recipe() self.fp8_meta["num_gemms"]=num_gemms self.fp8_meta["fp8_group"]=get_fp8_group() #SetFP8_MAXpertensoraccordingtorecipe self.fp8_meta["fp8_max_fwd"]=self.fp8_meta["recipe"].fp8_format.value.max_fwd self.fp8_meta["fp8_max_bwd"]=self.fp8_meta["recipe"].fp8_format.value.max_bwd #Allocatescalesandamaxes self.init_fp8_meta_tensors()
而相關(guān)Module如LayerNormMLP繼承該Module,并且傳入fp8_meta信息更新:
classLayerNormMLP(TransformerEngineBaseModule): defforward(...): out=_LayerNormMLP.apply( ..., self.fp8, self.fp8_meta, )
總結(jié)
大致瀏覽完其實(shí)思路不復(fù)雜,但感覺還是FP8技術(shù)的不穩(wěn)定,整個(gè)項(xiàng)目還是加入了很多限制。得益于PyTorch靈活的外部擴(kuò)展形式,只要不去觸碰框架底層運(yùn)行機(jī)制,僅僅在算子層面上的修改還是相當(dāng)簡(jiǎn)單。雖然不具備通用性,但是運(yùn)算主體就這幾個(gè)算子,為了性能也是可以接受的
審核編輯:湯梓紅
-
NVIDIA
+關(guān)注
關(guān)注
14文章
5080瀏覽量
103826 -
英偉達(dá)
+關(guān)注
關(guān)注
22文章
3854瀏覽量
92076 -
Transformer
+關(guān)注
關(guān)注
0文章
146瀏覽量
6056 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13378 -
H100
+關(guān)注
關(guān)注
0文章
32瀏覽量
311
原文標(biāo)題:詳解 NVIDIA H100 TransformerEngine
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
英偉達(dá)a100和h100哪個(gè)強(qiáng)?英偉達(dá)A100和H100的區(qū)別
NVIDIA發(fā)布新一代產(chǎn)品—NVIDIA H100
![<b class='flag-5'>NVIDIA</b>發(fā)布新一代產(chǎn)品—<b class='flag-5'>NVIDIA</b> <b class='flag-5'>H100</b>](https://file.elecfans.com/web2/M00/37/56/pYYBAGI65L-AXzR-AAKw8jWrMOw369.png)
GTC2022大會(huì)黃仁勛:NVIDIA H100的5項(xiàng)突破性創(chuàng)新
![GTC2022大會(huì)黃仁勛:<b class='flag-5'>NVIDIA</b> <b class='flag-5'>H100</b>的5項(xiàng)突破性創(chuàng)新](https://file.elecfans.com/web2/M00/37/51/poYBAGI65tCAP4S-AAJFLDRHwnI099.png)
GTC2022大會(huì)亮點(diǎn):NVIDIA發(fā)布全新AI計(jì)算系統(tǒng)—DGX H100
![GTC2022大會(huì)亮點(diǎn):<b class='flag-5'>NVIDIA</b>發(fā)布全新AI計(jì)算系統(tǒng)—DGX <b class='flag-5'>H100</b>](https://file.elecfans.com/web2/M00/37/8D/pYYBAGI8F2qAAnReAAPwaUIyH9E149.png)
NVIDIA發(fā)布DGX H100系統(tǒng) 羅德與施瓦茨提供O-RAN無(wú)線電單元方案
NVIDIA發(fā)布最新Hopper架構(gòu)的H100系列GPU和Grace CPU超級(jí)芯片
藍(lán)海大腦服務(wù)器全力支持NVIDIA H100 GPU
用NVIDIA H100 CNX構(gòu)建人工智能系統(tǒng)
![用<b class='flag-5'>NVIDIA</b> <b class='flag-5'>H100</b> CNX構(gòu)建人工智能系統(tǒng)](https://file.elecfans.com/web2/M00/3A/B3/poYBAGJFTzaARrHmAAEDPnPpxmY345.png)
利用NVIDIA HGX H100加速計(jì)算數(shù)據(jù)中心平臺(tái)應(yīng)用
![利用<b class='flag-5'>NVIDIA</b> HGX <b class='flag-5'>H100</b>加速計(jì)算數(shù)據(jù)中心平臺(tái)應(yīng)用](https://file.elecfans.com//web2/M00/3E/22/pYYBAGJfddyALufVAAB34X0yeGk024.png)
評(píng)論