本文整理汇总了Python中fairseq.models.build_model方法的典型用法代码示例。如果您正苦于以下问题:Python models.build_model方法的具体用法?Python models.build_model怎么用?Python models.build_model使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类fairseq.models的用法示例。
在下文中一共展示了models.build_model方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Construct the :class:`~fairseq.models.BaseFairseqModel` for this task.

    Args:
        args (argparse.Namespace): parsed command-line arguments

    Returns:
        a :class:`~fairseq.models.BaseFairseqModel` instance
    """
    from fairseq import models, quantization_utils

    model = models.build_model(args, self)
    # Re-layout the model for TPU execution when --tpu was passed.
    if getattr(args, 'tpu', False):
        model.prepare_for_tpu_()
    # Scalar quantization is applied last; it is a no-op unless enabled in args.
    return quantization_utils.quantize_model_scalar(model, args)
示例2: __init__
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def __init__(self, args, dicts, training):
    """Set up the multilingual task's language-pair bookkeeping.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dicts (dict): mapping from language name to its Dictionary
        training (bool): True when the task is built for training
    """
    super().__init__(args)
    self.dicts = dicts
    self.training = training
    # During training all configured pairs are used; at inference time only
    # the single requested source->target direction is needed.
    self.lang_pairs = (
        args.lang_pairs
        if training
        else ['{}-{}'.format(args.source_lang, args.target_lang)]
    )
    # eval_lang_pairs defaults to every training pair; subclasses that
    # evaluate on a different subset (e.g. other multitask settings, or
    # optimizing for specific languages) may override it.
    self.eval_lang_pairs = self.lang_pairs
    # model_lang_pairs drives which encoder-decoder pairs
    # models.build_model() creates, letting multitask subclasses build
    # models beyond the input lang_pairs.
    self.model_lang_pairs = self.lang_pairs
    self.langs = list(dicts.keys())
示例3: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build a multilingual model, verifying model args match the task's.

    Raises:
        ValueError: when task/model args disagree, or the built model is
            not a :class:`FairseqMultiModel`.
    """
    # Collect every inconsistency between the task's args and the model's
    # args so the user sees all problems in a single error message.
    problems = []
    if set(self.args.lang_pairs).symmetric_difference(args.lang_pairs):
        problems.append('--lang-pairs should include all the language pairs {}.'.format(args.lang_pairs))
    if self.args.encoder_langtok != args.encoder_langtok:
        problems.append('--encoder-langtok should be {}.'.format(args.encoder_langtok))
    if self.args.decoder_langtok != args.decoder_langtok:
        problems.append('--decoder-langtok should {} be set.'.format("" if args.decoder_langtok else "not"))
    if problems:
        raise ValueError(' '.join(problems))

    from fairseq import models

    model = models.build_model(args, self)
    if not isinstance(model, FairseqMultiModel):
        raise ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture')
    return model
示例4: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the model for this task from parsed command-line arguments.

    Args:
        args (argparse.Namespace): parsed command-line arguments

    Returns:
        a :class:`~fairseq.models.BaseFairseqModel` instance
    """
    # Import locally, consistent with the sibling build_model implementations,
    # instead of relying on a module-level `models` binding that is not
    # visible in this snippet.
    from fairseq import models

    return models.build_model(args, self)
示例5: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the model and attach a single-output sentence classification head."""
    from fairseq import models

    model = models.build_model(args, self)
    # Regression-style head: one output unit per sentence.
    model.register_classification_head('sentence_classification_head', num_classes=1)
    return model
示例6: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the MoE translation model, attaching a gating network when a
    learned prior is used and the model does not already provide one.

    Raises:
        ValueError: when a learned prior is requested but no gating network
            can be constructed from the given args.
    """
    from fairseq import models

    model = models.build_model(args, self)
    # Nothing to do for a uniform prior, or when the model ships its own gate.
    if self.uniform_prior or hasattr(model, 'gating_network'):
        return model
    if not self.args.mean_pool_gating_network:
        raise ValueError(
            'translation_moe task with learned prior requires the model to '
            'have a gating network; try using --mean-pool-gating-network'
        )

    # Resolve the gate's input dimension: explicit flag first, then fall back
    # to encoder_embed_dim (assumed to be the encoder's output dimension).
    encoder_dim = getattr(args, 'mean_pool_gating_network_encoder_dim', None)
    if not encoder_dim:
        encoder_dim = getattr(args, 'encoder_embed_dim', None)
    if not encoder_dim:
        raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')

    # Resolve dropout the same way: explicit flag first, then global dropout.
    dropout = getattr(args, 'mean_pool_gating_network_dropout', None)
    if not dropout:
        dropout = getattr(args, 'dropout', None)
    if not dropout:
        raise ValueError('Must specify --mean-pool-gating-network-dropout')

    model.gating_network = MeanPoolGatingNetwork(
        encoder_dim, args.num_experts, dropout,
    )
    return model
示例7: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the model and attach a single-output ranking head."""
    from fairseq import models

    model = models.build_model(args, self)
    # Head name is configurable via --ranking-head-name.
    head_name = getattr(args, 'ranking_head_name', 'sentence_classification_head')
    model.register_classification_head(head_name, num_classes=1)
    return model
示例8: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the multi-model and, when on-the-fly back-translation is active,
    a reverse-direction SequenceGenerator plus a backtranslate_fn per pair.

    Args:
        args (argparse.Namespace): parsed command-line arguments

    Returns:
        a :class:`FairseqMultiModel` instance

    Raises:
        ValueError: if the built model is not a FairseqMultiModel.
    """
    from fairseq import models
    model = models.build_model(args, self)
    if not isinstance(model, FairseqMultiModel):
        raise ValueError('SemisupervisedTranslationTask requires a FairseqMultiModel architecture')
    # create SequenceGenerator for each model that has backtranslation dependency on it
    self.sequence_generators = {}
    if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training:
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            # Back-translation runs in the reverse direction (tgt -> src).
            key = '{}-{}'.format(tgt, src)
            self.sequence_generators[key] = SequenceGenerator(
                [model.models[key]],
                tgt_dict=self.dicts[src],
                beam_size=args.bt_beam_size,
                max_len_a=args.bt_max_len_a,
                max_len_b=args.bt_max_len_b,
            )
            decoder_lang_tok_idx = self.get_decoder_langtok(src)

            # Default arguments deliberately bind the per-pair model,
            # generator, and BOS token at definition time, avoiding the
            # late-binding-closure pitfall in this loop.
            def backtranslate_fn(
                sample, model=model.models[key],
                bos_token=decoder_lang_tok_idx,
                sequence_generator=self.sequence_generators[key],
            ):
                return sequence_generator.generate(
                    [model],
                    sample,
                    bos_token=bos_token,
                )
            self.backtranslators[lang_pair] = backtranslate_fn
    return model
示例9: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the denoising-autoencoder model and its source/target noisers.

    Raises:
        ValueError: if the built model is not a SemiSupervisedModel.
    """
    model = models.build_model(args, self)
    if not isinstance(model, SemiSupervisedModel):
        raise ValueError(
            "PytorchTranslateDenoisingAutoencoder task requires a "
            "SemiSupervisedModel architecture"
        )
    # TODO(T35539829): implement a Noising registry so this can be built
    # with any noising class as long as it has a @register_noising decorator
    # Both noisers share the word-level noise settings; only the dictionary
    # and BPE markers differ per side.
    common_noise = dict(
        max_word_shuffle_distance=args.max_word_shuffle_distance,
        word_dropout_prob=args.word_dropout_prob,
        word_blanking_prob=args.word_blanking_prob,
    )
    self.source_noiser = noising.UnsupervisedMTNoising(
        dictionary=self.source_dictionary,
        bpe_cont_marker=self.args.source_bpe_cont_marker,
        bpe_end_marker=self.args.source_bpe_end_marker,
        **common_noise,
    )
    self.target_noiser = noising.UnsupervisedMTNoising(
        dictionary=self.target_dictionary,
        bpe_cont_marker=self.args.target_bpe_cont_marker,
        bpe_end_marker=self.args.target_bpe_end_marker,
        **common_noise,
    )
    return model
示例10: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the multi-model and record the forward/backward submodels.

    Raises:
        ValueError: if the built model is not a FairseqMultiModel.
    """
    model = models.build_model(args, self)
    self.model = model
    if not isinstance(model, FairseqMultiModel):
        raise ValueError(
            "PytorchTranslateSemiSupervised task requires a FairseqMultiModel "
            "architecture"
        )
    # Submodels are keyed "src-tgt" (forward) and "tgt-src" (backward).
    self.forward_model = model.models["-".join([self.source_lang, self.target_lang])]
    self.backward_model = model.models["-".join([self.target_lang, self.source_lang])]
    return model
示例11: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the model, requiring a FairseqMultiModel architecture."""
    model = models.build_model(args, self)
    if isinstance(model, FairseqMultiModel):
        return model
    raise ValueError(
        "PyTorchTranslateMultiTask requires a FairseqMultiModel architecture"
    )
示例12: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Delegate model construction to the fairseq model registry."""
    from fairseq import models

    model = models.build_model(args, self)
    return model
示例13: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Build the model, optionally warm-starting from --reload-checkpoint."""
    from fairseq import models

    model = models.build_model(args, self)
    checkpoint_path = args.reload_checkpoint
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        # Non-strict load so minor architecture differences don't hard-fail.
        state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint_path)
        model.load_state_dict(state['model'], strict=False)
    return model
示例14: load_ensemble_for_inference
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
    """Load an ensemble of models for inference.

    The source and target dictionaries can be given explicitly, or loaded
    from the `data_dir` directory.

    Args:
        filenames: iterable of checkpoint file paths, one per ensemble member.
        src_dict: source dictionary, or None to load it from ``data_dir``.
        dst_dict: target dictionary, or None to load it from ``data_dir``.
        data_dir: directory containing dictionaries; required when either
            dictionary is None.

    Returns:
        (ensemble, args): list of models with weights loaded, and the
        (upgraded) args recovered from the first checkpoint.

    Raises:
        FileNotFoundError: if any checkpoint file does not exist.
        ValueError: if dictionaries are missing and ``data_dir`` is None.
    """
    from fairseq import data, models

    # Load model architectures and weights.
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            # FileNotFoundError is a subclass of IOError/OSError, so existing
            # callers that catch IOError keep working.
            raise FileNotFoundError('Model file not found: {}'.format(filename))
        # Deserialize onto CPU; the caller moves models to the target device.
        states.append(torch.load(filename, map_location='cpu'))

    # All ensemble members share the architecture args of the first checkpoint.
    args = states[0]['args']
    args = _upgrade_args(args)

    if src_dict is None or dst_dict is None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if data_dir is None:
            raise ValueError('data_dir is required when src_dict/dst_dict are not provided')
        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)

    # Build the ensemble: one freshly-built model per checkpoint state.
    ensemble = []
    for state in states:
        model = models.build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
    return ensemble, args
示例15: build_model
# 需要导入模块: from fairseq import models [as 别名]
# 或者: from fairseq.models import build_model [as 别名]
def build_model(self, args):
    """Construct the :class:`~fairseq.models.BaseFairseqModel` for this task.

    Args:
        args (argparse.Namespace): parsed command-line arguments

    Returns:
        a :class:`~fairseq.models.BaseFairseqModel` instance
    """
    from fairseq import models, quantization_utils

    model = models.build_model(args, self)
    # Scalar quantization is a no-op unless enabled via args.
    model = quantization_utils.quantize_model_scalar(model, args)
    return model