本文整理汇总了Python中pytorch_lightning.LightningModule方法的典型用法代码示例。如果您正苦于以下问题:Python pytorch_lightning.LightningModule方法的具体用法?Python pytorch_lightning.LightningModule怎么用?Python pytorch_lightning.LightningModule使用的例子?那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块pytorch_lightning的用法示例。
在下文中一共展示了pytorch_lightning.LightningModule方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_example_input_array_types
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def test_example_input_array_types(example_input, expected_size, mode):
    """Check that every supported ``example_input_array`` type is summarized correctly.

    ``example_input``, ``expected_size`` and ``mode`` are supplied by the test
    parametrization; the model under test accepts arbitrary inputs so only the
    summary's reported input size is asserted.
    """

    class _AnyInputSubmodule(nn.Module):
        # Accepts any call signature and produces no output.
        def forward(self, *inputs, **named_inputs):
            return None

    class _AnyInputModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = _AnyInputSubmodule()

        # Both this LightningModule and its submodule tolerate any input type.
        def forward(self, *inputs, **named_inputs):
            return self.layer(*inputs, **named_inputs)

    model = _AnyInputModel()
    model.example_input_array = example_input
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [expected_size]
示例2: add_model_specific_args
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def add_model_specific_args(parent_parser):
    """Extend *parent_parser* with the hyperparameters of this LightningModule.

    Returns a new ``ArgumentParser`` that inherits every option already
    defined on ``parent_parser`` (``add_help=False`` avoids a duplicate
    ``-h`` flag from the parent).
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)

    # Model-specific hyperparameters.
    parser.add_argument('--learning_rate', type=float, default=0.02)
    parser.add_argument('--batch_size', type=int, default=32)

    # Training-loop options specific to this model.
    parser.add_argument('--max_nb_epochs', type=int, default=2)

    return parser
示例3: _process
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def _process(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Report the monitored metric to the Optuna trial and prune if requested.

    Raises ``optuna.TrialPruned`` when the trial's pruner decides to stop
    this run; returns silently when the monitored metric is not present yet.
    """
    epoch = pl_module.current_epoch
    current_score = trainer.callback_metrics.get(self.monitor)
    # Metric not logged yet (e.g. before the first validation pass) — nothing to report.
    if current_score is None:
        return
    self._trial.report(current_score, step=epoch)
    if self._trial.should_prune():
        raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(epoch))
# NOTE (crcrpar): This method is called <0.8.0
示例4: on_epoch_end
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Pruning hook for pytorch-lightning < 0.8.0; delegates to ``_process``."""
    return self._process(trainer, pl_module)
# NOTE (crcrpar): This method is called >=0.8.0
示例5: on_validation_end
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import LightningModule [as 别名]
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Pruning hook for pytorch-lightning >= 0.8.0; delegates to ``_process``."""
    return self._process(trainer, pl_module)