当前位置: 首页>>代码示例>>Python>>正文


Python pytorch_lightning.LightningModule方法代码示例

本文整理汇总了Python中pytorch_lightning.LightningModule方法的典型用法代码示例。如果您正苦于以下问题:Python pytorch_lightning.LightningModule方法的具体用法?Python pytorch_lightning.LightningModule怎么用?Python pytorch_lightning.LightningModule使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在pytorch_lightning的用法示例。


在下文中一共展示了pytorch_lightning.LightningModule方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_example_input_array_types

# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def test_example_input_array_types(example_input, expected_size, mode):
    """ Test the types of example inputs supported for display in the summary. """

    class DummyModule(nn.Module):
        def forward(self, *args, **kwargs):
            return None

    class DummyLightningModule(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = DummyModule()

        # this LightningModule and submodule accept any type of input
        def forward(self, *args, **kwargs):
            return self.layer(*args, **kwargs)

    model = DummyLightningModule()
    model.example_input_array = example_input
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [expected_size] 
开发者ID:PyTorchLightning,项目名称:pytorch-lightning,代码行数:23,代码来源:test_memory.py

示例2: add_model_specific_args

# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=0.02, type=float)
        parser.add_argument('--batch_size', default=32, type=int)

        # training specific (for this model)
        parser.add_argument('--max_nb_epochs', default=2, type=int)

        return parser 
开发者ID:PyTorchLightning,项目名称:pytorch-lightning-conference-seed,代码行数:15,代码来源:mnist.py

示例3: _process

# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def _process(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logs = trainer.callback_metrics
        epoch = pl_module.current_epoch
        current_score = logs.get(self.monitor)
        if current_score is None:
            return
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)

    # NOTE (crcrpar): This method is called <0.8.0 
开发者ID:optuna,项目名称:optuna,代码行数:14,代码来源:pytorch_lightning.py

示例4: on_epoch_end

# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        return self._process(trainer, pl_module)

    # NOTE (crcrpar): This method is called >=0.8.0 
开发者ID:optuna,项目名称:optuna,代码行数:6,代码来源:pytorch_lightning.py

示例5: on_validation_end

# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        return self._process(trainer, pl_module) 
开发者ID:optuna,项目名称:optuna,代码行数:4,代码来源:pytorch_lightning.py


注:本文中的pytorch_lightning.LightningModule方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。