本文整理汇总了Python中pytorch_lightning.Trainer方法的典型用法代码示例。如果您正苦于以下问题：Python pytorch_lightning.Trainer方法的具体用法？Python pytorch_lightning.Trainer怎么用？Python pytorch_lightning.Trainer使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块pytorch_lightning的用法示例。
在下文中一共展示了pytorch_lightning.Trainer方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train_mnist_tune_checkpoint
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def train_mnist_tune_checkpoint(config, checkpoint=None):
    """Fit the MNIST classifier, optionally resuming from ``checkpoint``.

    Args:
        config: Hyper-parameter mapping passed to ``LightningMNISTClassifier``;
            ``config["data_dir"]`` is read when a fresh model is built.
        checkpoint: Optional directory containing a file named
            ``"checkpoint"``. When falsy, a new model is created instead.
    """
    callbacks = [CheckpointCallback(), TuneReportCallback()]
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=callbacks,
    )

    if not checkpoint:
        model = LightningMNISTClassifier(
            config=config,
            data_dir=config["data_dir"],
        )
    else:
        # Calling LightningMNISTClassifier.load_from_checkpoint() on this file
        # currently leads to errors, so as a workaround the raw checkpoint is
        # loaded by hand and the model state and epoch are restored from it.
        ckpt_file = os.path.join(checkpoint, "checkpoint")
        ckpt = pl_load(ckpt_file, map_location=lambda storage, loc: storage)
        model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt["epoch"]

    trainer.fit(model)
# __tune_train_checkpoint_end__
# __tune_asha_begin__
示例2: main
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def main(args: argparse.Namespace) -> None:
    """Train the transfer-learning model.

    Args:
        args: Model hyper-parameters. ``root_data_path``, ``gpus`` and
            ``nb_epochs`` are read here; the full namespace is also
            forwarded to ``TransferLearningModel`` as keyword arguments.

    Note:
        For the sake of the example, the images dataset will be downloaded
        to a temporary directory.
    """
    with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:
        hparams = vars(args)
        model = TransferLearningModel(dl_path=tmp_dir, **hparams)

        # min_epochs and max_epochs are both set to args.nb_epochs.
        trainer_kwargs = dict(
            weights_summary=None,
            show_progress_bar=True,
            num_sanity_val_steps=0,
            gpus=args.gpus,
            min_epochs=args.nb_epochs,
            max_epochs=args.nb_epochs,
        )
        trainer = pl.Trainer(**trainer_kwargs)

        # Fit inside the ``with`` so the temporary directory still exists.
        trainer.fit(model)
示例3: test_load_past_checkpoint
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_load_past_checkpoint(tmpdir, past_key):
    """Check that hyper-parameters stored under ``past_key`` are still loaded.

    A checkpoint is produced by a one-epoch fit, rewritten on disk so its
    hyper-parameters live under ``past_key`` instead of the current key,
    and then loaded back through ``load_from_checkpoint``.
    """
    model = EvalModelTemplate()

    # A short training run produces the checkpoint to rewrite.
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    ckpt_path = _raw_checkpoint_path(trainer)
    ckpt = torch.load(ckpt_path)

    # Move the hyper-parameters to ``past_key`` and drop the current key.
    current_key = LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
    ckpt[past_key] = ckpt[current_key]
    ckpt['hparams_type'] = 'Namespace'
    # Sentinel value that is checked again after reloading.
    ckpt[past_key]['batch_size'] = -17
    del ckpt[current_key]

    # Overwrite the checkpoint file with the modified content.
    torch.save(ckpt, ckpt_path)

    reloaded = EvalModelTemplate.load_from_checkpoint(ckpt_path)
    assert reloaded.hparams.batch_size == -17
示例4: test_multi_gpu_model
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_multi_gpu_model(tmpdir, backend):
    """Fit on GPUs 0 and 1 with ``backend`` and check that fit() reports success."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=[0, 1],
        distributed_backend=backend,
    )
    assert trainer.fit(model)

    # Also exercise the memory helper function.
    memory.get_memory_profile('min_max')
示例5: test_multi_gpu_early_stop
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_multi_gpu_early_stop(tmpdir, backend):
    """Make sure training on GPUs 0 and 1 works with early stopping enabled."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend=backend,
    )
    assert trainer.fit(model)
示例6: test_ddp_all_dataloaders_passed_to_fit
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works when the dataloaders are handed to fit() directly."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    # Build the loaders from the model so they can be passed in explicitly.
    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        gpus=[0, 1],
        distributed_backend='ddp',
    )
    outcome = trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )
    assert outcome == 1, "DDP doesn't work with dataloaders passed to fit()."
示例7: test_amp_single_gpu
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_amp_single_gpu(tmpdir, backend):
    """Make sure 16-bit precision training works on one GPU with ``backend``."""
    tutils.reset_seed()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=1,
        distributed_backend=backend,
        precision=16,
    )
    trainer = Trainer(**trainer_options)

    model = EvalModelTemplate()
    outcome = trainer.fit(model)
    assert outcome == 1
示例8: test_amp_multi_gpu
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_amp_multi_gpu(tmpdir, backend):
    """Make sure 16-bit precision training works on two GPUs with ``backend``."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        # The GPUs are given as a string on purpose, to test init with a gpu string.
        gpus='0, 1',
        distributed_backend=backend,
        precision=16,
    )
    assert trainer.fit(model)
示例9: test_multi_gpu_wandb
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_multi_gpu_wandb(tmpdir, backend):
    """Fit and test on two GPUs with 16-bit precision and a ``WandbLogger`` attached."""
    from pytorch_lightning.loggers import WandbLogger

    tutils.set_random_master_port()

    model = EvalModelTemplate()
    wandb_logger = WandbLogger(name='utest')

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend=backend,
        precision=16,
        logger=wandb_logger,
    )
    assert trainer.fit(model)

    # The test loop is run as well, with the same trainer and logger.
    trainer.test(model)
示例10: test_dataloaders_passed_to_fit
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_dataloaders_passed_to_fit(tmpdir):
    """Test that dataloaders passed directly to fit() work on TPU."""
    model = EvalModelTemplate()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        tpu_cores=8,
    )
    trainer = Trainer(**trainer_options)

    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    outcome = trainer.fit(
        model,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )
    assert outcome, "TPU doesn't work with dataloaders passed to fit()."
示例11: test_on_before_zero_grad_called
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_on_before_zero_grad_called(tmpdir, max_steps):
    """Count ``on_before_zero_grad`` hook calls during fit() and test().

    Asserts that the hook has been called ``max_steps`` times after fit()
    and not at all during test().
    """

    class CountingModel(EvalModelTemplate):
        # Class-level default; ``+=`` in the hook creates the per-instance counter.
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = CountingModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_steps=max_steps,
        max_epochs=2,
        num_sanity_val_steps=5,
    )
    trainer = Trainer(**trainer_options)

    # Creating the trainer alone must not call the hook.
    assert model.on_before_zero_grad_called == 0

    trainer.fit(model)
    assert model.on_before_zero_grad_called == max_steps

    # Reset the counter so the test() run is measured on its own.
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0
示例12: test_early_stopping_functionality
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_early_stopping_functionality(tmpdir):
    """Feed a scripted validation loss and assert that training ends at epoch 5."""

    class ScriptedLossModel(EvalModelTemplate):
        def validation_epoch_end(self, outputs):
            # Fixed loss curve indexed by the current epoch.
            scripted_losses = [8, 4, 2, 3, 4, 5, 8, 10]
            epoch_loss = scripted_losses[self.current_epoch]
            return {'val_loss': torch.tensor(epoch_loss)}

    model = ScriptedLossModel()
    trainer_options = dict(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        overfit_batches=0.20,
        max_epochs=20,
    )
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    print(trainer.current_epoch)
    assert trainer.current_epoch == 5, 'early_stopping failed'
示例13: test_early_stopping_no_val_step
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping falls back to training metrics when no validation is defined."""

    class TrainMetricModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            step_output = super().training_step(*args, **kwargs)
            # Expose the loss under a second key; any other value would do.
            loss = step_output['loss']
            step_output.update({'my_train_metric': loss})
            return step_output

    model = TrainMetricModel()
    # Remove the validation step and validation dataloader from the model.
    model.validation_step = None
    model.val_dataloader = None

    early_stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer_options = dict(
        default_root_dir=tmpdir,
        early_stop_callback=early_stopping,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_options)

    outcome = trainer.fit(model)
    assert outcome == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs
示例14: test_model_checkpoint_with_non_string_input
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """Test that ``filepath=None`` is accepted by ``ModelCheckpoint`` and the checkpoint path is set."""
    tutils.reset_seed()

    model = EvalModelTemplate()
    checkpoint_cb = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer_options = dict(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_cb,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    # These should be different if the dirpath has been overridden.
    assert trainer.ckpt_path != trainer.default_root_dir
示例15: test_model_checkpoint_path
# 需要导入模块: import pytorch_lightning [as 别名]
# 或者: from pytorch_lightning import Trainer [as 别名]
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that the "version_" prefix is only added when the logger's version is an integer.

    The name of the parent directory of ``trainer.ckpt_path`` is compared
    against ``expected``.
    """
    tutils.reset_seed()

    model = EvalModelTemplate()
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer_options = dict(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=tb_logger,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)

    version_dir = Path(trainer.ckpt_path).parent.name
    assert version_dir == expected