当前位置: 首页>>代码示例>>Python>>正文


Python trainer.Trainer类代码示例

本文整理汇总了Python中allennlp.training.trainer.Trainer的典型用法代码示例。如果您正苦于以下问题：Python Trainer类的具体用法？Python Trainer怎么用？Python Trainer使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。


在下文中一共展示了Trainer类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_trainer_can_run_multiple_gpu

 def test_trainer_can_run_multiple_gpu(self):
     """Train for two epochs with ``cuda_device=[0, 1]``.

     The test is skipped on machines with fewer than two CUDA devices, since
     the requested device list cannot be satisfied there.
     """
     import pytest
     import torch

     if torch.cuda.device_count() < 2:
         pytest.skip("2 or more GPUs required.")

     multigpu_iterator = BasicIterator(batch_size=4)
     multigpu_iterator.index_with(self.vocab)
     trainer = Trainer(self.model, self.optimizer,
                       multigpu_iterator, self.instances, num_epochs=2,
                       cuda_device=[0, 1])
     trainer.train()
开发者ID:pyknife,项目名称:allennlp,代码行数:7,代码来源:trainer_test.py

示例2: test_trainer_respects_keep_serialized_model_every_num_seconds

    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        """Check which checkpoints survive under count- and time-based retention.

        Plan:
          * Create an iterator that sleeps for 2.5 seconds per epoch, so the
            total training time for one epoch is slightly greater than 2.5 seconds.
          * Run for 6 epochs, keeping the last 2 models, with models also kept
            every 5 seconds.
          * The resulting checkpoints should be for epochs 2 and 4, plus the
            last two at 5 and 6 (1-indexed).
        """
        class WaitingIterator(BasicIterator):
            # pylint: disable=arguments-differ
            def _create_batches(self, *args, **kwargs):
                time.sleep(2.5)
                return super(WaitingIterator, self)._create_batches(*args, **kwargs)

        iterator = WaitingIterator(batch_size=2)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=6,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=2,
                          keep_serialized_model_every_num_seconds=5)
        trainer.train()

        # Now check the serialized files. ``[0-9]+`` accepts multi-digit epoch
        # numbers; a single ``[0-9]`` would fail to match a name such as
        # "..._10.th", and ``re.search`` would return None.
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9]+)\.th", fname).group(1))
                      for fname in file_names]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:30,代码来源:trainer_test.py

示例3: test_trainer_can_run

    def test_trainer_can_run(self):
        """Training returns typed ``best_*`` summary metrics in both metric directions.

        Trains once with the trainer's default validation metric and once with
        ``validation_metric='+loss'``. After each run it checks that the same
        summary keys are present with the expected value types.
        """
        expected_summary = [
                ('best_validation_loss', float),
                ('best_validation_accuracy', float),
                ('best_validation_accuracy3', float),
                ('best_epoch', int),
        ]
        # The second configuration makes sure that both increasing and
        # decreasing validation metrics work.
        for extra_kwargs in ({}, {'validation_metric': '+loss'}):
            trainer = Trainer(model=self.model,
                              optimizer=self.optimizer,
                              iterator=self.iterator,
                              train_dataset=self.instances,
                              validation_dataset=self.instances,
                              num_epochs=2,
                              **extra_kwargs)
            metrics = trainer.train()
            for key, expected_type in expected_summary:
                assert key in metrics
                assert isinstance(metrics[key], expected_type)
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:34,代码来源:trainer_test.py

示例4: test_trainer_can_run_multiple_gpu

    def test_trainer_can_run_multiple_gpu(self):
        """Train on ``cuda_device=[0, 1]`` and check metadata is split with the batch.

        The test is skipped on machines with fewer than two CUDA devices.
        """
        import pytest
        import torch

        if torch.cuda.device_count() < 2:
            pytest.skip("2 or more GPUs required.")

        class MetaDataCheckWrapper(Model):
            """
            Checks that the metadata field has been correctly split across the batch dimension
            when running on multiple gpus.
            """
            def __init__(self, model):
                super().__init__(model.vocab)
                self.model = model

            def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
                # 'tokens' is indexed just below, so it is checked together with
                # 'metadata' and 'tags'. Otherwise a missing key would surface
                # as a bare KeyError instead of this assertion message.
                required_keys = ('tokens', 'metadata', 'tags')
                assert all(key in kwargs for key in required_keys), \
                    f'tokens, metadata and tags must be provided. Got {kwargs.keys()} instead.'
                batch_size = kwargs['tokens']['tokens'].size()[0]
                assert len(kwargs['metadata']) == batch_size, \
                    f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                    f"got {len(kwargs['metadata'])} elements."
                return self.model.forward(**kwargs)

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = Trainer(MetaDataCheckWrapper(self.model), self.optimizer,
                          multigpu_iterator, self.instances, num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:26,代码来源:trainer_test.py

示例5: test_should_stop_early_with_increasing_metric

 def test_should_stop_early_with_increasing_metric(self):
     """Check ``_should_stop_early`` for an increasing (``+test``) metric with patience 5."""
     # pylint: disable=protected-access
     new_trainer = Trainer(self.model, self.optimizer,
                           self.iterator, self.instances,
                           validation_dataset=self.instances,
                           num_epochs=3, serialization_dir=self.TEST_DIR,
                           patience=5, validation_metric="+test")
     # The best value (.5) comes first; none of the five later values exceed it.
     stalled_history = [.5, .3, .2, .1, .4, .4]
     # The best value (.5) appears at index 4, one entry from the end.
     late_peak_history = [.3, .3, .3, .2, .5, .1]
     assert new_trainer._should_stop_early(stalled_history)
     assert not new_trainer._should_stop_early(late_peak_history)
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:8,代码来源:trainer_test.py

示例6: test_trainer_raises_on_model_with_no_loss_key

 def test_trainer_raises_on_model_with_no_loss_key(self):
     """Training a model whose ``forward`` returns no ``'loss'`` key must raise RuntimeError."""
     class FakeModel(torch.nn.Module):
         def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
             return {}

     # Build the trainer outside ``pytest.raises`` so that only ``train()`` is
     # expected to raise. A RuntimeError from construction would otherwise make
     # this test pass for the wrong reason.
     trainer = Trainer(FakeModel(), self.optimizer,
                       self.iterator, self.instances,
                       num_epochs=2, serialization_dir=self.TEST_DIR)
     with pytest.raises(RuntimeError):
         trainer.train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:9,代码来源:trainer_test.py

示例7: test_trainer_can_log_histograms

    def test_trainer_can_log_histograms(self):
        """Train for 3 epochs with activation logging on and ``histogram_interval=2``."""
        # Flag every submodule of the model so its activations get logged.
        for submodule in self.model.modules():
            submodule.should_log_activations = True

        Trainer(self.model, self.optimizer,
                self.iterator, self.instances, num_epochs=3,
                serialization_dir=self.TEST_DIR,
                histogram_interval=2).train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:10,代码来源:trainer_test.py

示例8: test_trainer_can_log_learning_rates_tensorboard

    def test_trainer_can_log_learning_rates_tensorboard(self):
        """Train for 2 epochs with ``should_log_learning_rate=True`` and ``summary_interval=2``."""
        lr_logging_iterator = BasicIterator(batch_size=4)
        lr_logging_iterator.index_with(self.vocab)

        logging_options = dict(num_epochs=2,
                               serialization_dir=self.TEST_DIR,
                               should_log_learning_rate=True,
                               summary_interval=2)
        trainer = Trainer(self.model, self.optimizer,
                          lr_logging_iterator, self.instances, **logging_options)

        trainer.train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:11,代码来源:trainer_test.py

示例9: test_trainer_can_run_with_lr_scheduler

 def test_trainer_can_run_with_lr_scheduler(self):
     """Train for 2 epochs with a ``reduce_on_plateau`` scheduler and ``validation_metric="-loss"``."""
     scheduler = LearningRateScheduler.from_params(
             self.optimizer, Params({"type": "reduce_on_plateau"}))
     Trainer(model=self.model,
             optimizer=self.optimizer,
             iterator=self.iterator,
             learning_rate_scheduler=scheduler,
             validation_metric="-loss",
             train_dataset=self.instances,
             validation_dataset=self.instances,
             num_epochs=2).train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:12,代码来源:trainer_test.py

示例10: test_trainer_respects_num_serialized_models_to_keep

    def test_trainer_respects_num_serialized_models_to_keep(self):
        """After 5 epochs with ``num_serialized_models_to_keep=3``, only 3 checkpoints remain."""
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        # Now check the serialized files. ``[0-9]+`` accepts multi-digit epoch
        # numbers; a single ``[0-9]`` would fail to match a name such as
        # "..._10.th", and ``re.search`` would return None.
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9]+)\.th", fname).group(1))
                      for fname in file_names]
            # File names carry 0-based epoch numbers (0..4); the last three are kept.
            assert sorted(epochs) == [2, 3, 4]
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:13,代码来源:trainer_test.py

示例11: test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_decreasing_metric

 def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_decreasing_metric(self):
     """For a decreasing (``-test``) metric, only a strictly lower value counts as a new best."""
     # pylint: disable=protected-access
     new_trainer = Trainer(self.model, self.optimizer,
                           self.iterator, self.instances,
                           validation_dataset=self.instances,
                           num_epochs=3, serialization_dir=self.TEST_DIR,
                           patience=5, validation_metric="-test")
     past_values = (.3, .3, .3, .2, .5, .1)
     # With an empty history, the value is the only metric and so is the best.
     assert new_trainer._is_best_so_far(1, [])
     # Equal to earlier values (and above the minimum .1): not the best.
     assert not new_trainer._is_best_so_far(.3, list(past_values))
     # Below every earlier value: the best.
     assert new_trainer._is_best_so_far(.013, list(past_values))
     # Above every earlier value: not the best.
     assert not new_trainer._is_best_so_far(13.00, list(past_values))
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:14,代码来源:trainer_test.py

示例12: test_should_stop_early_with_early_stopping_disabled

    def test_should_stop_early_with_early_stopping_disabled(self):
        """With ``patience=None``, ``_should_stop_early`` is False even for worsening histories."""
        # pylint: disable=protected-access
        # Each case pairs a metric direction with a history that only gets worse
        # in that direction: falling values for "+test", rising values for "-test".
        cases = [
                ("+test", [float(i) for i in reversed(range(20))]),
                ("-test", [float(i) for i in range(20)]),
        ]
        for metric, worsening_history in cases:
            trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,
                              validation_dataset=self.instances, num_epochs=100,
                              patience=None, validation_metric=metric)
            assert not trainer._should_stop_early(worsening_history)
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:14,代码来源:trainer_test.py

示例13: test_trainer_can_resume_with_lr_scheduler

    def test_trainer_can_resume_with_lr_scheduler(self):
        """Resume training from a checkpoint and check the LR scheduler state comes back.

        After 2 epochs of training, a fresh trainer in the same serialization
        directory restores epoch 2 and a scheduler ``last_epoch`` of 1, then
        continues training.
        """
        # pylint: disable=protected-access
        def build_trainer(num_epochs):
            scheduler = LearningRateScheduler.from_params(
                    self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
            return Trainer(model=self.model,
                           optimizer=self.optimizer,
                           iterator=self.iterator,
                           learning_rate_scheduler=scheduler,
                           train_dataset=self.instances,
                           validation_dataset=self.instances,
                           num_epochs=num_epochs, serialization_dir=self.TEST_DIR)

        build_trainer(num_epochs=2).train()

        resumed_trainer = build_trainer(num_epochs=4)
        epoch, _ = resumed_trainer._restore_checkpoint()
        assert epoch == 2
        assert resumed_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
        resumed_trainer.train()
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:26,代码来源:trainer_test.py

示例14: test_trainer_saves_metrics_every_epoch

    def test_trainer_saves_metrics_every_epoch(self):
        """Each epoch writes ``metrics_epoch_{epoch}.json`` containing validation metrics.

        Metrics files are expected for all five epochs, even though only three
        model checkpoints are kept.
        """
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        for epoch in range(5):
            epoch_file = self.TEST_DIR / f'metrics_epoch_{epoch}.json'
            assert epoch_file.exists()
            # Use a context manager so the handle is closed deterministically
            # rather than leaked until garbage collection.
            with open(epoch_file) as metrics_file:
                metrics = json.load(metrics_file)
            assert "validation_loss" in metrics
            assert "best_validation_loss" in metrics
            assert metrics.get("epoch") == epoch
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:18,代码来源:trainer_test.py

示例15: test_trainer_saves_models_at_specified_interval

    def test_trainer_saves_models_at_specified_interval(self):
        """A tiny ``model_save_interval`` produces mid-epoch checkpoints that can be restored."""
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
                  for fname in file_names]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]. After sorting, each timestamped name
        # comes before its epoch's plain name: every character the regex allows
        # after the '.' (digit, '.', '-') sorts before the 't' of '.th'.
        assert len(epochs) == 4
        assert epochs[1] == '0'
        assert epochs[3] == '1'
        # Mid-epoch checkpoints carry a '.'-separated suffix after the epoch number.
        assert epochs[0].startswith('0.')
        assert epochs[2].startswith('1.')

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for both epochs (saved
        # with 0-based names 0 and 1) plus best.th, so that we are forced to
        # restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
            os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model, self.optimizer,
                                  self.iterator, self.instances, num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:38,代码来源:trainer_test.py


注:本文中的allennlp.training.trainer.Trainer类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。