当前位置: 首页>>代码示例>>Python>>正文


Python amp.load_state_dict方法代码示例

本文整理汇总了Python中apex.amp.load_state_dict方法的典型用法代码示例。如果您正苦于以下问题：Python amp.load_state_dict方法的具体用法？Python amp.load_state_dict怎么用？Python amp.load_state_dict使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块apex.amp的用法示例。


在下文中一共展示了amp.load_state_dict方法的8个代码示例，这些例子默认根据受欢迎程度排序（注：示例4与示例6的代码中并未直接调用amp.load_state_dict）。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_ckpt

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_ckpt(self):
        """Restore weights from the checkpoint named by ``self.paras.load``.

        Does nothing when no checkpoint path was given. The model (and the
        embedding decoder, when one exists) always gets its weights back. In
        'train' mode the global step and optimizer state are restored too; in
        any other mode the checkpoint is mapped to CPU and the networks are
        switched to eval(). The last float-valued top-level checkpoint entry is
        reported as the recorded metric ("None" / 0.0 when there is none).
        """
        ckpt_path = self.paras.load
        if not ckpt_path:
            return

        training = self.mode == 'train'
        ckpt = torch.load(ckpt_path, map_location=self.device if training else 'cpu')

        self.model.load_state_dict(ckpt['model'])
        if self.emb_decoder is not None:
            self.emb_decoder.load_state_dict(ckpt['emb_decoder'])
        # NOTE: apex amp state (ckpt['amp']) is not restored by this method.

        # Exact type check on purpose: only plain Python floats count as a metric.
        float_entries = [(key, val) for key, val in ckpt.items() if type(val) is float]
        metric, score = float_entries[-1] if float_entries else ("None", 0.0)

        if training:
            self.step = ckpt['global_step']
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.verbose('Load ckpt from {}, restarting at step {} (recorded {} = {:.2f} %)'.format(
                ckpt_path, self.step, metric, score))
            return

        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(ckpt_path, metric, score))
开发者ID:Alexander-H-Liu,项目名称:End-to-end-ASR-Pytorch,代码行数:29,代码来源:solver.py

示例2: prepare_for_training

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def prepare_for_training(args, model, checkpoint_state_dict, amp):
    """Create the optimizer, enable apex mixed precision and wrap the model.

    Parameters whose name contains 'bias' or 'LayerNorm.weight' are excluded
    from weight decay. When ``checkpoint_state_dict`` is truthy, the amp state
    (if amp is enabled), the optimizer state and the model weights are restored
    from its "amp", "optimizer" and "model" entries, in that order.

    Args:
        args: Namespace read for weight_decay, learning_rate, adam_epsilon,
            fp16_opt_level, n_gpu and local_rank.
        model: The model to optimize.
        checkpoint_state_dict: Saved training state, or a falsy value to start
            from scratch.
        amp: The apex ``amp`` module to use, or a falsy value (e.g. None) to
            train without mixed precision.

    Returns:
        ``(model, optimizer)``; the model may be wrapped in DataParallel and/or
        DistributedDataParallel.
    """
    no_decay = ['bias', 'LayerNorm.weight']
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            undecayed.append(param)
        else:
            decayed.append(param)

    optimizer = AdamW(
        [
            {'params': decayed, 'weight_decay': args.weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
        ],
        lr=args.learning_rate,
        eps=args.adam_epsilon,
    )

    if amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
        if checkpoint_state_dict:
            amp.load_state_dict(checkpoint_state_dict['amp'])

    if checkpoint_state_dict:
        optimizer.load_state_dict(checkpoint_state_dict['optimizer'])
        model.load_state_dict(checkpoint_state_dict['model'])

    # Both wrappers must be applied after apex fp16 initialization.
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    return model, optimizer
开发者ID:microsoft,项目名称:unilm,代码行数:30,代码来源:run_seq2seq.py

示例3: resume

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        ngpu: int = 0,
        use_apex: bool = False,
    ):
        """Restore training state in place from a checkpoint file.

        Args:
            checkpoint: Path of the checkpoint to load.
            model: Restored from ``states["model"]``.
            reporter: Restored from ``states["reporter"]``.
            optimizers: Paired in order with ``states["optimizers"]``.
            schedulers: Paired in order with ``states["schedulers"]``; None
                entries are skipped.
            ngpu: When > 0 tensors are mapped onto the current CUDA device,
                otherwise onto the CPU.
            use_apex: Restore apex amp state when the checkpoint contains a
                non-None ``"amp"`` entry.

        Raises:
            ImportError: if amp state has to be restored but apex is not
                installed.
        """
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        # .get(): checkpoints written without apex may have no "amp" entry at all.
        if use_apex and states.get("amp") is not None:
            try:
                from apex import amp
            except ImportError:
                logging.error(
                    "You need to install apex. "
                    "See https://github.com/NVIDIA/apex#linux"
                )
                # Without re-raising, `amp` stays unbound and the next line would
                # fail with a confusing UnboundLocalError instead.
                raise
            amp.load_state_dict(states["amp"])

        logging.info(f"The training was resumed using {checkpoint}")
开发者ID:espnet,项目名称:espnet,代码行数:33,代码来源:abs_task.py

示例4: build_model_from_file

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def build_model_from_file(
        cls,
        config_file: Union[Path, str],
        model_file: Union[Path, str] = None,
        device: str = "cpu",
    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
        """Build a model from its training-time YAML config, for inference or fine-tuning.

        Args:
            config_file: The YAML config file saved during training.
            model_file: The model state-dict file saved during training; when
                None, the freshly built model is returned without loading weights.
            device: Device the model is moved to and the weights are mapped onto.

        Returns:
            The built model and the config as an ``argparse.Namespace``.

        Raises:
            RuntimeError: if ``cls.build_model`` does not return an AbsESPnetModel.
        """
        assert check_argument_types()
        config_path = Path(config_file)
        with config_path.open("r", encoding="utf-8") as stream:
            config_dict = yaml.safe_load(stream)
        args = argparse.Namespace(**config_dict)

        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)

        if model_file is None:
            return model, args

        # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 in
        # PyTorch<=1.4, so pin the current CUDA device explicitly.
        map_location = (
            f"cuda:{torch.cuda.current_device()}" if device == "cuda" else device
        )
        model.load_state_dict(torch.load(model_file, map_location=map_location))
        return model, args
开发者ID:espnet,项目名称:espnet,代码行数:36,代码来源:abs_task.py

示例5: load_state_dict

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_state_dict(self, state):
        """Sets the state of the runner (models, optimizers, schedulers, amp, operator).

        Args:
            state: Dict with "models", "optimizers" and "epoch" entries, a
                "schedulers" entry when this runner has schedulers, and
                optional "amp" and "operator" entries.
        """
        for model, model_state in zip(self.models, state["models"]):
            model.load_state_dict(model_state)
        for optimizer, optimizer_state in zip(self.optimizers, state["optimizers"]):
            optimizer.load_state_dict(optimizer_state)
        if self.schedulers:
            for scheduler, scheduler_state in zip(self.schedulers,
                                                  state["schedulers"]):
                scheduler.load_state_dict(scheduler_state)

        # `amp` is the module-level apex handle (presumably falsy when apex is
        # not installed — confirm at the import site).
        if self.use_fp16 and "amp" in state and amp:
            amp.load_state_dict(state["amp"])
        self.epochs = state["epoch"]
        # The operator gets its own entry. The loop variables above hold
        # model/optimizer/scheduler states (and are unbound when every list is
        # empty), so they must not be passed here.
        # NOTE(review): "operator" must match the key written by the runner's
        # state_dict(); confirm.
        if "operator" in state:
            self.training_operator.load_state_dict(state["operator"])
开发者ID:ray-project,项目名称:ray,代码行数:17,代码来源:torch_runner.py

示例6: load_state_stream

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_state_stream(self, byte_obj):
        """Deserialize a training state from bytes and apply it via load_state_dict.

        Args:
            byte_obj: Serialized state, expected to be the output of
                ``torch.save`` on a state dict.

        Returns:
            Whatever ``self.load_state_dict`` returns.
        """
        # NOTE: torch.load unpickles its input; only pass bytes from a trusted source.
        return self.load_state_dict(torch.load(io.BytesIO(byte_obj)))
开发者ID:ray-project,项目名称:ray,代码行数:7,代码来源:torch_runner.py

示例7: resume_from_checkpoint

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def resume_from_checkpoint(self, checkpoint_dir: str) -> int:
        """Restore optimizer, apex amp and scheduler state from ``checkpoint_dir``.

        Reads ``<checkpoint_dir>/checkpoint.bin`` with tensors mapped onto
        ``self.device``. Model weights are not loaded by this method.

        Args:
            checkpoint_dir: Directory containing ``checkpoint.bin``.

        Returns:
            The epoch to resume from, i.e. the saved epoch + 1.
        """
        checkpoint = torch.load(
            os.path.join(checkpoint_dir, 'checkpoint.bin'), map_location=self.device)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.fp16:
            # NOTE(review): relies on private apex internals
            # (_lazy_init_maybe_master_weights, _amp_stash). Presumably this
            # forces apex to create its master weights now, so the optimizer
            # state can be reloaded against them below — confirm against the
            # apex version in use.
            self.optimizer._lazy_init_maybe_master_weights()
            self.optimizer._amp_stash.lazy_init_called = True
            # Reload the optimizer state after the lazy init above.
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # Copy the saved master parameters in, pairing them by position.
            for param, saved in zip(
                    amp.master_params(self.optimizer), checkpoint['master params']):
                param.data.copy_(saved.data)
            amp.load_state_dict(checkpoint['amp'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        return start_epoch
开发者ID:songlab-cal,项目名称:tape,代码行数:17,代码来源:training.py

示例8: load_model

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import load_state_dict [as 别名]
def load_model(self, checkpoint):
        """Restore model weights, plus apex amp state when mixed precision is on.

        Args:
            checkpoint: Anything ``torch.load`` accepts (e.g. a file path). The
                loaded dict must contain "model", and also "amp" when
                ``self.mixed_precision`` is set.
        """
        saved = torch.load(checkpoint)
        self.model.load_state_dict(saved['model'])
        if not self.mixed_precision:
            return
        # Imported lazily so apex is only needed for mixed-precision runs.
        import apex.amp as amp
        amp.load_state_dict(saved['amp'])
开发者ID:ZhengkunTian,项目名称:OpenTransformer,代码行数:9,代码来源:train.py


注：本文中的apex.amp.load_state_dict方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。