当前位置: 首页>>代码示例>>Python>>正文


Python serialization.default_restore_location方法代码示例

本文整理汇总了 Python 中 torch.serialization.default_restore_location 方法的典型用法代码示例。如果您想了解 serialization.default_restore_location 方法的具体用法、调用方式或实际使用示例，下面精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块 torch.serialization 的其他用法示例。


下文共展示了 serialization.default_restore_location 方法的 13 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助系统推荐更优质的 Python 代码示例。

示例1: load_model_state

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_model_state(filename, model):
    """Restore ``model``'s parameters from the checkpoint at ``filename``.

    All storages are mapped to CPU. For checkpoints whose ``args.arch`` is
    ``'convlm'``, every occurrence of ``'layers'`` in a parameter name is
    renamed to ``'convolutions'`` to fix a parameter-name mismatch. The
    checkpoint's parameters are merged over the model's current state dict,
    so model parameters missing from the checkpoint keep their values.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from the
        checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is chained as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    # Return value is ignored -- presumably upgrades state['model'] in place; TODO confirm.
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        if state['args'].arch == 'convlm':
            # fix parameter name mismatch: checkpoint keys say 'layers', model expects 'convolutions'
            for paramname in list(state['model'].keys()):
                state['model'][paramname.replace('layers', 'convolutions')] = state['model'].pop(paramname)
        model_state = model.state_dict()
        print('| mismatched parameters: {}'.format(set(model_state.keys()) ^ set(state['model'].keys())))
        # NOTE(review): load_state_dict is strict by default, so checkpoint keys
        # that the model does not have still make this fail; confirm whether
        # they should be filtered out instead.
        model_state.update(state['model'])
        model.load_state_dict(model_state)
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
开发者ID:nusnlp,项目名称:crosentgec,代码行数:24,代码来源:utils.py

示例2: _load_single_fairseq_checkpoint

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def _load_single_fairseq_checkpoint(self, path: str) -> Dict[str, Any]:
        """
        Loads a fairseq model from file.

        :param path:
            path to file

        :return state:
            loaded fairseq state
        """
        with open(path, "rb") as f:
            try:
                state = torch.load(
                    f, map_location=lambda s, l: default_restore_location(s, "cpu")
                )
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    "Please install fairseq: https://github.com/pytorch/fairseq#requirements-and-installation"
                )

        return state 
开发者ID:facebookresearch,项目名称:ParlAI,代码行数:23,代码来源:convert_fairseq_to_parlai.py

示例3: load_model_state

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_model_state(filename, model, cuda_device=None):
    """Restore ``model``'s parameters from the checkpoint at ``filename``.

    Args:
        filename: path to a checkpoint written with ``torch.save``.
        model: model providing ``upgrade_state_dict`` (its return value replaces
            the checkpoint's ``'model'`` entry) and ``load_state_dict``.
        cuda_device: when given, every storage is mapped to
            ``'cuda:<cuda_device>'``; when ``None`` the checkpoint is loaded
            with ``torch.load``'s default placement.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from the
        checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is chained as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, _loc: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
开发者ID:EdinburghNLP,项目名称:XSum,代码行数:23,代码来源:utils.py

示例4: load_ensemble_for_inference

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training. The overrides are applied to the saved args of every checkpoint,
    and each model is built from its own (overridden) args.

    Args:
        filenames: checkpoint paths; each is loaded with storages mapped to CPU.
        task: object whose ``build_model(args)`` constructs a model.
        model_arg_overrides: optional ``{'arg_name': value}`` mapping.

    Returns:
        ``(ensemble, args)`` -- the list of loaded models and the (overridden)
        args of the first checkpoint.

    Raises:
        IOError: if any path in ``filenames`` does not exist.
        IndexError: if ``filenames`` is empty.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)

    def _with_overrides(model_args):
        # Apply the caller's overrides (if any) to one checkpoint's saved args.
        if model_arg_overrides is None:
            return model_args
        return _override_model_args(model_args, model_arg_overrides)

    # Override every checkpoint's args so the overrides reach all ensemble
    # members, not just the first one.
    per_model_args = [_with_overrides(state['args']) for state in states]
    args = per_model_args[0]  # IndexError when filenames is empty

    # build ensemble
    ensemble = []
    for state, model_args in zip(states, per_model_args):
        model = task.build_model(model_args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args
开发者ID:nusnlp,项目名称:crosentgec,代码行数:29,代码来源:utils.py

示例5: load_checkpoint_to_cpu

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: checkpoint location, opened through ``PathManager``.
        arg_overrides: optional mapping; each ``name -> value`` pair is set
            with ``setattr`` on the checkpoint's ``state["args"]`` before the
            state dict is upgraded.

    Returns:
        The state dict produced by ``_upgrade_state_dict``.
    """
    def _cpu_location(storage, _location):
        # Ignore the saved location; place every storage on the CPU.
        return default_restore_location(storage, "cpu")

    with PathManager.open(path, "rb") as f:
        state = torch.load(f, map_location=_cpu_location)

    # Read unconditionally: a checkpoint without "args" raises KeyError here.
    saved_args = state["args"]
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(saved_args, arg_name, arg_val)
    return _upgrade_state_dict(state)
开发者ID:pytorch,项目名称:fairseq,代码行数:15,代码来源:checkpoint_utils.py

示例6: load_check_point

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_check_point(self, name):
    """Reload model weights from the checkpoint ``name`` under ``self.cp_save_path``.

    Storages are mapped to CPU. The file may contain either a plain dict
    (used directly as the state dict) or an object exposing ``state_dict()``
    (presumably a whole saved module); the result is loaded into ``self.model``.
    """
    checkpoint_path = os.path.join(self.cp_save_path, name)
    self.logger.info('reload best model from %s', checkpoint_path)

    def _map_to_cpu(storage, _location):
        return default_restore_location(storage, "cpu")

    loaded = torch.load(checkpoint_path, map_location=_map_to_cpu)
    state_dict = loaded if isinstance(loaded, dict) else loaded.state_dict()
    self.model.load_state_dict(state_dict)
开发者ID:fastnlp,项目名称:fastNLP,代码行数:11,代码来源:dist_trainer.py

示例7: load_model_state

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_model_state(filename, model):
    """Restore ``model``'s parameters from the checkpoint at ``filename``.

    All storages are mapped to CPU, and the parameters are loaded with
    ``strict=True``.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from the
        checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is chained as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    # Return value is ignored -- presumably upgrades state['model'] in place; TODO confirm.
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:17,代码来源:utils.py

示例8: load_ensemble_for_inference

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training. The overrides are applied to every checkpoint's saved args, and
    each model is built from its own (overridden) args.

    Returns:
        ``(ensemble, args)`` -- the loaded models and the (overridden) args of
        the *last* checkpoint, i.e. those used to build the final model.

    Raises:
        IOError: if any path in ``filenames`` does not exist.
        ValueError: if ``filenames`` yields no paths.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)

    # Without at least one checkpoint the loop below never binds ``args`` and
    # the final return would fail with UnboundLocalError.
    if not states:
        raise ValueError('load_ensemble_for_inference requires at least one model file')

    ensemble = []
    for state in states:
        args = state['args']

        if model_arg_overrides is not None:
            args = _override_model_args(args, model_arg_overrides)

        # build model for ensemble
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    return ensemble, args
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:32,代码来源:utils.py

示例9: load_ensemble_for_inference

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    Every checkpoint is loaded onto the CPU and upgraded. All models are then
    built from the *first* checkpoint's saved args, after applying
    ``model_arg_overrides`` (an optional ``{'arg_name': arg}`` dict that
    overrides the model args used during training). Each model then receives
    the weights of its own checkpoint.

    Returns:
        ``(ensemble, args)`` -- the list of models and the shared args.

    Raises:
        IOError: if any path in ``filenames`` does not exist.
        IndexError: if ``filenames`` is empty.
    """
    def _cpu_location(storage, _location):
        return default_restore_location(storage, 'cpu')

    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(_upgrade_state_dict(torch.load(filename, map_location=_cpu_location)))

    shared_args = states[0]['args']
    if model_arg_overrides is not None:
        shared_args = _override_model_args(shared_args, model_arg_overrides)

    ensemble = []
    for state in states:
        weights = state['model']
        member = task.build_model(shared_args)
        member.upgrade_state_dict(weights)
        member.load_state_dict(weights, strict=True)
        ensemble.append(member)
    return ensemble, shared_args
开发者ID:facebookresearch,项目名称:inversecooking,代码行数:28,代码来源:utils.py

示例10: load_ensemble_for_inference

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
    """Load an ensemble of models for inference.

    The source and target dictionaries can be given explicitly, or loaded from
    the `data_dir` directory (using the checkpoint's ``source_lang`` /
    ``target_lang``). Every checkpoint is loaded onto the CPU; all models are
    built from the first checkpoint's upgraded args.

    Returns:
        ``(ensemble, args)`` -- the list of models and the shared args.

    Raises:
        IOError: if any path in ``filenames`` does not exist.
        ValueError: if a dictionary is missing and ``data_dir`` is not given.
    """
    from fairseq import data, models

    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(
            torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        )
    args = states[0]['args']
    args = _upgrade_args(args)

    if src_dict is None or dst_dict is None:
        # Explicit check rather than ``assert``: asserts are stripped under -O.
        if data_dir is None:
            raise ValueError('data_dir is required when src_dict or dst_dict is not given')
        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)

    # build ensemble
    ensemble = []
    for state in states:
        model = models.build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
    return ensemble, args
开发者ID:EdinburghNLP,项目名称:XSum,代码行数:32,代码来源:utils.py

示例11: build_model

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def build_model(self, args):
    """Build the model and, if ``args.pretrained`` is set, adapt pretrained weights into it.

    The pretrained checkpoint is loaded with all storages mapped to CPU, and
    its ``'model'`` entry is passed to ``self.adapt_state`` together with the
    freshly built model.

    Raises:
        ValueError: if ``args.pretrained`` is set but the path does not exist.
    """
    model = super().build_model(args)
    if args.pretrained is not None:  # load pretrained model
        if not os.path.exists(args.pretrained):
            raise ValueError(
                'Could not load pretrained weights from {}'.format(args.pretrained)
            )
        from torch.serialization import default_restore_location
        saved_state = torch.load(
            args.pretrained,
            map_location=lambda s, _loc: default_restore_location(s, 'cpu')
        )
        self.adapt_state(saved_state['model'], model)

    return model
开发者ID:elbayadm,项目名称:attn2d,代码行数:16,代码来源:dynamic_simultaneous_translation.py

示例12: load_checkpoint_to_cpu

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def load_checkpoint_to_cpu(path):
    """Load the checkpoint at ``path`` with every storage on CPU, then upgrade it.

    Returns:
        The state dict produced by ``_upgrade_state_dict``.
    """
    def _to_cpu(storage, _location):
        # Ignore where the tensor was saved; always restore it on the CPU.
        return default_restore_location(storage, 'cpu')

    return _upgrade_state_dict(torch.load(path, map_location=_to_cpu))
开发者ID:kakaobrain,项目名称:helo_word,代码行数:6,代码来源:utils.py

示例13: import_individual_models

# 需要导入模块: from torch import serialization [as 别名]
# 或者: from torch.serialization import default_restore_location [as 别名]
def import_individual_models(restore_files, trainer):
    """Initialise ``trainer.model`` from several independently trained checkpoints.

    For the checkpoint at position ``idx`` in ``restore_files``, parameter
    names are remapped as follows:

    * ``encoder.X`` -> ``encoder.encoders.{idx}.X``
    * ``decoder.output_projection_w`` / ``_b`` ->
      ``decoder.combi_strat.output_projections.{idx}.output_projection_w`` / ``_b``
    * any other ``decoder.X`` -> ``decoder.decoders.{idx}.X``

    A parameter is skipped, with a ``WARN`` line printed, when:

    * it matches none of these rules;
    * its new name is not a parameter of ``trainer.model``; or
    * its size differs from that parameter.

    Storages are mapped to the current CUDA device. The collected parameters
    are loaded with ``strict=False``, and ``trainer._optim_history`` is reset
    to an empty list.
    """
    expected_sizes = {
        param_name: param.size()
        for param_name, param in trainer.model.named_parameters()
    }
    device_name = "cuda:{}".format(torch.cuda.current_device())

    def _combined_name(name, model_idx):
        # Map a sub-model parameter name to its name in the combined model (or None).
        if name.startswith("encoder."):
            return f"encoder.encoders.{model_idx}.{name[len('encoder.'):]}"
        if name in ("decoder.output_projection_w", "decoder.output_projection_b"):
            return f"decoder.combi_strat.output_projections.{model_idx}.{name[len('decoder.'):]}"
        if name.startswith("decoder."):
            return f"decoder.decoders.{model_idx}.{name[len('decoder.'):]}"
        return None

    imported = {}
    for model_idx, filename in enumerate(restore_files):
        checkpoint = torch.load(
            filename,
            map_location=lambda storage, _loc: default_restore_location(storage, device_name),
        )
        for name, value in checkpoint["model"].items():
            new_name = _combined_name(name, model_idx)
            if new_name is None:
                print(f"WARN: Ignoring {name} in {filename} (no match)")
                continue
            if new_name not in expected_sizes:
                print(f"WARN: Could not find {new_name}. Check architectures")
                continue
            if value.size() != expected_sizes[new_name]:
                print(
                    f"WARN: Tried to map {name} to {new_name}, but sizes do not match "
                    f"({value.size()} != {expected_sizes[new_name]})"
                )
                continue
            imported[new_name] = value
    trainer.model.load_state_dict(imported, strict=False)
    print(f"|  Imported {len(imported)} parameters.")
    trainer._optim_history = []
开发者ID:pytorch,项目名称:translate,代码行数:47,代码来源:multi_model.py


注:本文中的torch.serialization.default_restore_location方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。