本文整理匯總了Python中pytorch_lightning.LightningModule方法的典型用法代碼示例。如果您正苦於以下問題:Python pytorch_lightning.LightningModule方法的具體用法?Python pytorch_lightning.LightningModule怎麽用?Python pytorch_lightning.LightningModule使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類pytorch_lightning
的用法示例。
在下文中一共展示了pytorch_lightning.LightningModule方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: test_example_input_array_types
# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def test_example_input_array_types(example_input, expected_size, mode):
    """Check that the model summary reports the expected input size for each
    supported type of ``example_input_array``."""

    class _PassThrough(nn.Module):
        # Inner module that accepts any call signature and produces no output.
        def forward(self, *args, **kwargs):
            return None

    class _AnyInputModel(LightningModule):
        # LightningModule whose forward delegates arbitrary inputs to the submodule.
        def __init__(self):
            super().__init__()
            self.layer = _PassThrough()

        def forward(self, *args, **kwargs):
            return self.layer(*args, **kwargs)

    model = _AnyInputModel()
    model.example_input_array = example_input
    assert model.summarize(mode=mode).in_sizes == [expected_size]
示例2: add_model_specific_args
# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def add_model_specific_args(parent_parser):
    """Build and return an ``ArgumentParser`` that extends *parent_parser*
    with the hyperparameters of this LightningModule.

    The returned parser defines ``--learning_rate`` and ``--batch_size``
    (model-specific) plus ``--max_nb_epochs`` (training-specific).
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    # (flag, default, type) triples, added in the original declaration order.
    for flag, default, arg_type in (
        ('--learning_rate', 0.02, float),   # MODEL specific
        ('--batch_size', 32, int),          # MODEL specific
        ('--max_nb_epochs', 2, int),        # training specific (for this model)
    ):
        parser.add_argument(flag, default=default, type=arg_type)
    return parser
示例3: _process
# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def _process(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Report the monitored metric to the Optuna trial and prune if requested.

    Reads ``self.monitor`` from ``trainer.callback_metrics``; if the metric is
    absent the call is a no-op. Otherwise the value is reported at the current
    epoch and ``optuna.TrialPruned`` is raised when the trial should stop.
    """
    metric_value = trainer.callback_metrics.get(self.monitor)
    if metric_value is None:
        # Monitored metric not logged (yet) -- nothing to report to Optuna.
        return
    current_epoch = pl_module.current_epoch
    self._trial.report(metric_value, step=current_epoch)
    if self._trial.should_prune():
        raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(current_epoch))
# NOTE (crcrpar): This method is called <0.8.0
示例4: on_epoch_end
# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Epoch-end hook (pytorch-lightning < 0.8.0); delegates pruning to ``_process``."""
    return self._process(trainer, pl_module)
# NOTE (crcrpar): This method is called >=0.8.0
示例5: on_validation_end
# 需要導入模塊: import pytorch_lightning [as 別名]
# 或者: from pytorch_lightning import LightningModule [as 別名]
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Validation-end hook (pytorch-lightning >= 0.8.0); delegates pruning to ``_process``."""
    return self._process(trainer, pl_module)