0x1. OpenAI Triton介紹閱讀
這里來(lái)看官方的介紹:https://openai.com/research/triton ,從官方的介紹中我們可以看到OpenAI Triton的產(chǎn)生動(dòng)機(jī)以及它的目標(biāo)是什么,還可以看到一些經(jīng)典算法的實(shí)現(xiàn)例子展示。
這里的標(biāo)題是 Introducing Triton: Open-source GPU programming for neural networks ,翻譯就是《介紹 Triton:用于神經(jīng)網(wǎng)絡(luò)的開(kāi)源 GPU 編程語(yǔ)言》。然后下面的一句話翻譯過(guò)來(lái)是:我們發(fā)布了 Triton 1.0,這是一種開(kāi)源的類 Python 編程語(yǔ)言,它使得沒(méi)有 CUDA 經(jīng)驗(yàn)的研究人員能夠編寫(xiě)高效的 GPU 代碼——大多數(shù)情況下,其效能與專家所能編寫(xiě)的代碼相當(dāng)。這里指出了triton的目的,就是讓編寫(xiě)cuda kernrl變得更簡(jiǎn)單。接下來(lái)就逐步看一下介紹里的具體內(nèi)容,為了更加準(zhǔn)確這里會(huì)截圖對(duì)應(yīng)的原文然后放上我的翻譯或者理解。
這里的意思是Triton可以使得用戶用較少的努力就寫(xiě)出一個(gè)達(dá)到硬件峰值性能的kernel,比如使用 Triton 可以編寫(xiě) FP16 矩陣乘法的核函數(shù),其性能能夠匹配 cuBLAS,并且這個(gè)代碼不超過(guò)25行。然后研究者已經(jīng)用Triton開(kāi)發(fā)了一些高效的實(shí)現(xiàn),和功能相同的Torch實(shí)現(xiàn)相比,性能可以達(dá)到兩倍提升。后面一段就是強(qiáng)調(diào)了使用CUDA來(lái)把一些原始的PyTorch實(shí)現(xiàn)寫(xiě)一個(gè)算子一般會(huì)更加高效,但是這個(gè)難度不小,并且目前已有工作也不能很好覆蓋這種情況,所以O(shè)penAI Triton誕生。
這里講的是GPU編程的挑戰(zhàn),現(xiàn)代 GPU 的架構(gòu)大致可以分為三個(gè)主要部分——DRAM、SRAM 和 ALU。在優(yōu)化 CUDA 代碼時(shí),必須考慮到這些組件:
從 DRAM 的內(nèi)存?zhèn)鬏敱仨毢喜⒊纱笮褪聞?wù),以利用現(xiàn)代內(nèi)存接口的大總線寬度(內(nèi)存合并訪問(wèn))。
數(shù)據(jù)必須在重復(fù)使用前手動(dòng)存儲(chǔ)到 SRAM 中,并進(jìn)行管理來(lái)最小化bank conflict。
計(jì)算必須仔細(xì)地進(jìn)行劃分和調(diào)度,不僅是在流式多處理器(SMs)之間,還包括在其內(nèi)部,以促進(jìn)指令/線程級(jí)并行性,并利用專用的 ALU(例如,Tensor Cores)。
考慮所有這些因素可能對(duì)于擁有多年經(jīng)驗(yàn)的資深 CUDA 程序員來(lái)說(shuō)都是一個(gè)挑戰(zhàn)。Triton 的目的是完全自動(dòng)化這些優(yōu)化,以便開(kāi)發(fā)者能夠更好地專注于他們并行代碼的高層邏輯。Triton 旨在廣泛適用,因此不會(huì)自動(dòng)在流式多處理器(SMs)之間調(diào)度工作——留下一些重要的算法考慮(例如,tiling,跨 SM 同步)由開(kāi)發(fā)者自行決定。
然后給了一個(gè)表格展示cuda的編譯器和triton的區(qū)別。
在所有可用的領(lǐng)域特定語(yǔ)言和即時(shí)編譯器中,Triton可能和Numba最相似:kernel被定義為一個(gè)裝飾過(guò)的函數(shù),并以不同的 program_id 并行啟動(dòng)在所謂的網(wǎng)格實(shí)例上。然而,正如下面的代碼片段所示,相似之處僅此而已:Triton 通過(guò)對(duì)塊上的操作來(lái)暴露實(shí)例內(nèi)部的并行性——這些小數(shù)組的尺寸是二的冪次方——而不是單指令多線程(SIMT)執(zhí)行模型。這樣做,Triton 有效地抽象出了所有與 CUDA 線程塊內(nèi)部并發(fā)相關(guān)的問(wèn)題(例如,內(nèi)存合并、共享內(nèi)存同步/沖突、Tensor Cores調(diào)度)。
注意,Triton 的即時(shí)編譯器將 X 和 Y 視為指針而不是張量;我們認(rèn)為保留對(duì)內(nèi)存訪問(wèn)的低級(jí)控制對(duì)于處理更復(fù)雜的數(shù)據(jù)結(jié)構(gòu)(例如,塊稀疏張量)是重要的。重要的是,這種特定的 softmax 實(shí)現(xiàn)在整個(gè)標(biāo)準(zhǔn)化過(guò)程中將 X 的行保留在 SRAM 中,這在適用時(shí)最大化了數(shù)據(jù)重用(約 <32K 列)。這與 PyTorch 的內(nèi)部 CUDA 代碼不同,后者使用臨時(shí)內(nèi)存使其更具通用性,但顯著更慢(如下所示)。這里的關(guān)鍵不是 Triton 本質(zhì)上更好,而是它簡(jiǎn)化了專用kernel的開(kāi)發(fā),這些內(nèi)核可能比在通用庫(kù)中找到的內(nèi)核快得多。
Torch(v1.9)JIT編譯器的較低性能凸顯了從高級(jí)張量操作序列自動(dòng)生成 CUDA 代碼的難度。
這里是說(shuō)Triton大概只需要25行Python代碼就可以實(shí)現(xiàn)一個(gè)接近峰值的矩陣乘法。(后面有專門的一大節(jié)講這個(gè)代碼的原理)代碼如下:
@triton.jit defmatmul(A,B,C,M,N,K,stride_am,stride_ak, stride_bk,stride_bn,stride_cm,stride_cn, **META): #extractmetaparameters BLOCK_M,GROUP_M=META['BLOCK_M'],META['GROUP_M'] BLOCK_N=META['BLOCK_N'] BLOCK_K=META['BLOCK_K'] #programsaregroupedtogethertoimproveL2hitrate _pid_m=tl.program_id(0) _pid_n=tl.program_id(1) pid_m=_pid_m//GROUP_M pid_n=(_pid_n*GROUP_M)+(_pid_m%GROUP_M) #rm(resp.rn)denotesarangeofindices #forrows(resp.col)ofC rm=pid_m*BLOCK_M+tl.arange(0,BLOCK_M) rn=pid_n*BLOCK_N+tl.arange(0,BLOCK_N) #rkdenotesarangeofindicesforcolumns #(resp.rows)ofA(resp.B) rk=tl.arange(0,BLOCK_K) #thememoryaddressesofelementsinthefirstblockof #AandBcanbecomputedusingnumpy-stylebroadcasting A=A+(rm[:,None]*stride_am+rk[None,:]*stride_ak) B=B+(rk[:,None]*stride_bk+rn[None,:]*stride_bn) #initializeanditerativelyupdateaccumulator acc=tl.zeros((BLOCK_M,BLOCK_N),dtype=tl.float32) forkinrange(K,0,-BLOCK_K): a=tl.load(A) b=tl.load(B) #blocklevelmatrixmultiplication acc+=tl.dot(a,b) #incrementpointerssothatthenextblocksofAandB #areloadedduringthenextiteration A+=BLOCK_K*stride_ak B+=BLOCK_K*stride_bk #fuseleakyReLUifdesired #acc=tl.where(acc>=0,acc,alpha*acc) #writebackresult C=C+(rm[:,None]*stride_cm+rn[None,:]*stride_cn) mask=(rm[:,None]
手寫(xiě)矩陣乘法kernel的一個(gè)重要優(yōu)勢(shì)是,它們可以根據(jù)需要定制,以適應(yīng)輸入(例如,切片)和輸出(例如,LeakyReLU)的融合轉(zhuǎn)換。如果沒(méi)有像 Triton 這樣的系統(tǒng),沒(méi)有出色的 GPU 編程專長(zhǎng)的開(kāi)發(fā)者將無(wú)法進(jìn)行矩陣乘法內(nèi)核的定制修改。
這里是說(shuō)Triton 的良好性能源于一個(gè)以 Triton-IR 為中心的模塊化系統(tǒng)架構(gòu),Triton-IR 是一個(gè)基于 LLVM 的中間表示,在這個(gè)系統(tǒng)中,多維值塊(這個(gè)是MLIR的概念)是一等公民。GPT@triton.jit 裝飾器的工作原理是遍歷提供的 Python 函數(shù)的抽象語(yǔ)法樹(shù)(AST),以便使用常見(jiàn)的 SSA 構(gòu)建算法即時(shí)生成 Triton-IR。然后,編譯器后端會(huì)簡(jiǎn)化、優(yōu)化并自動(dòng)并行化所產(chǎn)生的 IR 代碼,再將其轉(zhuǎn)換為高質(zhì)量的 LLVM-IR —— 最終生成 PTX —— 以在近期的 NVIDIA GPU 上執(zhí)行。目前不支持 CPU 和 AMD GPU,但我們歡迎社區(qū)貢獻(xiàn),旨在解決這一限制。
我們發(fā)現(xiàn),通過(guò) Triton-IR 使用塊級(jí)別程序表示,使我們的編譯器能夠自動(dòng)執(zhí)行各種重要的程序優(yōu)化。例如,可以通過(guò)觀察計(jì)算密集型塊級(jí)操作(例如,tl.dot)的操作數(shù),自動(dòng)將數(shù)據(jù)暫存到共享內(nèi)存中,并使用標(biāo)準(zhǔn)的活性分析技術(shù)進(jìn)行分配和同步。另一方面,如下所示,Triton 程序可以高效且自動(dòng)地并行化,既可以(1)通過(guò)并發(fā)執(zhí)行不同的kernel實(shí)例在流式多處理器(SMs)間并行,也可以(2)通過(guò)分析每個(gè)塊級(jí)操作的迭代空間,并在不同的 SIMD 單元間適當(dāng)分配,從而在 SMs 內(nèi)部并行。
0x2. 教程1 Vector Addition閱讀
意思是這一節(jié)教程會(huì)介紹Triton編程模型定義kernel的基本寫(xiě)法,此外也會(huì)介紹一下怎么實(shí)現(xiàn)一個(gè)良好的benchmark測(cè)試。下面來(lái)看計(jì)算kernel實(shí)現(xiàn),我把注釋改成中文了:
importtorch importtriton importtriton.languageastl @triton.jit defadd_kernel(x_ptr,#*指針*,指向第一個(gè)輸入向量。 y_ptr,#*指針*,指向第二個(gè)輸入向量。 output_ptr,#*指針*,指向輸出向量。 n_elements,#向量的大小。 BLOCK_SIZE:tl.constexpr,#每個(gè)程序應(yīng)處理的元素?cái)?shù)量。 #注意:`constexpr`這樣可以被用作形狀值。 ): #這里有多個(gè)“程序”處理不同的數(shù)據(jù)。我們?cè)谶@里識(shí)別我們是哪一個(gè)程序: pid=tl.program_id(axis=0)#我們使用一維啟動(dòng)網(wǎng)格,所以軸是0。 #該程序?qū)⑻幚韽某跏紨?shù)據(jù)偏移的輸入。 #例如,如果你有一個(gè)長(zhǎng)度為256的向量和塊大小為64,那么程序 #將分別訪問(wèn)元素[0:64,64:128,128:192,192:256]。 #注意偏移量是一個(gè)指針列表: block_start=pid*BLOCK_SIZE offsets=block_start+tl.arange(0,BLOCK_SIZE) #創(chuàng)建一個(gè)掩碼以防止內(nèi)存操作越界訪問(wèn)。 mask=offsets
這里還聲明了一個(gè)輔助函數(shù)來(lái)(1)分配z張量,(2)使用適當(dāng)?shù)木W(wǎng)格/塊大小排隊(duì)上面的kernel:
defadd(x:torch.Tensor,y:torch.Tensor): #我們需要預(yù)分配輸出。 output=torch.empty_like(x) assertx.is_cudaandy.is_cudaandoutput.is_cuda n_elements=output.numel() #SPMD啟動(dòng)網(wǎng)格表示并行運(yùn)行的kernel實(shí)例的數(shù)量。 #它類似于CUDA啟動(dòng)網(wǎng)格。它可以是Tuple[int],也可以是Callable(metaparameters)->Tuple[int]。 #在這種情況下,我們使用一個(gè)1D網(wǎng)格,其大小是塊的數(shù)量: grid=lambdameta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),) #注意: #-每個(gè)torch.tensor對(duì)象都隱式地轉(zhuǎn)換為指向其第一個(gè)元素的指針。 #-使用`triton.jit`裝飾的函數(shù)可以用一個(gè)啟動(dòng)網(wǎng)格索引來(lái)獲得可調(diào)用的GPU內(nèi)核。 #-不要忘記將元參數(shù)作為關(guān)鍵字參數(shù)傳遞。 add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024) #我們返回一個(gè)指向z的句柄,但是因?yàn)閌torch.cuda.synchronize()`還沒(méi)有被調(diào)用,所以這時(shí)kernel仍然 #在異步運(yùn)行。 returnoutput
我們現(xiàn)在可以使用上面定義的函數(shù)來(lái)計(jì)算兩個(gè)torch.tensor對(duì)象的逐元素求和,并測(cè)試其正確性:
torch.manual_seed(0) size=98432 x=torch.rand(size,device='cuda') y=torch.rand(size,device='cuda') output_torch=x+y output_triton=add(x,y) print(output_torch) print(output_triton) print(f'Themaximumdifferencebetweentorchandtritonis' f'{torch.max(torch.abs(output_torch-output_triton))}')
輸出:
tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0') tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0') Themaximumdifferencebetweentorchandtritonis0.0
我們可以對(duì)不同大小的向量進(jìn)行自定義操作的性能基準(zhǔn)測(cè)試,以了解它相對(duì)于PyTorch的表現(xiàn)如何。為了簡(jiǎn)化操作,Triton提供了一系列內(nèi)置工具,使我們能夠簡(jiǎn)潔地繪制出自定義操作在不同問(wèn)題規(guī)模下的性能圖表。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['size'],#用作繪圖x軸的參數(shù)名。 x_vals=[2**iforiinrange(12,28,1)],#`x_name`的不同可能值。 x_log=True,#x軸是對(duì)數(shù)的。 line_arg='provider',#其值對(duì)應(yīng)于圖中不同線條的參數(shù)名。 line_vals=['triton','torch'],#`line_arg`的可能值。 line_names=['Triton','Torch'],#線條的標(biāo)簽名稱。 styles=[('blue','-'),('green','-')],#線條樣式。 ylabel='GB/s',#y軸的標(biāo)簽名稱。 plot_name='vector-add-performance',#繪圖的名稱。也用作保存繪圖的文件名。 args={},#不在`x_names`和`y_name`中的函數(shù)參數(shù)的值。 )) defbenchmark(size,provider): x=torch.rand(size,device='cuda',dtype=torch.float32) y=torch.rand(size,device='cuda',dtype=torch.float32) quantiles=[0.5,0.2,0.8] ifprovider=='torch': ms,min_ms,max_ms=triton.testing.do_bench(lambda:x+y,quantiles=quantiles) ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:add(x,y),quantiles=quantiles) gbps=lambdams:12*size/ms*1e-6 returngbps(ms),gbps(max_ms),gbps(min_ms)
gbps = lambda ms: 12 * size / ms * 1e-6這里的12表示的是數(shù)據(jù)讀寫(xiě)的bit,因?yàn)橛衳和y以及z的存在,所以是3*4=12bit?,F(xiàn)在可以運(yùn)行上面的裝飾函數(shù)了。傳遞 print_data=True 參數(shù)來(lái)查看性能數(shù)據(jù),傳遞 show_plots=True 參數(shù)來(lái)繪制圖表,和/或傳遞 save_path='/path/to/results/' 參數(shù)來(lái)將它們連同原始CSV數(shù)據(jù)一起保存到磁盤上:
benchmark.run(print_data=True,show_plots=True)
可以看到,對(duì)于elementwise任務(wù),Triton的性能幾乎和PyTorch持平,但是Triton寫(xiě)起來(lái)很簡(jiǎn)單。0x3. 教程2 Fused Softmax閱讀
在這個(gè)教程中,我們將編寫(xiě)一個(gè)融合的softmax操作,這個(gè)操作對(duì)于特定類型的矩陣來(lái)說(shuō)比PyTorch的原生操作要快得多:那些行的大小可以放入GPU的SRAM中的矩陣。
通過(guò)這樣做,我們將學(xué)習(xí)到:
kernel融合對(duì)于帶寬受限操作的好處。
Triton中的reduce操作符。
動(dòng)機(jī)
自定義GPU kernel用于逐元素加法在教育上是有價(jià)值的,但在實(shí)際應(yīng)用中可能作用有限。讓我們考慮一個(gè)簡(jiǎn)單的(數(shù)值穩(wěn)定的)softmax操作的情況:
importtorch importtriton importtriton.languageastl @torch.jit.script defnaive_softmax(x): """使用原生pytorch計(jì)算X的逐行softmax 我們減去最大元素是為了避免溢出。Softmax對(duì)這種偏移是不變的。 """ #讀取MN個(gè)元素;寫(xiě)入M個(gè)元素 x_max=x.max(dim=1)[0] #讀取MN+M個(gè)元素;寫(xiě)入MN個(gè)元素 z=x-x_max[:,None] #讀取MN個(gè)元素;寫(xiě)入MN個(gè)元素 numerator=torch.exp(z) #讀取MN個(gè)元素;寫(xiě)入M個(gè)元素 denominator=numerator.sum(dim=1) #讀取MN+M個(gè)元素;寫(xiě)入MN個(gè)元素 ret=numerator/denominator[:,None] #總計(jì):讀取5MN+2M個(gè)元素;寫(xiě)入3MN+2M個(gè)元素 returnret
計(jì)算kernel
我們的softmax kernel的工作方式如下:每個(gè)程序加載輸入矩陣X的一行,對(duì)其進(jìn)行歸一化處理,然后將結(jié)果寫(xiě)回到輸出Y中。需要注意的是,Triton的一個(gè)重要限制是每個(gè)塊必須包含2的冪次方個(gè)元素,因此如果我們想處理任何可能的輸入形狀,我們需要在內(nèi)部對(duì)每行進(jìn)行“pad”以及對(duì)內(nèi)存訪問(wèn)操作進(jìn)行保護(hù)(也就是防止越界):
@triton.jit defsoftmax_kernel(output_ptr,input_ptr,input_row_stride,output_row_stride,n_cols,BLOCK_SIZE:tl.constexpr): #softmax的各行是獨(dú)立的,所以我們?cè)谶@些行上進(jìn)行并行處理 row_idx=tl.program_id(0) #步長(zhǎng)代表我們需要增加多少指針來(lái)前進(jìn)1行 row_start_ptr=input_ptr+row_idx*input_row_stride #塊大小是大于n_cols的下一個(gè)2的冪次,因此我們可以將每一行放入單個(gè)塊中 col_offsets=tl.arange(0,BLOCK_SIZE) input_ptrs=row_start_ptr+col_offsets #將行加載到SRAM中,使用掩碼因?yàn)锽LOCK_SIZE可能大于n_cols row=tl.load(input_ptrs,mask=col_offsets
解析來(lái)創(chuàng)建一個(gè)輔助函數(shù),該函數(shù)為任何給定的輸入張量排隊(duì)執(zhí)行kernel并且設(shè)置了啟動(dòng)參數(shù)。
defsoftmax(x): n_rows,n_cols=x.shape #塊大小是大于`x`中列數(shù)的最小2的冪 BLOCK_SIZE=triton.next_power_of_2(n_cols) #我們可以使用的另一個(gè)技巧是要求編譯器通過(guò)增加每行分布的warp數(shù)(`num_warps`)來(lái)使用更多的線程。 #在下一個(gè)教程中,你將看到如何以更自然的方式自動(dòng)調(diào)整這個(gè)值,這樣你就不必自己想出手動(dòng)啟發(fā)式方法。 num_warps=4 ifBLOCK_SIZE>=2048: num_warps=8 ifBLOCK_SIZE>=4096: num_warps=16 #分配輸出 y=torch.empty_like(x) #排隊(duì)執(zhí)行內(nèi)核。一維啟動(dòng)網(wǎng)格很簡(jiǎn)單:我們有每行一個(gè)內(nèi)核實(shí)例 #輸入矩陣 softmax_kernel[(n_rows,)]( y, x, x.stride(0), y.stride(0), n_cols, num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) returny
這里是驗(yàn)證Triton實(shí)現(xiàn)的fuse softmax和PyTorch的naive實(shí)現(xiàn)等價(jià),顯然他們是等價(jià)的。BenchMark
這里設(shè)定矩陣的行數(shù)為固定的4096來(lái)做benchmark。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'],#用作繪圖x軸的參數(shù)名 x_vals=[128*iforiinrange(2,100)],#`x_name`的不同可能值 line_arg='provider',#其值對(duì)應(yīng)于圖中不同線條的參數(shù)名 line_vals=[ 'triton', 'torch-native', 'torch-jit', ],#`line_arg`的可能值 line_names=[ "Triton", "Torch(原生)", "Torch(jit)", ],#線條的標(biāo)簽名稱 styles=[('blue','-'),('green','-'),('green','--')],#線條樣式 ylabel="GB/s",#y軸的標(biāo)簽名稱 plot_name="softmax-performance",#繪圖的名稱。也用作保存繪圖的文件名。 args={'M':4096},#不在`x_names`和`y_name`中的函數(shù)參數(shù)的值 )) defbenchmark(M,N,provider): x=torch.randn(M,N,device='cuda',dtype=torch.float32) quantiles=[0.5,0.2,0.8] ifprovider=='torch-native': ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.softmax(x,axis=-1),quantiles=quantiles) ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:softmax(x),quantiles=quantiles) ifprovider=='torch-jit': ms,min_ms,max_ms=triton.testing.do_bench(lambda:naive_softmax(x),quantiles=quantiles) gbps=lambdams:2*x.nelement()*x.element_size()*1e-9/(ms*1e-3) returngbps(ms),gbps(max_ms),gbps(min_ms) benchmark.run(show_plots=True,print_data=True)
這里提到雖然Triton實(shí)現(xiàn)的softmax性能更好并且易于理解和維護(hù),但PyTorch的torch.softmax則更加通用。0x4. 教程3 Matrix Multiply閱讀
首先教程指出這里就是要寫(xiě)一個(gè)Block級(jí)別的矩陣乘法,然后這里會(huì)涉及到多維度的指針操作,程序重排以更好的命中l(wèi)2 cache以及自動(dòng)調(diào)優(yōu)。動(dòng)機(jī)
矩陣乘法是大多數(shù)現(xiàn)代高性能計(jì)算系統(tǒng)的關(guān)鍵構(gòu)建塊。它們眾所周知難以優(yōu)化,因此它們的實(shí)現(xiàn)通常由硬件供應(yīng)商自己作為所謂的“內(nèi)核庫(kù)”(例如,cuBLAS)的一部分來(lái)完成。不幸的是,這些庫(kù)通常是專有的,無(wú)法輕易地定制以適應(yīng)現(xiàn)代深度學(xué)習(xí)工作負(fù)載的需求(例如,融合激活函數(shù))。在這個(gè)教程中,你將學(xué)習(xí)如何使用Triton自己實(shí)現(xiàn)高效的矩陣乘法,這種方法易于定制和擴(kuò)展。
大致來(lái)說(shuō),我們將要編寫(xiě)的內(nèi)核將實(shí)現(xiàn)以下塊級(jí)算法來(lái)乘以一個(gè) (M, K) 矩陣和一個(gè) (K, N) 矩陣:
#Doinparallel forminrange(0,M,BLOCK_SIZE_M): #Doinparallel forninrange(0,N,BLOCK_SIZE_N): acc=zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=float32) forkinrange(0,K,BLOCK_SIZE_K): a=A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K] b=B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N] acc+=dot(a,b) C[m:m+BLOCK_SIZE_M,n:n+BLOCK_SIZE_N]=acc
其中,雙重嵌套的for循環(huán)的每次迭代都由一個(gè)專用的Triton program實(shí)例執(zhí)行。
計(jì)算kernel
上述算法實(shí)際上在Triton中相當(dāng)容易實(shí)現(xiàn)。主要的難點(diǎn)來(lái)自于在內(nèi)循環(huán)中計(jì)算必須讀取A和B塊的內(nèi)存位置。為此,我們需要多維指針運(yùn)算。
指針運(yùn)算
對(duì)于一個(gè)2D Tensor X,X[i, j]的內(nèi)存位置為&X[i, j] = X + i*stride_xi + j*stride_xj。因此,對(duì)于A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]和B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]的塊指針可以用下面的偽代碼定義:
&A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]=a_ptr+(m:m+BLOCK_SIZE_M)[:,None]*A.stride(0)+(k:k+BLOCK_SIZE_K)[None,:]*A.stride(1); &B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]=b_ptr+(k:k+BLOCK_SIZE_K)[:,None]*B.stride(0)+(n:n+BLOCK_SIZE_N)[None,:]*B.stride(1);
這意味著A和B塊的指針可以在Triton中初始化,比如 k=0 如下代碼所示。另外注意,我們需要一個(gè)額外的模運(yùn)算來(lái)處理M不是BLOCK_SIZE_M的倍數(shù)或N不是BLOCK_SIZE_N的倍數(shù)的情況,在這種情況下,我們可以用一些無(wú)用的值填充數(shù)據(jù),這些值不會(huì)對(duì)結(jié)果產(chǎn)生影響。對(duì)于K維度,我們稍后將使用掩碼加載語(yǔ)義來(lái)處理。
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)
然后在內(nèi)循環(huán)中按如下方式更新:
a_ptrs+=BLOCK_SIZE_K*stride_ak; b_ptrs+=BLOCK_SIZE_K*stride_bk;
如上所述,每個(gè)program實(shí)例計(jì)算一個(gè) [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計(jì)算順序是很重要的,因?yàn)樗鼤?huì)影響我們程序的L2緩存命中率,不幸的是,一個(gè)簡(jiǎn)單的行優(yōu)先順序是不夠的。
pid=triton.program_id(0); grid_m=(M+BLOCK_SIZE_M-1)//BLOCK_SIZE_M; grid_n=(N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N; pid_m=pid/grid_n; pid_n=pid%grid_n;
L2 Cache優(yōu)化
如上所述,每個(gè)程序?qū)嵗?jì)算一個(gè) [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計(jì)算順序很重要,因?yàn)樗鼤?huì)影響我們程序的L2緩存命中率,不幸的是,一個(gè)簡(jiǎn)單的行主序排序是不夠的。
一個(gè)可能的解決方案是以一種促進(jìn)數(shù)據(jù)重用的順序啟動(dòng)塊。這可以通過(guò)在切換到下一列之前將塊在GROUP_M行的super group中分組來(lái)實(shí)現(xiàn):
#程序ID pid=tl.program_id(axis=0) #沿M軸的程序ID數(shù)量 num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) #沿N軸的程序ID數(shù)量 num_pid_n=tl.cdiv(N,BLOCK_SIZE_N) #組中的程序數(shù)量 num_pid_in_group=GROUP_SIZE_M*num_pid_n #該程序所在組的ID group_id=pid//num_pid_in_group #組中第一個(gè)程序的行ID first_pid_m=group_id*GROUP_SIZE_M #如果`num_pid_m`不能被`GROUP_SIZE_M`整除,最后一個(gè)組更小 group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) #*在組內(nèi)*,程序按列主序排列 #程序在*啟動(dòng)網(wǎng)格*中的行ID pid_m=first_pid_m+(pid%group_size_m) #程序在*啟動(dòng)網(wǎng)格*中的列ID pid_n=(pid%num_pid_in_group)//group_size_m
例如,在下面的矩陣乘法中,每個(gè)矩陣由9個(gè)塊乘以9個(gè)塊組成,我們可以看到,如果我們按行主序計(jì)算輸出,我們需要將90個(gè)塊加載到SRAM中以計(jì)算前9個(gè)輸出塊,但如果我們按grouped ordering進(jìn)行計(jì)算,我們只需要加載54個(gè)塊。
在實(shí)際應(yīng)用中,這可以在某些硬件架構(gòu)上提高我們矩陣乘法內(nèi)核的性能超過(guò)10%(例如,在A100上從220提升到245 TFLOPS)。
L2 Cache優(yōu)化原理補(bǔ)充講解
上面的group oredering的訪問(wèn)代碼比較難理解,這里來(lái)更詳細(xì)的解析一下。
#程序ID pid=tl.program_id(axis=0) #沿M軸的程序ID數(shù)量 num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) #沿N軸的程序ID數(shù)量 num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
這里的num_pid_m和num_pid_n就是求分別要在M和N方向循環(huán)多少次。
然后上面圖中的黑色數(shù)字其實(shí)就可以理解為program id,我們可以看到program id增加的方向其實(shí)就代表了遍歷的ordering,對(duì)于row major來(lái)說(shuō)就是在行方向上順序遍歷,而對(duì)于group ordering來(lái)說(shuō)就是按照一個(gè)BLOCK_SIZE_M*BLOCK_SIZE_N這么大的一個(gè)小組來(lái)遍歷。其實(shí)這段代碼就是完成group ordering的遍歷:
num_pid_in_group=GROUP_SIZE_M*num_pid_n group_id=pid//num_pid_in_group first_pid_m=group_id*GROUP_SIZE_M group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m
以上面圖來(lái)看,num_pid_m=3,num_pid_n=3,num_pid_in_group=group_id * GROUP_SIZE_M=9*3=27,也就是下面的紅色框里面的program個(gè)數(shù),從名字也可以看出來(lái)這個(gè)紅色框劃分的區(qū)域也是一個(gè)group。
group_id 就表示當(dāng)前的這次 "循環(huán)", 是在第幾個(gè)紅色框里,以program 0為例,這里為group_id = pid // num_pid_in_group=0//27=0。而first_pid_m 代表當(dāng)前 group 中的第一個(gè)黃色program在全局的M維度上是第幾個(gè)program ,這里為first_pid_m = group_id * GROUP_SIZE_M=0,group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)這里是考慮到最后一個(gè)group可能占不滿數(shù)據(jù)(存在padding),所以就做一個(gè)截?cái)嗵幚怼?/p>
pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m
這兩行代碼計(jì)算當(dāng)前的program處理的黃色小塊坐標(biāo)([pid_m, pid_n]),pid_m這行是在行方向上移動(dòng),pid_n這行則是保證在上面的紅色框里面一定是一列一列來(lái)訪問(wèn)的。
作為對(duì)比,在Row-major的方法中,訪問(wèn)方式應(yīng)該是這樣的:
pid_m=pid//num_pid_n pid_n=pid%num_pid_n
計(jì)算最后的結(jié)果
有了上面的鋪墊,我們就可以計(jì)算最終的結(jié)果了,下面的代碼展示了完整的Triton 矩陣乘法kernel實(shí)現(xiàn)。
#使用`triton.jit`裝飾的函數(shù)可以通過(guò)`triton.autotune`裝飾器進(jìn)行自動(dòng)調(diào)優(yōu),該裝飾器包括: #-一系列定義不同配置的`triton.Config`對(duì)象, #這些配置涉及元參數(shù)(例如`BLOCK_SIZE_M`)和編譯選項(xiàng)(例如`num_warps`)的不同設(shè)置 #-一個(gè)自動(dòng)調(diào)優(yōu)*關(guān)鍵字*,其值的變化將觸發(fā)對(duì)所有 #提供的配置的評(píng)估 @triton.autotune( configs=[ #每個(gè)Config定義了一組特定的配置參數(shù)和編譯選項(xiàng) triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':64,'GROUP_SIZE_M':8},num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M':32,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5, num_warps=2), ], key=['M','N','K'],#自動(dòng)調(diào)優(yōu)關(guān)鍵字 ) @triton.jit defmatmul_kernel( #指向矩陣的指針 a_ptr,b_ptr,c_ptr, #矩陣維度 M,N,K, #步長(zhǎng)變量表示在特定維度上移動(dòng)1個(gè)元素時(shí)指針增加的量。 #例如`stride_am`是將`a_ptr`增加多少以獲取下一行的元素(A有M行)。 stride_am,stride_ak,#A矩陣的步長(zhǎng) stride_bk,stride_bn,#B矩陣的步長(zhǎng) stride_cm,stride_cn,#C矩陣的步長(zhǎng) #元參數(shù) BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_K:tl.constexpr,# GROUP_SIZE_M:tl.constexpr,# ACTIVATION:tl.constexpr#激活函數(shù) ): """用于計(jì)算矩陣乘法C=AxB的內(nèi)核。 A的形狀為(M,K),B的形狀為(K,N),C的形狀為(M,N)。 """ #----------------------------------------------------------- #將程序ID`pid`映射到它應(yīng)該計(jì)算的C矩陣的塊。 #這是以groupedordering完成的,以促進(jìn)L2數(shù)據(jù)重用。 #詳細(xì)解釋看一節(jié) pid=tl.program_id(axis=0) num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) num_pid_n=tl.cdiv(N,BLOCK_SIZE_N) num_pid_in_group=GROUP_SIZE_M*num_pid_n group_id=pid//num_pid_in_group first_pid_m=group_id*GROUP_SIZE_M group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m #---------------------------------------------------------- #為A和B的第一個(gè)塊創(chuàng)建指針。 #我們將在K方向移動(dòng)時(shí)推進(jìn)這個(gè)指針并累加 #`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針 #`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針 #有關(guān)詳細(xì)信息,請(qǐng)參閱上方“指針?biāo)阈g(shù)”部分 offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn) #----------------------------------------------------------- #迭代以計(jì)算C矩陣的一個(gè)塊。 #我們將累加到一個(gè)`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊 #的fp32值以獲得更高的精度。 #`accumulator`在循環(huán)后會(huì)轉(zhuǎn)換回fp16。 accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32) forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)): #LoadthenextblockofAandB,generateamaskbycheckingtheKdimension. #Ifitisoutofbounds,setitto0. a=tl.load(a_ptrs,mask=offs_k[None,:]=0,x,0.01*x)
我們現(xiàn)在可以創(chuàng)建一個(gè)方便的封裝函數(shù),它只需要兩個(gè)輸入張量,并且會(huì):(1)檢查任何形狀約束;(2)分配輸出;(3)啟動(dòng)上述kernel。
defmatmul(a,b,activation=""): #Checkconstraints. asserta.shape[1]==b.shape[0],"Incompatibledimensions" asserta.is_contiguous(),"MatrixAmustbecontiguous" assertb.is_contiguous(),"MatrixBmustbecontiguous" M,K=a.shape K,N=b.shape #Allocatesoutput. c=torch.empty((M,N),device=a.device,dtype=a.dtype) #1Dlaunchkernelwhereeachblockgetsitsownprogram. grid=lambdaMETA:(triton.cdiv(M,META['BLOCK_SIZE_M'])*triton.cdiv(N,META['BLOCK_SIZE_N']),) matmul_kernel[grid]( a,b,c,# M,N,K,# a.stride(0),a.stride(1),# b.stride(0),b.stride(1),# c.stride(0),c.stride(1),# ACTIVATION=activation# ) returnc
計(jì)算過(guò)程的補(bǔ)充說(shuō)明
上面的《L2 Cache優(yōu)化原理補(bǔ)充講解》這一節(jié)明確了kernel的group ordering的訪問(wèn)方式以及實(shí)現(xiàn),現(xiàn)在來(lái)看對(duì)于當(dāng)前的program實(shí)例具體是怎么計(jì)算的。現(xiàn)在以計(jì)算C中的第一個(gè)Block的(0, 0)為例子,它需要從A和B分別加載9個(gè)黃色的小塊數(shù)據(jù)相乘并累加最后得到C中的(0, 0)位置結(jié)果。如下圖所示:
下面的代碼先把program實(shí)例當(dāng)前要處理A和B的第一個(gè)Block加載上來(lái):
#---------------------------------------------------------- #為A和B的第一個(gè)塊創(chuàng)建指針。 #我們將在K方向移動(dòng)時(shí)推進(jìn)這個(gè)指針并累加 #`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針 #`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針 #有關(guān)詳細(xì)信息,請(qǐng)參閱上方“指針?biāo)阈g(shù)”部分 offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)
這里的a_ptr 是整個(gè) A 矩陣第一個(gè)元素的地址,offs_am和offs_bn表示當(dāng)前的program id在M維度和K維度的坐標(biāo),這個(gè)坐標(biāo)是一個(gè)list,用tl.arange(0, BLOCK_SIZE_K)來(lái)獲取。
得到 M 維度 和 K 維度的坐標(biāo)后, 就可以讓它們各自和 M 維度 和 K 維度的 stride 相乘, 然后和 a_ptr 相加, 就可以得到 A 矩陣 9 個(gè) block 中第一個(gè) block 中每個(gè)元素的地址了。 b_ptr也是同理。
最后一部分就是累加了,這里會(huì)在K維度上進(jìn)行累加,每次計(jì)算輸出的一個(gè)塊。
#迭代以計(jì)算C矩陣的一個(gè)塊。 #我們將累加到一個(gè)`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊 #的fp32值以獲得更高的精度。 #`accumulator`在循環(huán)后會(huì)轉(zhuǎn)換回fp16。 accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32) forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)): #LoadthenextblockofAandB,generateamaskbycheckingtheKdimension. #Ifitisoutofbounds,setitto0. a=tl.load(a_ptrs,mask=offs_k[None,:]
這行代碼a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)考慮到 K 可能不能被 BLOCK_SIZE_K 整除, 到每一行最后一個(gè) block 的時(shí)候, 實(shí)際大小是不足 BLOCK_SIZE_K 的,所以需要把超出的那部分元素mask掉。
最后這部分代碼是把當(dāng)前的算子和LeakyReLU激活函數(shù)進(jìn)行融合:
#當(dāng)累加器仍然是FP32時(shí),可以融合任意激活函數(shù) ifACTIVATION=="leaky_relu": accumulator=leaky_relu(accumulator) c=accumulator.to(tl.float16)
單元測(cè)試
Benchmark
這里使用一個(gè)方陣來(lái)對(duì)比Triton實(shí)現(xiàn)的matmul kernel和cublas的matmul kernel的性能。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['M','N','K'],#用作圖表x軸的參數(shù)名 x_vals=[128*iforiinrange(2,33)],#`x_name`的不同可能值 line_arg='provider',#其值對(duì)應(yīng)于圖表中不同線條的參數(shù)名 #`line_arg`的可能值 line_vals=['cublas','triton'], #線條的標(biāo)簽名稱 line_names=["cuBLAS","Triton"], #線條樣式 styles=[('green','-'),('blue','-')], ylabel="TFLOPS",#y軸的標(biāo)簽名稱 plot_name="matmul-performance",#圖表的名稱,也用作保存圖表的文件名。 args={},#其他參數(shù) )) defbenchmark(M,N,K,provider): #初始化張量 a=torch.randn((M,K),device='cuda',dtype=torch.float16) b=torch.randn((K,N),device='cuda',dtype=torch.float16) quantiles=[0.5,0.2,0.8]#分位數(shù) #如果提供者是cublas ifprovider=='cublas': ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.matmul(a,b),quantiles=quantiles) #如果提供者是triton ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:matmul(a,b),quantiles=quantiles) #性能計(jì)算函數(shù) perf=lambdams:2*M*N*K*1e-12/(ms*1e-3) returnperf(ms),perf(max_ms),perf(min_ms) #運(yùn)行基準(zhǔn)測(cè)試,展示圖表和打印數(shù)據(jù) benchmark.run(show_plots=True,print_data=True)
可以看到基于Triton實(shí)現(xiàn)的矩陣乘kernel性能大體可以和高度優(yōu)化的cuBlas持平。
審核編輯:劉清
-
sram
+關(guān)注
關(guān)注
6文章
768瀏覽量
114899 -
多處理器
+關(guān)注
關(guān)注
0文章
22瀏覽量
8986 -
Cache
+關(guān)注
關(guān)注
0文章
129瀏覽量
28441 -
python
+關(guān)注
關(guān)注
56文章
4808瀏覽量
85053 -
OpenAI
+關(guān)注
關(guān)注
9文章
1149瀏覽量
6729
原文標(biāo)題:【BBuf的CUDA筆記】十三,OpenAI Triton 入門筆記一
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論