當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch RNNTLoss用法及代碼示例


本文簡要介紹python語言中 torchaudio.transforms.RNNTLoss 的用法。

用法:

class torchaudio.transforms.RNNTLoss(blank: int = - 1, clamp: float = - 1.0, reduction: str = 'mean')

參數

  • blank(int,可選的) -空白標簽(默認:-1)

  • clamp(float,可選的) -漸變夾(默認值:-1)

  • reduction(string,可選的) -指定要應用於輸出的縮減:'none' | 'mean' | 'sum' 。 (默認:'mean')

計算 RNN 換能器損失循環神經網絡的序列轉導[5]。 RNN 換能器損失通過定義所有長度的輸出序列上的分布以及通過對輸入-輸出和output-output 依賴項進行聯合建模來擴展 CTC 損失。

示例
>>> # Hypothetical values
>>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.6, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.8, 0.1]],
>>>                         [[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.1, 0.1],
>>>                          [0.7, 0.1, 0.2, 0.1, 0.1]]]],
>>>                       dtype=torch.float32,
>>>                       requires_grad=True)
>>> targets = torch.tensor([[1, 2]], dtype=torch.int)
>>> logit_lengths = torch.tensor([2], dtype=torch.int)
>>> target_lengths = torch.tensor([2], dtype=torch.int)
>>> transform = transforms.RNNTLoss(blank=0)
>>> loss = transform(logits, targets, logit_lengths, target_lengths)
>>> loss.backward()

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchaudio.transforms.RNNTLoss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。