本章說(shuō)明 Python API 的基本用法,假設(shè)您從 ONNX 模型開始。onnx_resnet50.py示例更詳細(xì)地說(shuō)明了這個(gè)用例。
Python API 可以通過(guò)tensorrt模塊訪問(wèn):
import tensorrt as trt
4.1. The Build Phase
要?jiǎng)?chuàng)建構(gòu)建器,您需要首先創(chuàng)建一個(gè)記錄器。 Python 綁定包括一個(gè)簡(jiǎn)單的記錄器實(shí)現(xiàn),它將高于特定嚴(yán)重性的所有消息記錄到stdout
。
logger = trt.Logger(trt.Logger.WARNING)
或者,可以通過(guò)從ILogger
類派生來(lái)定義您自己的記錄器實(shí)現(xiàn):
class MyLogger(trt.ILogger): def __init__(self): trt.ILogger.__init__(self) def log(self, severity, msg): pass # Your custom logging implementation here logger = MyLogger()
然后,您可以創(chuàng)建一個(gè)構(gòu)建器:
builder = trt.Builder(logger)
4.1.1. Creating a Network Definition in Python
創(chuàng)建構(gòu)建器后,優(yōu)化模型的第一步是創(chuàng)建網(wǎng)絡(luò)定義:
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
為了使用 ONNX 解析器導(dǎo)入模型,需要EXPLICIT_BATCH
標(biāo)志。有關(guān)詳細(xì)信息,請(qǐng)參閱顯式與隱式批處理部分。
4.1.2. Importing a Model using the ONNX Parser
現(xiàn)在,需要從 ONNX 表示中填充網(wǎng)絡(luò)定義。您可以創(chuàng)建一個(gè) ONNX 解析器來(lái)填充網(wǎng)絡(luò),如下所示:
parser = trt.OnnxParser(network, logger)
然后,讀取模型文件并處理任何錯(cuò)誤:
success = parser.parse_from_file(model_path) for idx in range(parser.num_errors): print(parser.get_error(idx)) if not success: pass # Error handling code here
4.1.3. Building an Engine
下一步是創(chuàng)建一個(gè)構(gòu)建配置,指定 TensorRT 應(yīng)該如何優(yōu)化模型:
config = builder.create_builder_config()
這個(gè)接口有很多屬性,你可以設(shè)置這些屬性來(lái)控制 TensorRT 如何優(yōu)化網(wǎng)絡(luò)。一個(gè)重要的屬性是最大工作空間大小。層實(shí)現(xiàn)通常需要一個(gè)臨時(shí)工作空間,并且此參數(shù)限制了網(wǎng)絡(luò)中任何層可以使用的最大大小。如果提供的工作空間不足,TensorRT 可能無(wú)法找到層的實(shí)現(xiàn):
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20) # 1 MiB
指定配置后,可以使用以下命令構(gòu)建和序列化引擎:
serialized_engine = builder.build_serialized_network(network, config)
將引擎保存到文件以供將來(lái)使用可能很有用。你可以這樣做:
with open(“sample.engine”, “wb”) as f: f.write(serialized_engine)
4.2. Deserializing a Plan
要執(zhí)行推理,您首先需要使用Runtime接口反序列化引擎。與構(gòu)建器一樣,運(yùn)行時(shí)需要記錄器的實(shí)例。
runtime = trt.Runtime(logger)
然后,您可以從內(nèi)存緩沖區(qū)反序列化引擎:
engine = runtime.deserialize_cuda_engine(serialized_engine)
如果您需要首先從文件加載引擎,請(qǐng)運(yùn)行:
with open(“sample.engine”, “rb”) as f: serialized_engine = f.read()
4.3. Performing Inference
引擎擁有優(yōu)化的模型,但要執(zhí)行推理需要額外的中間激活狀態(tài)。這是通過(guò)IExecutionContext
接口完成的:
context = engine.create_execution_context()
一個(gè)引擎可以有多個(gè)執(zhí)行上下文,允許一組權(quán)重用于多個(gè)重疊的推理任務(wù)。 (當(dāng)前的一個(gè)例外是使用動(dòng)態(tài)形狀時(shí),每個(gè)優(yōu)化配置文件只能有一個(gè)執(zhí)行上下文。)
要執(zhí)行推理,您必須為輸入和輸出傳遞 TensorRT 緩沖區(qū),TensorRT 要求您在 GPU 指針列表中指定。您可以使用為輸入和輸出張量提供的名稱查詢引擎,以在數(shù)組中找到正確的位置:
input_idx = engine[input_name] output_idx = engine[output_name]
使用這些索引,為每個(gè)輸入和輸出設(shè)置 GPU 緩沖區(qū)。多個(gè) Python 包允許您在 GPU 上分配內(nèi)存,包括但不限于 PyTorch、Polygraphy CUDA 包裝器和 PyCUDA。
然后,創(chuàng)建一個(gè) GPU 指針列表。例如,對(duì)于 PyTorch CUDA 張量,您可以使用data_ptr()
方法訪問(wèn) GPU 指針;對(duì)于 PolygraphyDeviceArray
,使用ptr
屬性:
buffers = [None] * 2 # Assuming 1 input and 1 output buffers[input_idx] = input_ptr buffers[output_idx] = output_ptr
填充輸入緩沖區(qū)后,您可以調(diào)用 TensorRT 的execute_async
方法以使用 CUDA 流異步啟動(dòng)推理。
首先,創(chuàng)建 CUDA 流。如果您已經(jīng)有 CUDA 流,則可以使用指向現(xiàn)有流的指針。例如,對(duì)于 PyTorch CUDA 流,即torch.cuda.Stream()
,您可以使用cuda_stream
屬性訪問(wèn)指針;對(duì)于 Polygraphy CUDA 流,使用ptr
屬性。 接下來(lái),開始推理:
context.execute_async_v2(buffers, stream_ptr)
通常在內(nèi)核之前和之后將異步memcpy()
排入隊(duì)列以從 GPU 中移動(dòng)數(shù)據(jù)(如果數(shù)據(jù)尚不存在)。
要確定內(nèi)核(可能還有memcpy() )何時(shí)完成,請(qǐng)使用標(biāo)準(zhǔn) CUDA 同步機(jī)制,例如事件或等待流。例如,對(duì)于 Polygraphy,使用:
stream.synchronize()
如果您更喜歡同步推理,請(qǐng)使用execute_v2
方法而不是execute_async_v2
。
關(guān)于作者
Ken He 是 NVIDIA 企業(yè)級(jí)開發(fā)者社區(qū)經(jīng)理 & 高級(jí)講師,擁有多年的 GPU 和人工智能開發(fā)經(jīng)驗(yàn)。自 2017 年加入 NVIDIA 開發(fā)者社區(qū)以來(lái),完成過(guò)上百場(chǎng)培訓(xùn),幫助上萬(wàn)個(gè)開發(fā)者了解人工智能和 GPU 編程開發(fā)。在計(jì)算機(jī)視覺,高性能計(jì)算領(lǐng)域完成過(guò)多個(gè)獨(dú)立項(xiàng)目。并且,在機(jī)器人和無(wú)人機(jī)領(lǐng)域,有過(guò)豐富的研發(fā)經(jīng)驗(yàn)。對(duì)于圖像識(shí)別,目標(biāo)的檢測(cè)與跟蹤完成過(guò)多種解決方案。曾經(jīng)參與 GPU 版氣象模式GRAPES,是其主要研發(fā)者。
審核編輯:郭婷
-
gpu
+關(guān)注
關(guān)注
28文章
4783瀏覽量
129395 -
python
+關(guān)注
關(guān)注
56文章
4809瀏覽量
85065 -
CUDA
+關(guān)注
關(guān)注
0文章
121瀏覽量
13692
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
Python存儲(chǔ)數(shù)據(jù)詳解
有關(guān)Python的解析
一張圖學(xué)會(huì)Python3的基本用法
![一張圖學(xué)會(huì)<b class='flag-5'>Python</b>3的基本<b class='flag-5'>用法</b>](https://file.elecfans.com/web1/M00/45/EE/o4YBAFp8-FCAWsMoAAALBI2OPQ4077.jpg)
python代碼示例之基于Python的日歷api調(diào)用代碼實(shí)例
![<b class='flag-5'>python</b>代碼示例之基于<b class='flag-5'>Python</b>的日歷<b class='flag-5'>api</b>調(diào)用代碼實(shí)例](https://file.elecfans.com/web1/M00/63/17/o4YBAFuQy8-AO90pAAAei-DUxgU163.png)
API-Shop-OCR-營(yíng)業(yè)執(zhí)照識(shí)別API接口Python調(diào)用示例代碼說(shuō)明
![<b class='flag-5'>API</b>-Shop-OCR-營(yíng)業(yè)執(zhí)照識(shí)別<b class='flag-5'>API</b><b class='flag-5'>接口</b><b class='flag-5'>Python</b>調(diào)用示例代碼說(shuō)明](https://file.elecfans.com/web1/M00/82/25/pIYBAFw24LuAJk32AAM-2t6LMes231.png)
使用Python實(shí)現(xiàn)游戲APP充值API調(diào)用的代碼實(shí)例
ADM1266 Linux API和Python庫(kù)簡(jiǎn)介
![ADM1266 Linux <b class='flag-5'>API</b>和<b class='flag-5'>Python</b>庫(kù)簡(jiǎn)介](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
將Tengine Python API移植到Tengine Lite
![將Tengine <b class='flag-5'>Python</b> <b class='flag-5'>API</b>移植到Tengine Lite](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
評(píng)論