当前位置: 首页>>代码示例>>Python>>正文


Python parallel.DataParallel方法代码示例

本文整理汇总了Python中torch.nn.parallel.DataParallel的典型用法代码示例（注意：DataParallel实际上是一个类，而不是函数方法）。如果您正苦于以下问题：Python parallel.DataParallel的具体用法是什么？Python parallel.DataParallel怎么用？Python parallel.DataParallel有哪些使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解其所在模块torch.nn.parallel的用法示例。


在下文中一共展示了parallel.DataParallel方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _decorate_model

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _decorate_model(self, parallel_decorate=True):
    """Move ``self.model`` to ``self.device`` and optionally wrap it for parallelism.

    The model is cast with ``half()`` first when ``self.setting.fp16`` is set.
    With ``parallel_decorate`` enabled it is then wrapped in
    ``DistributedDataParallel`` (when ``self.in_distributed_mode()``; pinned to
    ``self.setting.local_rank``) or in ``DataParallel`` (when ``self.n_gpu > 1``);
    otherwise it is left unwrapped.

    :param parallel_decorate: whether to consider a parallel wrapper at all.
    """
    banner = '=' * 20
    self.logging(banner + 'Decorate Model' + banner)

    if self.setting.fp16:
        self.model.half()

    self.model.to(self.device)
    self.logging('Set model device to {}'.format(str(self.device)))

    if not parallel_decorate:
        self.logging('Do not wrap parallel layers')
        return

    if self.in_distributed_mode():
        rank = self.setting.local_rank
        self.model = para.DistributedDataParallel(self.model,
                                                  device_ids=[rank],
                                                  output_device=rank)
        self.logging('Wrap distributed data parallel')
    elif self.n_gpu > 1:
        self.model = para.DataParallel(self.model)
        self.logging('Wrap data parallel')
开发者ID:dolphin-zs,项目名称:Doc2EDAG,代码行数:23,代码来源:base_task.py

示例2: __init__

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(self, proto_model: Module,
             gpu_ids_abs: Union[list, tuple] = (),
             init_method: Union[str, FunctionType, None] = "kaiming",
             show_structure=False,
             check_point_pos=None, verbose=True):
    """Wrap ``proto_model`` and build it through ``self.define``.

    Bookkeeping attributes are first reset to empty placeholders. Then
    ``self.define(proto_model, gpu_ids_abs, init_method, show_structure)`` is
    called. Presumably ``define`` places the model on devices, applies the
    init method and fills in the placeholders; that is not visible here.

    :param proto_model: network to wrap; its class name is stored as ``model_name``.
    :param gpu_ids_abs: absolute GPU ids, forwarded to ``define``. Default ``()``.
    :param init_method: weight-init method name or callable, forwarded to
        ``define``. Default ``"kaiming"``.
    :param show_structure: forwarded to ``define``.
    :param check_point_pos: stored as ``self.check_point_pos``.
    :param verbose: stored as ``self.verbose``.
    """
    # Options kept verbatim.
    self.verbose = verbose
    self.check_point_pos = check_point_pos

    # Placeholders; ``define`` is expected to populate them.
    self.model: Union[DataParallel, Module] = None
    self.model_name = proto_model.__class__.__name__
    self.weights_init = None
    self.init_fc = None
    self.init_name: str = None
    self.num_params: int = 0

    self.define(proto_model, gpu_ids_abs, init_method, show_structure)
开发者ID:dingguanglei,项目名称:jdit,代码行数:20,代码来源:model.py

示例3: __init__

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(self, *kargs, **kwargs):
    """Initialise the trainer and wrap model + criterion for parallel execution.

    After the base trainer's ``__init__``, ``self.model`` and
    ``self.criterion`` are combined in an ``AddLossModule`` stored on
    ``self.model_with_loss``. That module is then wrapped in one of two ways:

    * ``DistributedDataParallel`` pinned to ``self.local_rank`` when
      ``self.distributed`` is set.
    * Otherwise, ``DataParallel`` over ``self.device_ids`` when that is a
      tuple, scattering along dim 0 if ``self.batch_first`` else dim 1.

    Finally ``self.contrast_batch`` is built from
    ``target_tok.common_words(8188)``.
    """
    super(NestedTrainer, self).__init__(*kargs, **kwargs)
    self.model_with_loss = AddLossModule(self.model, self.criterion)

    if self.distributed:
        rank = self.local_rank
        self.model_with_loss = DistributedDataParallel(
            self.model_with_loss, device_ids=[rank], output_device=rank)
    elif isinstance(self.device_ids, tuple):
        scatter_dim = 0 if self.batch_first else 1
        self.model_with_loss = DataParallel(
            self.model_with_loss, self.device_ids, dim=scatter_dim)

    # Assumes save_info['tokenizers'] holds exactly two tokenizers with the
    # target-side one second -- TODO confirm against where save_info is built.
    _, target_tok = self.save_info['tokenizers'].values()
    common = target_tok.common_words(8188)
    self.contrast_batch = batch_nested_sequences(common)
开发者ID:eladhoffer,项目名称:seq2seq.pytorch,代码行数:18,代码来源:trainer.py

示例4: save_checkpoint

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_checkpoint(model,
                    epoch,
                    num_iters,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    optimizer=None,
                    is_best=False):
    """Serialize model (and optionally optimizer) state into ``out_dir``.

    The checkpoint is written to ``out_dir/filename_tmpl.format(epoch)``. It is
    a dict with ``epoch``, ``num_iters`` and ``state_dict`` (the model's state
    passed through ``model_weights_to_cpu``), plus ``optimizer`` when one is
    given. ``latest.pth`` (and ``best.pth`` when ``is_best``) are then pointed
    at the new file via ``make_link``.

    Args:
        model: module to save; a ``DataParallel``/``DistributedDataParallel``
            wrapper is unwrapped first.
        epoch: epoch number, stored and used to format the file name.
        num_iters: iteration counter stored in the checkpoint.
        out_dir: target directory, created if missing.
        filename_tmpl: ``str.format`` template taking ``epoch``.
        optimizer: optional optimizer whose ``state_dict()`` is stored.
        is_best: also point ``best.pth`` at this checkpoint.
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. distributed workers) save into the same directory at once.
    os.makedirs(out_dir, exist_ok=True)
    # Save the inner module so parameter names carry no "module." prefix.
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        model = model.module
    filename = os.path.join(out_dir, filename_tmpl.format(epoch))
    checkpoint = {
        'epoch': epoch,
        'num_iters': num_iters,
        'state_dict': model_weights_to_cpu(model.state_dict())
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    torch.save(checkpoint, filename)
    latest_link = os.path.join(out_dir, 'latest.pth')
    make_link(filename, latest_link)
    if is_best:
        best_link = os.path.join(out_dir, 'best.pth')
        make_link(filename, best_link)
开发者ID:hellock,项目名称:torchpack,代码行数:27,代码来源:io.py

示例5: save_checkpoint

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_checkpoint(self, cpt_file_name=None, epoch=None):
    """Dump setting, model, optimizer and epoch into a checkpoint file.

    The file is ``os.path.join(self.setting.model_dir, cpt_file_name)`` and
    is written with ``torch.save``. Keys that may be stored:

    * ``'setting'``: always present.
    * ``'model_state'``: when ``self.model`` is set; a DataParallel /
      DistributedDataParallel wrapper is unwrapped first.
    * ``'optimizer_state'``: when ``self.optimizer`` is set.
    * ``'epoch'``: when ``epoch`` is given.

    A warning is logged for each missing model/optimizer.

    :param cpt_file_name: file name inside ``self.setting.model_dir``;
        defaults to ``self.setting.cpt_file_name``.
    :param epoch: stored under ``'epoch'`` whenever it is not ``None``
        (including ``0``).
    """
    self.logging('='*20 + 'Dump Checkpoint' + '='*20)
    if cpt_file_name is None:
        cpt_file_name = self.setting.cpt_file_name
    cpt_file_path = os.path.join(self.setting.model_dir, cpt_file_name)
    self.logging('Dump checkpoint into {}'.format(cpt_file_path))

    store_dict = {
        'setting': self.setting.__dict__,
    }

    # Compare with None rather than truthiness: containers such as an empty
    # nn.Sequential define __len__ and would be treated as "no model".
    if self.model is not None:
        model = self.model
        # Store the inner module so keys carry no 'module.' prefix.
        if isinstance(model, (para.DataParallel, para.DistributedDataParallel)):
            model = model.module
        store_dict['model_state'] = model.state_dict()
    else:
        self.logging('No model state is dumped', level=logging.WARNING)

    if self.optimizer is not None:
        store_dict['optimizer_state'] = self.optimizer.state_dict()
    else:
        self.logging('No optimizer state is dumped', level=logging.WARNING)

    # epoch 0 is a legitimate value; a truthiness test would silently drop it.
    if epoch is not None:
        store_dict['epoch'] = epoch

    torch.save(store_dict, cpt_file_path)
开发者ID:dolphin-zs,项目名称:Doc2EDAG,代码行数:32,代码来源:base_task.py

示例6: print_network

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log structure and parameter count of the SRGAN networks.

    The generator ``netG`` is always described. When ``self.is_train`` is
    set, the discriminator ``netD`` is described too, and the perceptual
    network ``netF`` as well if ``self.cri_fea`` is set.

    A ``DataParallel``/``DistributedDataParallel`` wrapper is shown as
    ``Wrapper - Inner``. Descriptions are computed on every rank, but lines
    are only logged when ``self.rank <= 0``.
    """
    wrapper_types = (nn.DataParallel, DistributedDataParallel)

    def _log_network(net, template):
        # Describe first, then name the (possibly wrapped) network.
        text, count = self.get_network_description(net)
        struct = net.__class__.__name__
        if isinstance(net, wrapper_types):
            struct = '{} - {}'.format(struct, net.module.__class__.__name__)
        if self.rank <= 0:
            logger.info(template.format(struct, count))
            logger.info(text)

    # Generator
    _log_network(self.netG, 'Network G structure: {}, with parameters: {:,d}')
    if not self.is_train:
        return
    # Discriminator
    _log_network(self.netD, 'Network D structure: {}, with parameters: {:,d}')
    if self.cri_fea:  # F, Perceptual Network
        _log_network(self.netF, 'Network F structure: {}, with parameters: {:,d}')
开发者ID:xinntao,项目名称:BasicSR,代码行数:39,代码来源:SRGAN_model.py

示例7: print_network

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log the generator's structure and parameter count.

    A ``DataParallel``/``DistributedDataParallel`` wrapper is reported as
    ``Wrapper - Inner``. Nothing is logged when ``self.rank > 0``.
    """
    description, n_params = self.get_network_description(self.netG)
    net = self.netG
    if isinstance(net, (nn.DataParallel, DistributedDataParallel)):
        struct = '{} - {}'.format(net.__class__.__name__,
                                  net.module.__class__.__name__)
    else:
        struct = '{}'.format(net.__class__.__name__)
    if self.rank > 0:
        return
    logger.info('Network G structure: {}, with parameters: {:,d}'.format(struct, n_params))
    logger.info(description)
开发者ID:xinntao,项目名称:BasicSR,代码行数:12,代码来源:SR_model.py

示例8: test_is_module_wrapper

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def test_is_module_wrapper():
    """``is_module_wrapper`` recognises built-in, mmcv and registered wrappers."""

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(2, 2, 1)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    assert not is_module_wrapper(model)

    # (wrapper class, whether it needs a process_group); each wrapper is
    # constructed and checked in turn, in the same order as before.
    wrapper_cases = (
        (DataParallel, False),
        (MMDataParallel, False),
        (DistributedDataParallel, True),
        (MMDistributedDataParallel, True),
        (DeprecatedMMDDP, False),
    )
    for wrapper_cls, needs_group in wrapper_cases:
        kwargs = {'process_group': MagicMock()} if needs_group else {}
        assert is_module_wrapper(wrapper_cls(model, **kwargs))

    # A plain class registered in MODULE_WRAPPERS must be recognised as well.
    @MODULE_WRAPPERS.register_module()
    class ModuleWrapper(object):

        def __init__(self, module):
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    module_wrapper = ModuleWrapper(model)
    assert is_module_wrapper(module_wrapper)
开发者ID:open-mmlab,项目名称:mmcv,代码行数:43,代码来源:test_parallel.py

示例9: load_weights

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def load_weights(self, weights: Union[OrderedDict, dict, str], strict=True):
    """Load weights into ``self.model`` from a state dict or a weights file.

    ``weights`` may be a path to a weights file, or the state dict itself
    (``dict`` / ``OrderedDict``). Files are loaded with a ``map_location``
    that keeps tensors on CPU.

    Keys are then adapted to the model via ``self._fix_weights``: a
    ``module.`` prefix is added when ``self.model`` is a ``DataParallel`` and
    removed otherwise.

    :param weights: Pytorch weights (state dict) or weights file path.
    :param strict: The same function in pytorch ``model.load_state_dict(weights,strict = strict)`` .
     default:``True``
    :return: ``None``
    :raises TypeError: if ``weights`` is neither a ``dict`` nor a ``str`` path.

    Example::

        >>> from torchvision.models.resnet import resnet18
        >>> model = Model(resnet18())
        ResNet Total number of parameters: 11689512
        ResNet model use CPU!
        apply kaiming weight init!
        >>> model.save_weights("model.pth",)
        try to remove 'module.' in keys of weights dict...
        >>> model.load_weights("model.pth", True)
        Try to remove `moudle.` to keys of weights dict

    """
    if isinstance(weights, str):
        weights = load(weights, map_location=lambda storage, loc: storage)
    elif not isinstance(weights, dict):
        # Previously *every* non-str argument was rejected, so passing a
        # state dict directly (as documented) always raised.
        raise TypeError("`weights` must be a `dict` or a path of weights file.")
    if isinstance(self.model, DataParallel):
        self._print("Try to add `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "add", False)
    else:
        self._print("Try to remove `moudle.` to keys of weights dict")
        weights = self._fix_weights(weights, "remove", False)
    self.model.load_state_dict(weights, strict=strict)
开发者ID:dingguanglei,项目名称:jdit,代码行数:36,代码来源:model.py

示例10: save_weights

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def save_weights(self, weights_path: str, fix_weights=True):
    """Save the model's weights (its ``state_dict``) to ``weights_path``.

    .. note::

        With ``fix_weights`` enabled, the saved keys go through
        ``self._fix_weights(..., "remove", ...)``, which (per the log message)
        strips the ``module.`` prefix that ``DataParallel`` adds. The file can
        then be loaded whether or not the target model is wrapped. No device
        transfer is performed here.

    :param weights_path: destination file path.
    :param fix_weights: if true (the default), save a deep copy of the state
     dict with ``module.`` removed from its keys; otherwise save
     ``self.model.state_dict()`` unchanged.
     default:``True``

    Example::

       >>> from torch.nn import Linear
       >>> model = Model(Linear(10,1))
       Linear Total number of parameters: 11
       Linear model use CPU!
       apply kaiming weight init!
       >>> model.save_weights("weights.pth")
       try to remove 'module.' in keys of weights dict...
       >>> model.load_weights("weights.pth")
       Try to remove `moudle.` to keys of weights dict

    """
    if fix_weights:
        import copy
        # Deep copy so key rewriting never touches the live state dict.
        weights = copy.deepcopy(self.model.state_dict())
        self._print("try to remove 'module.' in keys of weights dict...")
        weights = self._fix_weights(weights, "remove", False)
    else:
        weights = self.model.state_dict()

    save(weights, weights_path)
开发者ID:dingguanglei,项目名称:jdit,代码行数:39,代码来源:model.py

示例11: _set_device

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _set_device(self, proto_model: Module, gpu_ids_abs: list) -> Union[Module, DataParallel]:
    """Place ``proto_model`` on the requested devices.

    :param proto_model: model to place.
    :param gpu_ids_abs: absolute GPU ids. The model is handled as follows:

        * empty / ``None``: kept on CPU and returned unchanged.
        * one id: moved with ``proto_model.cuda(id)``.
        * several ids: moved to the first id and wrapped in ``DataParallel``
          over all of them.

    :return: the placed (and possibly wrapped) model.
    :raises EnvironmentError: if GPUs are requested but
        ``torch.cuda.is_available()`` is false.
    """
    model_name = proto_model.__class__.__name__

    if not gpu_ids_abs:
        self._print("%s model use CPU!" % model_name)
        return proto_model

    if not torch.cuda.is_available():
        # Use .get(): on CPU-only machines the variable is often unset, and
        # indexing os.environ would raise KeyError instead of this error.
        raise EnvironmentError("No gpu available! torch.cuda.is_available() is False. "
                               "CUDA_VISIBLE_DEVICES=%s"
                               % os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>"))

    if len(gpu_ids_abs) == 1:
        proto_model = proto_model.cuda(gpu_ids_abs[0])
        self._print("%s model use GPU %s!" % (model_name, gpu_ids_abs))
    else:
        proto_model = DataParallel(proto_model.cuda(gpu_ids_abs[0]), gpu_ids_abs)
        self._print("%s dataParallel use GPUs%s!" % (model_name, gpu_ids_abs))
    return proto_model
开发者ID:dingguanglei,项目名称:jdit,代码行数:29,代码来源:model.py

示例12: configure

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def configure(self):
    """Return a summary dict describing the wrapped model.

    Keys:

    * ``model_name``: class name of the inner module; a ``DataParallel``
      wrapper is looked through.
    * ``init_method``: ``str(self.init_name)``.
    * ``total_params``: ``self.num_params``.
    * ``structure``: ``str(self.model)``.

    :raises TypeError: if ``self.model`` is not a ``Module``.
    """
    model = self.model
    # DataParallel is itself a Module, so it must be tested first.
    if isinstance(model, DataParallel):
        core = model.module
    elif isinstance(model, Module):
        core = model
    else:
        raise TypeError("Type of `self.model` is wrong!")
    return {
        "model_name": str(core.__class__.__name__),
        "init_method": str(self.init_name),
        "total_params": self.num_params,
        "structure": str(model),
    }
开发者ID:dingguanglei,项目名称:jdit,代码行数:14,代码来源:model.py

示例13: __init__

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def __init__(
        self,
        model: nn.Module,
        dataset: Dataset = None,
        save_dir: str = "",
        *,
        save_to_disk: bool = True,
        **checkpointables: object,
):
    """
    Args:
        model (nn.Module): model to checkpoint. If it is wrapped in
            ``DistributedDataParallel`` or ``DataParallel``, the inner
            ``.module`` is stored instead.
        dataset (Dataset): optional dataset, kept on ``self.dataset``.
        save_dir (str): a directory to save and find checkpoints.
        save_to_disk (bool): if True, save checkpoint to disk, otherwise
            disable saving for this checkpointer.
        checkpointables (object): any checkpointable objects, i.e., objects
            that have the `state_dict()` and `load_state_dict()` method. A
            shallow copy of the mapping is kept. Because ``dataset`` is the
            second positional parameter, pass the directory by keyword, e.g.
            `Checkpointer(model, save_dir="dir", optimizer=optimizer)`.
    """
    wrappers = (DistributedDataParallel, DataParallel)
    self.model = model.module if isinstance(model, wrappers) else model
    self.dataset = dataset
    self.checkpointables = copy.copy(checkpointables)
    self.logger = logging.getLogger(__name__)
    self.save_dir = save_dir
    self.save_to_disk = save_to_disk
开发者ID:JDAI-CV,项目名称:fast-reid,代码行数:30,代码来源:checkpoint.py

示例14: _load_model

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def _load_model(self, checkpoint: Any):
    """
    Load weights from a checkpoint into ``self.model`` (non-strict).

    The steps are:

    1. Pop ``"model"`` out of ``checkpoint`` (mutating it) and pass the
       result to ``self._convert_ndarray_to_tensor``.
    2. Strip a ``"module."`` key prefix if present.
    3. Drop entries whose shape disagrees with the model, logging a warning
       for each.
    4. Load with ``strict=False`` and report missing/unexpected keys via
       ``self.logger.info``.

    Args:
        checkpoint (Any): checkpoint contains the weights under ``"model"``.
    """
    state = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(state)

    # A state_dict saved from a DataParallel / DistributedDataParallel model
    # has every key prefixed with "module."; drop it before matching.
    _strip_prefix_if_present(state, "module.")

    # Work around https://github.com/pytorch/pytorch/issues/24139:
    # load_state_dict still raises on size mismatch even with strict=False,
    # so mismatching entries are removed here first.
    expected = self.model.state_dict()
    for key in list(state):
        if key not in expected:
            continue
        model_shape = tuple(expected[key].shape)
        ckpt_shape = tuple(state[key].shape)
        if model_shape == ckpt_shape:
            continue
        self.logger.warning(
            "'{}' has shape {} in the checkpoint but {} in the "
            "model! Skipped.".format(
                key, ckpt_shape, model_shape
            )
        )
        del state[key]

    result = self.model.load_state_dict(state, strict=False)
    if result.missing_keys:
        self.logger.info(
            get_missing_parameters_message(result.missing_keys)
        )
    if result.unexpected_keys:
        self.logger.info(
            get_unexpected_parameters_message(result.unexpected_keys)
        )
开发者ID:JDAI-CV,项目名称:fast-reid,代码行数:42,代码来源:checkpoint.py

示例15: print_network

# 需要导入模块: from torch.nn import parallel [as 别名]
# 或者: from torch.nn.parallel import DataParallel [as 别名]
def print_network(self):
    """Log the generator's structure and parameter count.

    ``self.netG`` is described via ``self.get_network_description``. A
    ``DataParallel`` or ``DistributedDataParallel`` wrapper is reported as
    ``Wrapper - Inner``. Lines are only logged when ``self.rank <= 0``.
    """
    s, n = self.get_network_description(self.netG)
    # Unwrap DistributedDataParallel as well: previously only nn.DataParallel
    # was recognised, so DDP-wrapped generators showed just the wrapper name.
    # nn.parallel.* is reachable through the already-imported ``nn``.
    wrapper_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    if isinstance(self.netG, wrapper_types):
        net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                         self.netG.module.__class__.__name__)
    else:
        net_struc_str = '{}'.format(self.netG.__class__.__name__)
    if self.rank <= 0:
        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)
开发者ID:xinntao,项目名称:EDVR,代码行数:12,代码来源:Video_base_model.py


注:本文中的torch.nn.parallel.DataParallel方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。