

Python data_parallel.DataParallel Code Examples

This article collects typical usage examples of the DataParallel class from torch.nn.parallel.data_parallel in Python. If you are unsure how data_parallel.DataParallel is used in practice, or are looking for concrete examples, the curated code samples below may help. You can also explore further usage examples from the torch.nn.parallel.data_parallel module.


Below are 5 code examples of data_parallel.DataParallel, sorted by popularity by default.
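Before the examples, here is a minimal sketch of how DataParallel itself is typically used; the toy model, tensor shapes, and device ids below are illustrative assumptions and are not taken from the projects listed further down.

import torch
from torch import nn
from torch.nn.parallel.data_parallel import DataParallel

# A toy model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

if torch.cuda.device_count() > 1:
    # Replicate the module onto the listed GPUs and split each input batch
    # along dim 0 across them during forward().
    model = DataParallel(model, device_ids=[0, 1])
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

x = torch.randn(8, 16, device=next(model.parameters()).device)
out = model(x)  # outputs are gathered back onto device_ids[0] (or stay on CPU)
print(out.shape)  # torch.Size([8, 4])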

Example 1: patch_replication_callback

# Required import: from torch.nn.parallel import data_parallel [as alias]
# Or: from torch.nn.parallel.data_parallel import DataParallel [as alias]
# (The snippet also needs `import functools` and the execute_replication_callbacks
#  helper defined alongside it in the same project.)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate 
Developer: clovaai | Project: overhaul-distillation | Lines: 26 | Source file: replicate.py
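The snippet above relies on execute_replication_callbacks, which these projects define next to patch_replication_callback. A representative sketch of what that helper usually looks like in the Synchronized-BatchNorm codebases is shown below; it is reconstructed for context rather than copied from any single project on this page.

class CallbackContext(object):
    # Empty context object; replicas of the same submodule share one instance.
    pass


def execute_replication_callbacks(modules):
    # Run `__data_parallel_replicate__(ctx, copy_id)` on every submodule of every
    # replica that defines it. Because replicas of the same submodule receive the
    # same context, they can coordinate (e.g. synchronize batch-norm statistics).
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for copy_id, module in enumerate(modules):
        for idx, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[idx], copy_id)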

Example 2: patch_replication_callback

# Required import: from torch.nn.parallel import data_parallel [as alias]
# Or: from torch.nn.parallel.data_parallel import DataParallel [as alias]
# (The snippet also needs `import functools` and the execute_replication_callbacks
#  helper defined alongside it in the same project.)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate 
Developer: speedinghzl | Project: pytorch-segmentation-toolbox | Lines: 27 | Source file: encoding.py

Example 3: set_device

# Required import: from torch.nn.parallel import data_parallel [as alias]
# Or: from torch.nn.parallel.data_parallel import DataParallel [as alias]
# (Excerpt of a method from an argus model class; cast_device, device_to_str,
#  and `default` come from the argus package.)
def set_device(self, device):
        device = cast_device(device)
        str_device = device_to_str(device)
        nn_module = self.get_nn_module()

        if isinstance(device, (list, tuple)):
            device_ids = []
            for dev in device:
                if dev.type != 'cuda':
                    raise ValueError("Non cuda device in list of devices")
                if dev.index is None:
                    raise ValueError("Cuda device without index in list of devices")
                device_ids.append(dev.index)
            if len(device_ids) != len(set(device_ids)):
                raise ValueError("Cuda device indices must be unique")
            nn_module = DataParallel(nn_module, device_ids=device_ids)
            device = device[0]

        self.params['device'] = str_device
        self.device = device
        self.nn_module = nn_module.to(self.device)
        if self.loss is not default:
            self.loss = self.loss.to(self.device) 
Developer: lRomul | Project: argus | Lines: 25 | Source file: build.py
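Example 3 depends on argus-internal helpers (cast_device, device_to_str, and the `default` sentinel). A standalone sketch of the same idea is given below: wrap a module in DataParallel only when several distinct CUDA devices are requested, then move it to the first device. The function name and error messages are illustrative.

import torch
from torch.nn.parallel.data_parallel import DataParallel


def move_to_devices(nn_module, devices):
    # Wrap `nn_module` in DataParallel when several CUDA devices are given,
    # then move it to the first one; a single device is used as-is.
    devices = [torch.device(d) for d in devices]
    if len(devices) > 1:
        if any(d.type != 'cuda' or d.index is None for d in devices):
            raise ValueError("Multi-device mode needs indexed CUDA devices, e.g. 'cuda:0'")
        indices = [d.index for d in devices]
        if len(indices) != len(set(indices)):
            raise ValueError("CUDA device indices must be unique")
        nn_module = DataParallel(nn_module, device_ids=indices)
    return nn_module.to(devices[0])


# Usage (assumes at least two GPUs):
# model = move_to_devices(torch.nn.Linear(8, 2), ['cuda:0', 'cuda:1'])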

Example 4: patch_replication_callback

# Required import: from torch.nn.parallel import data_parallel [as alias]
# Or: from torch.nn.parallel.data_parallel import DataParallel [as alias]
# (The snippet also needs `import functools` and the execute_replication_callbacks
#  helper defined alongside it in the same project.)
def patch_replication_callback(data_parallel):
  """
  Monkey-patch an existing `DataParallel` object. Add the replication callback.
  Useful when you have a customized `DataParallel` implementation.

  Examples:
      > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
      > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
      > patch_replication_callback(sync_bn)
      # this is equivalent to
      > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
      > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
  """

  assert isinstance(data_parallel, DataParallel)

  old_replicate = data_parallel.replicate

  @functools.wraps(old_replicate)
  def new_replicate(module, device_ids):
    modules = old_replicate(module, device_ids)
    execute_replication_callbacks(modules)
    return modules

  data_parallel.replicate = new_replicate 
Developer: PRBonn | Project: lidar-bonnetal | Lines: 27 | Source file: replicate.py

Example 5: patch_replication_callback

# Required import: from torch.nn.parallel import data_parallel [as alias]
# Or: from torch.nn.parallel.data_parallel import DataParallel [as alias]
# (The snippet also needs `import functools` and the exec_data_parallel_replication_callback
#  helper defined alongside it in the same project.)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = ReplicationCallbackDataParallel(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        exec_data_parallel_replication_callback(modules)
        return modules

    data_parallel.replicate = new_replicate 
Developer: vacancy | Project: Jacinle | Lines: 27 | Source file: replication_callback.py
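All of the docstrings above note that patching an existing object is equivalent to constructing a DataParallelWithCallback (named ReplicationCallbackDataParallel in Example 5) in the first place. In these codebases that class is typically just a DataParallel subclass whose replicate() runs the same callbacks; a minimal sketch under that assumption:

from torch.nn.parallel.data_parallel import DataParallel


class DataParallelWithCallback(DataParallel):
    # DataParallel variant that triggers the replication callbacks right after
    # the wrapped module has been copied onto each device.

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)  # same helper as in the examples above
        return modules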


Note: The torch.nn.parallel.data_parallel.DataParallel examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by various developers, and copyright of the source code remains with the original authors; please follow each project's license when distributing or using the code, and do not reproduce this article without permission.