當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch CTCLoss用法及代碼示例


本文簡要介紹python語言中 torch.nn.CTCLoss 的用法。

用法:

class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

參數

  • blank(int,可選的) -空白標簽。默認

  • reduction(string,可選的) -指定要應用於輸出的縮減:'none' | 'mean' | 'sum''none' :不會應用減少,'mean' :輸出損失將除以目標長度,然後取批次的平均值。默認值:'mean'

  • zero_infinity(bool,可選的) -是否將無限損失和相關梯度歸零。默認值:False 無限損失主要發生在輸入太短而無法與目標對齊時。

聯結主義時間分類損失。

計算連續(未分段)時間序列和目標序列之間的損失。 CTCLoss 對輸入與目標可能對齊的概率求和,產生一個相對於每個輸入節點可微分的損失值。輸入到目標的對齊被假定為“many-to-one”,這限製了目標序列的長度,因此它必須是 的輸入長度。

形狀:
  • Log_probs:大小為 的張量,其中 。輸出的對數概率(例如,使用 torch.nn.functional.log_softmax() 獲得)。

  • 目標:大小為 的張量,其中 。它代表目標序列。目標序列中的每個元素都是一個類索引。並且目標索引不能為空(默認=0)。在 形式中,目標被填充到最長序列的長度,並堆疊。在 形式中,假定目標為un-padded 並在一維內連接。

  • Input_lengths:大小為 的元組或張量,其中 。它表示輸入的長度(每個都必須是 )。並且在序列被填充到相等長度的假設下,為每個序列指定長度以實現掩蔽。

  • Target_lengths:大小為 的元組或張量,其中 。它代表目標的長度。在序列被填充到相等長度的假設下,為每個序列指定長度以實現掩蔽。如果目標形狀是 ,則 target_lengths 實際上是每個目標序列的停止索引 ,因此批次中每個目標的 target_n = targets[n,0:s_n]。每個長度都必須是 如果目標是作為單個目標串聯的一維張量給出的,則 target_lengths 必須加起來就是張量的總長度。

  • 輸出:標量。如果 reduction'none' ,那麽 ,其中

例子:

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
參考:

A. Graves 等人:連接主義時間分類:使用遞歸神經網絡標記未分段的序列數據:https://www.cs.toronto.edu/~graves/icml_2006.pdf

注意

為了使用 CuDNN,必須滿足以下條件: targets 必須是連接格式,所有 input_lengths 必須是 T , target_lengths ,整數參數必須是 dtype torch.int32

常規實現使用(在 PyTorch 中更常見)torch.long dtype。

注意

在某些情況下,當將 CUDA 後端與 CuDNN 一起使用時,此運算符可能會選擇一種非確定性算法來提高性能。如果這是不可取的,您可以嘗試通過設置 torch.backends.cudnn.deterministic = True 來使操作具有確定性(可能以性能為代價)。請參閱有關背景的可重複性說明。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.CTCLoss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。