本文整理汇总了Python中allennlp.training.trainer.Trainer.from_params方法的典型用法代码示例。如果您正苦于以下问题：Python Trainer.from_params方法的具体用法是什么？Python Trainer.from_params怎么用？有哪些Python Trainer.from_params使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类allennlp.training.trainer.Trainer的用法示例。
在下文中一共展示了Trainer.from_params方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train_model
# 需要导入模块: from allennlp.training.trainer import Trainer [as 别名]
# 或者: 导入 Trainer 后通过 Trainer.from_params(...) 调用（from_params 是 Trainer 类上的方法，无法直接 import）
def train_model(params: Params,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                recover: bool = False) -> Model:
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results in ``serialization_dir``.

    The following are written into ``serialization_dir``: the full configuration
    (``CONFIG_NAME``), the vocabulary (``vocabulary/``), a model archive, and ``metrics.json``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory. This is only intended for use when something actually crashed during the middle
        of a run. For continuing training a model on new data, see the ``fine-tune`` command.

    Returns
    -------
    ``Model``
        The model object that was handed to the trainer, returned after ``trainer.train()``
        completes.

    Raises
    ------
    ConfigurationError
        If ``datasets_for_vocab_creation`` names a dataset that ``datasets_from_params`` did not
        return.
    KeyboardInterrupt
        Re-raised when training is interrupted, after attempting to archive any weights that were
        already saved to ``serialization_dir``.
    """
    prepare_environment(params)
    create_serialization_dir(params, serialization_dir, recover)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    # Snapshot the complete configuration before the ``pop`` calls below consume ``params``.
    serialization_params = deepcopy(params).as_dict(quiet=True)
    with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
        json.dump(serialization_params, param_file, indent=4)

    all_datasets = datasets_from_params(params)
    # By default every loaded dataset contributes to the vocabulary.
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'datasets_for_vocab_creation' {dataset}")

    # Sorted so the log line is stable across runs (set iteration order is arbitrary).
    logger.info("Creating a vocabulary using %s data.",
                ", ".join(sorted(datasets_for_vocab_creation)))
    vocab = Vocabulary.from_params(params.pop("vocabulary", {}),
                                   (instance for key, dataset in all_datasets.items()
                                    for instance in dataset
                                    if key in datasets_for_vocab_creation))
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    model = Model.from_params(vocab, params.pop('model'))
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  validation_data,
                                  trainer_params)

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    # Every configuration key must have been consumed by now; leftovers indicate a typo.
    params.assert_empty('base train command')

    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # If weights were already written (i.e. at least one epoch completed), archive them so
        # the interrupted run is still usable, then propagate the interrupt to the caller.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
        raise

    # Now tar up results
    archive_model(serialization_dir, files_to_archive=params.files_to_archive)

    if test_data and evaluate_on_test:
        test_metrics = evaluate(model, test_data, iterator, cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access
        metrics.update({"test_" + key: value for key, value in test_metrics.items()})
    elif test_data:
        logger.info("To evaluate on the test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    metrics_json = json.dumps(metrics, indent=2)
    with open(os.path.join(serialization_dir, "metrics.json"), "w") as metrics_file:
        metrics_file.write(metrics_json)
    logger.info("Metrics: %s", metrics_json)
    return model
示例2: fine_tune_model
# 需要导入模块: from allennlp.training.trainer import Trainer [as 别名]
# 或者: 导入 Trainer 后通过 Trainer.from_params(...) 调用（from_params 是 Trainer 类上的方法，无法直接 import）
def fine_tune_model(model: Model,
params: Params,
serialization_dir: str,
file_friendly_logging: bool = False) -> Model:
"""
Fine tunes the given model, using a set of parameters that is largely identical to those used
for :func:`~allennlp.commands.train.train_model`, except that the ``model`` section is ignored,
if it is present (as we are already given a ``Model`` here).
The main difference between the logic done here and the logic done in ``train_model`` is that
here we do not worry about vocabulary construction or creating the model object. Everything
else is the same.
Parameters
----------
archive : ``Archive``
A saved model archive that is the result of running the ``train`` command.
train_data_path : ``str``
Path to the training data to use for fine-tuning.
serialization_dir : ``str``
The directory in which to save results and logs.
validation_data_path : ``str``, optional
Path to the validation data to use while fine-tuning.
file_friendly_logging : ``bool``, optional (default=False)
If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
down tqdm's output to only once every 10 seconds.
"""
prepare_environment(params)
os.makedirs(serialization_dir)
prepare_global_logging(serialization_dir, file_friendly_logging)
serialization_params = deepcopy(params).as_dict(quiet=True)
with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
json.dump(serialization_params, param_file, indent=4)
if params.pop('model', None):
logger.warning("You passed parameters for the model in your configuration file, but we "
"are ignoring them, using instead the model parameters in the archive.")
vocabulary_params = params.pop('vocabulary', {})
if vocabulary_params.get('directory_path', None):
logger.warning("You passed `directory_path` in parameters for the vocabulary in "
"your configuration file, but it will be ignored. "
"Vocabulary from the saved model will be extended with current data.")
all_datasets = datasets_from_params(params)
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")
logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
vocab = model.vocab
vocab.extend_from_instances(vocabulary_params,
(instance for key, dataset in all_datasets.items()
for instance in dataset
if key in datasets_for_vocab_creation))
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)
train_data = all_datasets['train']
validation_data = all_datasets.get('validation')
test_data = all_datasets.get('test')
trainer_params = params.pop("trainer")
no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
if any(re.search(regex, name) for regex in no_grad_regexes):
parameter.requires_grad_(False)
frozen_parameter_names, tunable_parameter_names = \
get_frozen_and_tunable_parameter_names(model)
logger.info("Following parameters are Frozen (without gradient):")
for name in frozen_parameter_names:
logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in tunable_parameter_names:
logger.info(name)
trainer = Trainer.from_params(model,
serialization_dir,
iterator,
train_data,
validation_data,
trainer_params)
evaluate_on_test = params.pop_bool("evaluate_on_test", False)
params.assert_empty('base train command')
try:
metrics = trainer.train()
except KeyboardInterrupt:
# if we have completed an epoch, try to create a model archive.
if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
logging.info("Fine-tuning interrupted by the user. Attempting to create "
"a model archive using the current best epoch weights.")
archive_model(serialization_dir, files_to_archive=params.files_to_archive)
#.........这里部分代码省略.........