當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch annotate用法及代碼示例


本文簡要介紹python語言中 torch.jit.annotate 的用法。

用法:

torch.jit.annotate(the_type, the_value)

參數

  • the_type-應作為 the_value 的類型提示傳遞給 TorchScript 編譯器的 Python 類型

  • the_value-提示類型的值或表達式。

返回

the_value 作為返回值傳回。

該方法是一個 pass-through 函數,返回 the_value ,用於提示 TorchScript 編譯器 the_value 的類型。在 TorchScript 之外運行時,它是 no-op。

雖然 TorchScript 可以推斷出大多數 Python 表達式的正確類型,但在某些情況下類型推斷可能會出錯,包括:

  • 空容器,例如 []{} ,其中 TorchScript 假定為 Tensor 的容器

  • 可選類型,如 Optional[T] 但分配了 T 類型的有效值,TorchScript 會假設它是類型 T 而不是 Optional[T]

請注意,annotate()torch.nn.Module 子類的 __init__ 方法沒有幫助,因為它是在即刻模式下執行的。要注釋 torch.nn.Module 屬性的類型,請改用 Annotate()

例子:

import torch
from typing import Dict

@torch.jit.script
def fn():
    # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
    # instead of default dictionary type of (str -> Tensor).
    d = torch.jit.annotate(Dict[str, int], {})

    # Without `torch.jit.annotate` above, following statement would fail because of
    # type mismatch.
    d["name"] = 20

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.annotate。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。