

Python lr_scheduler._LRScheduler Code Examples

This article collects typical usage examples of torch.optim.lr_scheduler._LRScheduler in Python. If you are wondering what lr_scheduler._LRScheduler is for, how to use it, or what real code that uses it looks like, the curated examples below should help. Note that _LRScheduler is the base class of PyTorch's learning-rate schedulers rather than a method; you can also browse further usage examples for its containing module, torch.optim.lr_scheduler.


The sections below present 15 code examples that use lr_scheduler._LRScheduler, sorted by popularity by default.
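
Before the examples, the short sketch below (not taken from any of the projects cited here) shows the usual way _LRScheduler is subclassed: a custom scheduler only has to implement get_lr(), returning one learning rate per parameter group. The class name HalvingScheduler and its halving rule are invented for illustration.

# Minimal sketch of subclassing _LRScheduler (illustrative only).
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler

class HalvingScheduler(_LRScheduler):
    """Hypothetical scheduler that halves every base learning rate each epoch."""
    def get_lr(self):
        return [base_lr * (0.5 ** self.last_epoch) for base_lr in self.base_lrs]

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = HalvingScheduler(optimizer)
for epoch in range(3):
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # 0.05, 0.025, 0.0125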

Example 1: load

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])

        # Only the model weights are restored here: the step counter is reset to 0,
        # and the optimizer/scheduler state stored in the checkpoint is ignored.
        # See Example 7 for a variant that also restores step, optimizer and scheduler.
        step = 0
        return step 
Developer: MagicChuyi, Project: SlowFast-Network-pytorch, Lines: 18, Source: model.py

Example 2: checkpoint_model

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def checkpoint_model(
        self,
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> None:
        """Checkpoint the model.

        Args:
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: the metric dict.
        """
        self.checkpointer.checkpoint(
            self.unit_total, model, optimizer, lr_scheduler, metric_dict
        ) 
Developer: SenWu, Project: emmental, Lines: 20, Source: logging_manager.py

Example 3: __init__

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def __init__(self, config: Config, optimizer):
        super().__init__(config)
        name = config.get("train.lr_scheduler")
        args = config.get("train.lr_scheduler_args")
        self._lr_scheduler: _LRScheduler = None
        if name != "":
            try:
                self._lr_scheduler = getattr(torch.optim.lr_scheduler, name)(
                    optimizer, **args
                )
            except Exception as e:
                raise ValueError(
                    (
                        "Invalid LR scheduler {} or scheduler arguments {}. "
                        "Error: {}"
                    ).format(name, args, e)
                )

        self._metric_based = name in ["ReduceLROnPlateau"] 
Developer: uma-pi1, Project: kge, Lines: 21, Source: optimizer.py
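
The name-based lookup in Example 3 works with any scheduler class exported by torch.optim.lr_scheduler. Below is a standalone sketch of the same pattern, where the scheduler name and argument dict are stand-ins for hypothetical config values:

import torch
from torch.optim import SGD

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
name, args = "StepLR", {"step_size": 10, "gamma": 0.5}   # stand-ins for values read from a config
scheduler = getattr(torch.optim.lr_scheduler, name)(optimizer, **args)
print(type(scheduler).__name__)   # StepLR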

Example 4: __init__

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def __init__(self, lr_scheduler, save_history=False, **kwds):

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError("Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                            "but given {}".format(type(lr_scheduler)))

        if len(lr_scheduler.optimizer.param_groups) > 1:
            raise ValueError("Optimizer passed to lr_scheduler should have a single param group, "
                             "but currently there are {} param groups".format(len(lr_scheduler.optimizer.param_groups)))

        self.lr_scheduler = lr_scheduler
        super(LRScheduler, self).__init__(
            optimizer=self.lr_scheduler.optimizer,
            param_name='lr',
            save_history=save_history
        ) 
Developer: leokarlin, Project: LaSO, Lines: 18, Source: param_scheduler.py

Example 5: simulate_values

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Method to simulate scheduled values during num_events events.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value]

        """
        copy_lr_scheduler = LRScheduler._copy_lr_scheduler(lr_scheduler)
        values = []
        scheduler = cls(save_history=False, lr_scheduler=copy_lr_scheduler)
        for i in range(num_events):
            scheduler(engine=None)
            values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])

        return values 
Developer: leokarlin, Project: LaSO, Lines: 21, Source: param_scheduler.py

Example 6: save

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
        checkpoint = {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint 
Developer: potterhsu, Project: easy-faster-rcnn.pytorch, Lines: 12, Source: model.py

Example 7: load

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
        checkpoint = torch.load(path_to_checkpoint)
        self.load_state_dict(checkpoint['state_dict'])
        step = checkpoint['step']
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return step 
Developer: potterhsu, Project: easy-faster-rcnn.pytorch, Lines: 11, Source: model.py
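
Examples 6 and 7 form a save/load pair. For reference, here is a self-contained sketch of the same checkpoint round trip written as plain functions rather than Model methods; the function names, paths, and hyperparameters are illustrative:

import os
import torch
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler

def save_checkpoint(model: torch.nn.Module, checkpoints_dir: str, step: int,
                    optimizer: Optimizer, scheduler: _LRScheduler) -> str:
    # Same checkpoint layout as Example 6.
    path = os.path.join(checkpoints_dir, f'model-{step}.pth')
    torch.save({'state_dict': model.state_dict(),
                'step': step,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()}, path)
    return path

def load_checkpoint(model: torch.nn.Module, path: str,
                    optimizer: Optimizer, scheduler: _LRScheduler) -> int:
    # Mirrors Example 7: restores weights, optimizer and scheduler, returns the step.
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['step']

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=10000, gamma=0.1)
path = save_checkpoint(model, '.', 1000, optimizer, scheduler)
step = load_checkpoint(model, path, optimizer, scheduler)   # step == 1000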

Example 8: save

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer, scheduler: _LRScheduler) -> str:
        path_to_checkpoint = os.path.join(path_to_checkpoints_dir, 'model-{}.pth'.format(step))
        checkpoint = {
            'state_dict': self.state_dict(),
            'step': step,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, path_to_checkpoint)
        return path_to_checkpoint
Developer: MagicChuyi, Project: SlowFast-Network-pytorch, Lines: 13, Source: model.py

Example 9: collect_state_dict

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        """Collect the state dict of the model.

        Args:
          iteration: The current iteration.
          model: The model to checkpoint.
          optimizer: The optimizer used during training process.
          lr_scheduler: Learning rate scheduler.
          metric_dict: the metric dict.

        Returns:
          The state dict.
        """
        model_params = {
            "name": model.name,
            "module_pool": model.collect_state_dict(),
            # "task_names": model.task_names,
            # "task_flows": model.task_flows,
            # "loss_funcs": model.loss_funcs,
            # "output_funcs": model.output_funcs,
            # "scorers": model.scorers,
        }

        state_dict = {
            "iteration": iteration,
            "model": model_params,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None,
            "metric_dict": metric_dict,
        }

        return state_dict 
Developer: SenWu, Project: emmental, Lines: 41, Source: checkpointer.py

Example 10: get_lr

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def get_lr(self):
        """Get updated learning rate."""
        # HACK: We need to check if this is the first time ``self.get_lr()`` was called,
        # since ``torch.optim.lr_scheduler._LRScheduler`` will call ``self.get_lr()``
        # when first initialized, but the learning rate should remain unchanged
        # for the first epoch.
        if not self._initialized:
            self._initialized = True
            return self.base_lrs

        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart

        lrs = [
            self.eta_min + ((lr - self.eta_min) / 2) * (
                np.cos(
                    np.pi *
                    (self._cycle_counter % self._updated_cycle_len) / self._updated_cycle_len
                ) + 1
            ) for lr in self.base_lrs
        ]

        if self._cycle_counter % self._updated_cycle_len == 0:
            # Adjust the cycle length.
            self._cycle_factor *= self.factor
            self._cycle_counter = 0
            self._updated_cycle_len = int(self._cycle_factor * self.t_max)
            self._last_restart = step

        return lrs 
Developer: arthurdouillard, Project: incremental_learning.pytorch, Lines: 32, Source: schedulers.py
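
The update rule in Example 10 is ordinary cosine annealing between base_lr and eta_min, with the cycle length stretched by `factor` after every restart. A standalone sketch of the per-step formula, with made-up values for the bounds and cycle length:

import numpy as np

base_lr, eta_min, cycle_len = 0.1, 0.001, 10   # illustrative values only
for step in range(cycle_len + 1):
    lr = eta_min + (base_lr - eta_min) / 2 * (np.cos(np.pi * step / cycle_len) + 1)
    print(step, round(lr, 5))   # decays smoothly from 0.1 at step 0 to 0.001 at step 10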

Example 11: maybe_update_lr

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def maybe_update_lr(self):
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                self.lr_scheduler.step(self.train_loss_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr'])) 
Developer: MIC-DKFZ, Project: nnUNet, Lines: 13, Source: network_trainer.py

Example 12: maybe_update_lr

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def maybe_update_lr(self, epoch=None):
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                if self.epoch > 0:  # otherwise self.train_loss_MA is None
                    self.lr_scheduler.step(self.train_loss_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr'])) 
Developer: MIC-DKFZ, Project: nnUNet, Lines: 14, Source: nnUNetTrainerV2_SGD_ReduceOnPlateau.py

Example 13: maybe_update_lr

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def maybe_update_lr(self, epoch=None):
        # maybe update learning rate
        if self.lr_scheduler is not None:
            assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))

            if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
                # lr scheduler is updated with moving average val loss. should be more robust
                if self.epoch > 0 and self.train_loss_MA is not None:  # otherwise self.train_loss_MA is None
                    self.lr_scheduler.step(self.train_loss_MA)
            else:
                self.lr_scheduler.step(self.epoch + 1)
        self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr'])) 
Developer: MIC-DKFZ, Project: nnUNet, Lines: 14, Source: nnUNetTrainerV2_Adam_ReduceOnPlateau.py

Example 14: __init__

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def __init__(self, lr_scheduler, save_history=False, **kwargs):

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}".format(type(lr_scheduler))
            )

        self.lr_scheduler = lr_scheduler
        super(LRScheduler, self).__init__(
            optimizer=self.lr_scheduler.optimizer, param_name="lr", save_history=save_history
        )
        self._state_attrs += [
            "lr_scheduler",
        ] 
Developer: pytorch, Project: ignite, Lines: 17, Source: param_scheduler.py

Example 15: simulate_values

# Required import: from torch.optim import lr_scheduler [as alias]
# Or: from torch.optim.lr_scheduler import _LRScheduler [as alias]
def simulate_values(cls, num_events, lr_scheduler, **kwargs):
        """Method to simulate scheduled values during num_events events.

        Args:
            num_events (int): number of events during the simulation.
            lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

        Returns:
            list of pairs: [event_index, value]

        """

        if not isinstance(lr_scheduler, _LRScheduler):
            raise TypeError(
                "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
                "but given {}".format(type(lr_scheduler))
            )

        # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
        # should be replicated in order to simulate LR values and
        # not perturb original scheduler.
        with tempfile.TemporaryDirectory() as tmpdirname:
            cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
            obj = {
                "lr_scheduler": lr_scheduler.state_dict(),
                "optimizer": lr_scheduler.optimizer.state_dict(),
            }
            torch.save(obj, cache_filepath.as_posix())

            values = []
            scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)
            for i in range(num_events):
                params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
                values.append([i] + params)
                scheduler(engine=None)

            obj = torch.load(cache_filepath.as_posix())
            lr_scheduler.load_state_dict(obj["lr_scheduler"])
            lr_scheduler.optimizer.load_state_dict(obj["optimizer"])

            return values 
Developer: pytorch, Project: ignite, Lines: 43, Source: param_scheduler.py
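
One possible way to call the simulate_values classmethod from Examples 14 and 15 is sketched below; it assumes the ignite.contrib.handlers import location, which may differ across ignite versions:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from ignite.contrib.handlers import LRScheduler   # module path may vary by ignite version

optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
torch_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Preview the first 10 scheduled values without disturbing torch_scheduler's state.
values = LRScheduler.simulate_values(num_events=10, lr_scheduler=torch_scheduler)
print(values)   # list of [event_index, lr] pairs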


Note: the torch.optim.lr_scheduler._LRScheduler examples in this article were compiled by 純淨天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets are taken from open-source projects contributed by their respective authors; copyright of the source code remains with the original authors, and distribution and use are subject to each project's license. Do not reproduce without permission.