This article collects typical usage examples of the Python method allennlp.training.trainer.Trainer.list_available. If you are unsure what Trainer.list_available does, how to call it, or where it is used in practice, the curated code examples below should help. You can also read further about the class that defines this method, allennlp.training.trainer.Trainer.
Two code examples of Trainer.list_available are shown below, sorted by popularity by default.
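Before the full examples, here is a minimal, hedged sketch (not taken from the examples below) of what the method itself returns: Trainer.list_available comes from AllenNLP's Registrable machinery and lists the names under which Trainer implementations have been registered, typically with the default implementation first.

from allennlp.training.trainer import Trainer

# Names of all registered trainer implementations; the default one
# (usually "default") is typically listed first.
print(Trainer.list_available())

# A registered name can then be resolved back to a concrete class.
trainer_class = Trainer.by_name(Trainer.list_available()[0])
print(trainer_class)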
Example 1: train_model
# Required import: from allennlp.training.trainer import Trainer [as alias]
# Or: from allennlp.training.trainer.Trainer import list_available [as alias]
def train_model(params: Params,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                recover: bool = False,
                force: bool = False) -> Model:
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory. This is only intended for use when something actually crashed during the middle
        of a run. For continuing training a model on new data, see the ``fine-tune`` command.

    Returns
    -------
    best_model : ``Model``
        The model with the best epoch weights.
    """
    prepare_environment(params)
    create_serialization_dir(params, serialization_dir, recover, force)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    cuda_device = params.params.get('trainer').get('cuda_device', -1)
    if isinstance(cuda_device, list):
        for device in cuda_device:
            check_for_gpu(device)
    else:
        check_for_gpu(cuda_device)

    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    all_datasets = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    model = Model.from_params(vocab=vocab, params=params.pop('model'))

    # Initializing the model can have the side effect of expanding the vocabulary
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)
    validation_iterator_params = params.pop("validation_iterator", None)
    if validation_iterator_params:
        validation_iterator = DataIterator.from_params(validation_iterator_params)
        validation_iterator.index_with(vocab)
    else:
        validation_iterator = None

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())

    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
            get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer_choice = trainer_params.pop_choice("type",
                                               Trainer.list_available(),
                                               default_to_first_choice=True)
    trainer = Trainer.by_name(trainer_choice).from_params(model=model,
                                                          serialization_dir=serialization_dir,
                                                          iterator=iterator,
                                                          train_data=train_data,
                                                          validation_data=validation_data,
                                                          params=trainer_params,
#.........the remainder of this example is omitted here.........
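In this example the only role of Trainer.list_available() is to validate the "type" key of the trainer configuration. The following small, hedged sketch (the config values are made up for illustration and are not part of the example above) shows how Params.pop_choice with default_to_first_choice=True falls back to the first registered trainer when no "type" is given.

from allennlp.common import Params
from allennlp.training.trainer import Trainer

# Hypothetical trainer config without an explicit "type" key.
trainer_params = Params({"num_epochs": 3, "optimizer": "adam"})

trainer_choice = trainer_params.pop_choice("type",
                                           Trainer.list_available(),
                                           default_to_first_choice=True)
# With default_to_first_choice=True this falls back to the first
# registered name (usually "default") instead of raising an error.
print(trainer_choice)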
Example 2: fine_tune_model
# Required import: from allennlp.training.trainer import Trainer [as alias]
# Or: from allennlp.training.trainer.Trainer import list_available [as alias]
def fine_tune_model(model: Model,
                    params: Params,
                    serialization_dir: str,
                    extend_vocab: bool = False,
                    file_friendly_logging: bool = False,
                    batch_weight_key: str = "") -> Model:
    """
    Fine tunes the given model, using a set of parameters that is largely identical to those used
    for :func:`~allennlp.commands.train.train_model`, except that the ``model`` section is ignored,
    if it is present (as we are already given a ``Model`` here).

    The main difference between the logic done here and the logic done in ``train_model`` is that
    here we do not worry about vocabulary construction or creating the model object. Everything
    else is the same.

    Parameters
    ----------
    model : ``Model``
        A model to fine-tune, typically loaded from an archive produced by the ``train`` command.
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment, including the training and
        (optionally) validation data to use for fine-tuning.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    extend_vocab : ``bool``, optional (default=False)
        If ``True``, we use the new instances to extend your vocabulary.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    """
    prepare_environment(params)
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(f"Serialization directory ({serialization_dir}) "
                                 f"already exists and is not empty.")

    os.makedirs(serialization_dir, exist_ok=True)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    serialization_params = deepcopy(params).as_dict(quiet=True)
    with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
        json.dump(serialization_params, param_file, indent=4)

    if params.pop('model', None):
        logger.warning("You passed parameters for the model in your configuration file, but we "
                       "are ignoring them, using instead the model parameters in the archive.")

    vocabulary_params = params.pop('vocabulary', {})
    if vocabulary_params.get('directory_path', None):
        logger.warning("You passed `directory_path` in parameters for the vocabulary in "
                       "your configuration file, but it will be ignored. ")

    all_datasets = datasets_from_params(params)
    vocab = model.vocab

    if extend_vocab:
        datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

        for dataset in datasets_for_vocab_creation:
            if dataset not in all_datasets:
                raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

        logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
        vocab.extend_from_instances(vocabulary_params,
                                    (instance for key, dataset in all_datasets.items()
                                     for instance in dataset
                                     if key in datasets_for_vocab_creation))

    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(model.vocab)

    validation_iterator_params = params.pop("validation_iterator", None)
    if validation_iterator_params:
        validation_iterator = DataIterator.from_params(validation_iterator_params)
        validation_iterator.index_with(vocab)
    else:
        validation_iterator = None

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())

    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
            get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer_choice = trainer_params.pop_choice("type",
                                               Trainer.list_available(),
#.........the remainder of this example is omitted here.........
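For context, a hypothetical way to invoke fine_tune_model (the paths below are placeholders, not taken from the example above) is to load a trained archive and hand its model to the function together with a fresh configuration:

from allennlp.common import Params
from allennlp.models.archival import load_archive

# Placeholder paths; substitute your own archive and configuration file.
archive = load_archive("model.tar.gz")
params = Params.from_file("fine_tune_config.json")

fine_tuned = fine_tune_model(model=archive.model,
                             params=params,
                             serialization_dir="fine_tune_output",
                             extend_vocab=True)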