本文简要介绍python语言中 torch.nn.DataParallel
的用法。
用法:
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
module(torch.nn.Module) -要并行化的模块
device_ids(Python列表:int或者torch.device) -CUDA 设备(默认:所有设备)
output_device(int或者torch.device) -输出的设备位置(默认值:device_ids[0])
~DataParallel.module(torch.nn.Module) -要并行化的模块
在模块级别实现数据并行。
该容器并行化给定
module
的应用程序,方法是通过在批处理维度中分块将输入拆分到指定的设备(其他对象将在每个设备上复制一次)。在前向传递中,模块在每个设备上复制,每个副本处理一部分输入。在向后传递期间,来自每个副本的梯度被汇总到原始模块中。批量大小应大于使用的 GPU 数量。
警告
建议使用
DistributedDataParallel
代替此类来进行multi-GPU训练,即使只有一个节点。请参阅:使用 nn.parallel.DistributedDataParallel 而不是多处理或 nn.DataParallel 和分布式数据并行。允许将任意位置和关键字输入传递到DataParallel,但某些类型会经过特殊处理。张量将分散在指定的dim(默认为 0)上。 tuple、list 和 dict 类型将被浅拷贝。其他类型将在不同线程之间共享,并且如果写入模型的正向传递中可能会损坏。
在运行
DataParallel
模块之前,并行化的module
必须在device_ids[0]
上拥有其参数和缓冲区。警告
在每一个前锋中,
module
是复制的在每个设备上,因此对正在运行的模块的任何更新forward
会迷路。例如,如果module
有一个计数器属性,在每个递增forward
,它将始终保持在初始值,因为更新是在之后销毁的副本上完成的forward
.然而,DataParallel
保证副本上device[0]
将使其参数和缓冲区与基本并行化共享存储module
.所以到位更新参数或缓冲区device[0]
将被记录。例如:,torch.nn.BatchNorm2d和torch.nn.utils.spectral_norm依靠这种行为来更新缓冲区。警告
module
及其子模块上定义的前向和后向钩子将被调用len(device_ids)
次,每次的输入都位于特定设备上。特别地,钩子仅保证相对于相应设备上的操作以正确的顺序执行。例如,不能保证通过register_forward_pre_hook()
设置的钩子在all
len(device_ids)
forward()
调用之前执行,但每个这样的钩子都在该设备的相应forward()
调用之前执行。警告
当
module
在forward()
中返回标量(即 0 维张量)时,此包装器将返回一个长度等于数据并行中使用的设备数量的向量,其中包含来自每个设备的结果。注意
在
DataParallel
包的Module
中使用pack sequence -> recurrent network -> unpack sequence
模式有一个微妙之处。有关详细信息,请参阅常见问题解答中的“我的循环网络无法使用数据并行性”部分。例子:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU
参数:
变量:
相关用法
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DistributedModelParallel用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.DataParallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。