This article collects typical usage examples of the Python method allennlp.training.trainer.Trainer._restore_checkpoint. If you are unsure exactly what Trainer._restore_checkpoint does, how to call it, or what working code using it looks like, the curated examples below should help. You may also want to look at usage examples for the enclosing class, allennlp.training.trainer.Trainer.
Three code examples of Trainer._restore_checkpoint are shown below, sorted by popularity by default.
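Before the examples, here is the general pattern they all exercise, as a minimal sketch. The model, optimizer, iterator, and train_data objects are hypothetical placeholders for components built elsewhere, and the Trainer signature is the pre-1.0 AllenNLP one used by these tests:

from allennlp.training.trainer import Trainer

# `model`, `optimizer`, `iterator`, and `train_data` are assumed to exist;
# serialization_dir is where checkpoint files will be written.
trainer = Trainer(model=model, optimizer=optimizer, iterator=iterator,
                  train_dataset=train_data, num_epochs=2,
                  serialization_dir="checkpoints/")
trainer.train()  # writes model_state_epoch_*.th / training_state_epoch_*.th

# Later: build a fresh Trainer over the same directory and restore.
resumed_trainer = Trainer(model=model, optimizer=optimizer, iterator=iterator,
                          train_dataset=train_data, num_epochs=4,
                          serialization_dir="checkpoints/")
epoch, val_metrics_per_epoch = resumed_trainer._restore_checkpoint()  # pylint: disable=protected-access
resumed_trainer.train()  # continues from `epoch`, not from scratch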
Example 1: test_trainer_can_resume_with_lr_scheduler
# Required import: from allennlp.training.trainer import Trainer
# Method under test: Trainer._restore_checkpoint
def test_trainer_can_resume_with_lr_scheduler(self):
    # pylint: disable=protected-access
    lr_scheduler = LearningRateScheduler.from_params(
        self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
    trainer = Trainer(model=self.model,
                      optimizer=self.optimizer,
                      iterator=self.iterator,
                      learning_rate_scheduler=lr_scheduler,
                      train_dataset=self.instances,
                      validation_dataset=self.instances,
                      num_epochs=2, serialization_dir=self.TEST_DIR)
    trainer.train()

    # A fresh scheduler and trainer pick up where the first run stopped.
    new_lr_scheduler = LearningRateScheduler.from_params(
        self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
    new_trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=new_lr_scheduler,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=4, serialization_dir=self.TEST_DIR)
    epoch, _ = new_trainer._restore_checkpoint()
    assert epoch == 2
    assert new_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
    new_trainer.train()
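The first assertion checks that _restore_checkpoint returns the next epoch to run (2, since two epochs completed). The second checks that the wrapped PyTorch scheduler's epoch counter was restored from the checkpoint: the scheduler last stepped at the end of zero-indexed epoch 1, so last_epoch is 1 and the exponential decay continues from there rather than resetting to the initial learning rate.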
Example 2: test_trainer_can_resume_training
# Required import: from allennlp.training.trainer import Trainer
# Method under test: Trainer._restore_checkpoint
def test_trainer_can_resume_training(self):
    trainer = Trainer(self.model, self.optimizer,
                      self.iterator, self.instances,
                      validation_dataset=self.instances,
                      num_epochs=1, serialization_dir=self.TEST_DIR)
    trainer.train()
    new_trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances,
                          validation_dataset=self.instances,
                          num_epochs=3, serialization_dir=self.TEST_DIR)
    epoch, val_metrics_per_epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
    assert epoch == 1
    assert len(val_metrics_per_epoch) == 1
    assert isinstance(val_metrics_per_epoch[0], float)
    assert val_metrics_per_epoch[0] != 0.
    new_trainer.train()
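Besides the resume epoch, _restore_checkpoint returns the per-epoch validation metric history (one float per completed epoch here), which the trainer uses internally for patience-based early stopping and for tracking the best checkpoint, so it has to survive a restart.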
Example 3: test_trainer_saves_models_at_specified_interval
# Required import: from allennlp.training.trainer import Trainer
# Method under test: Trainer._restore_checkpoint
def test_trainer_saves_models_at_specified_interval(self):
    iterator = BasicIterator(batch_size=4)
    iterator.index_with(self.vocab)
    trainer = Trainer(self.model, self.optimizer,
                      iterator, self.instances, num_epochs=2,
                      serialization_dir=self.TEST_DIR,
                      model_save_interval=0.0001)
    trainer.train()

    # Now check the serialized files for models saved during the epoch.
    prefix = 'model_state_epoch_*'
    file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
    epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
              for fname in file_names]
    # We should have checkpoints at the end of each epoch and during each, e.g.
    # [0.timestamp, 0, 1.timestamp, 1]
    assert len(epochs) == 4
    assert epochs[3] == '1'
    assert '.' in epochs[0]

    # Now make certain we can restore from a timestamped checkpoint.
    # To do so, remove the end-of-epoch checkpoints for epochs 0 and 1, so
    # that we are forced to restore from the timestamped checkpoints.
    for k in range(2):
        os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
    os.remove(os.path.join(self.TEST_DIR, 'best.th'))

    restore_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances, num_epochs=2,
                              serialization_dir=self.TEST_DIR,
                              model_save_interval=0.0001)
    epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
    assert epoch == 2
    # One batch per epoch.
    assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
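With model_save_interval set, the trainer also writes timestamp-suffixed mid-epoch checkpoints (the '.' asserted in epochs[0] comes from that timestamp) alongside the end-of-epoch ones. Deleting the end-of-epoch files demonstrates that _restore_checkpoint falls back to the newest timestamped checkpoint, recovering both the epoch counter and the cumulative _batch_num_total.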