当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch set_default_dtype用法及代码示例


本文简要介绍python语言中 torch.set_default_dtype 的用法。

用法:

torch.set_default_dtype(d)

参数

d(torch.dtype) -要设为默认值的浮点 dtype。 torch.float32 或 torch.float64。

将默认浮点数据类型设置为 d 。支持 torch.float32 和 torch.float64 作为输入。其他 dtype 可能会毫无怨言地被接受,但不受支持且不太可能按预期工作。

当PyTorch初始化时,其默认浮点数据类型为torch.float32,set_default_dtype(torch.float64)的目的是促进NumPy-like类型推断。默认浮点数据类型用于:

  1. 隐式确定默认的复杂数据类型。当默认浮点类型为 float32 时,默认复数 dtype 为 complex64,当默认浮点类型为 float64 时,默认复数类型为 complex128。

  2. 推断使用 Python 浮点数或复杂 Python 数字构造的张量的 dtype。请参阅下面的示例。

  3. 确定 bool 和整数张量以及 Python 浮点数和复杂 Python 数字之间的类型提升的结果。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype   # a new complex tensor
torch.complex128

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.set_default_dtype。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。