当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch annotate用法及代码示例


本文简要介绍python语言中 torch.jit.annotate 的用法。

用法:

torch.jit.annotate(the_type, the_value)

参数

  • the_type-应作为 the_value 的类型提示传递给 TorchScript 编译器的 Python 类型

  • the_value-提示类型的值或表达式。

返回

the_value 作为返回值传回。

该方法是一个 pass-through 函数,返回 the_value ,用于提示 TorchScript 编译器 the_value 的类型。在 TorchScript 之外运行时,它是 no-op。

虽然 TorchScript 可以推断出大多数 Python 表达式的正确类型,但在某些情况下类型推断可能会出错,包括:

  • 空容器,例如 []{} ,其中 TorchScript 假定为 Tensor 的容器

  • 可选类型,如 Optional[T] 但分配了 T 类型的有效值,TorchScript 会假设它是类型 T 而不是 Optional[T]

请注意,annotate()torch.nn.Module 子类的 __init__ 方法没有帮助,因为它是在即刻模式下执行的。要注释 torch.nn.Module 属性的类型,请改用 Annotate()

例子:

import torch
from typing import Dict

@torch.jit.script
def fn():
    # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
    # instead of default dictionary type of (str -> Tensor).
    d = torch.jit.annotate(Dict[str, int], {})

    # Without `torch.jit.annotate` above, following statement would fail because of
    # type mismatch.
    d["name"] = 20

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.annotate。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。