本文整理汇总了 Python 中 torch.nn.parallel.data_parallel.DataParallel 类的典型用法代码示例。如果您正苦于以下问题：Python data_parallel.DataParallel 的具体用法？Python data_parallel.DataParallel 怎么用？Python data_parallel.DataParallel 使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torch.nn.parallel.data_parallel 的用法示例。
在下文中一共展示了 data_parallel.DataParallel 类的 5 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: patch_replication_callback
# 需要导入模块: from torch.nn.parallel import data_parallel [as 别名]
# 或者: from torch.nn.parallel.data_parallel import DataParallel [as 别名]
def patch_replication_callback(data_parallel):
    """Attach the replication callback to an already-built `DataParallel`.

    The instance's ``replicate`` attribute is replaced by a wrapper that
    first calls the original ``replicate`` and then hands the resulting
    replicas to ``execute_replication_callbacks`` before returning them.
    This is handy when a customized `DataParallel` implementation is in use
    and cannot simply be swapped for another class.

    Args:
        data_parallel: The `DataParallel` instance to patch in place.

    Raises:
        AssertionError: If ``data_parallel`` is not a `DataParallel`.

    Example::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        patch_replication_callback(sync_bn)

        # this is equivalent to
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    # Capture the bound method now so the wrapper keeps delegating to it
    # after the instance attribute has been overwritten below.
    original_replicate = data_parallel.replicate

    def replicate_then_notify(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # update_wrapper copies the original's metadata onto the wrapper and
    # returns the wrapper itself.
    data_parallel.replicate = functools.update_wrapper(
        replicate_then_notify, original_replicate
    )
示例2: patch_replication_callback
# 需要导入模块: from torch.nn.parallel import data_parallel [as 别名]
# 或者: from torch.nn.parallel.data_parallel import DataParallel [as 别名]
def patch_replication_callback(data_parallel):
    """Monkey-patch a live `DataParallel` so replication triggers callbacks.

    After this call, ``data_parallel.replicate`` is an instance-level
    wrapper: it delegates to the previous ``replicate``, passes the list it
    returned to ``execute_replication_callbacks``, and then returns that
    same list. Intended for cases where a customized `DataParallel`
    implementation is already in place.

    Args:
        data_parallel: A `DataParallel` object; it is modified in place.

    Raises:
        AssertionError: If the argument is not a `DataParallel` instance.

    Example::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        patch_replication_callback(sync_bn)

        # this is equivalent to
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    inner_replicate = data_parallel.replicate

    def wrapper(module, device_ids):
        copies = inner_replicate(module, device_ids)
        execute_replication_callbacks(copies)
        return copies

    # Apply functools.wraps as a plain call rather than as a decorator; the
    # wrapper keeps the name/doc metadata of the method it replaces.
    wrapper = functools.wraps(inner_replicate)(wrapper)
    data_parallel.replicate = wrapper
示例3: set_device
# 需要导入模块: from torch.nn.parallel import data_parallel [as 别名]
# 或者: from torch.nn.parallel.data_parallel import DataParallel [as 别名]
def set_device(self, device):
    """Move the wrapped module (and loss, if set) to ``device``.

    ``device`` is first passed through ``cast_device``. If the result is a
    list or tuple, every entry must be a CUDA device with an explicit,
    unique index; the module is then wrapped in `DataParallel` over those
    indices and the first entry becomes the primary device.

    Side effects: updates ``self.params['device']``, ``self.device``,
    ``self.nn_module`` and, when ``self.loss`` is not the ``default``
    sentinel, ``self.loss``.

    Args:
        device: Device specification accepted by ``cast_device``
            (exact accepted forms are defined by that helper).

    Raises:
        ValueError: If a multi-device list contains a non-CUDA device, a
            CUDA device without an index, or duplicate indices.
    """
    device = cast_device(device)
    device_label = device_to_str(device)
    module = self.get_nn_module()

    if isinstance(device, (list, tuple)):
        cuda_indices = []
        for entry in device:
            # Checks run per entry, type before index, so the first
            # offending entry determines which error is raised.
            if entry.type != 'cuda':
                raise ValueError("Non cuda device in list of devices")
            if entry.index is None:
                raise ValueError("Cuda device without index in list of devices")
            cuda_indices.append(entry.index)
        if len(set(cuda_indices)) != len(cuda_indices):
            raise ValueError("Cuda device indices must be unique")
        module = DataParallel(module, device_ids=cuda_indices)
        # The first listed device is the one the module and loss are moved to.
        device = device[0]

    self.params['device'] = device_label
    self.device = device
    self.nn_module = module.to(self.device)
    if self.loss is default:
        return
    self.loss = self.loss.to(self.device)
示例4: patch_replication_callback
# 需要导入模块: from torch.nn.parallel import data_parallel [as 别名]
# 或者: from torch.nn.parallel.data_parallel import DataParallel [as 别名]
def patch_replication_callback(data_parallel):
    """Install a callback-running ``replicate`` on a `DataParallel` object.

    The object's current ``replicate`` is kept and wrapped: each call
    forwards ``(module, device_ids)`` to it, runs
    ``execute_replication_callbacks`` on the returned replicas, and gives
    the replicas back to the caller. Use it when the `DataParallel`
    implementation is customized and must be patched rather than replaced.

    Args:
        data_parallel: The `DataParallel` instance whose ``replicate``
            attribute is overwritten.

    Raises:
        AssertionError: If ``data_parallel`` is not a `DataParallel`.

    Example::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        patch_replication_callback(sync_bn)

        # this is equivalent to
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    previous = data_parallel.replicate

    def patched(module, device_ids):
        result = previous(module, device_ids)
        execute_replication_callbacks(result)
        return result

    # Carry the original method's metadata over to the replacement.
    functools.update_wrapper(patched, previous)
    data_parallel.replicate = patched
示例5: patch_replication_callback
# 需要导入模块: from torch.nn.parallel import data_parallel [as 别名]
# 或者: from torch.nn.parallel.data_parallel import DataParallel [as 别名]
def patch_replication_callback(data_parallel):
    """Patch an existing `DataParallel` to fire the replication callback.

    ``data_parallel.replicate`` is swapped for a wrapper that calls the
    original ``replicate``, passes its result to
    ``exec_data_parallel_replication_callback``, and returns that result.
    Useful when working with a customized `DataParallel` implementation.

    Args:
        data_parallel: The `DataParallel` instance to patch in place.

    Raises:
        AssertionError: If ``data_parallel`` is not a `DataParallel`.

    Example::

        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        patch_replication_callback(sync_bn)

        # this is equivalent to
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    # Hold on to the current bound method; the wrapper below delegates to it.
    base_replicate = data_parallel.replicate

    @functools.wraps(base_replicate)
    def replicate_with_callback(module, device_ids):
        replicated = base_replicate(module, device_ids)
        exec_data_parallel_replication_callback(replicated)
        return replicated

    data_parallel.replicate = replicate_with_callback