当前位置: 首页>>代码示例>>Python>>正文


Python lr_scheduler._LRScheduler类代码示例

本文整理汇总了Python中torch.optim.lr_scheduler._LRScheduler类的典型用法代码示例。如果您正苦于以下问题:Python lr_scheduler._LRScheduler类的具体用法?Python lr_scheduler._LRScheduler怎么用?Python lr_scheduler._LRScheduler使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.optim.lr_scheduler的用法示例。


在下文中一共展示了lr_scheduler._LRScheduler类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        """Restore the model weights from a checkpoint file.

        Only the ``'state_dict'`` entry of the checkpoint is read. The
        ``optimizer`` and ``scheduler`` arguments are accepted but not touched,
        and the step stored in the checkpoint (if any) is not consulted.

        Args:
            path_to_checkpoint: Path of the file passed to ``torch.load``.
            optimizer: Unused by this implementation.
            scheduler: Unused by this implementation.

        Returns:
            Always ``0``; the step counter is reported as starting from zero.
        """
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])

        # NOTE: restoring optimizer/scheduler state and the saved step is
        # disabled here, so callers always resume counting from step 0.
        step = 0
        return step
开发者ID:MagicChuyi,项目名称:SlowFast-Network-pytorch,代码行数:18,代码来源:model.py

示例2: checkpoint_model

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def checkpoint_model(
        self,
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Hand the current training state to the checkpointer.

        The call is forwarded to ``self.checkpointer.checkpoint`` together with
        the manager's ``unit_total`` progress counter.

        Args:
          model: The model to checkpoint.
          optimizer: The optimizer used during the training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: The metric dict.
        """
        do_checkpoint = self.checkpointer.checkpoint
        progress = self.unit_total
        do_checkpoint(progress, model, optimizer, lr_scheduler, metric_dict)
开发者ID:SenWu,项目名称:emmental,代码行数:20,代码来源:logging_manager.py

示例3: __init__

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def __init__(self, config: Config, optimizer):
        """Create the optional learning-rate scheduler named in the config.

        Reads ``train.lr_scheduler`` (the name of a class looked up in
        ``torch.optim.lr_scheduler``) and ``train.lr_scheduler_args`` (keyword
        arguments passed to that class). An empty name leaves
        ``self._lr_scheduler`` as ``None``.

        Args:
            config: Configuration object providing ``get(key)``.
            optimizer: Optimizer handed to the scheduler constructor.

        Raises:
            ValueError: If the scheduler cannot be looked up or constructed;
                the underlying error is chained as the cause.
        """
        super().__init__(config)
        name = config.get("train.lr_scheduler")
        args = config.get("train.lr_scheduler_args")
        self._lr_scheduler: _LRScheduler = None
        if name != "":
            try:
                self._lr_scheduler = getattr(torch.optim.lr_scheduler, name)(
                    optimizer, **args
                )
            except Exception as e:
                # Chain explicitly so the traceback shows the original failure
                # as the direct cause of the configuration error.
                raise ValueError(
                    (
                        "Invalid LR scheduler {} or scheduler arguments {}. "
                        "Error: {}"
                    ).format(name, args, e)
                ) from e

        # Only ReduceLROnPlateau is flagged as metric based.
        self._metric_based = name in ["ReduceLROnPlateau"]
开发者ID:uma-pi1,项目名称:kge,代码行数:21,代码来源:optimizer.py

示例4: __init__

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def __init__(self, lr_scheduler, save_history=False, **kwds):
        """Wrap a PyTorch learning-rate scheduler as a parameter scheduler.

        Args:
            lr_scheduler: An instance of ``torch.optim.lr_scheduler._LRScheduler``
                whose optimizer has at most one param group.
            save_history: Forwarded to the parent initializer.
            **kwds: Accepted but not used by this initializer.

        Raises:
            TypeError: If ``lr_scheduler`` is not an ``_LRScheduler`` instance.
            ValueError: If the scheduler's optimizer has more than one param group.
        """
        if not isinstance(lr_scheduler, _LRScheduler):
            given = type(lr_scheduler)
            raise TypeError("Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                            "but given {}".format(given))

        group_count = len(lr_scheduler.optimizer.param_groups)
        if group_count > 1:
            raise ValueError("Optimizer passed to lr_scheduler should have a single param group, "
                             "but currently there are {} param groups".format(group_count))

        self.lr_scheduler = lr_scheduler
        wrapped_optimizer = self.lr_scheduler.optimizer
        super(LRScheduler, self).__init__(optimizer=wrapped_optimizer, param_name='lr', save_history=save_history)
开发者ID:leokarlin,项目名称:LaSO,代码行数:18,代码来源:param_scheduler.py

示例5: simulate_values

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Simulate the scheduled values over ``num_events`` events.

        A copy of ``lr_scheduler`` (made with ``LRScheduler._copy_lr_scheduler``)
        is wrapped and stepped. ``**kwargs`` is accepted but not used.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value], where the value is read from
            the first param group after the scheduler call for that event.

        """
        replica = LRScheduler._copy_lr_scheduler(lr_scheduler)
        wrapper = cls(save_history=False, lr_scheduler=replica)
        history = []
        for event_index in range(num_events):
            wrapper(engine=None)
            current = wrapper.optimizer_param_groups[0][wrapper.param_name]
            history.append([event_index, current])

        return history
开发者ID:leokarlin,项目名称:LaSO,代码行数:21,代码来源:param_scheduler.py

示例6: save

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Write a training checkpoint and return the path of the written file.

        The file is named ``model-<step>.pth`` inside ``path_to_checkpoints_dir``
        and stores the model, optimizer and scheduler state dicts plus the step.

        Args:
            path_to_checkpoints_dir: Directory that receives the checkpoint.
            step: Training step recorded in the checkpoint and in the file name.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: Scheduler whose ``state_dict()`` is stored.

        Returns:
            The path of the checkpoint file.
        """
        filename = f'model-{step}.pth'
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, filename)

        model_state = self.state_dict()
        optimizer_state = optimizer.state_dict()
        scheduler_state = scheduler.state_dict()

        torch.save(
            {
                'state_dict': model_state,
                'step': step,
                'optimizer_state_dict': optimizer_state,
                'scheduler_state_dict': scheduler_state,
            },
            path_to_checkpoint,
        )
        return path_to_checkpoint
开发者ID:potterhsu,项目名称:easy-faster-rcnn.pytorch,代码行数:12,代码来源:model.py

示例7: load

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

        Args:
            path_to_checkpoint: Path of the file passed to ``torch.load``. The
                checkpoint must contain ``'state_dict'`` and ``'step'``.
            optimizer: If given, restored from ``'optimizer_state_dict'``.
            scheduler: If given, restored from ``'scheduler_state_dict'``.

        Returns:
            The training step stored in the checkpoint.
        """
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])
        step = checkpoint['step']
        # The optimizer/scheduler entries are only read when the caller
        # actually wants them restored.
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step
开发者ID:potterhsu,项目名称:easy-faster-rcnn.pytorch,代码行数:11,代码来源:model.py

示例8: save

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        """Serialize model, optimizer and scheduler state to ``model-<step>.pth``.

        Args:
            path_to_checkpoints_dir: Directory in which the file is created.
            step: Training step; stored in the checkpoint and used in the file name.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: Scheduler whose ``state_dict()`` is stored.

        Returns:
            The path of the checkpoint file that was written.
        """
        target = os.path.join(path_to_checkpoints_dir, 'model-{}.pth'.format(step))

        payload = {}
        payload['state_dict'] = self.state_dict()
        payload['step'] = step
        payload['optimizer_state_dict'] = optimizer.state_dict()
        payload['scheduler_state_dict'] = scheduler.state_dict()

        torch.save(payload, target)
        return target
    # 
开发者ID:MagicChuyi,项目名称:SlowFast-Network-pytorch,代码行数:13,代码来源:model.py

示例9: collect_state_dict

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        """Assemble everything that goes into a checkpoint into one dict.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during the training process.
          lr_scheduler: Learning rate scheduler; a falsy value is stored as None.
          metric_dict: The metric dict.

        Returns:
          A dict with the keys ``iteration``, ``model``, ``optimizer``,
          ``lr_scheduler`` and ``metric_dict``.
        """
        # Only the model name and its module state are stored; task names,
        # task flows, loss/output functions and scorers are not included.
        model_params = {"name": model.name, "module_pool": model.collect_state_dict()}

        optimizer_state = optimizer.state_dict()
        if lr_scheduler:
            scheduler_state = lr_scheduler.state_dict()
        else:
            scheduler_state = None

        return {
            "iteration": iteration,
            "model": model_params,
            "optimizer": optimizer_state,
            "lr_scheduler": scheduler_state,
            "metric_dict": metric_dict,
        }
开发者ID:SenWu,项目名称:emmental,代码行数:41,代码来源:checkpointer.py

示例10: get_lr

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def get_lr(self):
        """Return the learning rates for the next step of the cosine cycle.

        The very first call returns ``self.base_lrs`` unchanged. Afterwards each
        base rate is interpolated between ``eta_min`` and itself with a cosine
        of the position inside the current cycle. When that position wraps to
        zero, the cycle length is rescaled by ``factor`` and a new cycle starts.
        """
        # HACK: ``torch.optim.lr_scheduler._LRScheduler`` calls ``get_lr()``
        # while it is being constructed, and the learning rate must stay
        # untouched for the first epoch, so the first call is special-cased.
        if not self._initialized:
            self._initialized = True
            return self.base_lrs

        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart

        # Position within the current cycle; the cosine factor is shared by
        # every base learning rate.
        position = self._cycle_counter % self._updated_cycle_len
        cosine_term = np.cos(np.pi * position / self._updated_cycle_len) + 1

        lrs = [
            self.eta_min + ((base_lr - self.eta_min) / 2) * cosine_term
            for base_lr in self.base_lrs
        ]

        if position == 0:
            # A cycle just completed: stretch the next one and restart here.
            self._cycle_factor *= self.factor
            self._cycle_counter = 0
            self._updated_cycle_len = int(self._cycle_factor * self.t_max)
            self._last_restart = step

        return lrs
开发者ID:arthurdouillard,项目名称:incremental_learning.pytorch,代码行数:32,代码来源:schedulers.py

示例11: maybe_update_lr

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def maybe_update_lr(self):
        """Step the learning-rate scheduler, if one is set, and log the current lr.

        A ``ReduceLROnPlateau`` scheduler is stepped with ``self.train_loss_MA``;
        any other scheduler is stepped with ``self.epoch + 1``. The lr of the
        first optimizer param group is written to the log in every case.
        """
        scheduler = self.lr_scheduler
        if scheduler is not None:
            plateau_cls = lr_scheduler.ReduceLROnPlateau
            assert isinstance(scheduler, (plateau_cls, lr_scheduler._LRScheduler))

            if not isinstance(scheduler, plateau_cls):
                scheduler.step(self.epoch + 1)
            else:
                # Plateau detection uses the moving-average training loss.
                scheduler.step(self.train_loss_MA)

        current_lr = self.optimizer.param_groups[0]['lr']
        self.print_to_log_file("lr is now (scheduler) %s" % str(current_lr))
开发者ID:MIC-DKFZ,项目名称:nnUNet,代码行数:13,代码来源:network_trainer.py

示例12: maybe_update_lr

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def maybe_update_lr(self, epoch=None):
        """Step the learning-rate scheduler, if one is set, and log the current lr.

        A ``ReduceLROnPlateau`` scheduler is stepped with ``self.train_loss_MA``,
        but only once ``self.epoch`` is greater than zero. Any other scheduler is
        stepped with ``self.epoch + 1``. The ``epoch`` argument is not used.
        """
        scheduler = self.lr_scheduler
        if scheduler is not None:
            accepted = (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler)
            assert isinstance(scheduler, accepted)

            is_plateau = isinstance(scheduler, lr_scheduler.ReduceLROnPlateau)
            if not is_plateau:
                scheduler.step(self.epoch + 1)
            elif self.epoch > 0:
                # Skipped at epoch 0 because self.train_loss_MA is still None then.
                scheduler.step(self.train_loss_MA)

        lr_now = str(self.optimizer.param_groups[0]['lr'])
        self.print_to_log_file("lr is now (scheduler) %s" % lr_now)
开发者ID:MIC-DKFZ,项目名称:nnUNet,代码行数:14,代码来源:nnUNetTrainerV2_SGD_ReduceOnPlateau.py

示例13: maybe_update_lr

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def maybe_update_lr(self, epoch=None):
        """Step the learning-rate scheduler, if one is set, and log the current lr.

        A ``ReduceLROnPlateau`` scheduler is stepped with ``self.train_loss_MA``
        only when ``self.epoch`` is greater than zero and that loss is not None.
        Any other scheduler is stepped with ``self.epoch + 1``. The ``epoch``
        argument is not used.
        """
        if self.lr_scheduler is not None:
            plateau_cls = lr_scheduler.ReduceLROnPlateau
            assert isinstance(self.lr_scheduler, (plateau_cls, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, plateau_cls):
                # The moving-average training loss drives plateau detection,
                # so there is nothing to step with until it exists.
                if self.epoch > 0:
                    if self.train_loss_MA is not None:
                        self.lr_scheduler.step(self.train_loss_MA)
            else:
                next_epoch = self.epoch + 1
                self.lr_scheduler.step(next_epoch)

        first_group = self.optimizer.param_groups[0]
        self.print_to_log_file("lr is now (scheduler) %s" % str(first_group['lr']))
开发者ID:MIC-DKFZ,项目名称:nnUNet,代码行数:14,代码来源:nnUNetTrainerV2_Adam_ReduceOnPlateau.py

示例14: __init__

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def __init__(self, lr_scheduler, save_history=False, **kwargs):
        """Wrap a PyTorch learning-rate scheduler as a parameter scheduler.

        Args:
            lr_scheduler: An instance of ``torch.optim.lr_scheduler._LRScheduler``.
            save_history: Forwarded to the parent initializer.
            **kwargs: Accepted but not used by this initializer.

        Raises:
            TypeError: If ``lr_scheduler`` is not an ``_LRScheduler`` instance.
        """
        if not isinstance(lr_scheduler, _LRScheduler):
            given = type(lr_scheduler)
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}".format(given)
            )

        self.lr_scheduler = lr_scheduler
        wrapped_optimizer = self.lr_scheduler.optimizer
        super(LRScheduler, self).__init__(
            optimizer=wrapped_optimizer,
            param_name="lr",
            save_history=save_history,
        )
        # Register the wrapped scheduler among the attributes listed in
        # ``_state_attrs``.
        self._state_attrs += ["lr_scheduler"]
开发者ID:pytorch,项目名称:ignite,代码行数:17,代码来源:param_scheduler.py

示例15: simulate_values

# 需要导入模块: from torch.optim import lr_scheduler [as 别名]
# 或者: from torch.optim.lr_scheduler import _LRScheduler [as 别名]
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Method to simulate scheduled values during num_events events.

        The given scheduler and its optimizer are stepped for real during the
        simulation. Their state is saved to a temporary file beforehand and
        restored afterwards, also when the simulation raises.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.
            **kwargs: extra keyword arguments forwarded to ``cls``.

        Returns:
            list: one entry per event, ``[event_index, value, ...]`` with one
            value per optimizer param group, read before the scheduler call
            for that event.

        Raises:
            TypeError: if ``lr_scheduler`` is not an ``_LRScheduler`` instance.

        """

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}".format(type(lr_scheduler))
            )

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(),
            }
            torch.save(obj, cache_filepath.as_posix())

            try:
                values = []
                scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
                for i in range(num_events):
                    params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
                    values.append([i] + params)
                    scheduler(engine=None)
            finally:
                # Always roll the caller's scheduler/optimizer back, so a
                # failing simulation does not leave them partially stepped.
                obj = torch.load(cache_filepath.as_posix())
                lr_scheduler.load_state_dict(obj["lr_scheduler"])
                lr_scheduler.optimizer.load_state_dict(obj["optimizer"])

            return values
开发者ID:pytorch,项目名称:ignite,代码行数:43,代码来源:param_scheduler.py


注:本文中的torch.optim.lr_scheduler._LRScheduler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。