当前位置: 首页>>代码示例>>Python>>正文


Python amp.state_dict方法代码示例

本文整理汇总了 Python 中 apex.amp.state_dict 方法的典型用法代码示例。如果您正苦于以下问题:Python amp.state_dict 方法的具体用法?Python amp.state_dict 怎么用?Python amp.state_dict 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 apex.amp 的用法示例。


在下文中一共展示了 amp.state_dict 方法的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: _save

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
        """Collect the training state into one dict and write it with ``torch.save``.

        The dict always holds the epoch, ``args.model`` (under ``'arch'``), the
        model and optimizer state dicts, ``args`` itself and a format version.
        The ``'amp'``, ``'state_dict_ema'`` and ``'metric'`` entries are added
        only when the corresponding input is available.
        """
        checkpoint = dict(
            epoch=epoch,
            arch=args.model,
            state_dict=get_state_dict(model),
            optimizer=optimizer.state_dict(),
            args=args,
            version=2,  # version < 2 increments epoch before save
        )
        # Only store amp state when the imported amp module exposes state_dict.
        if use_amp and 'state_dict' in vars(amp):
            checkpoint['amp'] = amp.state_dict()
        if model_ema is not None:
            checkpoint['state_dict_ema'] = get_state_dict(model_ema)
        if metric is not None:
            checkpoint['metric'] = metric
        torch.save(checkpoint, save_path)
开发者ID:rwightman,项目名称:pytorch-image-models,代码行数:18,代码来源:utils.py

示例2: state_dict

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def state_dict(self):
        """Return a dict snapshot of the runner.

        Keys: ``"epoch"``, ``"operator"``, ``"models"`` and ``"optimizers"``;
        ``"schedulers"`` is added when ``self.schedulers`` is truthy and
        ``"amp"`` when fp16 is enabled and ``amp`` is truthy.
        """
        snapshot = dict(
            epoch=self.epochs,
            operator=self.training_operator.state_dict(),
            models=[m.state_dict() for m in self.models],
            optimizers=[o.state_dict() for o in self.optimizers],
        )
        if self.schedulers:
            snapshot["schedulers"] = [s.state_dict() for s in self.schedulers]
        # amp state is only recorded when fp16 is on and the amp name is truthy
        # (i.e. NVIDIA Apex was imported).
        if self.use_fp16 and amp:
            snapshot["amp"] = amp.state_dict()
        return snapshot
开发者ID:ray-project,项目名称:ray,代码行数:20,代码来源:torch_runner.py

示例3: save_state

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def save_state(self, save_directory: typing.Union[str, Path], epoch_id: int):
        """Write the model and the optimizer/scheduler state into ``save_directory``.

        The directory is created when it does not exist (parents are not
        created); an existing path must be a directory.  The model is written
        through its ``save_pretrained`` method and the remaining state goes to
        ``checkpoint.bin`` inside the directory.
        """
        target_dir = Path(save_directory)
        if target_dir.exists():
            assert target_dir.is_dir(), "Save path should be a directory"
        else:
            target_dir.mkdir()

        # Save the inner ``module`` when the model exposes one, else the model itself.
        unwrapped = getattr(self.model, 'module', self.model)
        unwrapped.save_pretrained(target_dir)

        checkpoint: typing.Dict[str, typing.Any] = {
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epoch': epoch_id,
        }
        if APEX_FOUND:
            checkpoint['master params'] = list(amp.master_params(self.optimizer))
            try:
                checkpoint['amp'] = amp.state_dict()
            except AttributeError:
                # The imported amp has no state_dict(); skip that entry.
                pass
        torch.save(checkpoint, target_dir / 'checkpoint.bin')
开发者ID:songlab-cal,项目名称:tape,代码行数:21,代码来源:training.py

示例4: save_model

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def save_model(self, epoch=None, save_name=None):
        """Save a checkpoint file under ``self.expdir``.

        When ``save_name`` is not given the file is named
        ``model.epoch.<epoch>.pt``.  The checkpoint holds the epoch,
        ``self.params``, the model weights and the apex amp state (``None``
        unless ``self.mixed_precision`` is set).
        NOTE: the optimizer state is not stored in this checkpoint.
        """
        if save_name is None:
            save_name = 'model.epoch.%d.pt' % epoch

        amp_state_dict = None
        if self.mixed_precision:
            import apex.amp as amp
            amp_state_dict = amp.state_dict()

        checkpoint = dict(
            epoch=epoch,
            params=self.params,
            # With more than one GPU the weights are read from ``model.module``.
            model=self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
            amp=amp_state_dict,
        )

        target_path = os.path.join(self.expdir, save_name)
        torch.save(checkpoint, target_path)
开发者ID:ZhengkunTian,项目名称:OpenTransformer,代码行数:21,代码来源:train.py

示例5: get_checkpoint_state

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def get_checkpoint_state(self) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Yield ``(model_state, training_states)`` once for checkpointing.

        If a moving average is configured, its averaged values are assigned to
        the model parameters before the state is read and restored again when
        the generator is resumed or closed.
        NOTE(review): presumably wrapped with ``contextlib.contextmanager``
        where it is defined - the decorator is not visible here.
        """
        averager = self._moving_average
        if averager is not None:
            # Checkpoint the averaged weights; they are put back in the
            # ``finally`` clause below.
            averager.assign_average_value()

        weights = self.model.state_dict()

        # Training-side state that has to be persisted next to the weights.
        trainer_state = {
            "metric_tracker": self._metric_tracker.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "batch_num_total": self._batch_num_total,
        }

        # Optional pieces: learning-rate / momentum schedulers and amp state.
        lr_scheduler = self._learning_rate_scheduler
        if lr_scheduler is not None:
            trainer_state["learning_rate_scheduler"] = lr_scheduler.state_dict()
        momentum_scheduler = self._momentum_scheduler
        if momentum_scheduler is not None:
            trainer_state["momentum_scheduler"] = momentum_scheduler.state_dict()
        if self._opt_level is not None:
            # An opt level is set, so the amp state is persisted as well.
            trainer_state["amp"] = amp.state_dict()

        try:
            yield weights, trainer_state
        finally:
            if self._moving_average is not None:
                self._moving_average.restore()
开发者ID:allenai,项目名称:allennlp,代码行数:31,代码来源:trainer.py

示例6: get_state_dict

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def get_state_dict(model):
    """Return the state dict of ``model`` after passing it through ``unwrap_model``."""
    unwrapped = unwrap_model(model)
    return unwrapped.state_dict()
开发者ID:rwightman,项目名称:pytorch-image-models,代码行数:4,代码来源:utils.py

示例7: update

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy, in place.

        Every entry of ``self.ema.state_dict()`` becomes
        ``ema * decay + (1 - decay) * model_value``.
        """
        # When ``model`` has a ``module`` attribute but the EMA copy was built
        # without one, the model's keys carry an extra 'module.' prefix.
        if hasattr(model, 'module') and not self.ema_has_module:
            prefix = 'module.'
        else:
            prefix = ''
        with torch.no_grad():
            source = model.state_dict()
            for name, ema_v in self.ema.state_dict().items():
                model_v = source[prefix + name].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
开发者ID:rwightman,项目名称:pytorch-image-models,代码行数:14,代码来源:utils.py

示例8: load_state_dict

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def load_state_dict(self, state):
        """Restore the runner from a state dict.

        Args:
            state: dict with the keys ``"models"``, ``"optimizers"``,
                ``"epoch"`` and ``"operator"``; ``"schedulers"`` is read when
                the runner has schedulers, and ``"amp"`` is applied when
                present, fp16 is enabled and ``amp`` is truthy.

        Raises:
            KeyError: if a required key is missing from ``state``.
        """
        for model, model_state in zip(self.models, state["models"]):
            model.load_state_dict(model_state)
        for optimizer, optimizer_state in zip(self.optimizers,
                                              state["optimizers"]):
            optimizer.load_state_dict(optimizer_state)
        if self.schedulers:
            for scheduler, scheduler_state in zip(self.schedulers,
                                                  state["schedulers"]):
                scheduler.load_state_dict(scheduler_state)

        if self.use_fp16 and "amp" in state and amp:
            amp.load_state_dict(state["amp"])
        self.epochs = state["epoch"]
        # The operator must get its own saved state. Reusing the loop variable
        # from above would hand it the last optimizer/scheduler state instead.
        self.training_operator.load_state_dict(state["operator"])
开发者ID:ray-project,项目名称:ray,代码行数:17,代码来源:torch_runner.py

示例9: state_stream

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def state_stream(self):
        """Serialize ``self.state_dict()`` with ``torch.save`` and return the bytes."""
        stream = io.BytesIO()
        torch.save(self.state_dict(), stream)
        return stream.getvalue()
开发者ID:ray-project,项目名称:ray,代码行数:8,代码来源:torch_runner.py

示例10: load_state_stream

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def load_state_stream(self, byte_obj):
        """Deserialize ``byte_obj`` with ``torch.load`` and pass it to ``load_state_dict``.

        Returns whatever ``self.load_state_dict`` returns.
        """
        restored = torch.load(io.BytesIO(byte_obj))
        return self.load_state_dict(restored)
开发者ID:ray-project,项目名称:ray,代码行数:7,代码来源:torch_runner.py

示例11: check_state_dict_fp32

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def check_state_dict_fp32(self, state_dict):
        """Assert that every entry of ``state_dict`` has the tensor type ``FLOAT``.

        Entries whose key contains ``'num_batches_tracked'`` are not checked.
        """
        for name, tensor in state_dict.items():
            if 'num_batches_tracked' not in name:
                self.assertEqual(tensor.type(), FLOAT,
                                 'Parameter in state_dict not FLOAT')
开发者ID:NVIDIA,项目名称:apex,代码行数:9,代码来源:test_checkpointing.py

示例12: compare_models

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def compare_models(self, modelA, modelB, test_setup=''):
        state_dictA = modelA.state_dict()
        state_dictB = modelB.state_dict()
        self.assertEqual(len(state_dictA), len(state_dictB),
                         'state_dicts have different lengths' + test_setup)
        for key in state_dictA:
            paramA = state_dictA[key]
            paramB = state_dictB[key]
            self.assertTrue((paramA==paramB).all(),
                msg='Parameters in state_dices not equal.' +
                    'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                        key, paramA, paramB, paramA - paramB, test_setup)) 
开发者ID:NVIDIA,项目名称:apex,代码行数:14,代码来源:test_checkpointing.py

示例13: test_state_dict

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def test_state_dict(self):
        """Check the exported model state dict and that training still works.

        For every opt level except 'O3': no tensor in ``model.state_dict()``
        may report a 'Half' type after ``amp.initialize``, and the loss on a
        fixed random batch must strictly decrease over ten further steps.
        """
        for opt_level in self.test_opt_levels:
            # O3 is excluded from this test.
            if opt_level == 'O3':
                continue

            model = MyModel().to('cuda')
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0)

            # The exported state dict must not contain Half tensors.
            for tensor in model.state_dict().values():
                self.assertFalse('Half' in tensor.type())

            # Dummy batch used to verify the model is still trainable.
            data = torch.randn(10, 3, 4, 4, device='cuda')
            target = torch.randn(10, 6, 4, 4, device='cuda')

            def train_step():
                # One optimisation step on the fixed batch; returns the loss value.
                optimizer.zero_grad()
                loss = F.mse_loss(model(data), target)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                return loss.item()

            # Initial loss, then ten more steps that must each improve on it.
            last_loss = train_step()
            for _ in range(10):
                current_loss = train_step()
                self.assertTrue(current_loss < last_loss)
                last_loss = current_loss
开发者ID:NVIDIA,项目名称:apex,代码行数:42,代码来源:test_checkpointing.py

示例14: load_model

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def load_model(self, checkpoint):
        """Load model weights (and the amp state, if mixed precision is on).

        ``checkpoint`` is passed straight to ``torch.load``; the loaded dict
        must hold a ``'model'`` entry and, for mixed precision, an ``'amp'`` one.
        """
        restored = torch.load(checkpoint)
        self.model.load_state_dict(restored['model'])
        if not self.mixed_precision:
            return
        import apex.amp as amp
        amp.load_state_dict(restored['amp'])
开发者ID:ZhengkunTian,项目名称:OpenTransformer,代码行数:9,代码来源:train.py

示例15: test_loss_scale_decrease

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import state_dict [as 别名]
def test_loss_scale_decrease(self):
        """Check that dynamic loss scales drop as expected and appear in ``amp.state_dict()``.

        For every opt level in ``self.test_opt_levels`` that uses a dynamic
        loss scale, loss ``idx`` is stepped ``nb_decrease_loss_scales[idx]``
        times on a heavily scaled-up input.  The test then expects each loss
        scale to equal its initial value divided by
        ``2**nb_decrease_loss_scales[idx]``, both on the live scalers and in
        the dict returned by ``amp.state_dict()``.
        """
        num_losses = 3
        # Number of steps to run for each loss id (index == loss id).
        nb_decrease_loss_scales = [0, 1, 2]
        for opt_level in self.test_opt_levels:
            #print('#' * 75 + f'\n opt_level {opt_level}\n')
            # Create new tmp copy for this run
            nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)

            model = MyModel().to('cuda')

            optimizer = optim.SGD(model.parameters(),
                                  lr=self.initial_lr)

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            # Only opt levels with a dynamic loss scale are exercised.
            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                #print('Static loss scale set. Skipping opt_level.')
                continue

            # force to skip some updates to decrease the loss_scale
            # Record the loss scale of every scaler before any step is taken.
            initial_loss_scales = []
            for idx in range(num_losses):
                initial_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())

            # The counters reach 0 during the first pass, so the remaining
            # passes of this outer loop never enter the while body again.
            for _ in range(len(nb_decrease_loss_scales)):
                x = torch.randn(16, 3, 24, 24, device='cuda')
                for idx in range(num_losses):
                    while nb_decrease_loss_scales_tmp[idx] > 0:
                        optimizer.zero_grad()
                        # NOTE(review): the 2**17 input scaling is presumably
                        # meant to overflow the fp16 gradients so that amp
                        # skips the step - confirm.
                        output = model(x * 2**17)
                        loss = output.mean()

                        with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                        optimizer.step()
                        nb_decrease_loss_scales_tmp[idx] -= 1

            # Check loss scales afterwards
            updated_loss_scales = []
            for idx in range(num_losses):
                updated_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())
            # Each scale is expected to have been halved once per step taken.
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict
            amp_state_dict = amp.state_dict()
            # Iterating the state dict yields its keys; they are paired
            # positionally with the per-loss expectations.
            for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                                   nb_decrease_loss_scales,
                                                   initial_loss_scales):
                scaler = amp_state_dict[scaler_idx]
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                # Every scaler is expected to report zero unskipped steps.
                unskipped_target = 0
                self.assertEqual(scaler['unskipped'], unskipped_target)
开发者ID:NVIDIA,项目名称:apex,代码行数:61,代码来源:test_checkpointing.py


注:本文中的 apex.amp.state_dict 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的 License;未经允许,请勿转载。