本文整理汇总了Python中torch.nn.parallel.DataParallel的典型用法代码示例。如果您正苦于以下问题：Python parallel.DataParallel的具体用法？Python parallel.DataParallel怎么用？Python parallel.DataParallel使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解DataParallel所在模块torch.nn.parallel的其他用法示例。
在下文中一共展示了parallel.DataParallel的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _decorate_model
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _decorate_model(self, parallel_decorate=True):
    """Move ``self.model`` to ``self.device`` and optionally wrap it for parallelism.

    The model is cast with ``.half()`` first when ``self.setting.fp16`` is set.

    Args:
        parallel_decorate: When true, replace ``self.model`` with a
            ``DistributedDataParallel`` wrapper (if ``self.in_distributed_mode()``)
            or a ``DataParallel`` wrapper (if ``self.n_gpu > 1``). When false,
            the model is only moved, never wrapped.
    """
    banner = '=' * 20
    self.logging(banner + 'Decorate Model' + banner)

    if self.setting.fp16:
        self.model.half()
    self.model.to(self.device)
    self.logging('Set model device to {}'.format(str(self.device)))

    if not parallel_decorate:
        return

    if self.in_distributed_mode():
        # One process per device: both input and output live on the local rank.
        rank = self.setting.local_rank
        self.model = para.DistributedDataParallel(
            self.model,
            device_ids=[rank],
            output_device=rank,
        )
        self.logging('Wrap distributed data parallel')
    elif self.n_gpu > 1:
        self.model = para.DataParallel(self.model)
        self.logging('Wrap data parallel')
    else:
        self.logging('Do not wrap parallel layers')
示例2: __init__
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(self, proto_model: Module,
             gpu_ids_abs: Union[list, tuple] = (),
             init_method: Union[str, FunctionType, None] = "kaiming",
             show_structure=False,
             check_point_pos=None, verbose=True):
    """Reset bookkeeping attributes, then delegate model construction to ``self.define``.

    Args:
        proto_model: Network to manage; only its class name is read here.
        gpu_ids_abs: GPU ids, forwarded unchanged to ``self.define``.
        init_method: Weight-initialisation spec, forwarded unchanged to ``self.define``.
        show_structure: Forwarded unchanged to ``self.define``.
        check_point_pos: Stored as ``self.check_point_pos``.
        verbose: Stored as ``self.verbose``.
    """
    # NOTE(review): no type check is performed on ``proto_model``; the original
    # author left a commented-out ``isinstance(proto_model, Module)`` guard here.

    # Placeholders; presumably populated by ``self.define`` (not visible here).
    self.model: Union[DataParallel, Module] = None
    self.model_name = proto_model.__class__.__name__
    self.weights_init = None
    self.init_fc = None
    self.init_name: str = None
    self.num_params: int = 0

    # Caller-supplied options kept verbatim.
    self.verbose = verbose
    self.check_point_pos = check_point_pos

    self.define(proto_model, gpu_ids_abs, init_method, show_structure)
示例3: __init__
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(self, *kargs, **kwargs):
    """Initialise the trainer, wrap model+loss for parallelism, and build ``contrast_batch``.

    After the parent initialiser runs, ``self.model`` and ``self.criterion`` are
    combined with ``AddLossModule``. The combination is then wrapped:

    * in ``DistributedDataParallel`` on ``self.local_rank`` when ``self.distributed``;
    * otherwise in ``DataParallel`` over ``self.device_ids``, but only when those
      ids are a ``tuple``. The scatter dim is 0 if ``self.batch_first`` else 1.

    All arguments are forwarded unchanged to the parent initialiser.
    """
    super(NestedTrainer, self).__init__(*kargs, **kwargs)

    self.model_with_loss = AddLossModule(self.model, self.criterion)
    if self.distributed:
        self.model_with_loss = DistributedDataParallel(
            self.model_with_loss,
            device_ids=[self.local_rank],
            output_device=self.local_rank)
    elif isinstance(self.device_ids, tuple):
        # Any non-tuple ``device_ids`` (e.g. a list) leaves the module unwrapped.
        self.model_with_loss = DataParallel(
            self.model_with_loss,
            self.device_ids,
            dim=0 if self.batch_first else 1)

    # Unpacking requires exactly two tokenizers; presumably (source, target)
    # in insertion order — TODO confirm against save_info construction.
    _, target_tok = self.save_info['tokenizers'].values()
    # NOTE(review): the meaning of 8188 (word count requested) is not documented here.
    common = target_tok.common_words(8188)
    self.contrast_batch = batch_nested_sequences(common)
示例4: save_checkpoint
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    """Write a checkpoint to ``out_dir/filename_tmpl.format(epoch)`` and update links.

    The checkpoint dict holds ``epoch``, ``num_iters``, the model's state dict
    (passed through ``model_weights_to_cpu``) and, if given, the optimizer's
    state dict. ``latest.pth`` (and ``best.pth`` when ``is_best``) are then
    pointed at the new file through ``make_link``.

    Args:
        model: Module to save. A ``DataParallel`` / ``DistributedDataParallel``
            wrapper is unwrapped, so the inner module's state dict is stored.
        epoch: Stored in the checkpoint and substituted into ``filename_tmpl``.
        num_iters: Stored in the checkpoint.
        out_dir: Destination directory; created (with parents) if missing.
        filename_tmpl: ``str.format`` template receiving ``epoch``.
        optimizer: Optional optimizer whose ``state_dict()`` is also saved.
        is_best: When true, also link ``best.pth`` to the new file.
    """
    # exist_ok=True avoids the isdir()/makedirs() check-then-create race; a
    # non-directory at ``out_dir`` still raises FileExistsError as before.
    os.makedirs(out_dir, exist_ok=True)
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        # presumably copies every tensor to CPU — see model_weights_to_cpu.
        'state_dict': model_weights_to_cpu(model.state_dict()),
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    make_link(filename, os.path.join(out_dir, 'latest.pth'))
    if is_best:
        make_link(filename, os.path.join(out_dir, 'best.pth'))
示例5: save_checkpoint
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_checkpoint(self, cpt_file_name=None, epoch=None):
    """Dump setting, model state, optimizer state and epoch into one checkpoint file.

    The file is written to ``os.path.join(self.setting.model_dir, cpt_file_name)``
    with ``torch.save``. The stored dict always contains ``'setting'``
    (``self.setting.__dict__``). It contains ``'model_state'`` /
    ``'optimizer_state'`` when ``self.model`` / ``self.optimizer`` are not None,
    and ``'epoch'`` when ``epoch`` is not None.

    Args:
        cpt_file_name: File name; defaults to ``self.setting.cpt_file_name``.
        epoch: Optional epoch number. ``0`` is a valid value and is stored.
    """
    self.logging('='*20 + 'Dump Checkpoint' + '='*20)
    if cpt_file_name is None:
        cpt_file_name = self.setting.cpt_file_name
    cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
    self.logging('Dump checkpoint into {}'.format(cpt_file_path))

    store_dict = {
        'setting': self.setting.__dict__,
    }

    # ``is not None`` rather than truthiness: an empty container module such as
    # nn.Sequential() defines __len__ and would otherwise be treated as "no model".
    if self.model is not None:
        model = self.model
        # Store the inner module so keys are not prefixed by the wrapper.
        if isinstance(model, (para.DataParallel, para.DistributedDataParallel)):
            model = model.module
        store_dict['model_state'] = model.state_dict()
    else:
        self.logging('No model state is dumped', level=logging.WARNING)

    if self.optimizer is not None:
        store_dict['optimizer_state'] = self.optimizer.state_dict()
    else:
        self.logging('No optimizer state is dumped', level=logging.WARNING)

    # Previously ``if epoch:`` silently dropped epoch 0.
    if epoch is not None:
        store_dict['epoch'] = epoch
    torch.save(store_dict, cpt_file_path)
示例6: print_network
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log structure and parameter count of G, and of D and F during training.

    Each network is summarised with ``self.get_network_description``. Networks
    wrapped in ``nn.DataParallel`` or ``DistributedDataParallel`` are named
    ``"<Wrapper> - <Inner>"``. Nothing is logged unless ``self.rank <= 0``.
    D is reported only when ``self.is_train`` is truthy. F (the perceptual
    network) is reported only when ``self.is_train`` and ``self.cri_fea`` are
    both truthy.
    """

    def _log_one(label, net):
        # Describe one network and log it under the single-letter ``label``.
        s, n = self.get_network_description(net)
        if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(net.__class__.__name__,
                                             net.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network {} structure: {}, with parameters: {:,d}'.format(
                label, net_struc_str, n))
            logger.info(s)

    # Generator
    _log_one('G', self.netG)
    if self.is_train:
        # Discriminator
        _log_one('D', self.netD)
        if self.cri_fea:  # F, Perceptual Network
            _log_one('F', self.netF)
示例7: print_network
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log the generator's structure string and parameter count.

    A ``nn.DataParallel`` / ``DistributedDataParallel`` wrapper is shown as
    ``"<Wrapper> - <Inner>"``. Nothing is logged when ``self.rank > 0``.
    """
    net = self.netG
    description, num_params = self.get_network_description(net)

    class_name = net.__class__.__name__
    is_wrapped = isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel)
    if is_wrapped:
        net_struc_str = '{} - {}'.format(class_name, net.module.__class__.__name__)
    else:
        net_struc_str = class_name

    if self.rank > 0:
        return
    logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
    logger.info(description)
示例8: test_is_module_wrapper
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def test_is_module_wrapper():
    """``is_module_wrapper`` rejects bare modules and accepts every known wrapper.

    Covers torch's DataParallel/DDP, the MM* variants, the deprecated MMDDP,
    and a user class registered in ``MODULE_WRAPPERS``.
    """

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    # Each wrapper is built and checked in turn, in the same order as before.
    wrapper_factories = (
        lambda: DataParallel(model),
        lambda: MMDataParallel(model),
        lambda: DistributedDataParallel(model, process_group=MagicMock()),
        lambda: MMDistributedDataParallel(model, process_group=MagicMock()),
        lambda: DeprecatedMMDDP(model),
    )
    for make_wrapper in wrapper_factories:
        assert is_module_wrapper(make_wrapper())

    # A class registered in the MODULE_WRAPPERS registry also counts as a wrapper.
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):
        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
示例9: load_weights
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
    """Load weights into ``self.model`` from a state dict or a weights file.

    Keys are passed through ``self._fix_weights`` so that they match the
    model. The ``"add"`` mode is used when ``self.model`` is a ``DataParallel``,
    and ``"remove"`` otherwise.

    :param weights: A state dict (``dict`` or ``OrderedDict``), or a path to a
        file readable by ``torch.load``.
    :param strict: Same as in ``model.load_state_dict(weights, strict=strict)``.
        default: ``True``
    :return: ``None``
    :raises TypeError: if ``weights`` is neither a ``str`` nor a ``dict``.

    Example::

        >>> from torchvision.models.resnet import resnet18
        >>> model = Model(resnet18())
        ResNet Total number of parameters: 11689512
        ResNet model use CPU!
        apply kaiming weight init!
        >>> model.save_weights("model.pth",)
        try to remove 'module.' in keys of weights dict...
        >>> model.load_weights("model.pth", True)
        Try to remove `moudle.` to keys of weights dict
    """
    if isinstance(weights, str):
        # map_location returning ``storage`` loads every tensor onto the CPU.
        weights = load(weights, map_location=lambda storage, loc: storage)
    elif not isinstance(weights, dict):
        # OrderedDict is a dict subclass, so in-memory state dicts pass through.
        # Previously every non-str input (including dicts) was rejected here.
        raise TypeError("`weights` must be a `dict` or a path of weights file.")
    if isinstance(self.model, DataParallel):
        self._print("Try to add `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "add", False)
    else:
        self._print("Try to remove `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "remove", False)
    self.model.load_state_dict(weights, strict=strict)
示例10: save_weights
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_weights(self, weights_path: str, fix_weights=True):
    """Save the state dict of ``self.model`` to ``weights_path`` with ``torch.save``.

    :param weights_path: Destination file path.
    :param fix_weights: If true, the state dict is deep-copied and passed through
        ``self._fix_weights(..., "remove", False)`` before saving. This is intended
        to strip the ``module.`` prefix that ``DataParallel`` adds, so the file can
        be loaded into an unwrapped model. If false, ``self.model.state_dict()``
        is saved as-is. Neither branch moves tensors between devices.
        default: ``True``

    Example::

        >>> from torch.nn import Linear
        >>> model = Model(Linear(10,1))
        Linear Total number of parameters: 11
        Linear model use CPU!
        apply kaiming weight init!
        >>> model.save_weights("weights.pth")
        try to remove 'module.' in keys of weights dict...
        >>> model.load_weights("weights.pth")
        Try to remove `moudle.` to keys of weights dict
    """
    if fix_weights:
        import copy
        # Deep copy so that rewriting keys never touches the live model's dict.
        weights = copy.deepcopy(self.model.state_dict())
        self._print("try to remove 'module.' in keys of weights dict...")
        weights = self._fix_weights(weights, "remove", False)
    else:
        weights = self.model.state_dict()
    save(weights, weights_path)
示例11: _set_device
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
    """Place ``proto_model`` on the CPU, one GPU, or several GPUs.

    Args:
        proto_model: Model to place.
        gpu_ids_abs: GPU ids. Empty or ``None`` means CPU. One id means
            ``proto_model.cuda(id)``. Several ids means a ``DataParallel`` over
            them, with the parameters first moved to the first id.

    Returns:
        The model, moved and possibly wrapped in ``DataParallel``.

    Raises:
        EnvironmentError: GPU ids were given but ``torch.cuda.is_available()``
            is False.
    """
    if not gpu_ids_abs:
        gpu_ids_abs = []
    model_name = proto_model.__class__.__name__

    if gpu_ids_abs and not torch.cuda.is_available():
        # .get(): CUDA_VISIBLE_DEVICES may be unset; indexing os.environ would
        # raise KeyError and hide this more informative error.
        raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
                               "CUDA_VISIBLE_DEVICES=%s" %
                               os.environ.get("CUDA_VISIBLE_DEVICES"))

    if len(gpu_ids_abs) == 1:
        proto_model = proto_model.cuda(gpu_ids_abs[0])
        self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
    elif len(gpu_ids_abs) > 1:
        proto_model = DataParallel(proto_model.cuda(gpu_ids_abs[0]), gpu_ids_abs)
        self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
    else:
        self._print("%s model use CPU!" % model_name)
    return proto_model
示例12: configure
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def configure(self):
    """Summarise the managed model as a plain dict.

    Keys, in order: ``model_name`` (inner class name when wrapped in
    ``DataParallel``), ``init_method`` (``str(self.init_name)``),
    ``total_params`` (``self.num_params``), and ``structure`` (``str(self.model)``).

    Raises:
        TypeError: ``self.model`` is neither a ``DataParallel`` nor a ``Module``.
    """
    model = self.model
    if isinstance(model, DataParallel):
        named = model.module
    elif isinstance(model, Module):
        named = model
    else:
        raise TypeError("Type of `self.model` is wrong!")

    return {
        "model_name": str(named.__class__.__name__),
        "init_method": str(self.init_name),
        "total_params": self.num_params,
        "structure": str(model),
    }
示例13: __init__
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(
    self,
    model: nn.Module,
    dataset: Dataset = None,
    save_dir: str = "",
    *,
    save_to_disk: bool = True,
    **checkpointables: object,
):
    """
    Args:
        model (nn.Module): model. If it is wrapped in ``DataParallel`` or
            ``DistributedDataParallel``, the inner ``.module`` is stored instead.
        dataset (Dataset): stored as ``self.dataset``; how it is used is not
            visible in this method.
        save_dir (str): a directory to save and find checkpoints.
        save_to_disk (bool): if True, save checkpoint to disk, otherwise
            disable saving for this checkpointer.
        checkpointables (object): any checkpointable objects, i.e., objects
            that have the `state_dict()` and `load_state_dict()` method. For
            example, it can be used like
            `Checkpointer(model, "dir", optimizer=optimizer)`.
    """
    wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
    self.model = model.module if wrapped else model
    self.dataset = dataset
    # Shallow copy of the keyword-argument mapping (values are not copied).
    self.checkpointables = copy.copy(checkpointables)
    self.logger = logging.getLogger(__name__)
    self.save_dir = save_dir
    self.save_to_disk = save_to_disk
示例14: _load_model
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _load_model(self, checkpoint: Any):
    """
    Load weights from a checkpoint into ``self.model`` (non-strict).

    Steps:

    1. Pop ``"model"`` out of ``checkpoint`` (the caller's object is mutated).
    2. Convert it with ``self._convert_ndarray_to_tensor``.
    3. Strip a ``module.`` prefix, which a DataParallel/DistributedDataParallel
       wrapper would have added during serialization.
    4. Drop every entry whose shape disagrees with the model, logging a warning.
    5. Call ``load_state_dict(strict=False)`` and log any missing/unexpected keys.

    Args:
        checkpoint (Any): checkpoint contains the weights.
    """
    state = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(state)
    _strip_prefix_if_present(state, "module.")

    # work around https://github.com/pytorch/pytorch/issues/24139
    reference = self.model.state_dict()
    for name in list(state):  # snapshot keys: entries are popped while looping
        if name not in reference:
            continue
        expected = tuple(reference[name].shape)
        found = tuple(state[name].shape)
        if expected == found:
            continue
        self.logger.warning(
            "'{}' has shape {} in the checkpoint but {} in the "
            "model! Skipped.".format(name, found, expected)
        )
        state.pop(name)

    result = self.model.load_state_dict(state, strict=False)
    if result.missing_keys:
        self.logger.info(get_missing_parameters_message(result.missing_keys))
    if result.unexpected_keys:
        self.logger.info(get_unexpected_parameters_message(result.unexpected_keys))
示例15: print_network
# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log the generator's structure string and parameter count.

    Networks wrapped in ``nn.DataParallel`` or
    ``nn.parallel.DistributedDataParallel`` are shown as ``"<Wrapper> - <Inner>"``.
    Nothing is logged unless ``self.rank <= 0``.
    """
    s, n = self.get_network_description(self.netG)
    # DDP is checked too, so a DDP-wrapped generator also reports its inner class.
    if isinstance(self.netG, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                         self.netG.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(self.netG.__class__.__name__)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)