當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch import_fairseq_model用法及代碼示例


本文簡要介紹python語言中 torchaudio.models.wav2vec2.utils.import_fairseq_model 的用法。

用法:

torchaudio.models.wav2vec2.utils.import_fairseq_model(original: torch.nn.Module) → torchaudio.models.Wav2Vec2Model

參數

original(torch.nn.Module) -fairseq 的 Wav2Vec2.0 或 HuBERT 模型的一個實例。 fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoderfairseq.models.wav2vec.wav2vec2.Wav2Vec2Modelfairseq.models.hubert.hubert_asr.HubertEncoder 之一。

返回

導入型號。

返回類型

Wav2Vec2模型

fairseq 的相應模型對象構建 Wav2Vec2Model 。

示例 - 加載 pretrain-only 模型
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original)
>>>
>>> # Perform feature extraction
>>> waveform, _ = torchaudio.load('audio.wav')
>>> features, _ = imported.extract_features(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> reference = original.feature_extractor(waveform).transpose(1, 2)
>>> torch.testing.assert_allclose(features, reference)
示例 - Fine-tuned 型號
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model
>>>
>>> # Load model using fairseq
>>> model_file = 'wav2vec_small_960h.pt'
>>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
>>> original = model[0]
>>> imported = import_fairseq_model(original.w2v_encoder)
>>>
>>> # Perform encoding
>>> waveform, _ = torchaudio.load('audio.wav')
>>> emission, _ = imported(waveform)
>>>
>>> # Compare result with the original model from fairseq
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchaudio.models.wav2vec2.utils.import_fairseq_model。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。