This article collects typical usage examples of Python's torch.nn.parallel.DistributedDataParallel. If you are wondering how parallel.DistributedDataParallel is used in practice, or what real-world calls to it look like, the curated examples below may help. You can also explore further usage examples from its containing module, torch.nn.parallel.
Fifteen code examples of parallel.DistributedDataParallel are shown below, sorted by popularity by default.
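Before the individual examples, here is a minimal sketch of the pattern they all share: initialize the default process group, pin the process to one device, move the model there, and wrap it. The helper name and the reliance on LOCAL_RANK/torchrun are assumptions for illustration, not code taken from any example below.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model):
    # Assumed helper: expects RANK / WORLD_SIZE / LOCAL_RANK to be set by the
    # launcher (e.g. torchrun), so init_process_group can use the env:// rendezvous.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    # One GPU per process: restrict DDP to that single device.
    return DDP(model, device_ids=[local_rank], output_device=local_rank)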
Example 1: _decorate_model

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def _decorate_model(self, parallel_decorate=True):
    self.logging('=' * 20 + 'Decorate Model' + '=' * 20)

    if self.setting.fp16:
        self.model.half()

    self.model.to(self.device)
    self.logging('Set model device to {}'.format(str(self.device)))

    if parallel_decorate:
        if self.in_distributed_mode():
            self.model = para.DistributedDataParallel(self.model,
                                                      device_ids=[self.setting.local_rank],
                                                      output_device=self.setting.local_rank)
            self.logging('Wrap distributed data parallel')
            # self.logging('In Distributed Mode, but do not use DistributedDataParallel Wrapper')
        elif self.n_gpu > 1:
            self.model = para.DataParallel(self.model)
            self.logging('Wrap data parallel')
    else:
        self.logging('Do not wrap parallel layers')
Example 2: __init__

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def __init__(self, hp, net_arch, loss_f, rank=0, world_size=1):
    self.hp = hp
    self.device = self.hp.model.device
    self.net = net_arch.to(self.device)
    self.rank = rank
    self.world_size = world_size
    if self.device != "cpu" and self.world_size != 0:
        self.net = DDP(self.net, device_ids=[self.rank])
    self.input = None
    self.GT = None
    self.step = 0
    self.epoch = -1

    # init optimizer
    optimizer_mode = self.hp.train.optimizer.mode
    if optimizer_mode == "adam":
        self.optimizer = torch.optim.Adam(
            self.net.parameters(), **(self.hp.train.optimizer[optimizer_mode])
        )
    else:
        raise Exception("%s optimizer not supported" % optimizer_mode)

    # init loss
    self.loss_f = loss_f
    self.log = DotDict()
Example 3: build_model

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def build_model(self):
    self.G = networks.get_generator(encoder=self.model_config.arch.encoder,
                                    decoder=self.model_config.arch.decoder)
    self.G.cuda()

    if CONFIG.dist:
        self.logger.info("Using pytorch synced BN")
        self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

    self.G_optimizer = torch.optim.Adam(self.G.parameters(),
                                        lr=self.train_config.G_lr,
                                        betas=[self.train_config.beta1, self.train_config.beta2])

    if CONFIG.dist:
        # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
        self.G = DistributedDataParallel(self.G, device_ids=[CONFIG.local_rank],
                                         output_device=CONFIG.local_rank)
    else:
        self.G = nn.DataParallel(self.G)

    self.build_lr_scheduler()
Example 4: demo_basic

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def demo_basic(local_world_size, local_rank):
    # Set up devices for this process. For local_world_size = 2, num_gpus = 8,
    # rank 0 uses GPUs [0, 1, 2, 3] and
    # rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // local_world_size
    device_ids = list(range(local_rank * n, (local_rank + 1) * n))
    print(
        f"[{os.getpid()}] rank = {dist.get_rank()}, "
        + f"world_size = {dist.get_world_size()}, n = {n}, device_ids = {device_ids}"
    )
    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()
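The ToyModel used in Examples 4 and 5 is not shown in the snippets. A plausible definition, matching the (20, 10) inputs and (20, 5) labels above, is sketched here as an assumption:

import torch.nn as nn

class ToyModel(nn.Module):
    # Assumed definition: a tiny two-layer net whose shapes match
    # torch.randn(20, 10) inputs and torch.randn(20, 5) labels in the examples.
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))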
Example 5: demo_basic

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
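Examples 5 and 6 call setup(rank, world_size) and cleanup() without defining them. A sketch of what such helpers typically do, plus a spawn-based launcher, might look like this (the rendezvous address and port are placeholder assumptions):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank, world_size):
    # Assumed helper: point every process at the same rendezvous address
    # and join the default process group.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

if __name__ == "__main__":
    # Launch one process per visible GPU; mp.spawn passes the process index
    # as the first argument (rank) to demo_basic.
    world_size = torch.cuda.device_count()
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)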
Example 6: demo_model_parallel

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # set up mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()
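ToyMpModel is likewise not defined in the snippet. A minimal sketch of a two-device, model-parallel module consistent with how it is used above (an assumption, not the example's actual code):

import torch.nn as nn

class ToyMpModel(nn.Module):
    # Assumed definition: the two halves live on different devices, so the
    # activations must be moved between dev0 and dev1 inside forward().
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = nn.Linear(10, 10).to(dev0)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = self.relu(self.net1(x.to(self.dev0)))
        return self.net2(x.to(self.dev1))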
Example 7: data_parallel

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def data_parallel(self):
    """Wraps the model with PyTorch's DistributedDataParallel. The
    intention is for rlpyt to create a separate Python process to drive
    each GPU (or CPU-group for CPU-only, MPI-like configuration). Agents
    with additional model components (beyond ``self.model``) which will
    have gradients computed through them should extend this method to wrap
    those, as well.

    Typically called in the runner during startup.
    """
    if self.device.type == "cpu":
        self.model = DDPC(self.model)
        logger.log("Initialized DistributedDataParallelCPU agent model.")
    else:
        self.model = DDP(self.model,
                         device_ids=[self.device.index],
                         output_device=self.device.index)
        logger.log("Initialized DistributedDataParallel agent model on "
                   f"device {self.device}.")
Example 8: main

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    do_train(cfg, model, resume=args.resume)
    return do_test(cfg, model)
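This main(args) follows the detectron2 training-script convention. If the surrounding script does too, it would typically be started through the library's multi-process launcher, roughly as below; the argument-parser attributes and launch call are assumptions based on that convention, not code from this example:

from detectron2.engine import default_argument_parser, launch

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    # Spawn one process per GPU on this machine; each process ends up calling main(args).
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )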
Example 9: initialize_model

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    print(f"=> creating model: {arch}")
    model = models.__dict__[arch]()
    # For multiprocessing distributed training, the DistributedDataParallel
    # constructor should always set the single device scope; otherwise
    # DistributedDataParallel will use all available devices.
    model.cuda(device_id)
    cudnn.benchmark = True
    model = DistributedDataParallel(model, device_ids=[device_id])
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(device_id)
    optimizer = SGD(
        model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return model, criterion, optimizer
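initialize_model assumes one process per GPU and an already-configured rendezvous. A sketch of a per-process worker that could call it follows; the worker name, the "resnet18" architecture (assuming models refers to torchvision.models), and the batch shapes are illustrative assumptions:

import torch
import torch.distributed as dist

def worker(device_id, world_size):
    # Assumed caller: env:// rendezvous, so MASTER_ADDR/MASTER_PORT must already be set.
    dist.init_process_group("nccl", rank=device_id, world_size=world_size)
    model, criterion, optimizer = initialize_model(
        arch="resnet18", lr=0.1, momentum=0.9, weight_decay=1e-4, device_id=device_id
    )
    # One dummy training step to show the returned objects working together.
    images = torch.randn(8, 3, 224, 224).cuda(device_id)
    targets = torch.randint(0, 1000, (8,)).cuda(device_id)
    loss = criterion(model(images), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()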
Example 10: main

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def main(args):
    cfg = setup(args)

    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return do_test(cfg, model)

    distributed = comm.get_world_size() > 1
    if distributed:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )

    do_train(cfg, model)
    return do_test(cfg, model)
Example 11: __init__

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def __init__(self, *kargs, **kwargs):
    super(NestedTrainer, self).__init__(*kargs, **kwargs)
    self.model_with_loss = AddLossModule(self.model, self.criterion)
    if self.distributed:
        self.model_with_loss = DistributedDataParallel(
            self.model_with_loss,
            device_ids=[self.local_rank],
            output_device=self.local_rank)
    else:
        if isinstance(self.device_ids, tuple):
            self.model_with_loss = DataParallel(self.model_with_loss,
                                                self.device_ids,
                                                dim=0 if self.batch_first else 1)
    _, target_tok = self.save_info['tokenizers'].values()
    target_words = target_tok.common_words(8188)
    self.contrast_batch = batch_nested_sequences(target_words)
Example 12: save

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def save(self, name, **kwargs):
    if not self.save_dir:
        return

    if not self.save_to_disk:
        return

    data = {}
    if isinstance(self.model, DistributedDataParallel):
        data['model'] = self.model.module.state_dict()
    else:
        data['model'] = self.model.state_dict()
    if self.optimizer is not None:
        data["optimizer"] = self.optimizer.state_dict()
    if self.scheduler is not None:
        data["scheduler"] = self.scheduler.state_dict()
    data.update(kwargs)

    save_file = os.path.join(self.save_dir, "{}.pth".format(name))
    self.logger.info("Saving checkpoint to {}".format(save_file))
    torch.save(data, save_file)
    self.tag_last_checkpoint(save_file)
Example 13: load

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def load(self, f=None, use_latest=True):
    if self.has_checkpoint() and use_latest:
        # override argument with existing checkpoint
        f = self.get_checkpoint_file()
    if not f:
        # no checkpoint could be found
        self.logger.info("No checkpoint found.")
        return {}
    self.logger.info("Loading checkpoint from {}".format(f))
    checkpoint = self._load_file(f)
    model = self.model
    if isinstance(model, DistributedDataParallel):
        model = self.model.module
    model.load_state_dict(checkpoint.pop("model"))
    if "optimizer" in checkpoint and self.optimizer:
        self.logger.info("Loading optimizer from {}".format(f))
        self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
    if "scheduler" in checkpoint and self.scheduler:
        self.logger.info("Loading scheduler from {}".format(f))
        self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

    # return any further checkpoint data
    return checkpoint
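Examples 12 and 13 rely on tag_last_checkpoint, has_checkpoint and get_checkpoint_file, which are not shown. Checkpointers of this style usually just read and write a small "last_checkpoint" marker file in save_dir; a sketch under that assumption (these would live on the same checkpointer class as save and load):

import os

def tag_last_checkpoint(self, last_filename):
    # Assumed helper: record the most recent checkpoint path on disk.
    with open(os.path.join(self.save_dir, "last_checkpoint"), "w") as f:
        f.write(last_filename)

def has_checkpoint(self):
    return os.path.exists(os.path.join(self.save_dir, "last_checkpoint"))

def get_checkpoint_file(self):
    try:
        with open(os.path.join(self.save_dir, "last_checkpoint"), "r") as f:
            return f.read().strip()
    except IOError:
        # the marker file may be missing or unreadable
        return ""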
Example 14: save_checkpoint

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        'state_dict': model_weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link)
Example 15: save_checkpoint

# Required module import: from torch.nn import parallel [as alias]
# Or: from torch.nn.parallel import DistributedDataParallel [as alias]
def save_checkpoint(self, cpt_file_name=None, epoch=None):
    self.logging('=' * 20 + 'Dump Checkpoint' + '=' * 20)
    if cpt_file_name is None:
        cpt_file_name = self.setting.cpt_file_name
    cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
    self.logging('Dump checkpoint into {}'.format(cpt_file_path))

    store_dict = {
        'setting': self.setting.__dict__,
    }

    if self.model:
        if isinstance(self.model, para.DataParallel) or \
                isinstance(self.model, para.DistributedDataParallel):
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()
        store_dict['model_state'] = model_state
    else:
        self.logging('No model state is dumped', level=logging.WARNING)

    if self.optimizer:
        store_dict['optimizer_state'] = self.optimizer.state_dict()
    else:
        self.logging('No optimizer state is dumped', level=logging.WARNING)

    if epoch:
        store_dict['epoch'] = epoch

    torch.save(store_dict, cpt_file_path)
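A matching resume routine for the store_dict written above might look like the sketch below. It assumes the same para alias and trainer attributes as Example 15 and is not part of the original example:

def resume_checkpoint(self, cpt_file_path):
    # Assumed counterpart of save_checkpoint: load the dict written above and
    # restore the underlying module even when the model is wrapped in (D)DP.
    store_dict = torch.load(cpt_file_path, map_location=self.device)
    if self.model and 'model_state' in store_dict:
        if isinstance(self.model, (para.DataParallel, para.DistributedDataParallel)):
            self.model.module.load_state_dict(store_dict['model_state'])
        else:
            self.model.load_state_dict(store_dict['model_state'])
    if self.optimizer and 'optimizer_state' in store_dict:
        self.optimizer.load_state_dict(store_dict['optimizer_state'])
    return store_dict.get('epoch')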