当前位置: 首页>>代码示例>>Python>>正文


Python distributed.is_initialized方法代码示例

本文整理汇总了Python中torch.distributed.is_initialized方法的典型用法代码示例。如果您正苦于以下问题:Python distributed.is_initialized方法的具体用法?Python distributed.is_initialized怎么用?Python distributed.is_initialized使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.distributed的用法示例。


在下文中一共展示了distributed.is_initialized方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: average_across_processes

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]]):
    r"""
    Averages a tensor, or a dict of tensors across all processes in a process
    group. Objects in all processes will finally have same mean value.

    The averaging happens in place: every tensor is summed across processes
    with ``all_reduce`` and then divided by the world size. When
    ``torch.distributed`` is unavailable in this torch build, or no process
    group has been initialized, the input is left untouched.

    .. note::

        Nested dicts of tensors are not supported.

    Parameters
    ----------
    t: torch.Tensor or Dict[str, torch.Tensor]
        A tensor or dict of tensors to average across processes.
    """
    # ``is_initialized`` is only exposed when the distributed package is
    # available, so availability must be checked first.
    if not (dist.is_available() and dist.is_initialized()):
        return

    # Query once and use the same value for every tensor.
    world_size = dist.get_world_size()
    if isinstance(t, torch.Tensor):
        tensors = [t]
    elif isinstance(t, dict):
        tensors = list(t.values())
    else:
        return
    for tensor in tensors:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        # In-place division keeps the caller's tensor objects updated.
        tensor /= world_size
开发者ID:kdexd,项目名称:virtex,代码行数:24,代码来源:distributed.py

示例2: scatter

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def scatter(self, scatter_list, src, size=None, device=None):
        """Scatters a list of tensors to all parties.

        Args:
            scatter_list: On the source rank, the list of tensors to send,
                indexed by destination rank. On other ranks it is only read
                to infer ``size``/``device`` when those are not given.
            src: Rank that owns ``scatter_list`` and sends its entries.
            size: Shape of the receive buffer on non-source ranks. Defaults
                to ``scatter_list[rank].size()``.
            device: Device of the receive buffer on non-source ranks.
                Defaults to ``scatter_list[rank].device`` when that can be
                read, otherwise to ``torch.empty``'s default device.

        Returns:
            The tensor received by this party. On the source rank this is
            the ``.data`` of ``scatter_list[src]``.
        """
        assert dist.is_initialized(), "initialize the communicator first"
        rank = self.get_rank()
        if src != rank:
            if size is None:
                size = scatter_list[rank].size()
            if device is None:
                # Best effort: on receiving ranks scatter_list may be None,
                # too short, or hold objects without ``.device``. Only those
                # expected failures fall back to the default device; any other
                # error propagates instead of being silently swallowed.
                try:
                    device = scatter_list[rank].device
                except (TypeError, LookupError, AttributeError):
                    pass
            # The receive buffer is always allocated as int64 (torch.long).
            tensor = torch.empty(size=size, dtype=torch.long, device=device)
            dist.scatter(tensor, [], src, group=self.main_group)
        else:
            scatter_list = [s.data for s in scatter_list]
            tensor = scatter_list[rank]
            dist.scatter(tensor, scatter_list, src, group=self.main_group)
        return tensor
开发者ID:facebookresearch,项目名称:CrypTen,代码行数:20,代码来源:distributed_communicator.py

示例3: reduce

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def reduce(self, input, dst, op=ReduceOp.SUM, batched=False):
        """Reduce ``input`` onto rank ``dst`` across all parties.

        Args:
            input: A tensor, or a list of tensors when ``batched`` is True.
            dst: Rank that receives the reduced result.
            op: Reduction operator forwarded to ``dist.reduce``.
            batched: If True, every tensor in the list ``input`` is reduced
                with an asynchronous request, and all requests are awaited.

        Returns:
            On rank ``dst``, the reduced copy: a list of ``.data`` clones
            when batched, otherwise a clone of ``input``. ``None`` on every
            other rank.
        """
        assert dist.is_initialized(), "initialize the communicator first"

        if not batched:
            assert torch.is_tensor(
                input.data
            ), "unbatched input for reduce must be a torch tensor"
            result = input.clone()
            dist.reduce(result.data, dst, op=op, group=self.main_group)
        else:
            assert isinstance(input, list), "batched reduce input must be a list"
            result = [item.clone().data for item in input]
            # Launch every reduction first, then block until all complete.
            pending = [
                dist.reduce(buf, dst, op=op, group=self.main_group, async_op=True)
                for buf in result
            ]
            for handle in pending:
                handle.wait()

        if dst != self.get_rank():
            return None
        return result
开发者ID:facebookresearch,项目名称:CrypTen,代码行数:26,代码来源:distributed_communicator.py

示例4: all_reduce

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def all_reduce(self, input, op=ReduceOp.SUM, batched=False):
        """Reduce ``input`` across all parties; every party gets the result.

        Args:
            input: A tensor, or a list of tensors when ``batched`` is True.
            op: Reduction operator forwarded to ``dist.all_reduce``.
            batched: If True, every tensor in the list ``input`` is
                all-reduced with an asynchronous request, and all requests
                are awaited.

        Returns:
            A clone of ``input`` (or a list of clones when batched) holding
            the reduced values. ``input`` itself is not modified.
        """
        assert dist.is_initialized(), "initialize the communicator first"

        if not batched:
            assert torch.is_tensor(
                input.data
            ), "unbatched input for reduce must be a torch tensor"
            result = input.clone()
            dist.all_reduce(result.data, op=op, group=self.main_group)
            return result

        assert isinstance(input, list), "batched reduce input must be a list"
        result = [item.clone() for item in input]
        # Launch every reduction first, then block until all complete.
        handles = [
            dist.all_reduce(
                buf.data, op=op, group=self.main_group, async_op=True
            )
            for buf in result
        ]
        for handle in handles:
            handle.wait()
        return result
开发者ID:facebookresearch,项目名称:CrypTen,代码行数:25,代码来源:distributed_communicator.py

示例5: reduce_mean

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def reduce_mean(tensor):
    """Return the mean of ``tensor`` over all processes.

    Outside an initialized process group the input object is returned as-is.
    Otherwise a copy is divided by the world size and then summed across
    processes, so the caller's tensor is not modified.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if distributed:
        averaged = tensor.clone()
        averaged.div_(dist.get_world_size())
        dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
        return averaged
    return tensor
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:8,代码来源:gfl_head.py

示例6: _parse_losses

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary infomation.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars 
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:36,代码来源:base.py

示例7: get_world_size

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size():
    """Return the default group's world size, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:8,代码来源:comm.py

示例8: get_rank

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_rank():
    """Return this process's rank, or 0 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:8,代码来源:comm.py

示例9: synchronize

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def synchronize():
    """
    Block until every process reaches this point (a barrier).

    Does nothing when torch.distributed is unavailable, no process group has
    been initialized, or the world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
开发者ID:Res2Net,项目名称:Res2Net-maskrcnn,代码行数:15,代码来源:comm.py

示例10: get_world_size

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size() -> int:
    """Return the default group's world size, or 1 when not running distributed."""
    running_distributed = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if running_distributed else 1
开发者ID:soeaver,项目名称:Parsing-R-CNN,代码行数:8,代码来源:misc.py

示例11: get_world_size

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size():
    """Return the default group's world size, or 1 when not running distributed.

    ``torch.distributed.is_available()`` is checked first: on torch builds
    without distributed support the package exposes no other APIs, so calling
    ``is_initialized`` directly would raise AttributeError.
    """
    if not torch.distributed.is_available():
        return 1
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
开发者ID:AceCoooool,项目名称:LEDNet,代码行数:6,代码来源:parallel.py

示例12: get_rank

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_rank():
    """Return this process's rank, or 0 when not running distributed.

    ``torch.distributed.is_available()`` is checked first: on torch builds
    without distributed support the package exposes no other APIs, so calling
    ``is_initialized`` directly would raise AttributeError.
    """
    if not torch.distributed.is_available():
        return 0
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
开发者ID:AceCoooool,项目名称:LEDNet,代码行数:6,代码来源:parallel.py

示例13: is_dist_avail_and_initialized

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is both available and initialized."""
    return bool(dist.is_available() and dist.is_initialized())
开发者ID:paperswithcode,项目名称:torchbench,代码行数:8,代码来源:coco_eval.py

示例14: setup_logger

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def setup_logger(logpth):
    """Configure root logging to a timestamped file in ``logpth`` plus the console.

    Non-zero ranks of an initialized process group log at WARNING level;
    every other process logs at INFO level.

    Args:
        logpth: Directory where ``Deeplab_v3plus-<timestamp>.log`` is written.
    """
    logfile = 'Deeplab_v3plus-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
    logfile = osp.join(logpth, logfile)
    FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
    log_level = logging.INFO
    # is_initialized() only exists when the distributed package is available
    # in this torch build, so check availability first.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        log_level = logging.WARNING
    # NOTE(review): basicConfig is a no-op if the root logger already has
    # handlers, and each call appends another console handler.
    logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
    logging.root.addHandler(logging.StreamHandler())
开发者ID:CoinCheung,项目名称:DeepLab-v3-plus-cityscapes,代码行数:11,代码来源:logger.py

示例15: __init__

# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def __init__(self, cfg, *args, **kwargs):
        """Build the validation dataloader described by ``cfg``.

        Args:
            cfg: Config object. It is passed to ``CityScapes`` and its
                ``eval_batchsize`` and ``eval_n_workers`` are read here.
            *args, **kwargs: Accepted but unused by this method.
        """
        self.cfg = cfg
        # is_initialized() only exists when the distributed package is
        # available in this torch build, so check availability first.
        self.distributed = dist.is_available() and dist.is_initialized()
        ## dataloader
        dsval = CityScapes(cfg, mode='val')
        sampler = None
        if self.distributed:
            # Each process evaluates its own partition of the validation set.
            sampler = torch.utils.data.distributed.DistributedSampler(dsval)
        self.dl = DataLoader(dsval,
                        batch_size = cfg.eval_batchsize,
                        sampler = sampler,
                        shuffle = False,
                        num_workers = cfg.eval_n_workers,
                        drop_last = False)
开发者ID:CoinCheung,项目名称:DeepLab-v3-plus-cityscapes,代码行数:16,代码来源:evaluate.py


注:本文中的torch.distributed.is_initialized方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。