在我們對(duì)線(xiàn)性回歸的介紹中,我們介紹了各種組件,包括數(shù)據(jù)、模型、損失函數(shù)和優(yōu)化算法。事實(shí)上,線(xiàn)性回歸是最簡(jiǎn)單的機(jī)器學(xué)習(xí)模型之一。然而,訓(xùn)練它使用許多與本書(shū)中其他模型所需的組件相同的組件。因此,在深入了解實(shí)現(xiàn)細(xì)節(jié)之前,有必要設(shè)計(jì)一些貫穿本書(shū)的 API。將深度學(xué)習(xí)中的組件視為對(duì)象,我們可以從為這些對(duì)象及其交互定義類(lèi)開(kāi)始。這種面向?qū)ο蟮膶?shí)現(xiàn)設(shè)計(jì)將極大地簡(jiǎn)化演示,您甚至可能想在您的項(xiàng)目中使用它。
受PyTorch Lightning等開(kāi)源庫(kù)的啟發(fā),在高層次上我們希望擁有三個(gè)類(lèi):(i)Module
包含模型、損失和優(yōu)化方法;(ii)DataModule
提供用于訓(xùn)練和驗(yàn)證的數(shù)據(jù)加載器;(iii) 兩個(gè)類(lèi)結(jié)合使用該類(lèi) Trainer
,這使我們能夠在各種硬件平臺(tái)上訓(xùn)練模型。本書(shū)中的大部分代碼都改編自Module
and DataModule
。Trainer
只有在討論 GPU、CPU、并行訓(xùn)練和優(yōu)化算法時(shí),我們才會(huì)涉及該類(lèi)。
import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
import tensorflow as tf
from d2l import torch as d2l
3.2.1. 公用事業(yè)
我們需要一些實(shí)用程序來(lái)簡(jiǎn)化 Jupyter 筆記本中的面向?qū)ο?a target='_blank' class='arckwlink_none'>編程。挑戰(zhàn)之一是類(lèi)定義往往是相當(dāng)長(zhǎng)的代碼塊。筆記本電腦的可讀性需要簡(jiǎn)短的代碼片段,穿插著解釋?zhuān)@種要求與 Python 庫(kù)常見(jiàn)的編程風(fēng)格不相容。第一個(gè)實(shí)用函數(shù)允許我們?cè)趧?chuàng)建類(lèi)后將函數(shù)注冊(cè)為類(lèi)中的方法。事實(shí)上,即使我們已經(jīng)創(chuàng)建了類(lèi)的實(shí)例,我們也可以這樣做!它允許我們將一個(gè)類(lèi)的實(shí)現(xiàn)拆分成多個(gè)代碼塊。
def add_to_class(Class): #@save
"""Register functions as methods in created class."""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper
讓我們快速瀏覽一下如何使用它。我們計(jì)劃 A
用一個(gè)方法來(lái)實(shí)現(xiàn)一個(gè)類(lèi)do
。我們可以先聲明類(lèi)并創(chuàng)建一個(gè)實(shí)例,而不是在同一個(gè)代碼塊中A
同時(shí) 擁有兩者的代碼。do
A
a
do
接下來(lái)我們像往常一樣 定義方法,但不在 classA
的范圍內(nèi)。相反,我們add_to_class
用類(lèi)A
作為參數(shù)來(lái)裝飾這個(gè)方法。這樣做時(shí),該方法能夠訪(fǎng)問(wèn) 的成員變量,A
正如我們所期望的那樣,如果它已被定義為 的A
定義的一部分。讓我們看看當(dāng)我們?yōu)閷?shí)例調(diào)用它時(shí)會(huì)發(fā)生什么a
。
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
Class attribute "b" is 1
Class attribute "b" is 1
第二個(gè)是實(shí)用程序類(lèi),它將類(lèi) __init__
方法中的所有參數(shù)保存為類(lèi)屬性。這使我們無(wú)需額外代碼即可隱式擴(kuò)展構(gòu)造函數(shù)調(diào)用簽名。
我們將其實(shí)施推遲到第 23.7 節(jié)。HyperParameters
要使用它,我們定義繼承自該方法并調(diào)用 save_hyperparameters
該方法的類(lèi)__init__
。
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
最后一個(gè)實(shí)用程序允許我們?cè)趯?shí)驗(yàn)進(jìn)行時(shí)以交互方式繪制實(shí)驗(yàn)進(jìn)度。為了尊重更強(qiáng)大(和復(fù)雜)的TensorBoard,我們將其命名為ProgressBoard
。實(shí)現(xiàn)推遲到 第 23.7 節(jié)。現(xiàn)在,讓我們簡(jiǎn)單地看看它的實(shí)際效果。
該方法在圖中 draw
繪制一個(gè)點(diǎn),并在圖例中指定。可選的僅通過(guò)顯示來(lái)平滑線(xiàn)條(x, y)
label
every_n
1/n圖中的點(diǎn)。他們的價(jià)值是從平均n原始圖中的鄰居點(diǎn)。
class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation."""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImpleme
評(píng)論