This article collects and summarizes typical usage examples of the Python class fairseq.trainer.Trainer. If you are wondering what trainer.Trainer does, how to use it, or are looking for concrete examples of it in action, the curated code samples below may help. You can also explore further usage examples from the fairseq.trainer module.
Five code examples of trainer.Trainer are shown below, sorted by popularity by default.
Example 1: get_parser_with_args
# Required module import: from fairseq import trainer [as alias]
# Or: from fairseq.trainer import Trainer [as alias]
def get_parser_with_args(default_task="pytorch_translate"):
    parser = options.get_parser("Trainer", default_task=default_task)
    pytorch_translate_options.add_verbosity_args(parser, train=True)
    pytorch_translate_options.add_dataset_args(parser, train=True, gen=True)
    options.add_distributed_training_args(parser)
    # Adds args related to training (validation and stopping criterions).
    optimization_group = options.add_optimization_args(parser)
    pytorch_translate_options.expand_optimization_args(optimization_group)
    # Adds args related to checkpointing.
    checkpointing_group = options.add_checkpoint_args(parser)
    pytorch_translate_options.expand_checkpointing_args(checkpointing_group)
    # Adds model-related args.
    options.add_model_args(parser)
    # Adds args for generating intermediate BLEU eval while training.
    generation_group = options.add_generation_args(parser)
    pytorch_translate_options.expand_generation_args(generation_group, train=True)
    # Adds args related to input data files (preprocessing, numberizing, and
    # binarizing text files; creating vocab files).
    pytorch_translate_options.add_preprocessing_args(parser)
    return parser
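The returned object is a standard argparse-based fairseq parser, so it can be inspected or driven directly. A minimal usage sketch (--max-epoch is a standard fairseq optimization flag; the exact set of available and required arguments varies by fairseq/pytorch_translate version, so treat this as illustrative):

parser = get_parser_with_args()
parser.print_usage()  # shows the assembled option groups
# parse_known_args tolerates unrecognized flags; required args may still
# differ by version.
args, _unknown = parser.parse_known_args(["--max-epoch", "10"])
print(args.max_epoch)  # -> 10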
Example 2: single_process_main
# Required module import: from fairseq import trainer [as alias]
# Or: from fairseq.trainer import Trainer [as alias]
def single_process_main(args, trainer_class=Trainer, **train_step_kwargs):
    """Train the model for multiple epochs."""
    pytorch_translate_options.print_args(args)
    trainer, task, epoch_itr = setup_training(args, trainer_class)
    extra_state, epoch_itr, checkpoint_manager = setup_training_state(
        args=args, trainer=trainer, task=task, epoch_itr=epoch_itr
    )
    train(
        args=args,
        extra_state=extra_state,
        trainer=trainer,
        task=task,
        epoch_itr=epoch_itr,
        checkpoint_manager=checkpoint_manager,
        **train_step_kwargs,
    )
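A typical entry point would wire the parser from Example 1 into this function. A hedged sketch, assuming fairseq's standard options.parse_args_and_arch helper (which also resolves the --arch choice) applies to this parser:

if __name__ == "__main__":
    # Parse CLI args and run single-process training with the default Trainer.
    parser = get_parser_with_args()
    args = options.parse_args_and_arch(parser)
    single_process_main(args)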
Example 3: main
# Required module import: from fairseq import trainer [as alias]
# Or: from fairseq.trainer import Trainer [as alias]
def main(args, trainer_class=Trainer, **train_step_kwargs):
    # We preprocess the data (generating vocab files and binarized data files
    # if needed) outside of the train processes to prevent them from having to
    # wait while the master process is doing this.
    preprocess.preprocess_corpora(args)
    if args.distributed_world_size == 1:
        single_process_main(args, trainer_class, **train_step_kwargs)
    else:
        spawn_context, output_queue = multi_process_main(args=args, start_rank=0)
        while not spawn_context.join(timeout=30):
            # Periodically clears the output queue to ensure that the processes
            # don't deadlock due to the queue buffer being full. This is also
            # necessary to ensure that processes join correctly, since a process
            # may not terminate until all items it put on the queue have been
            # consumed (per
            # https://docs.python.org/3/library/multiprocessing.html#all-start-methods).
            try:
                while True:
                    output_queue.get_nowait()
            except queue.Empty:
                pass
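The queue-draining loop at the end is a general multiprocessing pattern, not fairseq-specific. A self-contained sketch of the same idea:

import queue

def drain(q) -> None:
    # Pop everything currently buffered, without blocking. A child process
    # may not exit until the items it put on the queue have been consumed,
    # so draining periodically keeps a full buffer from stalling join().
    try:
        while True:
            q.get_nowait()
    except queue.Empty:
        pass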
Example 4: gpu_train_step
# Required module import: from fairseq import trainer [as alias]
# Or: from fairseq.trainer import Trainer [as alias]
def gpu_train_step(test_args: ModelParamsDict) -> Tuple[Trainer, Dict[Any, Any]]:
    """Sets up inputs from test_args, then executes a single train step. A
    train step always requires a GPU."""
    samples, src_dict, tgt_dict = prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)
    model = task.build_model(test_args)
    criterion = task.build_criterion(test_args)
    sample = next(samples)
    trainer = Trainer(test_args, task, model, criterion)
    logging_dict = trainer.train_step([sample])
    return trainer, logging_dict
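In a unit test, this helper is driven by a small test config. A hedged sketch (assumes ModelParamsDict's defaults describe a tiny model and that the criterion reports a "loss" entry in the logging output, as fairseq criterions typically do):

# Run one optimizer step on the GPU and inspect the logging output.
test_args = ModelParamsDict()  # hypothetical: defaults build a tiny test model
trainer, logging_dict = gpu_train_step(test_args)
assert "loss" in logging_dict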
Example 5: setup_training
# Required module import: from fairseq import trainer [as alias]
# Or: from fairseq.trainer import Trainer [as alias]
def setup_training(args, trainer_class=None):
    """Perform several steps:
    - build model using provided criterion and task
    - load data
    - build trainer
    """
    # Overrides the default print() to always prepend the timestamp for more
    # informative logging.
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        if "file" not in kwargs:
            builtin_print(
                f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}]",
                *args,
                **kwargs,
            )
        else:
            builtin_print(*args, **kwargs)

    __builtin__.print = print

    task, model, criterion = setup_training_model(args)
    if trainer_class is None:
        trainer_class = Trainer
    trainer, epoch_itr = build_trainer(
        args=args,
        task=task,
        model=model,
        criterion=criterion,
        trainer_class=trainer_class,
    )
    return trainer, task, epoch_itr
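The returned (trainer, task, epoch_itr) triple follows fairseq conventions, where epoch_itr is an EpochBatchIterator. A hedged sketch of driving one epoch manually with it (next_epoch_itr is the standard fairseq iterator API; device placement and batching are handled by the trainer):

trainer, task, epoch_itr = setup_training(args)
itr = epoch_itr.next_epoch_itr(shuffle=True)
for sample in itr:
    # One forward/backward/optimizer step per batch, as in Example 4.
    trainer.train_step([sample])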