当前位置: 首页>>代码示例>>Python>>正文


Python parallel.DistributedDataParallel方法代码示例

本文整理汇总了Python中torch.nn.parallel.DistributedDataParallel方法的典型用法代码示例。如果您正苦于以下问题:Python parallel.DistributedDataParallel方法的具体用法?Python parallel.DistributedDataParallel怎么用?Python parallel.DistributedDataParallel使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.nn.parallel的用法示例。


在下文中一共展示了parallel.DistributedDataParallel方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _decorate_model

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def _decorate_model(self, parallel_decorate=True):
        """Prepare ``self.model``: precision, device placement, parallel wrapper.

        The model is cast to half precision when ``self.setting.fp16`` is set,
        moved to ``self.device`` and then optionally wrapped.

        Args:
            parallel_decorate: When true, wrap the model in
                ``DistributedDataParallel`` if ``self.in_distributed_mode()``
                is true, otherwise in ``DataParallel`` if ``self.n_gpu > 1``.
                When false, the model is left unwrapped.
        """
        banner = '=' * 20
        self.logging(banner + 'Decorate Model' + banner)

        if self.setting.fp16:
            self.model.half()

        self.model.to(self.device)
        self.logging('Set model device to {}'.format(str(self.device)))

        # Guard clause: the caller asked for the bare model.
        if not parallel_decorate:
            self.logging('Do not wrap parallel layers')
            return

        if self.in_distributed_mode():
            # One process drives one device, identified by its local rank.
            local_rank = self.setting.local_rank
            self.model = para.DistributedDataParallel(
                self.model,
                device_ids=[local_rank],
                output_device=local_rank,
            )
            self.logging('Wrap distributed data parallel')
        elif self.n_gpu > 1:
            self.model = para.DataParallel(self.model)
            self.logging('Wrap data parallel')
开发者ID:dolphin-zs,项目名称:Doc2EDAG,代码行数:23,代码来源:base_task.py

示例2: __init__

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def __init__(self, hp, net_arch, loss_f, rank=0, world_size=1):
        """Set up the network, its optimizer and the bookkeeping fields.

        Args:
            hp: Hyper-parameter object; ``hp.model.device`` and
                ``hp.train.optimizer`` are read here.
            net_arch: Network to train; moved to ``hp.model.device``.
            loss_f: Loss callable, stored as ``self.loss_f``.
            rank: Device id handed to the DDP wrapper.
            world_size: Stored as ``self.world_size``; a value of ``0``
                disables the DDP wrapper.

        Raises:
            Exception: If ``hp.train.optimizer.mode`` is not ``"adam"``.
        """
        self.hp = hp
        self.device = self.hp.model.device
        self.net = net_arch.to(self.device)
        self.rank = rank
        self.world_size = world_size

        # DDP is only applied off-CPU and when world_size is non-zero.
        if self.device != "cpu" and self.world_size != 0:
            self.net = DDP(self.net, device_ids=[self.rank])

        self.input = None
        self.GT = None
        self.step = 0
        self.epoch = -1

        # Optimizer: "adam" is the only supported mode; its keyword arguments
        # live under the key named after the mode.
        optimizer_cfg = self.hp.train.optimizer
        optimizer_mode = optimizer_cfg.mode
        if optimizer_mode != "adam":
            raise Exception("%s optimizer not supported" % optimizer_mode)
        self.optimizer = torch.optim.Adam(
            self.net.parameters(), **(optimizer_cfg[optimizer_mode])
        )

        # Loss function and logging container.
        self.loss_f = loss_f
        self.log = DotDict()
开发者ID:ryul99,项目名称:pytorch-project-template,代码行数:27,代码来源:model.py

示例3: build_model

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def build_model(self):
        """Create generator ``self.G``, its Adam optimizer and the LR scheduler.

        In distributed mode (``CONFIG.dist``) the batch-norm layers are
        converted to ``SyncBatchNorm`` before the optimizer is created and the
        generator is wrapped in ``DistributedDataParallel`` afterwards;
        otherwise it is wrapped in ``nn.DataParallel``.
        """
        arch = self.model_config.arch
        self.G = networks.get_generator(encoder=arch.encoder, decoder=arch.decoder)
        self.G.cuda()

        if CONFIG.dist:
            self.logger.info("Using pytorch synced BN")
            self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

        train_cfg = self.train_config
        self.G_optimizer = torch.optim.Adam(
            self.G.parameters(),
            lr=train_cfg.G_lr,
            betas=[train_cfg.beta1, train_cfg.beta2],
        )

        if not CONFIG.dist:
            self.G = nn.DataParallel(self.G)
        else:
            # SyncBatchNorm only supports DistributedDataParallel with a
            # single GPU per process.
            self.G = DistributedDataParallel(
                self.G,
                device_ids=[CONFIG.local_rank],
                output_device=CONFIG.local_rank,
            )

        self.build_lr_scheduler()
开发者ID:Yaoyi-Li,项目名称:GCA-Matting,代码行数:22,代码来源:trainer.py

示例4: demo_basic

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def demo_basic(local_world_size, local_rank):
    """Run one training step of a DDP-wrapped ``ToyModel`` on this process.

    The visible CUDA devices are divided evenly between the local processes;
    this process uses the slice selected by ``local_rank``.

    Args:
        local_world_size: Number of processes the local GPUs are split among.
        local_rank: Index of this process; selects which GPU slice it uses.
    """
    # Each process gets a contiguous slice of GPUs. For example, with
    # local_world_size = 2 and 8 GPUs, local rank 0 uses [0, 1, 2, 3] and
    # local rank 1 uses [4, 5, 6, 7].
    gpus_per_process = torch.cuda.device_count() // local_world_size
    first_gpu = local_rank * gpus_per_process
    device_ids = list(range(first_gpu, first_gpu + gpus_per_process))

    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        f"world_size = {dist.get_world_size()}, "
        f"n = {gpus_per_process}, device_ids = {device_ids}"
    )

    # The model lives on the first GPU of the slice.
    primary_gpu = device_ids[0]
    ddp_model = DDP(ToyModel().cuda(primary_gpu), device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # A single forward / backward / update on random data.
    optimizer.zero_grad()
    predictions = ddp_model(torch.randn(20, 10))
    targets = torch.randn(20, 5).to(primary_gpu)
    loss_fn(predictions, targets).backward()
    optimizer.step()
开发者ID:pytorch,项目名称:examples,代码行数:26,代码来源:example.py

示例5: demo_basic

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def demo_basic(rank, world_size):
    """Run one DDP training step of ``ToyModel`` on GPU ``rank``.

    Args:
        rank: Rank of this process; also used as the CUDA device id.
        world_size: Total number of processes, forwarded to ``setup``.

    ``setup`` and ``cleanup`` are defined elsewhere in the file; presumably
    they create and destroy the process group (verify against their
    definitions).
    """
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # Once setup() has succeeded, cleanup() must run even if the training
    # step raises, so the teardown is not skipped on failure.
    try:
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        cleanup()
开发者ID:pytorch,项目名称:examples,代码行数:20,代码来源:main.py

示例6: demo_model_parallel

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def demo_model_parallel(rank, world_size):
    """Run one DDP training step of a two-device model-parallel ``ToyMpModel``.

    Args:
        rank: Rank of this process; it uses devices ``2 * rank`` and
            ``2 * rank + 1``.
        world_size: Total number of processes, forwarded to ``setup``.

    ``setup`` and ``cleanup`` are defined elsewhere in the file; presumably
    they create and destroy the process group (verify against their
    definitions).
    """
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # Once setup() has succeeded, cleanup() must run even if the training
    # step raises, so the teardown is not skipped on failure.
    try:
        # setup mp_model and devices for this process
        dev0 = rank * 2
        dev1 = rank * 2 + 1
        mp_model = ToyMpModel(dev0, dev1)
        # No device_ids here: the wrapped module already spans two devices.
        ddp_mp_model = DDP(mp_model)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # outputs will be on dev1
        outputs = ddp_mp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(dev1)
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        cleanup()
开发者ID:pytorch,项目名称:examples,代码行数:23,代码来源:main.py

示例7: data_parallel

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def data_parallel(self):
        """Wrap ``self.model`` for distributed data-parallel training.

        The design intent is that rlpyt runs one Python process per GPU (or
        per CPU group in a CPU-only, MPI-like configuration). Agents that own
        further model components through which gradients flow, beyond
        ``self.model``, should extend this method and wrap those as well.

        Typically called by the runner during startup.
        """
        if self.device.type == "cpu":
            self.model = DDPC(self.model)
            logger.log("Initialized DistributedDataParallelCPU agent model.")
            return

        # GPU case: inputs and outputs both stay on this agent's own device.
        gpu_index = self.device.index
        self.model = DDP(
            self.model,
            device_ids=[gpu_index],
            output_device=gpu_index,
        )
        logger.log(
            f"Initialized DistributedDataParallel agent model on device {self.device}."
        )
开发者ID:astooke,项目名称:rlpyt,代码行数:20,代码来源:base.py

示例8: main

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def main(args):
    """Build the model, then either evaluate it or train and evaluate it.

    Args:
        args: Parsed command-line arguments; ``eval_only`` and ``resume`` are
            read here and the whole object is passed to ``setup``.

    Returns:
        Whatever ``do_test(cfg, model)`` returns.
    """
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))

    # Evaluation-only run: load the weights and go straight to testing.
    if args.eval_only:
        checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
        return do_test(cfg, model)

    # Wrap in DDP only when more than one process takes part.
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
开发者ID:facebookresearch,项目名称:detectron2,代码行数:21,代码来源:plain_train_net.py

示例9: initialize_model

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    """Create a DDP-wrapped model together with its loss and SGD optimizer.

    Args:
        arch: Name of the model constructor looked up in ``models``.
        lr: Learning rate for SGD.
        momentum: Momentum for SGD.
        weight_decay: Weight decay for SGD.
        device_id: CUDA device that hosts the model and the loss.

    Returns:
        A ``(model, criterion, optimizer)`` tuple.
    """
    print(f"=> creating model: {arch}")
    network = models.__dict__[arch]()

    # Pin the model to a single device before wrapping; without an explicit
    # device scope DistributedDataParallel would use all available devices.
    network.cuda(device_id)
    cudnn.benchmark = True
    ddp_network = DistributedDataParallel(network, device_ids=[device_id])

    # Loss function (criterion) and optimizer.
    loss_fn = nn.CrossEntropyLoss().cuda(device_id)
    sgd = SGD(
        ddp_network.parameters(),
        lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    return ddp_network, loss_fn, sgd
开发者ID:pytorch,项目名称:elastic,代码行数:19,代码来源:main.py

示例10: main

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def main(args):
    """Build the model, then either evaluate it or train and evaluate it.

    Args:
        args: Parsed command-line arguments; ``eval_only`` and ``resume`` are
            read here and the whole object is passed to ``setup``.

    Returns:
        Whatever ``do_test(cfg, model)`` returns.
    """
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))

    # Evaluation-only run: load the weights and go straight to testing.
    if args.eval_only:
        checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
        return do_test(cfg, model)

    # Wrap in DDP only when more than one process takes part.
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    do_train(cfg, model)
    return do_test(cfg, model)
开发者ID:conansherry,项目名称:detectron2,代码行数:21,代码来源:plain_train_net.py

示例11: __init__

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def __init__(self, *kargs, **kwargs):
        """Initialise the base trainer, then fuse model and loss in one module.

        All arguments are forwarded unchanged to the parent constructor. The
        combined module is wrapped in ``DistributedDataParallel`` when
        ``self.distributed`` is set, or in ``DataParallel`` when
        ``self.device_ids`` is a tuple. A contrast batch is finally built from
        the target tokenizer's common words.
        """
        super(NestedTrainer, self).__init__(*kargs, **kwargs)
        self.model_with_loss = AddLossModule(self.model, self.criterion)

        if self.distributed:
            self.model_with_loss = DistributedDataParallel(
                self.model_with_loss,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
            )
        elif isinstance(self.device_ids, tuple):
            self.model_with_loss = DataParallel(
                self.model_with_loss,
                self.device_ids,
                dim=0 if self.batch_first else 1,
            )

        # The unpacking requires exactly two tokenizers; only the second one
        # (the target side) is used here.
        tokenizers = self.save_info['tokenizers']
        _source_tok, target_tok = tokenizers.values()
        common_target_words = target_tok.common_words(8188)
        self.contrast_batch = batch_nested_sequences(common_target_words)
开发者ID:eladhoffer,项目名称:seq2seq.pytorch,代码行数:18,代码来源:trainer.py

示例12: save

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def save(self, name, **kwargs):
        """Write a checkpoint named ``<name>.pth`` into ``self.save_dir``.

        Nothing happens when ``self.save_dir`` is empty or
        ``self.save_to_disk`` is false.

        Args:
            name: Checkpoint file name without the ``.pth`` extension.
            **kwargs: Extra entries merged into the saved dictionary; they
                override ``model`` / ``optimizer`` / ``scheduler`` on a key
                clash.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        # Save the underlying module so the keys carry no DDP prefix.
        model = self.model
        if isinstance(model, DistributedDataParallel):
            model = model.module
        payload = {'model': model.state_dict()}

        # Optimizer and scheduler are optional parts of the checkpoint.
        for key in ("optimizer", "scheduler"):
            component = getattr(self, key)
            if component is not None:
                payload[key] = component.state_dict()
        payload.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(payload, save_file)

        self.tag_last_checkpoint(save_file)
开发者ID:lufficc,项目名称:SSD,代码行数:25,代码来源:checkpoint.py

示例13: load

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def load(self, f=None, use_latest=True):
        """Restore model, optimizer and scheduler state from a checkpoint.

        Args:
            f: Checkpoint to load. Ignored when ``use_latest`` is true and
                ``self.has_checkpoint()`` reports an existing checkpoint.
            use_latest: Prefer the file returned by
                ``self.get_checkpoint_file()`` over ``f`` when one exists.

        Returns:
            The checkpoint entries that were not consumed here, or an empty
            dict when there is nothing to load.
        """
        if self.has_checkpoint() and use_latest:
            # An existing checkpoint takes precedence over the argument.
            f = self.get_checkpoint_file()
        if not f:
            # Neither an argument nor an existing checkpoint is available.
            self.logger.info("No checkpoint found.")
            return {}

        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)

        # Load into the underlying module when the model is DDP-wrapped.
        wrapped = isinstance(self.model, DistributedDataParallel)
        target = self.model.module if wrapped else self.model
        target.load_state_dict(checkpoint.pop("model"))

        # Optional parts are restored only if both sides have them.
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            optimizer_state = checkpoint.pop("optimizer")
            self.optimizer.load_state_dict(optimizer_state)
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            scheduler_state = checkpoint.pop("scheduler")
            self.scheduler.load_state_dict(scheduler_state)

        # Whatever is left is extra data for the caller.
        return checkpoint
开发者ID:lufficc,项目名称:SSD,代码行数:27,代码来源:checkpoint.py

示例14: save_checkpoint

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    """Save a training checkpoint and refresh the ``latest``/``best`` links.

    Args:
        model: Model to save; a ``DataParallel`` or
            ``DistributedDataParallel`` wrapper is unwrapped first.
        epoch: Epoch number, stored and used to format the file name.
        num_iters: Iteration count, stored in the checkpoint.
        out_dir: Output directory, created if missing.
        filename_tmpl: File name template formatted with ``epoch``.
        optimizer: Optional optimizer whose state is stored as well.
        is_best: When true, ``best.pth`` is also pointed at this checkpoint.
    """
    # exist_ok avoids the check-then-create race that occurs when several
    # processes save into the same directory at the same time.
    os.makedirs(out_dir, exist_ok=True)

    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        # model_weights_to_cpu is defined elsewhere; presumably it returns a
        # CPU copy of the state dict -- verify against its definition.
        'state_dict': model_weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)

    # make_link is defined elsewhere; presumably it (re)creates a link at the
    # second path pointing to the first -- verify against its definition.
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link)
开发者ID:hellock,项目名称:torchpack,代码行数:27,代码来源:io.py

示例15: save_checkpoint

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DistributedDataParallel [as 别名]
def save_checkpoint(self, cpt_file_name=None, epoch=None):
        """Dump settings, model state and optimizer state to a checkpoint file.

        Args:
            cpt_file_name: File name inside ``self.setting.model_dir``;
                defaults to ``self.setting.cpt_file_name``.
            epoch: Epoch number to record. Any value other than ``None`` is
                stored, including ``0``.
        """
        self.logging('='*20 + 'Dump Checkpoint' + '='*20)
        if cpt_file_name is None:
            cpt_file_name = self.setting.cpt_file_name
        cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
        self.logging('Dump checkpoint into {}'.format(cpt_file_path))

        store_dict = {
            'setting': self.setting.__dict__,
        }

        if self.model:
            # Save the underlying module so the keys carry no wrapper prefix.
            if isinstance(self.model,
                          (para.DataParallel, para.DistributedDataParallel)):
                model_state = self.model.module.state_dict()
            else:
                model_state = self.model.state_dict()
            store_dict['model_state'] = model_state
        else:
            self.logging('No model state is dumped', level=logging.WARNING)

        if self.optimizer:
            store_dict['optimizer_state'] = self.optimizer.state_dict()
        else:
            self.logging('No optimizer state is dumped', level=logging.WARNING)

        # Compare against None: a plain truthiness test would silently drop
        # epoch 0 from the checkpoint.
        if epoch is not None:
            store_dict['epoch'] = epoch

        torch.save(store_dict, cpt_file_path)
开发者ID:dolphin-zs,项目名称:Doc2EDAG,代码行数:32,代码来源:base_task.py


注:本文中的torch.nn.parallel.DistributedDataParallel方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。