注意力機(jī)制的掩碼允許我們發(fā)送不同長度的批次數(shù)據(jù)一次性的發(fā)送到transformer中。在代碼中是通過將所有序列填充到相同的長度,然后使用“attention_mask”張量來識(shí)別哪些令牌是填充的來做到這一點(diǎn),本文將詳細(xì)介紹這個(gè)掩碼的原理和機(jī)制。
我們先介紹下如果不使用掩碼,是如何運(yùn)行的。這里用GPT-2每次使用一個(gè)序列來執(zhí)行推理,因?yàn)槊看沃挥幸粋€(gè)序列,所以速度很慢:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
context = tokenizer('It will rain in the', return_tensors='pt')
prediction = gpt2.generate(**context, max_length=10)
tokenizer.decode(prediction[0])
# prints 'It will rain in the morning, and the rain'
在顯存允許的情況下,使用批處理輸入的速度更快,因?yàn)槲覀冊(cè)谝淮瓮评淼倪^程可以同時(shí)處理多個(gè)序列。對(duì)許多樣本執(zhí)行推理要快得多,但也稍微復(fù)雜一些,下面是使用transformer庫進(jìn)行推理的代碼:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
sentences = ["It will rain in the",
"I want to eat a big bowl of",
"My dog is"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
output_sequences = gpt2.generate(**inputs)
for seq in output_sequences:
print(tokenizer.decode(seq))
transformer庫幫我們處理了很多細(xì)節(jié),我們現(xiàn)在詳細(xì)的介紹它里面到底做了什么。
我們將令牌輸入到語言模型中,如GPT-2和BERT,作為張量進(jìn)行推理。張量就像一個(gè)python列表,但有一些額外的特征和限制。比如說,對(duì)于一個(gè)2+維的張量,該維中的所有向量必須是相同的長度。例如,
from torch import tensor
tensor([[1,2], [3,4]]) # ok
tensor([[1,2], [3]]) # error!
當(dāng)我們對(duì)輸入進(jìn)行標(biāo)記時(shí),它將被轉(zhuǎn)換為序列的張量,每個(gè)整數(shù)對(duì)應(yīng)于模型詞表中的一個(gè)項(xiàng)。以下是GPT-2中的標(biāo)記化示例:
如果我們想在輸入中包含第二個(gè)序列:
因?yàn)檫@兩個(gè)序列有不同的長度,所以不能把它們組合成一個(gè)張量。這時(shí)就需要用虛擬標(biāo)記填充較短的序列,以便每個(gè)序列具有相同的長度。因?yàn)槲覀兿胱屇P屠^續(xù)向序列的右側(cè)添加,我們將填充較短序列的左側(cè)。
這就是注意力掩碼的一個(gè)應(yīng)用。注意力掩碼告訴模型哪些令牌是填充的,在填充令牌的位置放置0,在實(shí)際令牌的位置放置1?,F(xiàn)在我們理解了這一點(diǎn),讓我們逐行查看代碼。
tokenizer.padding_side = "left"
這一行告訴標(biāo)記器從左邊開始填充(默認(rèn)是右邊),因?yàn)樽钣疫厴?biāo)記的logits將用于預(yù)測未來的標(biāo)記。
tokenizer.pad_token = tokenizer.eos_token
這一行指定將使用哪個(gè)令牌進(jìn)行填充。選擇哪一個(gè)并不重要,這里我們選擇的是“序列結(jié)束”標(biāo)記。
sentences = ["It will rain in the",
"I want to eat a big bowl of",
"My dog is"]
上面這三個(gè)序列在標(biāo)記時(shí)都有不同的長度,我們使用下面的方法填充:
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
在進(jìn)行表計(jì)劃和添加填充后,得到了以下的結(jié)果:
{'input_ids': tensor([
[50256, 50256, 50256, 1026, 481, 6290, 287, 262],
[ 40, 765, 284, 4483, 257, 1263, 9396, 286],
[50256, 50256, 50256, 50256, 50256, 3666, 3290, 318]
]),
'attention_mask': tensor([
[0, 0, 0, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1]
])}
可以看到,第一個(gè)和第三個(gè)序列在開始時(shí)進(jìn)行了填充,并且attention_mask參數(shù)標(biāo)記了這個(gè)填充的位置。
現(xiàn)在讓我們將這個(gè)輸入傳遞給模型來生成新的文本:
output_sequences = gpt2.generate(**inputs)
如果你不熟悉函數(shù)調(diào)用的**kwargs語法,它是將輸入字典作為命名參數(shù)傳入,使用鍵作為參數(shù)名,并使用值作為相應(yīng)的實(shí)參值。
我們只需要循環(huán)遍歷每個(gè)生成的序列并以人類可讀的形式打印出結(jié)果,使用decode()函數(shù)將令牌id轉(zhuǎn)換為字符串。
for seq in output_sequences:
print(tokenizer.decode(seq))
在注意力掩碼中,我們的輸入是0和1,但是在最終的計(jì)算時(shí),會(huì)將在將無效位置的注意力權(quán)重設(shè)置為一個(gè)很小的值,通常為負(fù)無窮(-inf),以便在計(jì)算注意力分?jǐn)?shù)時(shí)將其抑制為接近零的概率。
這時(shí)因?yàn)?,在?jì)算注意力權(quán)重時(shí),需要進(jìn)行Softmax的計(jì)算:
Softmax函數(shù)的性質(zhì):注意力機(jī)制通常使用Softmax函數(shù)將注意力分?jǐn)?shù)轉(zhuǎn)化為注意力權(quán)重,Softmax函數(shù)對(duì)輸入值進(jìn)行指數(shù)運(yùn)算,然后進(jìn)行歸一化。當(dāng)輸入值非常小或負(fù)無窮時(shí),經(jīng)過指數(shù)運(yùn)算后會(huì)接近零。因此,將掩碼設(shè)置為負(fù)無窮可以確保在Softmax函數(shù)計(jì)算時(shí),對(duì)應(yīng)位置的注意力權(quán)重趨近于零。
排除無效位置的影響:通過將無效位置的注意力權(quán)重設(shè)置為負(fù)無窮,可以有效地將這些位置的權(quán)重壓低。在計(jì)算注意力權(quán)重時(shí),負(fù)無窮的權(quán)重會(huì)使對(duì)應(yīng)位置的注意力權(quán)重接近于零,從而模型會(huì)忽略無效位置的影響。這樣可以確保模型更好地關(guān)注有效的信息,提高模型的準(zhǔn)確性和泛化能力。
但是負(fù)無窮并不是唯一的選擇。有時(shí)也可以選擇使用一個(gè)很大的負(fù)數(shù),以達(dá)到相似的效果。具體的選擇可以根據(jù)具體的任務(wù)和模型的需求來確定。
-
處理器
+關(guān)注
關(guān)注
68文章
19432瀏覽量
231284 -
虛擬機(jī)
+關(guān)注
關(guān)注
1文章
949瀏覽量
28457 -
python
+關(guān)注
關(guān)注
56文章
4809瀏覽量
85065
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
深度分析NLP中的注意力機(jī)制
注意力機(jī)制的誕生、方法及幾種常見模型
注意力機(jī)制或?qū)⑹俏磥頇C(jī)器學(xué)習(xí)的核心要素
基于注意力機(jī)制的深度學(xué)習(xí)模型AT-DPCNN
![基于<b class='flag-5'>注意力</b><b class='flag-5'>機(jī)制</b>的深度學(xué)習(xí)模型AT-DPCNN](https://file.elecfans.com/web1/M00/E5/D2/pIYBAGBRYa2AJBfwAAItGPaXyTE263.png)
基于多層CNN和注意力機(jī)制的文本摘要模型
![基于多層CNN和<b class='flag-5'>注意力</b><b class='flag-5'>機(jī)制</b>的文本摘要模型](https://file.elecfans.com/web1/M00/E9/AD/pIYBAGBtKnSAQwZNAAJC8wqBACw976.png)
結(jié)合注意力機(jī)制的跨域服裝檢索方法
基于多層注意力機(jī)制的回指消解算法綜述
基于注意力機(jī)制等的社交網(wǎng)絡(luò)熱度預(yù)測模型
基于非對(duì)稱注意力機(jī)制殘差網(wǎng)絡(luò)的圖像檢測
計(jì)算機(jī)視覺中的注意力機(jī)制
![計(jì)算機(jī)視覺<b class='flag-5'>中</b>的<b class='flag-5'>注意力</b><b class='flag-5'>機(jī)制</b>](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
PyTorch教程11.4之Bahdanau注意力機(jī)制
![PyTorch教程11.4之Bahdanau<b class='flag-5'>注意力</b><b class='flag-5'>機(jī)制</b>](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
評(píng)論