本文整理匯總了Python中torch.serialization.default_restore_location方法的典型用法代碼示例。如果您正苦於以下問題：Python serialization.default_restore_location方法的具體用法？Python serialization.default_restore_location怎麼用？Python serialization.default_restore_location使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch.serialization的用法示例。
在下文中一共展示了serialization.default_restore_location方法的13個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: load_model_state
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_model_state(filename, model):
    """Load model parameters and trainer bookkeeping from a checkpoint.

    The checkpoint is always deserialized onto the CPU. For checkpoints whose
    ``args.arch`` is ``'convlm'``, every parameter name containing
    ``'layers'`` is rewritten to use ``'convolutions'`` instead. The
    checkpoint's parameters are then merged over ``model.state_dict()`` and
    the merged dict is loaded into ``model``.

    Args:
        filename: path to the checkpoint file.
        model: model exposing ``upgrade_state_dict``, ``state_dict`` and
            ``load_state_dict``.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from
        the checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is attached as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])
    # load model parameters
    try:
        if state['args'].arch == 'convlm':
            # fix parameter name mismatch; iterate over a snapshot of the keys
            # because the dict is modified inside the loop
            for paramname in list(state['model'].keys()):
                state['model'][paramname.replace('layers', 'convolutions')] = state['model'].pop(paramname)
        model_state = model.state_dict()
        print('| mismatched parameters: {}'.format(set(model_state.keys()) ^ set(state['model'].keys())))
        # checkpoint values override the model's current ones; parameters
        # absent from the checkpoint keep their current values
        model_state.update(state['model'])
        model.load_state_dict(model_state)
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err
    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
示例2: _load_single_fairseq_checkpoint
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def _load_single_fairseq_checkpoint(self, path: str) -> Dict[str, Any]:
    """
    Loads a fairseq model from file.

    All storages are restored onto the CPU, regardless of the device they
    were saved from.

    :param path:
        path to file
    :return state:
        loaded fairseq state
    :raises ModuleNotFoundError:
        if unpickling the checkpoint needs a module that is not installed.
        The message assumes that module is fairseq; the original error
        (naming the real module) is chained as ``__cause__``.
    """
    with open(path, "rb") as f:
        try:
            state = torch.load(
                f, map_location=lambda s, _loc: default_restore_location(s, "cpu")
            )
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install fairseq: https://github.com/pytorch/fairseq#requirements-and-installation"
            ) from err
    return state
示例3: load_model_state
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_model_state(filename, model, cuda_device=None):
    """Load model parameters and trainer bookkeeping from a checkpoint.

    Args:
        filename: path to the checkpoint file.
        model: model exposing ``upgrade_state_dict`` and ``load_state_dict``.
        cuda_device: if given, every storage is restored onto
            ``'cuda:<cuda_device>'``; otherwise ``torch.load``'s default
            device mapping is used.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from
        the checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is attached as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, _loc: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)
    # this model API returns the upgraded dict rather than editing it in place
    state['model'] = model.upgrade_state_dict(state['model'])
    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err
    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
示例4: load_ensemble_for_inference
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training. The overrides are applied to the args of every checkpoint, and
    each ensemble member is built from its own overridden args.

    Returns:
        ``(ensemble, args)`` where ``args`` are the (overridden) args of the
        first checkpoint.

    Raises:
        IOError: if a file in ``filenames`` does not exist.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)
    # Build every model from its *overridden* args so the overrides actually
    # affect model construction, not just the returned args.
    model_args = [state['args'] for state in states]
    if model_arg_overrides is not None:
        model_args = [_override_model_args(a, model_arg_overrides) for a in model_args]
    # build ensemble
    ensemble = []
    for state, state_args in zip(states, model_args):
        model = task.build_model(state_args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, model_args[0]
示例5: load_checkpoint_to_cpu
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    Args:
        path: location opened via ``PathManager.open(path, "rb")``.
        arg_overrides: optional mapping of attribute name -> value; each pair
            is assigned onto ``state["args"]`` before the state is upgraded.

    Returns:
        The checkpoint dict as returned by ``_upgrade_state_dict``.
    """

    def _restore_on_cpu(storage, _location):
        # ignore the device the tensor was saved from; always use the CPU
        return default_restore_location(storage, "cpu")

    with PathManager.open(path, "rb") as stream:
        checkpoint = torch.load(stream, map_location=_restore_on_cpu)

    ckpt_args = checkpoint["args"]
    if arg_overrides is not None:
        for key, value in arg_overrides.items():
            setattr(ckpt_args, key, value)

    return _upgrade_state_dict(checkpoint)
示例6: load_check_point
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_check_point(self, name):
    """Reload the weights stored as ``name`` under ``self.cp_save_path`` into ``self.model``.

    The file may hold a dict, which is used directly as the state dict, or
    any other object exposing ``state_dict()``, whose state dict is
    extracted first. Storages are always restored onto the CPU.
    """
    checkpoint_path = os.path.join(self.cp_save_path, name)
    self.logger.info('reload best model from %s', checkpoint_path)
    loaded = torch.load(
        checkpoint_path,
        map_location=lambda storage, _loc: default_restore_location(storage, "cpu"),
    )
    weights = loaded if isinstance(loaded, dict) else loaded.state_dict()
    self.model.load_state_dict(weights)
示例7: load_model_state
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_model_state(filename, model):
    """Load model parameters and trainer bookkeeping from a checkpoint.

    The checkpoint is deserialized onto the CPU, upgraded, and its
    ``'model'`` entry is loaded into ``model`` with ``strict=True``.

    Args:
        filename: path to the checkpoint file.
        model: model exposing ``upgrade_state_dict`` and ``load_state_dict``.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` read from
        the checkpoint, or ``(None, [], None)`` if ``filename`` does not exist.

    Raises:
        Exception: if the parameters cannot be loaded into ``model``; the
            underlying error is attached as ``__cause__``.
    """
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])
    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception as err:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from err
    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
示例8: load_ensemble_for_inference
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary model_arg_overrides --
    {'arg_name': arg} -- to override model args that were used during model
    training. The overrides are applied to every checkpoint's args, and each
    model is built from its own checkpoint's (overridden) args.

    Returns:
        ``(ensemble, args)`` where ``args`` are the (overridden) args of the
        last checkpoint in ``filenames``.

    Raises:
        IOError: if a file in ``filenames`` does not exist.
        ValueError: if ``filenames`` is empty.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)
    if not states:
        # without this guard the final `return` fails with UnboundLocalError on `args`
        raise ValueError('No model files given to load_ensemble_for_inference')
    ensemble = []
    for state in states:
        args = state['args']
        if model_arg_overrides is not None:
            args = _override_model_args(args, model_arg_overrides)
        # build model for ensemble
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args
示例9: load_ensemble_for_inference
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    Every ensemble member is built from the args of the FIRST checkpoint,
    optionally updated with ``model_arg_overrides`` (a dict
    ``{'arg_name': arg}`` overriding args used during model training), and
    then receives the weights of its own checkpoint.

    Returns:
        ``(ensemble, args)`` where ``args`` are the shared build args.

    Raises:
        IOError: if a file in ``filenames`` does not exist.
    """

    def _restore_on_cpu(storage, _location):
        return default_restore_location(storage, 'cpu')

    # read and upgrade every checkpoint before building anything
    states = []
    for path in filenames:
        if not os.path.exists(path):
            raise IOError('Model file not found: {}'.format(path))
        states.append(_upgrade_state_dict(torch.load(path, map_location=_restore_on_cpu)))

    shared_args = states[0]['args']
    if model_arg_overrides is not None:
        shared_args = _override_model_args(shared_args, model_arg_overrides)

    ensemble = []
    for state in states:
        member = task.build_model(shared_args)
        member.upgrade_state_dict(state['model'])
        member.load_state_dict(state['model'], strict=True)
        ensemble.append(member)
    return ensemble, shared_args
示例10: load_ensemble_for_inference
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
    """Load an ensemble of models for inference.

    The source and target dictionaries can be given explicitly, or loaded from
    the `data_dir` directory. All models are built from the upgraded args of
    the first checkpoint and then receive their own checkpoint's weights.

    Returns:
        ``(ensemble, args)``.

    Raises:
        IOError: if a file in ``filenames`` does not exist.
        ValueError: if a dictionary is missing and ``data_dir`` is None.
    """
    from fairseq import data, models
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        states.append(
            torch.load(filename, map_location=lambda s, _loc: default_restore_location(s, 'cpu'))
        )
    args = states[0]['args']
    args = _upgrade_args(args)
    if src_dict is None or dst_dict is None:
        # explicit check instead of `assert`, which is stripped under `python -O`
        if data_dir is None:
            raise ValueError('data_dir is required when src_dict or dst_dict is not given')
        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
    # build ensemble
    ensemble = []
    for state in states:
        model = models.build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
    return ensemble, args
示例11: build_model
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def build_model(self, args):
    """Build the model and optionally initialise it from pretrained weights.

    When ``args.pretrained`` is set, that checkpoint is loaded onto the CPU
    and its ``'model'`` entry is passed to ``self.adapt_state`` together
    with the freshly built model.

    Raises:
        ValueError: if ``args.pretrained`` is set but the path does not exist.
    """
    model = super().build_model(args)
    if args.pretrained is not None:  # load pretrained model:
        if not os.path.exists(args.pretrained):
            # Keep the literal on one line: a backslash continuation inside the
            # string would embed a stray "- " and the next line's indentation.
            raise ValueError('Could not load pretrained weights from {}'.format(args.pretrained))
        from torch.serialization import default_restore_location
        saved_state = torch.load(
            args.pretrained,
            map_location=lambda s, _loc: default_restore_location(s, 'cpu')
        )
        self.adapt_state(saved_state['model'], model)
    return model
示例12: load_checkpoint_to_cpu
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def load_checkpoint_to_cpu(path):
    """Load the checkpoint at ``path`` onto the CPU and upgrade its format.

    Returns the dict produced by ``_upgrade_state_dict``.
    """

    def _cpu_location(storage, _location):
        # discard the saved device tag; always restore to the CPU
        return default_restore_location(storage, 'cpu')

    return _upgrade_state_dict(torch.load(path, map_location=_cpu_location))
示例13: import_individual_models
# 需要導入模塊: from torch import serialization [as 別名]
# 或者: from torch.serialization import default_restore_location [as 別名]
def import_individual_models(restore_files, trainer):
    """Merge several single-model checkpoints into ``trainer.model``.

    Parameters from the ``idx``-th file in ``restore_files`` are renamed:

    * ``encoder.<rest>`` -> ``encoder.encoders.<idx>.<rest>``
    * ``decoder.output_projection_w`` / ``decoder.output_projection_b`` ->
      ``decoder.combi_strat.output_projections.<idx>.output_projection_w`` / ``_b``
    * ``decoder.<rest>`` -> ``decoder.decoders.<idx>.<rest>``

    A parameter that matches no rule, whose new name is not a parameter of
    ``trainer.model``, or whose size differs from that parameter is skipped
    with a printed warning. The rest are loaded with ``strict=False`` and
    ``trainer._optim_history`` is reset to an empty list.

    Checkpoints are mapped onto the current CUDA device when CUDA is
    available, otherwise onto the CPU.
    """
    encoder_prefix = "encoder."
    decoder_prefix = "decoder."
    output_projections = ("decoder.output_projection_w", "decoder.output_projection_b")

    param2size = {name: param.size() for name, param in trainer.model.named_parameters()}
    if torch.cuda.is_available():
        map_target = "cuda:{}".format(torch.cuda.current_device())
    else:
        # torch.cuda.current_device() raises on CPU-only hosts; fall back to CPU
        map_target = "cpu"
    model_state = {}
    for idx, filename in enumerate(restore_files):
        sub_state = torch.load(
            filename,
            map_location=lambda s, _loc: default_restore_location(s, map_target),
        )
        for name, value in sub_state["model"].items():
            new_name = None
            if name.startswith(encoder_prefix):
                subname = name[len(encoder_prefix):]
                new_name = f"encoder.encoders.{idx}.{subname}"
            elif name in output_projections:
                leaf = name[len(decoder_prefix):]
                new_name = f"decoder.combi_strat.output_projections.{idx}.{leaf}"
            elif name.startswith(decoder_prefix):
                subname = name[len(decoder_prefix):]
                new_name = f"decoder.decoders.{idx}.{subname}"
            if new_name is None:
                print(f"WARN: Ignoring {name} in {filename} (no match)")
            elif new_name not in param2size:
                print(f"WARN: Could not find {new_name}. Check architectures")
            elif value.size() != param2size[new_name]:
                print(
                    f"WARN: Tried to map {name} to {new_name}, but sizes do not match "
                    f"({value.size()} != {param2size[new_name]})"
                )
            else:
                model_state[new_name] = value
    trainer.model.load_state_dict(model_state, strict=False)
    print(f"| Imported {len(model_state)} parameters.")
    trainer._optim_history = []