当前位置: 首页>>代码示例>>Python>>正文


Python distributed.is_available方法代码示例

本文整理汇总了Python中torch.distributed.is_available方法的典型用法代码示例。如果您正苦于以下问题：Python distributed.is_available方法的具体用法是什么？该怎么用？有哪些使用例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.distributed的用法示例。


下文一共展示了distributed.is_available方法的12个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。

示例1: __init__

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        """Record how ``dataset`` is split across distributed replicas.

        Args:
            dataset: Sized collection to partition; only ``len()`` is used here.
            num_replicas: Number of participating processes. When ``None`` it
                is queried from ``dist.get_world_size()``.
            rank: Index of the current process. When ``None`` it is queried
                from ``dist.get_rank()``.
            shuffle: Stored unchanged on ``self.shuffle``.

        Raises:
            RuntimeError: If a value has to be queried but ``torch.distributed``
                is not available in this build.
        """
        needs_query = num_replicas is None or rank is None
        if needs_query and not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Round up so every replica reports the same sample count; total_size
        # is therefore len(dataset) rounded up to a multiple of num_replicas.
        per_replica = math.ceil(float(len(dataset)) / num_replicas)
        self.num_samples = per_replica
        self.total_size = per_replica * num_replicas
        self.shuffle = shuffle
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:18,代码来源:distributed.py

示例2: __init__

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        """Initialise the sampler and record the replica topology.

        Args:
            dataset: Sized collection to partition; forwarded to the parent
                constructor and measured with ``len()``.
            num_replicas: Number of participating processes; queried from
                ``torch.distributed`` when ``None``.
            rank: Index of the current process; queried from
                ``torch.distributed`` when ``None``.
            shuffle: Stored unchanged on ``self.shuffle``.

        Raises:
            RuntimeError: If a value has to be queried but ``torch.distributed``
                is not available.
        """
        # Imported lazily so torch is only required when the sampler is built.
        import torch.distributed as dist

        super().__init__(dataset)
        if num_replicas is None or rank is None:  # pragma: no cover
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Pad the per-replica count up so all replicas draw the same number.
        count = math.ceil(float(len(dataset)) / num_replicas)
        self.num_samples = count
        self.total_size = count * num_replicas
        self.shuffle = shuffle
开发者ID:mars-project,项目名称:mars,代码行数:22,代码来源:sampler.py

示例3: __init__

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        """Record how ``dataset`` is split across distributed replicas.

        Args:
            dataset: Sized collection to partition; only ``len()`` is used here.
            num_replicas: Number of participating processes. When ``None`` it
                is queried from ``dist.get_world_size()``.
            rank: Index of the current process. When ``None`` it is queried
                from ``dist.get_rank()``.
            shuffle: Whether the sampler should shuffle; stored on
                ``self.shuffle``.

        Raises:
            RuntimeError: If a value has to be queried but ``torch.distributed``
                is not available.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Round up so every replica gets the same count; total_size is
        # len(dataset) rounded up to a multiple of num_replicas.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Honour the caller's choice; this was previously hard-coded to True,
        # which made ``shuffle=False`` silently ineffective.
        self.shuffle = shuffle
开发者ID:clw5180,项目名称:remote_sensing_object_detection_2019,代码行数:18,代码来源:distributed.py

示例4: __init__

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def __init__(self, dataset, num_replicas=None, rank=None, pad=True):
        """Record the replica topology and per-rank sample counts.

        Args:
            dataset: Sized collection to partition; only ``len()`` is used here.
            num_replicas: Number of participating processes; queried from
                ``dist.get_world_size()`` when ``None``.
            rank: Index of the current process; queried from
                ``dist.get_rank()`` when ``None``.
            pad: When true, every rank reports the same ``num_samples`` and
                ``total_size`` is rounded up to a multiple of ``num_replicas``.
                When false, ``total_size`` equals ``len(dataset)`` and
                ``num_samples`` is ``ceil((len(dataset) - rank) / num_replicas)``.

        Raises:
            RuntimeError: If a value has to be queried but ``torch.distributed``
                is not available.
        """
        if num_replicas is None or rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            if num_replicas is None:
                num_replicas = dist.get_world_size()
            if rank is None:
                rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad
        self.epoch = 0

        dataset_len = len(dataset)
        if not pad:
            # NOTE(review): presumably ranks take a strided slice
            # (rank, rank + num_replicas, ...) — confirm against __iter__.
            self.num_samples = math.ceil(float(dataset_len - rank) / num_replicas)
            self.total_size = dataset_len
            return

        per_replica = math.ceil(float(dataset_len) / num_replicas)
        self.num_samples = per_replica
        self.total_size = per_replica * num_replicas
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:22,代码来源:sampler.py

示例5: setup_distributed

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def setup_distributed(port=29500):
    """Initialise an NCCL process group when more than one GPU is usable.

    Args:
        port: Value written to ``MASTER_PORT`` on the MPI launch path.

    Returns:
        ``(rank, world_size)``. Returns ``(0, 1)`` without initialising
        anything when ``torch.distributed`` or CUDA is unavailable, or when
        at most one CUDA device is visible.
    """
    multi_gpu = (
        dist.is_available()
        and torch.cuda.is_available()
        and torch.cuda.device_count() > 1
    )
    if not multi_gpu:
        return 0, 1

    if 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' not in os.environ:
        # Default path: rank/world size come from the env:// rendezvous.
        dist.init_process_group(backend="nccl", init_method="env://")
        return dist.get_rank(), dist.get_world_size()

    # NOTE(review): the presence of this MPICH variable is used as the signal
    # for an MPI launch. MASTER_ADDR is hard-coded to loopback, so this path
    # presumably only supports single-node jobs — confirm.
    from mpi4py import MPI
    world = MPI.COMM_WORLD
    rank, size = world.Get_rank(), world.Get_size()
    os.environ.update(MASTER_ADDR='127.0.0.1', MASTER_PORT=str(port))
    dist.init_process_group(backend="nccl", world_size=size, rank=rank)
    return rank, size
开发者ID:openai,项目名称:gpt-2-output-dataset,代码行数:19,代码来源:train.py

示例6: reduce_mean

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def reduce_mean(tensor):
    """Average ``tensor`` across all processes of the default group.

    Returns the input object unchanged when ``torch.distributed`` is not
    available or not initialised. Otherwise, it divides a copy by the
    world size, sums that copy across ranks with ``all_reduce``, and
    returns it. The caller's tensor is never modified in place.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if not distributed:
        return tensor
    averaged = tensor.clone().div_(dist.get_world_size())
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    return averaged
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:8,代码来源:gfl_head.py

示例7: _parse_losses

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def _parse_losses(self, losses):
        """Collapse raw network outputs into a total loss and loggable scalars.

        Args:
            losses (dict): Raw outputs of the network. Each value must be a
                tensor or a list of tensors; they usually hold losses and other
                necessary information.

        Returns:
            tuple[Tensor, dict]: ``(loss, log_vars)``. ``loss`` is the sum of
                every averaged entry whose key contains ``'loss'``. ``log_vars``
                maps each input key, plus ``'loss'``, to a Python number. When
                distributed training is initialised, each number is averaged
                over ranks.

        Raises:
            TypeError: If a value is neither a tensor nor a list.
        """
        log_vars = OrderedDict()
        for name, value in losses.items():
            if isinstance(value, torch.Tensor):
                reduced = value.mean()
            elif isinstance(value, list):
                reduced = sum(part.mean() for part in value)
            else:
                raise TypeError(
                    f'{name} is not a tensor or list of tensors')
            log_vars[name] = reduced

        loss = sum(val for key, val in log_vars.items() if 'loss' in key)
        log_vars['loss'] = loss

        for name in log_vars:
            scalar = log_vars[name]
            if dist.is_available() and dist.is_initialized():
                # Average across ranks on a detached copy so the returned
                # `loss` tensor (still in the graph) is left untouched.
                scalar = scalar.data.clone()
                dist.all_reduce(scalar.div_(dist.get_world_size()))
            log_vars[name] = scalar.item()

        return loss, log_vars
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:36,代码来源:base.py

示例8: get_world_size

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def get_world_size():
    """Return the default group's world size, or 1 outside distributed mode."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:8,代码来源:comm.py

示例9: get_rank

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def get_rank():
    """Return this process's rank, or 0 when not running distributed."""
    active = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if active else 0
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:8,代码来源:comm.py

示例10: synchronize

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def synchronize():
    """
    Barrier across all processes during distributed training.

    Does nothing when torch.distributed is unavailable or uninitialised, or
    when the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:15,代码来源:comm.py

示例11: get_world_size

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def get_world_size() -> int:
    """Number of processes in the default group; 1 when not distributed."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:8,代码来源:misc.py

示例12: is_dist_avail_and_initialized

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_available [as 别名]
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialised."""
    # Short-circuit: is_initialized is only consulted on builds that ship
    # the distributed package.
    return bool(dist.is_available() and dist.is_initialized())
开发者ID:paperswithcode,项目名称:torchbench,代码行数:8,代码来源:coco_eval.py


注：本文中的torch.distributed.is_available方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的许可证（License）；未经允许，请勿转载。