本文整理汇总了Python中allennlp.training.trainer.Trainer.by_name方法的典型用法代码示例。如果您正苦于以下问题：Python Trainer.by_name方法的具体用法？Python Trainer.by_name怎么用？Python Trainer.by_name使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类allennlp.training.trainer.Trainer的用法示例。
在下文中一共展示了Trainer.by_name方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train_model
# 需要导入模块: from allennlp.training.trainer import Trainer [as 别名]
# 说明：by_name 是 Trainer 类的方法，不能从类中单独 import，应通过 Trainer.by_name(...) 调用
def train_model(params: Params,
serialization_dir: str,
file_friendly_logging: bool = False,
recover: bool = False,
force: bool = False) -> Model:
"""
Trains the model specified in the given :class:`Params` object, using the data and training
parameters also specified in that object, and saves the results in ``serialization_dir``.

Along the way this function writes ``params`` to ``serialization_dir/CONFIG_NAME`` (before
any section is popped by this function), writes the vocabulary to
``serialization_dir/vocabulary`` (after the model is built), and disables gradients for every
model parameter whose name matches one of the regexes listed under the trainer's ``no_grad``.

Parameters
----------
params : ``Params``
A parameter object specifying an AllenNLP Experiment.
serialization_dir : ``str``
The directory in which to save results and logs.
file_friendly_logging : ``bool``, optional (default=False)
If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
recover : ``bool``, optional (default=False)
If ``True``, we will try to recover a training run from an existing serialization
directory. This is only intended for use when something actually crashed during the middle
of a run. For continuing training a model on new data, see the ``fine-tune`` command.
force : ``bool``, optional (default=False)
Passed straight through to ``create_serialization_dir``. Presumably it allows reusing or
overwriting an existing ``serialization_dir``; confirm against that helper.

Returns
-------
best_model: ``Model``
The model with the best epoch weights.

Raises
------
ConfigurationError
If ``datasets_for_vocab_creation`` names a dataset that ``datasets_from_params`` did not
produce.
"""
prepare_environment(params)
create_serialization_dir(params, serialization_dir, recover, force)
prepare_global_logging(serialization_dir, file_friendly_logging)
# Read cuda_device from the raw trainer section (before "trainer" is popped further down)
# so the requested GPU(s) are validated early; it may be a single device id or a list.
# NOTE(review): assuming params.params is a plain dict, a config with no "trainer" section
# makes .get('trainer') return None and the chained .get raise AttributeError.
cuda_device = params.params.get('trainer').get('cuda_device', -1)
if isinstance(cuda_device, list):
for device in cuda_device:
check_for_gpu(device)
else:
check_for_gpu(cuda_device)
# Persist the configuration before any section is popped below.
params.to_file(os.path.join(serialization_dir, CONFIG_NAME))
all_datasets = datasets_from_params(params)
# Defaults to every dataset: all_datasets is used as a mapping (see .items() below), so
# set() over it yields its dataset names.
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
# NOTE(review): the message says 'dataset_for_vocab_creation' (singular) while the
# config key read above is "datasets_for_vocab_creation".
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")
logger.info("From dataset instances, %s will be considered for vocabulary creation.",
", ".join(datasets_for_vocab_creation))
# Build the vocabulary from a generator over only the selected datasets' instances.
vocab = Vocabulary.from_params(
params.pop("vocabulary", {}),
(instance for key, dataset in all_datasets.items()
for instance in dataset
if key in datasets_for_vocab_creation)
)
model = Model.from_params(vocab=vocab, params=params.pop('model'))
# Initializing the model can have side effect of expanding the vocabulary,
# which is why the vocabulary is only saved after the model has been built.
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)
# Optional separate iterator for validation; None when not configured.
validation_iterator_params = params.pop("validation_iterator", None)
if validation_iterator_params:
validation_iterator = DataIterator.from_params(validation_iterator_params)
validation_iterator.index_with(vocab)
else:
validation_iterator = None
# "train" is indexed directly and is therefore required; validation/test are optional.
train_data = all_datasets['train']
validation_data = all_datasets.get('validation')
test_data = all_datasets.get('test')
trainer_params = params.pop("trainer")
# Freeze every parameter whose name matches any regex listed under the trainer's "no_grad".
no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
if any(re.search(regex, name) for regex in no_grad_regexes):
parameter.requires_grad_(False)
# Log which parameters will (tunable) and will not (frozen) be updated during training.
frozen_parameter_names, tunable_parameter_names = \
get_frozen_and_tunable_parameter_names(model)
logger.info("Following parameters are Frozen (without gradient):")
for name in frozen_parameter_names:
logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in tunable_parameter_names:
logger.info(name)
# Choose the registered Trainer implementation from "type", defaulting to the first
# available choice when "type" is absent.
trainer_choice = trainer_params.pop_choice("type",
Trainer.list_available(),
default_to_first_choice=True)
trainer = Trainer.by_name(trainer_choice).from_params(model=model,
serialization_dir=serialization_dir,
iterator=iterator,
train_data=train_data,
validation_data=validation_data,
params=trainer_params,
# ......... (the rest of this example is omitted on the source page) .........
示例2: fine_tune_model
# 需要导入模块: from allennlp.training.trainer import Trainer [as 别名]
# 说明：by_name 是 Trainer 类的方法，不能从类中单独 import，应通过 Trainer.by_name(...) 调用
#.........这里部分代码省略.........
"are ignoring them, using instead the model parameters in the archive.")
vocabulary_params = params.pop('vocabulary', {})
if vocabulary_params.get('directory_path', None):
logger.warning("You passed `directory_path` in parameters for the vocabulary in "
"your configuration file, but it will be ignored. ")
all_datasets = datasets_from_params(params)
vocab = model.vocab
if extend_vocab:
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")
logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
vocab.extend_from_instances(vocabulary_params,
(instance for key, dataset in all_datasets.items()
for instance in dataset
if key in datasets_for_vocab_creation))
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(model.vocab)
validation_iterator_params = params.pop("validation_iterator", None)
if validation_iterator_params:
validation_iterator = DataIterator.from_params(validation_iterator_params)
validation_iterator.index_with(vocab)
else:
validation_iterator = None
train_data = all_datasets['train']
validation_data = all_datasets.get('validation')
test_data = all_datasets.get('test')
trainer_params = params.pop("trainer")
no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
if any(re.search(regex, name) for regex in no_grad_regexes):
parameter.requires_grad_(False)
frozen_parameter_names, tunable_parameter_names = \
get_frozen_and_tunable_parameter_names(model)
logger.info("Following parameters are Frozen (without gradient):")
for name in frozen_parameter_names:
logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in tunable_parameter_names:
logger.info(name)
trainer_choice = trainer_params.pop_choice("type",
Trainer.list_available(),
default_to_first_choice=True)
trainer = Trainer.by_name(trainer_choice).from_params(model=model,
serialization_dir=serialization_dir,
iterator=iterator,
train_data=train_data,
validation_data=validation_data,
params=trainer_params,
validation_iterator=validation_iterator)
evaluate_on_test = params.pop_bool("evaluate_on_test", False)
params.assert_empty('base train command')
try:
metrics = trainer.train()
except KeyboardInterrupt:
# if we have completed an epoch, try to create a model archive.
if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
logging.info("Fine-tuning interrupted by the user. Attempting to create "
"a model archive using the current best epoch weights.")
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
raise
# Now tar up results
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
if test_data and evaluate_on_test:
test_metrics = evaluate(
model,
test_data,
iterator,
cuda_device=trainer._cuda_devices[0], # pylint: disable=protected-access
batch_weight_key=batch_weight_key
)
for key, value in test_metrics.items():
metrics["test_" + key] = value
elif test_data:
logger.info("To evaluate on the test set after training, pass the "
"'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")
metrics_json = json.dumps(metrics, indent=2)
with open(os.path.join(serialization_dir, "metrics.json"), "w") as metrics_file:
metrics_file.write(metrics_json)
logger.info("Metrics: %s", metrics_json)
return model