本文整理汇总了Python中apex.amp.load_state_dict方法的典型用法代码示例。如果您正苦于以下问题：Python amp.load_state_dict方法的具体用法？Python amp.load_state_dict怎么用？Python amp.load_state_dict使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块apex.amp的用法示例。
在下文中一共展示了amp.load_state_dict方法的8个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_ckpt
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_ckpt(self):
    """Restore state from the checkpoint named by ``--load``, if one was given.

    Always restores the model weights, plus the embedding decoder when one
    exists. In ``'train'`` mode it also restores the global step and the
    optimizer state. In any other mode it switches the loaded modules to
    eval mode. The last top-level entry whose type is exactly ``float`` is
    reported as the recorded metric.
    """
    ckpt_path = self.paras.load
    if not ckpt_path:
        return

    is_training = self.mode == 'train'
    # Training resumes on the run's device; evaluation maps tensors to CPU.
    ckpt = torch.load(ckpt_path, map_location=self.device if is_training else 'cpu')

    self.model.load_state_dict(ckpt['model'])
    if self.emb_decoder is not None:
        self.emb_decoder.load_state_dict(ckpt['emb_decoder'])
    # NOTE(review): apex AMP state (ckpt['amp']) is not restored here even
    # when self.amp is set -- confirm this is intended.

    # Task-dependent items: the last exact-float entry is the recorded metric.
    recorded = [(key, val) for key, val in ckpt.items() if type(val) is float]
    metric, score = recorded[-1] if recorded else ("None", 0.0)

    if is_training:
        self.step = ckpt['global_step']
        self.optimizer.load_opt_state_dict(ckpt['optimizer'])
        self.verbose('Load ckpt from {}, restarting at step {} (recorded {} = {:.2f} %)'.format(
            ckpt_path, self.step, metric, score))
        return

    self.model.eval()
    if self.emb_decoder is not None:
        self.emb_decoder.eval()
    self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(ckpt_path, metric, score))
示例2: prepare_for_training
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def prepare_for_training(args, model, checkpoint_state_dict, amp):
    """Build the AdamW optimizer, optionally enable apex AMP, restore a
    checkpoint, and wrap the model for multi-GPU / distributed training.

    Args:
        args: Namespace read for weight_decay, learning_rate, adam_epsilon,
            fp16_opt_level, n_gpu and local_rank.
        model: The module to train.
        checkpoint_state_dict: Checkpoint mapping with 'model' and 'optimizer'
            (and 'amp' when ``amp`` is truthy), or a falsy value to start fresh.
        amp: The apex ``amp`` module to enable mixed precision, or a falsy value.

    Returns:
        ``(model, optimizer)``; the model may be wrapped in DataParallel
        and/or DistributedDataParallel.
    """
    # Parameters whose name contains any of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    decayed, exempt = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            exempt.append(param)
        else:
            decayed.append(param)
    optimizer_grouped_parameters = [
        {'params': decayed, 'weight_decay': args.weight_decay},
        {'params': exempt, 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    if amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        # AMP state is restored only after amp.initialize() has run.
        if checkpoint_state_dict:
            amp.load_state_dict(checkpoint_state_dict['amp'])

    if checkpoint_state_dict:
        optimizer.load_state_dict(checkpoint_state_dict['optimizer'])
        model.load_state_dict(checkpoint_state_dict['model'])

    # Parallel wrappers are applied after apex fp16 initialization.
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
    return model, optimizer
示例3: resume
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def resume(
    checkpoint: Union[str, Path],
    model: torch.nn.Module,
    reporter: Reporter,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    ngpu: int = 0,
    use_apex: bool = False,
):
    """Restore model, reporter, optimizer, scheduler and (optionally) AMP state.

    Args:
        checkpoint: Path of the checkpoint file to load.
        model: Module receiving ``states["model"]``.
        reporter: Reporter receiving ``states["reporter"]``.
        optimizers: Optimizers paired in order with ``states["optimizers"]``.
        schedulers: Schedulers paired in order with ``states["schedulers"]``;
            ``None`` entries are skipped.
        ngpu: When > 0, tensors are mapped to the current CUDA device,
            otherwise to CPU.
        use_apex: Whether to restore apex AMP state when the checkpoint has it.

    Raises:
        ImportError: If AMP state must be restored but apex is not installed.
    """
    states = torch.load(
        checkpoint,
        map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
    )
    model.load_state_dict(states["model"])
    reporter.load_state_dict(states["reporter"])
    for optimizer, state in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, states["schedulers"]):
        if scheduler is not None:
            scheduler.load_state_dict(state)

    # Tolerate checkpoints written without an "amp" entry.
    amp_state = states.get("amp")
    if use_apex and amp_state is not None:
        try:
            from apex import amp
        except ImportError:
            logging.error(
                "You need to install apex. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            # Without apex the AMP state cannot be restored; propagate instead
            # of falling through to an undefined `amp` (NameError).
            raise
        amp.load_state_dict(amp_state)
    logging.info(f"The training was resumed using {checkpoint}")
示例4: build_model_from_file
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def build_model_from_file(
    cls,
    config_file: Union[Path, str],
    model_file: Union[Path, str, None] = None,
    device: str = "cpu",
) -> Tuple[AbsESPnetModel, argparse.Namespace]:
    """Build a model from its training config and optionally load its weights.

    This method is used for inference or fine-tuning.

    Args:
        config_file: The YAML file saved when training. Its top-level mapping
            becomes the ``argparse.Namespace`` passed to ``cls.build_model``.
        model_file: The model file saved when training. If None, the freshly
            built model is returned without loading any weights.
        device: Device the model is moved to. It is also the ``map_location``
            used to load ``model_file``; "cuda" is resolved to the current
            CUDA device for that purpose.

    Returns:
        ``(model, args)``: the built model and the namespace made from the config.

    Raises:
        TypeError: If the config file does not contain a YAML mapping.
        RuntimeError: If ``cls.build_model`` does not return an AbsESPnetModel.
    """
    assert check_argument_types()
    config_file = Path(config_file)
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)
    # An empty config yields None; report it clearly instead of failing inside
    # Namespace(**None) with an unrelated-looking message.
    if not isinstance(args, dict):
        raise TypeError(
            f"{config_file} must contain a YAML mapping, but got {type(args)}"
        )
    args = argparse.Namespace(**args)
    model = cls.build_model(args)
    if not isinstance(model, AbsESPnetModel):
        raise RuntimeError(
            f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
        )
    model.to(device)
    if model_file is not None:
        if device == "cuda":
            # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
            # in PyTorch<=1.4
            device = f"cuda:{torch.cuda.current_device()}"
        model.load_state_dict(torch.load(model_file, map_location=device))
    return model, args
示例5: load_state_dict
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_state_dict(self, state):
    """Sets the state of the runner from a state mapping.

    Args:
        state: Mapping with "models", "optimizers" and "epoch". "schedulers"
            is read when this runner has schedulers. "amp" is read when fp16
            is enabled and apex's ``amp`` is available. "operator" (the
            training operator's own state) is applied when present.
    """
    for model, model_state in zip(self.models, state["models"]):
        model.load_state_dict(model_state)
    for optimizer, optimizer_state in zip(self.optimizers, state["optimizers"]):
        optimizer.load_state_dict(optimizer_state)
    if self.schedulers:
        for scheduler, scheduler_state in zip(self.schedulers,
                                              state["schedulers"]):
            scheduler.load_state_dict(scheduler_state)

    # `amp` is the module-level apex handle; presumably falsy when apex is
    # not installed -- confirm at the import site.
    if self.use_fp16 and "amp" in state and amp:
        amp.load_state_dict(state["amp"])
    self.epochs = state["epoch"]
    # The operator needs its own saved state, not a loop variable from the
    # loops above (that would be an optimizer/scheduler state, or undefined
    # when every loop is empty).
    # NOTE(review): assumes the matching state_dict() stores it under
    # "operator" -- confirm.
    if "operator" in state:
        self.training_operator.load_state_dict(state["operator"])
示例6: load_state_stream
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_state_stream(self, byte_obj):
    """Deserialize a training-state dict from bytes and apply it.

    Args:
        byte_obj: Serialized state, presumably produced by ``torch.save``.

    Returns:
        Whatever ``self.load_state_dict`` returns.
    """
    # NOTE(review): torch.load unpickles its input; only pass bytes from a
    # trusted source.
    state = torch.load(io.BytesIO(byte_obj))
    return self.load_state_dict(state)
示例7: resume_from_checkpoint
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def resume_from_checkpoint(self, checkpoint_dir: str) -> int:
    """Restore optimizer, AMP and scheduler state from ``checkpoint_dir/checkpoint.bin``.

    Args:
        checkpoint_dir: Directory containing ``checkpoint.bin``.

    Returns:
        The epoch to resume from: the saved ``'epoch'`` value plus one.
    """
    checkpoint = torch.load(
        os.path.join(checkpoint_dir, 'checkpoint.bin'), map_location=self.device)
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    if self.fp16:
        # Appears to trigger apex's lazy master-weight setup now and mark it
        # as done (private apex internals), so the master params exist
        # before their saved values are copied in below.
        self.optimizer._lazy_init_maybe_master_weights()
        self.optimizer._amp_stash.lazy_init_called = True
        # Optimizer state is loaded a second time after that setup.
        # NOTE(review): presumably so it binds to the master params -- confirm.
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # Copy saved master params into the live ones, pairwise in order.
        for param, saved in zip(
                amp.master_params(self.optimizer), checkpoint['master params']):
            param.data.copy_(saved.data)
        # NOTE(review): indentation was lost in this listing; this restore is
        # placed inside the fp16 branch -- confirm against upstream.
        amp.load_state_dict(checkpoint['amp'])
    self.scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch'] + 1
    return start_epoch
示例8: load_model
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_model(self, checkpoint):
    """Restore model weights, plus apex AMP state when mixed precision is on.

    Args:
        checkpoint: Path (or file-like object) accepted by ``torch.load``; it
            must hold a 'model' entry, and an 'amp' entry when
            ``self.mixed_precision`` is set.
    """
    # NOTE(review): no map_location is given, so tensors load onto the device
    # they were saved from.
    saved = torch.load(checkpoint)
    self.model.load_state_dict(saved['model'])
    if not self.mixed_precision:
        return
    # apex is imported lazily so it is only needed for mixed-precision runs.
    import apex.amp as amp
    amp.load_state_dict(saved['amp'])