當前位置: 首頁>>代碼示例>>Python>>正文


Python pytorch_lightning.LightningModule方法代碼示例

本文整理匯總了Python中pytorch_lightning.LightningModule方法的典型用法代碼示例。如果您正苦於以下問題:Python pytorch_lightning.LightningModule方法的具體用法?Python pytorch_lightning.LightningModule怎麽用?Python pytorch_lightning.LightningModule使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在pytorch_lightning的用法示例。


在下文中一共展示了pytorch_lightning.LightningModule方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: test_example_input_array_types

# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def test_example_input_array_types(example_input, expected_size, mode):
    """ Test the types of example inputs supported for display in the summary. """

    class DummyModule(nn.Module):
        def forward(self, *args, **kwargs):
            return None

    class DummyLightningModule(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = DummyModule()

        # this LightningModule and submodule accept any type of input
        def forward(self, *args, **kwargs):
            return self.layer(*args, **kwargs)

    model = DummyLightningModule()
    model.example_input_array = example_input
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [expected_size] 
開發者ID:PyTorchLightning,項目名稱:pytorch-lightning,代碼行數:23,代碼來源:test_memory.py

示例2: add_model_specific_args

# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=0.02, type=float)
        parser.add_argument('--batch_size', default=32, type=int)

        # training specific (for this model)
        parser.add_argument('--max_nb_epochs', default=2, type=int)

        return parser 
開發者ID:PyTorchLightning,項目名稱:pytorch-lightning-conference-seed,代碼行數:15,代碼來源:mnist.py

示例3: _process

# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def _process(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logs = trainer.callback_metrics
        epoch = pl_module.current_epoch
        current_score = logs.get(self.monitor)
        if current_score is None:
            return
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)

    # NOTE (crcrpar): This method is called <0.8.0 
開發者ID:optuna,項目名稱:optuna,代碼行數:14,代碼來源:pytorch_lightning.py

示例4: on_epoch_end

# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        return self._process(trainer, pl_module)

    # NOTE (crcrpar): This method is called >=0.8.0 
開發者ID:optuna,項目名稱:optuna,代碼行數:6,代碼來源:pytorch_lightning.py

示例5: on_validation_end

# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        return self._process(trainer, pl_module) 
開發者ID:optuna,項目名稱:optuna,代碼行數:4,代碼來源:pytorch_lightning.py


注:本文中的pytorch_lightning.LightningModule方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。