與我們大多數(shù)從頭開始的實(shí)施一樣, 第 9.5 節(jié)旨在深入了解每個(gè)組件的工作原理。但是,當(dāng)您每天使用 RNN 或編寫生產(chǎn)代碼時(shí),您會(huì)希望更多地依賴于減少實(shí)現(xiàn)時(shí)間(通過為通用模型和函數(shù)提供庫代碼)和計(jì)算時(shí)間(通過優(yōu)化這些庫實(shí)現(xiàn))。本節(jié)將向您展示如何使用深度學(xué)習(xí)框架提供的高級(jí) API 更有效地實(shí)現(xiàn)相同的語言模型。和以前一樣,我們首先加載時(shí)間機(jī)器數(shù)據(jù)集。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import tensorflow as tf
from d2l import tensorflow as d2l
9.6.1. 定義模型
我們使用由高級(jí) API 實(shí)現(xiàn)的 RNN 定義以下類。
Specifically, to initialize the hidden state, we invoke the member method begin_state
. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.
class RNN(d2l.Module): #@save
"""The RNN model implemented with high-level APIs."""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = rnn.RNN(num_hiddens)
def forward(self, inputs, H=None):
if H is None:
H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
outputs, (H, ) = self.rnn(inputs, (H, ))
return outputs, H
Flax does not provide an RNNCell for concise implementation of Vanilla RNNs as of today. There are more advanced variants of RNNs like LSTMs and GRUs which are available in the Flax linen
API.
class RNN(d2l.Module): #@save
"""The RNN model implemented with high-level APIs."""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = tf.keras.layers.SimpleRNN(
num_hiddens, return_sequences=True, return_state=True,
time_major=True)
def forward(self, inputs, H=None):
outputs, H = self.rnn(inputs, H)
return outputs, H
繼承自9.5 節(jié)RNNLMScratch
中的類 ,下面的類定義了一個(gè)完整的基于 RNN 的語言模型。請(qǐng)注意,我們需要?jiǎng)?chuàng)建一個(gè)單獨(dú)的全連接輸出層。RNNLM
class RNNLM(d2l.RNNLMScratch): #@save
"""The RNN-based language model implemented with high-level APIs."""
def init_params(self):
self.linear = nn.LazyLinear(self.vocab_size)
def output_layer(self, hiddens):
return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch): #@save
"""The RNN-based language model implemented with high-level APIs."""
training: bool = True
def setup(self):
self.linear = nn.Dense(self.vocab_size)
def output_layer(self, hiddens):
return self.linear(hiddens).swapaxes(0, 1)
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state, self.training)
return self.output_layer(rnn_outputs)
9.6.2. 訓(xùn)練和預(yù)測(cè)
在訓(xùn)練模型之前,讓我們使用隨機(jī)權(quán)重初始化的模型進(jìn)行預(yù)測(cè)。鑒于我們還沒有訓(xùn)練網(wǎng)絡(luò),它會(huì)產(chǎn)生無意義的預(yù)測(cè)。
'it hasgggggggggggggggggggg'
'it hasxlxlxlxlxlxlxlxlxlxl'
接下來,我們利用高級(jí) API 訓(xùn)練我們的模型。
與第 9.5 節(jié)相比,該模型實(shí)現(xiàn)了相當(dāng)?shù)睦Щ蠖?,但由于?shí)現(xiàn)優(yōu)化,運(yùn)行速度更快。和以前一樣,我們可以在指定的前綴字符串之后生成預(yù)測(cè)標(biāo)記。
評(píng)論