

Python torch.distributed Code Examples

This article collects typical usage examples of torch.distributed in Python. If you have been wondering what torch.distributed does, how it is used, or simply want to see it in real code, the curated examples below should help. You can also explore further usage examples for the parent torch package.


A total of 15 torch.distributed code examples are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the site recommend better Python code examples.
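
Before the individual examples, here is a minimal sketch (not taken from any project below) of the typical torch.distributed workflow: initialize a process group, run a collective, then tear it down. The gloo backend and the env:// launch method are assumptions; any launcher that sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT (e.g. torchrun) would work.

import torch
import torch.distributed as dist

def main():
    # Assumes the process was started by a launcher (e.g. torchrun) that sets
    # RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT in the environment.
    dist.init_process_group(backend="gloo", init_method="env://")
    tensor = torch.ones(1) * dist.get_rank()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # every rank now holds the sum of all ranks
    print("rank {} -> {}".format(dist.get_rank(), tensor.item()))
    dist.destroy_process_group()

if __name__ == "__main__":
    main()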

Example 1: __iter__

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self._get_epoch_indices(g)
            randperm = torch.randperm(len(indices), generator=g).tolist()
            indices = indices[randperm]
        else:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = self._get_epoch_indices(g)
            # indices = torch.arange(len(self.dataset)).tolist()

        # after balancing, len(indices) may differ from the dataset image count
        self.total_size = len(indices)
        logging_rank('balance sample total_size: {}'.format(self.total_size), distributed=1, local_rank=self.rank)
        # subsample
        self.num_samples = int(len(indices) / self.num_replicas)
        offset = self.num_samples * self.rank
        indices = indices[offset: offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices) 
Developer: soeaver, Project: Parsing-R-CNN, Lines of code: 26, Source file: repeat_factor.py
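
As a side note, the rank-based subsampling at the end of __iter__ can be illustrated in isolation. The values below are made up and serve only to show that each replica takes a contiguous slice of the shuffled index tensor.

import torch

indices = torch.randperm(10)             # stand-in for the balanced, shuffled indices
num_replicas, rank = 2, 1                # hypothetical world size and rank
num_samples = len(indices) // num_replicas
offset = num_samples * rank
print(indices[offset: offset + num_samples].tolist())  # the 5 indices owned by this rank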

Example 2: _init_device

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def _init_device(self):
        self.logging('='*20 + 'Init Device' + '='*20)

        # set device
        if self.setting.local_rank == -1 or self.setting.no_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available() and not self.setting.no_cuda else "cpu")
            self.n_gpu = torch.cuda.device_count()
        else:
            self.device = torch.device("cuda", self.setting.local_rank)
            self.n_gpu = 1
            if self.setting.fp16:
                self.logging("16-bits training currently not supported in distributed training")
                self.setting.fp16 = False  # (see https://github.com/pytorch/pytorch/pull/13496)
        self.logging("device {} n_gpu {} distributed training {}".format(
            self.device, self.n_gpu,self.in_distributed_mode()
        )) 
Developer: dolphin-zs, Project: Doc2EDAG, Lines of code: 18, Source file: base_task.py

Example 3: init_epoch

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def init_epoch(self):
        """Set up the batch generator for a new epoch."""
        if not self.distributed:
            if self._restored_from_state:
                self.random_shuffler.random_state = self._random_state_this_epoch
            else:
                self._random_state_this_epoch = self.random_shuffler.random_state

        self.create_batches()

        if not self.distributed:
            if self._restored_from_state:
                self._restored_from_state = False
            else:
                self._iterations_this_epoch = 0
        else:
            self._iterations_this_epoch = 0


        if not self.repeat:
            self.iterations = 0
        self.epoch += 1
        if self.distributed:
            self.random_shuffler.set_epoch(self.epoch) 
Developer: salesforce, Project: decaNLP, Lines of code: 26, Source file: iterator.py

Example 4: _decorate_model

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def _decorate_model(self, parallel_decorate=True):
        self.logging('='*20 + 'Decorate Model' + '='*20)

        if self.setting.fp16:
            self.model.half()

        self.model.to(self.device)
        self.logging('Set model device to {}'.format(str(self.device)))

        if parallel_decorate:
            if self.in_distributed_mode():
                self.model = para.DistributedDataParallel(self.model,
                                                          device_ids=[self.setting.local_rank],
                                                          output_device=self.setting.local_rank)
                self.logging('Wrap distributed data parallel')
                # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
            elif self.n_gpu > 1:
                self.model = para.DataParallel(self.model)
                self.logging('Wrap data parallel')
        else:
            self.logging('Do not wrap parallel layers') 
Developer: dolphin-zs, Project: Doc2EDAG, Lines of code: 23, Source file: base_task.py
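
`para` in the snippet above presumably aliases torch.nn.parallel (inferred from the call sites, not confirmed by this page). A stand-alone helper performing the same wrapping logic might look like the sketch below; it assumes init_process_group has already been called whenever local_rank >= 0.

import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

def wrap_model(model: nn.Module, local_rank: int = -1, n_gpu: int = 1) -> nn.Module:
    if local_rank >= 0:
        # distributed mode: one process per GPU, model pinned to its local device
        model = model.cuda(local_rank)
        return DistributedDataParallel(model, device_ids=[local_rank],
                                       output_device=local_rank)
    if n_gpu > 1:
        # single-process multi-GPU fallback
        return DataParallel(model)
    return model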

Example 5: __init__

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Developer: Res2Net, Project: Res2Net-maskrcnn, Lines of code: 18, Source file: distributed.py
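
A typical way to drive a sampler like this (shown here with the built-in torch.utils.data.distributed.DistributedSampler; dataset size, batch size and rank are placeholders) is to call set_epoch before each epoch so shuffling changes between epochs but stays consistent across ranks:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)   # feeds the epoch into the shuffle seed used by __iter__
    for batch in loader:
        pass                   # training step would go here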

Example 6: parse_args

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def parse_args():
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("cfg_file", help="config file", type=str)
    parser.add_argument("--iter", dest="start_iter",
                        help="train at iteration i",
                        default=0, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--initialize", action="store_true")

    parser.add_argument("--distributed", action="store_true")
    parser.add_argument("--world-size", default=-1, type=int,
                        help="number of nodes of distributed training")
    parser.add_argument("--rank", default=0, type=int,
                        help="node rank for distributed training")
    parser.add_argument("--dist-url", default=None, type=str,
                        help="url used to set up distributed training")
    parser.add_argument("--dist-backend", default="nccl", type=str)

    args = parser.parse_args()
    return args 
Developer: DataXujing, Project: CornerNet-Lite-Pytorch, Lines of code: 22, Source file: trainmyData.py

Example 7: parse_args

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def parse_args():
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("cfg_file", help="config file", type=str) #训练用的配置文件
    parser.add_argument("--iter", dest="start_iter",
                        help="train at iteration i",
                        default=0, type=int)  # start training from iteration i
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--initialize", action="store_true")

    parser.add_argument("--distributed", action="store_true")  # 分布式训练
    parser.add_argument("--world-size", default=-1, type=int,
                        help="number of nodes of distributed training")  # 分布式节点的数量
    parser.add_argument("--rank", default=0, type=int,
                        help="node rank for distributed training")  # 分布式训练节点的等级
    parser.add_argument("--dist-url", default=None, type=str,
                        help="url used to set up distributed training")
    parser.add_argument("--dist-backend", default="nccl", type=str)

    args = parser.parse_args()
    return args 
Developer: DataXujing, Project: CornerNet-Lite-Pytorch, Lines of code: 22, Source file: train.py

Example 8: distributed_init

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)
    if args.distributed_init_method.startswith('tcp://'):
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size, rank=args.distributed_rank)
    else:
        torch.distributed.init_process_group(
            backend=args.distributed_backend, init_method=args.distributed_init_method,
            world_size=args.distributed_world_size)

    args.distributed_rank = torch.distributed.get_rank()
    if not is_master(args):
        suppress_output()

    return args.distributed_rank 
Developer: nusnlp, Project: crosentgec, Lines of code: 22, Source file: distributed_utils.py
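
For reference, an invocation of distributed_init might be set up as below; the attribute names mirror those read inside the function, and the TCP address is a placeholder, so the call itself is left commented out (it blocks until all ranks join).

from argparse import Namespace

args = Namespace(
    distributed_world_size=2,
    distributed_rank=0,
    distributed_backend='nccl',
    distributed_init_method='tcp://127.0.0.1:23456',   # placeholder address
)
# rank = distributed_init(args)   # would call torch.distributed.init_process_group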

Example 9: __init__

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        import torch.distributed as dist

        super().__init__(dataset)
        if num_replicas is None:  # pragma: no cover
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:  # pragma: no cover
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle 
Developer: mars-project, Project: mars, Lines of code: 22, Source file: sampler.py

Example 10: parse_args

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def parse_args():
    parser = argparse.ArgumentParser(description='Test dense matching benchmark')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument('--out_dir', help='output result directory')
    parser.add_argument('--show', type=str, default='False', help='show results in images')
    parser.add_argument('--validate', action='store_true', help='whether to evaluate the result')
    parser.add_argument('--gpus', type=int, default=1,
        help='number of gpus to use (only applicable to non-distributed training)')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='pytorch',
        help='job launcher'
    )
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args 
Developer: DeepMotionAIResearch, Project: DenseMatchingBenchmark, Lines of code: 24, Source file: test.py

Example 11: run

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def run(backend, rank, rows, columns, num_gpus):
    # https://pytorch.org/docs/master/distributed.html
    if backend == 'gloo':
        print('Run operations supported by \'gloo\' backend.')
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _barrier(rank)

        # this operation is supported only on CPU
        if num_gpus == 0:
            _send_recv(rank, rows, columns)
    elif backend == 'nccl':
        print('Run operations supported by \'nccl\' backend.')
        # Note: nccl does not support gather or scatter either:
        # https://github.com/pytorch/pytorch/blob/v0.4.0/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
        _broadcast(rank, rows, columns)
        _all_reduce(rank, rows, columns)
        _reduce(rank, rows, columns)
        _all_gather(rank, rows, columns) 
Developer: aws, Project: sagemaker-pytorch-training-toolkit, Lines of code: 21, Source file: distributed_operations.py
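
The helpers called above (_broadcast, _all_reduce, _reduce, ...) are defined elsewhere in the source project. Purely as an illustration of their likely shape, a minimal _all_reduce over a rows x columns tensor could look like this (the tensor contents and the SUM reduction are assumptions, and an initialized process group is required before calling it):

import torch
import torch.distributed as dist

def _all_reduce(rank, rows, columns):
    # Each rank contributes a tensor filled with its own rank id; after the
    # collective, every rank holds the element-wise sum across all ranks.
    tensor = torch.full((rows, columns), float(rank))
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print('rank {}: all_reduce result {}'.format(rank, tensor[0, 0].item()))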

Example 12: init_distributed_mode

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0) 
Developer: lopuhin, Project: kaggle-kuzushiji-2019, Lines of code: 25, Source file: utils.py
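
This helper takes its first branch when the process is started by torchrun (or torch.distributed.launch), which exports RANK, WORLD_SIZE and LOCAL_RANK. A hypothetical driver, with dist_url and the other field names mirroring what the function reads, is sketched below.

import os
from argparse import Namespace

args = Namespace(dist_url='env://', distributed=False, rank=0, world_size=1, gpu=0)

if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    print('launcher detected; init_distributed_mode(args) would initialize the NCCL group')
else:
    print('no launcher environment; the function would fall back to single-process mode')
# Typical launch command (hypothetical): torchrun --nproc_per_node=2 train.py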

Example 13: _get_train_data_loader

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
    logger.info("Get train data loader")
    dataset = datasets.MNIST(
        training_dir,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
        download=False,  # True sets a dependency on an external site for our canaries.
    )
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs
    )
    return train_sampler, train_loader 
Developer: aws, Project: sagemaker-python-sdk, Lines of code: 23, Source file: mnist.py

Example 14: process_generic_model

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def process_generic_model(params: List, iters: int, has_early_stop: bool = False):
    """
    Runs a mock training with zero grads. This is due to a bug where the connection gets reset with custom new groups.
    :param params: The params of the model
    :param iters: Iterations.
    :param has_early_stop: Whether to poll an extra all_reduce each iteration for an early-stop signal.
    """
    # Hopefully this function can go away in newer versions.
    for i in range(iters):
        for p in params:
            z = torch.zeros(p)
            dist.all_reduce(z, op=torch.distributed.ReduceOp.SUM)

        if has_early_stop:
            dist.all_reduce(torch.tensor(0.0), op=torch.distributed.ReduceOp.SUM)
            zeros = torch.zeros(1)
            dist.all_reduce(zeros, op=torch.distributed.ReduceOp.SUM)
            if zeros.item() > 0:
                break 
Developer: dmmiller612, Project: sparktorch, Lines of code: 20, Source file: distributed.py

Example 15: _parse_losses

# Required import: import torch [as alias]
# Or: from torch import distributed [as alias]
def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars 
Developer: open-mmlab, Project: mmdetection, Lines of code: 36, Source file: base.py
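
To make the input/output contract concrete, here is a hypothetical loss dictionary and what _parse_losses would do with it (run outside a process group, so the all_reduce branch is skipped; the key names are invented):

import torch

losses = {
    'loss_cls': torch.tensor(0.7),
    'loss_bbox': [torch.tensor(0.2), torch.tensor(0.1)],
    'acc': torch.tensor(0.9),   # logged, but not summed into 'loss' (its key lacks 'loss')
}
# loss, log_vars = model._parse_losses(losses)
# loss would be 0.7 + (0.2 + 0.1) = 1.0, and log_vars would additionally contain
# 'acc' plus the aggregated 'loss' entry, all converted to Python floats.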


Note: The torch.distributed examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by many developers, and the copyright of the source code belongs to the original authors. Please consult the license of the corresponding project before distributing or using the code; do not reproduce this article without permission.