本文简要介绍python语言中 torch.nn.Module.to
的用法。
用法:
to(*args, **kwargs)
device(
torch.device
) -该模块中参数和缓冲区的所需设备dtype(
torch.dtype
) -此模块中参数和缓冲区的所需浮点或复杂数据类型tensor(torch.Tensor) -张量,其 dtype 和 device 是该模块中所有参数和缓冲区所需的 dtype 和 device
memory_format(
torch.memory_format
) -此模块中 4D 参数和缓冲区的所需内存格式(仅关键字参数)
self
移动和/或强制转换参数和缓冲区。
这可以称为
to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)
它的签名类似于
torch.Tensor.to()
,但只接受浮点或复数dtype
s。此外,此方法只会将浮点数或复数参数和缓冲区转换为dtype
(如果给定)。积分参数和缓冲区将被移动device
,如果给出,但 dtypes 不变。当non_blocking
设置时,如果可能,它会尝试相对于主机异步转换/移动,例如,将带有固定内存的 CPU 张量移动到 CUDA 设备。请参阅下面的示例。
注意
此方法就地修改模块。
例子:
>>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
参数:
返回:
返回类型:
相关用法
- Python PyTorch Module.buffers用法及代码示例
- Python PyTorch Module.register_full_backward_hook用法及代码示例
- Python PyTorch Module.named_modules用法及代码示例
- Python PyTorch Module.parameters用法及代码示例
- Python PyTorch Module.register_forward_hook用法及代码示例
- Python PyTorch Module.named_parameters用法及代码示例
- Python PyTorch Module.state_dict用法及代码示例
- Python PyTorch Module.register_forward_pre_hook用法及代码示例
- Python PyTorch Module.named_children用法及代码示例
- Python PyTorch Module.modules用法及代码示例
- Python PyTorch Module.register_buffer用法及代码示例
- Python PyTorch Module.apply用法及代码示例
- Python PyTorch Module.named_buffers用法及代码示例
- Python PyTorch ModuleList用法及代码示例
- Python PyTorch Module用法及代码示例
- Python PyTorch ModuleDict用法及代码示例
- Python PyTorch MaxUnpool3d用法及代码示例
- Python PyTorch MultiStepLR用法及代码示例
- Python PyTorch MaxPool1d用法及代码示例
- Python PyTorch MetaInferGroupedPooledEmbeddingsLookup.state_dict用法及代码示例
- Python PyTorch MetaInferGroupedEmbeddingsLookup.named_buffers用法及代码示例
- Python PyTorch MultiLabelMarginLoss用法及代码示例
- Python PyTorch MultiplicativeLR用法及代码示例
- Python PyTorch MixtureSameFamily用法及代码示例
- Python PyTorch MultiheadAttention用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Module.to。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。