本文整理汇总了Python中fairseq.options.get_generation_parser方法的典型用法代码示例。如果您正苦于以下问题：Python options.get_generation_parser方法的具体用法？Python options.get_generation_parser怎么用？Python options.get_generation_parser的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块fairseq.options的用法示例。
在下文中一共展示了options.get_generation_parser方法的10个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: generate_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def generate_main(data_dir):
    """Run fairseq generation on ``data_dir`` in batch mode, then interactively.

    Builds generation arguments that point at ``checkpoint_last.pt`` inside
    ``data_dir`` (beam 3, batch size 64, ``--max-len-b 5``, ``valid`` subset),
    runs ``generate.main`` once, and then feeds the single line
    ``'h e l l o'`` through ``interactive.main`` by temporarily replacing
    ``sys.stdin``.

    Args:
        data_dir: Directory holding the data and ``checkpoint_last.pt``.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ],
    )

    # Evaluate the model in batch mode.
    generate.main(generate_args)

    # Evaluate the model interactively, reusing the parsed namespace with
    # buffer_size and max_sentences overridden.
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Always restore stdin, even if interactive generation raises, so the
        # substitution does not leak into later code in the same process.
        sys.stdin = orig_stdin
示例2: cli_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def cli_main():
    """Command-line entry point: parse generation arguments and run ``main``."""
    main(options.parse_args_and_arch(options.get_generation_parser()))
示例3: cli_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def cli_main():
    """Command-line entry point with the ASR evaluation arguments added.

    Extends the generation parser via ``add_asr_eval_argument`` before the
    command line is parsed, then hands the parsed arguments to ``main``.
    """
    asr_parser = add_asr_eval_argument(options.get_generation_parser())
    main(options.parse_args_and_arch(asr_parser))
示例4: generate_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def generate_main(data_dir, extra_flags=None):
    """Run fairseq generation on ``data_dir`` in batch mode, then interactively.

    Builds generation arguments that point at ``checkpoint_last.pt`` inside
    ``data_dir`` (beam 3, batch size 64, ``--max-len-b 5``, ``valid`` subset)
    plus ``extra_flags``, runs ``generate.main`` once, and then feeds the
    single line ``'h e l l o'`` through ``interactive.main`` by temporarily
    replacing ``sys.stdin``.

    Args:
        data_dir: Directory holding the data and ``checkpoint_last.pt``.
        extra_flags: Extra command-line flags (list or tuple) appended after
            the defaults. ``None`` means ``['--print-alignment']``; pass an
            empty sequence to add no extra flags.
    """
    if extra_flags is None:
        extra_flags = ['--print-alignment']
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
        ] + list(extra_flags),
    )

    # Evaluate the model in batch mode.
    generate.main(generate_args)

    # Evaluate the model interactively, reusing the parsed namespace. The
    # input is '-' (stdin), which is temporarily replaced below.
    generate_args.buffer_size = 0
    generate_args.input = '-'
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Always restore stdin, even if interactive generation raises.
        sys.stdin = orig_stdin
示例5: generate_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def generate_main(data_dir, extra_flags=None):
    """Run fairseq generation (with alignments) in batch mode, then interactively.

    Builds generation arguments that point at ``checkpoint_last.pt`` inside
    ``data_dir`` (beam 3, batch size 64, ``--max-len-b 5``, ``valid`` subset,
    ``--print-alignment``) plus any ``extra_flags``, runs ``generate.main``
    once, and then feeds the single line ``'h e l l o'`` through
    ``interactive.main`` by temporarily replacing ``sys.stdin``.

    Args:
        data_dir: Directory holding the data and ``checkpoint_last.pt``.
        extra_flags: Optional extra command-line flags (list or tuple)
            appended after the defaults; ``None`` adds nothing.
    """
    generate_parser = options.get_generation_parser()
    generate_args = options.parse_args_and_arch(
        generate_parser,
        [
            data_dir,
            '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
            '--beam', '3',
            '--batch-size', '64',
            '--max-len-b', '5',
            '--gen-subset', 'valid',
            '--no-progress-bar',
            '--print-alignment',
        ] + list(extra_flags or ()),
    )

    # Evaluate the model in batch mode.
    generate.main(generate_args)

    # Evaluate the model interactively, reusing the parsed namespace with
    # buffer_size and max_sentences overridden.
    generate_args.buffer_size = 0
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
    try:
        interactive.main(generate_args)
    finally:
        # Always restore stdin, even if interactive generation raises.
        sys.stdin = orig_stdin
示例6: cli_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def cli_main():
    """Entry point for interactive generation: parse CLI arguments and run ``main``."""
    interactive_parser = options.get_generation_parser(interactive=True)
    parsed = options.parse_args_and_arch(interactive_parser)
    main(parsed)
示例7: infer
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def infer(model_path, vocab_dir, arch, test_data, max_len, temperature):
    """Run fairseq interactive generation with parser defaults built from arguments.

    Args:
        model_path: Installed as the parser's ``path`` default.
        vocab_dir: Passed as the sole positional command-line argument.
        arch: Installed as the ``arch`` default.
        test_data: Installed as the ``input`` default.
        max_len: Installed as the ``max_tokens`` default.
        temperature: Installed as the ``temperature`` default.

    Returns:
        Whatever ``interactive.main`` returns.
    """
    # NOTE(review): ``max_len`` is wired to ``max_tokens``; fairseq documents
    # --max-tokens as a per-batch token limit rather than an output-length
    # limit. Confirm this is intended.
    overrides = {
        'arch': arch,
        'input': test_data,
        'max_tokens': max_len,
        'temperature': temperature,
        'path': model_path,
    }
    interactive_parser = options.get_generation_parser(interactive=True)
    interactive_parser.set_defaults(**overrides)
    parsed_args = options.parse_args_and_arch(interactive_parser, input_args=[vocab_dir])
    return interactive.main(parsed_args)
示例8: cli_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def cli_main():
    """Parse script options first, then fairseq generation options from the rest.

    ``get_script_parser`` consumes the arguments it recognizes. Everything
    left over is parsed by an interactive generation parser whose task
    defaults to ``'captioning'``. Both namespaces are passed to ``main``.
    """
    script_args, remaining = get_script_parser().parse_known_args()
    generation_parser = options.get_generation_parser(interactive=True, default_task='captioning')
    main(script_args, options.parse_args_and_arch(generation_parser, input_args=remaining))
示例9: cli_main
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def cli_main():
    """Command-line entry point for generation.

    MODIFIED: task defaults to gec
    """
    gec_parser = options.get_generation_parser(default_task='gec')
    main(options.parse_args_and_arch(gec_parser))
示例10: __init__
# 需要导入模块: from fairseq import options [as 别名]
# 或者: from fairseq.options import get_generation_parser [as 别名]
def __init__(self, model_path, user_dir, lang_pair, n_cpu_threads=-1):
    """Initializes a fairseq predictor.

    Sets up the fairseq task, using the directory of the (first) checkpoint
    as the data directory. Records vocabulary sizes and the padding id,
    loads the checkpoint(s) as an ensemble, prepares each model for
    generation, and puts the ensemble in eval mode.

    Args:
        model_path (string): Path to the fairseq model (*.pt). Like
            --path in fairseq-interactive; several checkpoints may be
            separated by ':' to form an ensemble.
        user_dir (string): Path to fairseq user directory.
        lang_pair (string): Language pair string (e.g. 'en-fr'). If empty,
            no --source-lang/--target-lang flags are passed.
        n_cpu_threads (int): Number of CPU threads. If negative,
            use GPU (only when CUDA is available).

    Raises:
        ValueError: If ``lang_pair`` is non-empty but not of the form
            'src-trg'.
    """
    super(FairseqPredictor, self).__init__()
    _initialize_fairseq(user_dir)
    self.use_cuda = torch.cuda.is_available() and n_cpu_threads < 0
    # NOTE(review): a non-negative n_cpu_threads only disables CUDA here; the
    # thread count itself is not applied in this method. Confirm it is set
    # elsewhere.

    # model_path may list several checkpoints joined by ':'. Take the data
    # directory from the first one; os.path.dirname on the joined string would
    # give a bogus path such as 'a/m1.pt:b' for an ensemble.
    model_paths = model_path.split(':')
    data_dir = os.path.dirname(model_paths[0])

    parser = options.get_generation_parser()
    input_args = ["--path", model_path, data_dir]
    if lang_pair:
        lang_parts = lang_pair.split("-")
        if len(lang_parts) != 2:
            raise ValueError(
                "lang_pair must look like 'src-trg', got %r" % (lang_pair,))
        src, trg = lang_parts
        input_args.extend(["--source-lang", src, "--target-lang", trg])
    args = options.parse_args_and_arch(parser, input_args)

    # Setup task, e.g., translation
    task = tasks.setup_task(args)
    self.src_vocab_size = len(task.source_dictionary)
    self.trg_vocab_size = len(task.target_dictionary)
    self.pad_id = task.source_dictionary.pad()

    # Load ensemble
    logging.info('Loading fairseq model(s) from %s', model_path)
    self.models, _ = checkpoint_utils.load_model_ensemble(
        model_paths,
        task=task,
    )

    # Optimize ensemble for generation and move each model to the GPU if used.
    for model in self.models:
        model.make_generation_fast_(
            beamable_mm_beam_size=1,
            need_attn=False,
        )
        if self.use_cuda:
            model.cuda()
    self.model = EnsembleModel(self.models)
    self.model.eval()