当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch RNNTLoss用法及代码示例


本文简要介绍python语言中 torchaudio.transforms.RNNTLoss 的用法。

用法:

class torchaudio.transforms.RNNTLoss(blank: int = - 1, clamp: float = - 1.0, reduction: str = 'mean')

参数

  • blank(int,可选的) -空白标签(默认:-1)

  • clamp(float,可选的) -渐变夹(默认值:-1)

  • reduction(string,可选的) -指定要应用于输出的缩减:'none' | 'mean' | 'sum' 。 (默认:'mean')

计算 RNN 换能器损失循环神经网络的序列转导[5]。 RNN 换能器损失通过定义所有长度的输出序列上的分布以及通过对输入-输出和output-output 依赖项进行联合建模来扩展 CTC 损失。

示例
>>> # Hypothetical values
>>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.6, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.8, 0.1]],
>>>                         [[0.1, 0.6, 0.1, 0.1, 0.1],
>>>                          [0.1, 0.1, 0.2, 0.1, 0.1],
>>>                          [0.7, 0.1, 0.2, 0.1, 0.1]]]],
>>>                       dtype=torch.float32,
>>>                       requires_grad=True)
>>> targets = torch.tensor([[1, 2]], dtype=torch.int)
>>> logit_lengths = torch.tensor([2], dtype=torch.int)
>>> target_lengths = torch.tensor([2], dtype=torch.int)
>>> transform = transforms.RNNTLoss(blank=0)
>>> loss = transform(logits, targets, logit_lengths, target_lengths)
>>> loss.backward()

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchaudio.transforms.RNNTLoss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。