当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch ScriptModule.to用法及代码示例


本文简要介绍python语言中 torch.jit.ScriptModule.to 的用法。

用法:

to(*args, **kwargs)

参数

  • device(torch.device) -该模块中参数和缓冲区的所需设备

  • dtype(torch.dtype) -此模块中参数和缓冲区的所需浮点或复杂数据类型

  • tensor(torch.Tensor) -张量,其 dtype 和 device 是该模块中所有参数和缓冲区所需的 dtype 和 device

  • memory_format(torch.memory_format) -此模块中 4D 参数和缓冲区的所需内存格式(仅关键字参数)

返回

self

返回类型

torch.nn.Module

移动和/或强制转换参数和缓冲区。

这可以称为

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

它的签名类似于 torch.Tensor.to() ,但只接受浮点或复数 dtype s。此外,此方法只会将浮点数或复数参数和缓冲区转换为dtype(如果给定)。积分参数和缓冲区将被移动 device ,如果给出,但 dtypes 不变。当 non_blocking 设置时,如果可能,它会尝试相对于主机异步转换/移动,例如,将带有固定内存的 CPU 张量移动到 CUDA 设备。

请参阅下面的示例。

注意

此方法就地修改模块。

例子:

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.ScriptModule.to。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。