

Python pytorch_lightning.Trainer Code Examples

This article collects typical usage examples of the pytorch_lightning.Trainer class in Python. If you are unsure how pytorch_lightning.Trainer is used in practice, the curated examples below should help. You can also browse further usage examples from the pytorch_lightning package.


The sections below present 15 code examples of pytorch_lightning.Trainer, sorted by popularity by default.
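Before the individual examples, here is a minimal, self-contained sketch of the basic Trainer workflow that all of them build on. It is written against the pre-1.0 API used throughout this page; the model, data, and hyper-parameters are illustrative only and are not taken from any of the examples below.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MinimalModel(pl.LightningModule):
    """Tiny model used only to illustrate the Trainer workflow."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# a small random regression dataset keeps the sketch self-contained
dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=16)

trainer = pl.Trainer(max_epochs=1)  # the object every example below configures differently
trainer.fit(MinimalModel(), train_dataloader=loader)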

Example 1: train_mnist_tune_checkpoint

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def train_mnist_tune_checkpoint(config, checkpoint=None):
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=[CheckpointCallback(),
                   TuneReportCallback()])
    if checkpoint:
        # Currently, this leads to errors:
        # model = LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # Workaround:
        ckpt = pl_load(
            os.path.join(checkpoint, "checkpoint"),
            map_location=lambda storage, loc: storage)
        model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
        trainer.current_epoch = ckpt["epoch"]
    else:
        model = LightningMNISTClassifier(
            config=config, data_dir=config["data_dir"])

    trainer.fit(model)
Developer: ray-project, Project: ray, Lines: 27, Source: mnist_pytorch_lightning.py
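For context, a trainable like train_mnist_tune_checkpoint is normally handed to Ray Tune, which calls it once per trial with a sampled config. The sketch below shows that call under stated assumptions: the search-space keys and the "loss" metric name are placeholders, not taken from the original script, and must match what LightningMNISTClassifier and TuneReportCallback actually use.

from ray import tune

# hypothetical search space; the real keys depend on what
# LightningMNISTClassifier reads from its config
search_space = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64, 128]),
    "data_dir": "~/data/mnist",
}

analysis = tune.run(
    train_mnist_tune_checkpoint,
    config=search_space,
    num_samples=4,
    resources_per_trial={"cpu": 1, "gpu": 0},
)

# "loss" is an assumed metric name reported via TuneReportCallback
print(analysis.get_best_config(metric="loss", mode="min"))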

Example 2: main

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def main(args: argparse.Namespace) -> None:
    """Train the model.

    Args:
        args: Model hyper-parameters

    Note:
        For the sake of the example, the images dataset will be downloaded
        to a temporary directory.
    """

    with TemporaryDirectory(dir=args.root_data_path) as tmp_dir:

        model = TransferLearningModel(dl_path=tmp_dir, **vars(args))

        trainer = pl.Trainer(
            weights_summary=None,
            show_progress_bar=True,
            num_sanity_val_steps=0,
            gpus=args.gpus,
            min_epochs=args.nb_epochs,
            max_epochs=args.nb_epochs)

        trainer.fit(model) 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 26, Source: computer_vision_fine_tuning.py
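main() above expects an argparse.Namespace, but the excerpt omits the CLI definition. The following is a hypothetical reconstruction that covers only the attributes main() actually reads (root_data_path, gpus, nb_epochs) plus one placeholder hyper-parameter forwarded to TransferLearningModel via **vars(args); the real script may define different arguments.

import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Transfer-learning example")
    # attributes accessed directly in main()
    parser.add_argument("--root-data-path", type=str, default=".")
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--nb-epochs", type=int, default=15)
    # placeholder hyper-parameter consumed by TransferLearningModel via **vars(args)
    parser.add_argument("--lr", type=float, default=1e-3)
    return parser.parse_args()


if __name__ == "__main__":
    main(get_args())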

Example 3: test_load_past_checkpoint

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_load_past_checkpoint(tmpdir, past_key):
    model = EvalModelTemplate()

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    raw_checkpoint['hparams_type'] = 'Namespace'
    raw_checkpoint[past_key]['batch_size'] = -17
    del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    # save back the checkpoint
    torch.save(raw_checkpoint, raw_checkpoint_path)

    # verify that model loads correctly
    model2 = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
    assert model2.hparams.batch_size == -17 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 22, Source: test_hparams.py

Example 4: test_multi_gpu_model

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_multi_gpu_model(tmpdir, backend):
    """Make sure DDP works."""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=[0, 1],
        distributed_backend=backend,
    )

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result

    # test memory helper functions
    memory.get_memory_profile('min_max') 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 23, Source: test_gpu.py
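In this and the following GPU tests, tmpdir is pytest's built-in temporary-directory fixture, while backend is supplied by a parametrize decorator that the excerpt strips out. A hedged reconstruction of that decorator is shown below; the exact backend values are an assumption.

import pytest


# hypothetical parametrization; the original decorator is not part of the excerpt
@pytest.mark.parametrize("backend", ["dp", "ddp_spawn"])
def test_multi_gpu_model(tmpdir, backend):
    ...  # body as shown above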

Example 5: test_multi_gpu_early_stop

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_multi_gpu_early_stop(tmpdir, backend):
    """Make sure DDP works. with early stopping"""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        max_epochs=50,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend=backend,
    )

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 21, Source: test_gpu.py

Example 6: test_ddp_all_dataloaders_passed_to_fit

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
    """Make sure DDP works with dataloaders passed to fit()"""
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        gpus=[0, 1],
        distributed_backend='ddp'
    )

    model = EvalModelTemplate()
    fit_options = dict(train_dataloader=model.train_dataloader(),
                       val_dataloaders=model.val_dataloader())

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model, **fit_options)
    assert result == 1, "DDP doesn't work with dataloaders passed to fit()." 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 23, Source: test_gpu.py

Example 7: test_amp_single_gpu

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_amp_single_gpu(tmpdir, backend):
    """Make sure DP/DDP + AMP work."""
    tutils.reset_seed()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=1,
        distributed_backend=backend,
        precision=16,
    )

    model = EvalModelTemplate()
    # tutils.run_model_test(trainer_options, model)
    result = trainer.fit(model)

    assert result == 1 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 18, Source: test_amp.py

Example 8: test_amp_multi_gpu

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_amp_multi_gpu(tmpdir, backend):
    """Make sure DP/DDP + AMP work."""
    tutils.set_random_master_port()

    model = EvalModelTemplate()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        # gpus=2,
        gpus='0, 1',  # test init with gpu string
        distributed_backend=backend,
        precision=16,
    )

    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 21, Source: test_amp.py

Example 9: test_multi_gpu_wandb

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_multi_gpu_wandb(tmpdir, backend):
    """Make sure DP/DDP + AMP work."""
    from pytorch_lightning.loggers import WandbLogger
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    logger = WandbLogger(name='utest')

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend=backend,
        precision=16,
        logger=logger,
    )
    # tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result
    trainer.test(model) 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 24, Source: test_amp.py

Example 10: test_dataloaders_passed_to_fit

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_dataloaders_passed_to_fit(tmpdir):
    """Test if dataloaders passed to trainer works on TPU"""

    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        tpu_cores=8,
    )
    result = trainer.fit(
        model,
        train_dataloader=model.train_dataloader(),
        val_dataloaders=model.val_dataloader(),
    )
    assert result, "TPU doesn't work with dataloaders passed to fit()." 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 18, Source: test_tpu.py

Example 11: test_on_before_zero_grad_called

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_on_before_zero_grad_called(tmpdir, max_steps):

    class CurrentTestModel(EvalModelTemplate):
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = CurrentTestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=max_steps,
        max_epochs=2,
        num_sanity_val_steps=5,
    )
    assert 0 == model.on_before_zero_grad_called
    trainer.fit(model)
    assert max_steps == model.on_before_zero_grad_called

    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert 0 == model.on_before_zero_grad_called 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 25, Source: test_hooks.py

Example 12: test_early_stopping_functionality

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_early_stopping_functionality(tmpdir):

    class CurrentModel(EvalModelTemplate):
        def validation_epoch_end(self, outputs):
            losses = [8, 4, 2, 3, 4, 5, 8, 10]
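            # assuming EarlyStopping defaults (monitor='val_loss', patience=3):
            # the best loss (2) occurs at epoch 2, so three non-improving epochs
            # (3, 4, 5) trigger the stop asserted below at current_epoch == 5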
            val_loss = losses[self.current_epoch]
            return {'val_loss': torch.tensor(val_loss)}

    model = CurrentModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        overfit_batches=0.20,
        max_epochs=20,
    )
    result = trainer.fit(model)
    print(trainer.current_epoch)

    assert trainer.current_epoch == 5, 'early_stopping failed' 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 22, Source: test_callbacks.py

Example 13: test_early_stopping_no_val_step

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""

    class CurrentModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            output.update({'my_train_metric': output['loss']})  # could be anything else
            return output

    model = CurrentModel()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=stopping,
        overfit_batches=0.20,
        max_epochs=2,
    )
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 26, Source: test_callbacks.py

Example 14: test_model_checkpoint_with_non_string_input

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint,
        overfit_batches=0.20,
        max_epochs=2,
    )
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 19, Source: test_callbacks.py

Example 15: test_model_checkpoint_path

# Required import: import pytorch_lightning [as alias]
# Or: from pytorch_lightning import Trainer [as alias]
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()
    model = EvalModelTemplate()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_batches=0.2,
        max_epochs=2,
        logger=logger,
    )
    trainer.fit(model)

    ckpt_version = Path(trainer.ckpt_path).parent.name
    assert ckpt_version == expected 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 18, Source: test_callbacks.py
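As with the GPU tests, the parametrize decorator that supplies logger_version and expected is not part of the excerpt. A reconstruction consistent with the docstring is sketched below; the exact value pairs are assumptions.

import pytest


# assumed pairs: an integer version gets the "version_" prefix,
# while a string version is used verbatim as the directory name
@pytest.mark.parametrize("logger_version,expected", [
    (0, "version_0"),
    ("awesome", "awesome"),
])
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    ...  # body as shown above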


Note: The pytorch_lightning.Trainer examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective developers, and copyright remains with the original authors; refer to each project's License before distributing or using the code. Do not republish without permission.