当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch import_fairseq_model用法及代码示例


本文简要介绍python语言中 torchaudio.models.wav2vec2.utils.import_fairseq_model 的用法。

用法:

torchaudio.models.wav2vec2.utils.import_fairseq_model(original: torch.nn.Module) → torchaudio.models.Wav2Vec2Model

参数

original(torch.nn.Module) -fairseq 的 Wav2Vec2.0 或 HuBERT 模型的一个实例。 fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoderfairseq.models.wav2vec.wav2vec2.Wav2Vec2Modelfairseq.models.hubert.hubert_asr.HubertEncoder 之一。

返回

导入型号。

返回类型

Wav2Vec2模型

fairseq 的相应模型对象构建 Wav2Vec2Model 。

示例 - 加载 pretrain-only 模型
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.extract_features(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_allclose(features, reference)
示例 - Fine-tuned 型号
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original.w2v_encoder)
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchaudio.models.wav2vec2.utils.import_fairseq_model。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。