本文簡要介紹python語言中 torch.nn.CTCLoss
的用法。
用法:
class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
聯結主義時間分類損失。
計算連續(未分段)時間序列和目標序列之間的損失。 CTCLoss 對輸入與目標可能對齊的概率求和,產生一個相對於每個輸入節點可微分的損失值。輸入到目標的對齊被假定為“many-to-one”,這限製了目標序列的長度,因此它必須是 的輸入長度。
- 形狀:
Log_probs:大小為
torch.nn.functional.log_softmax()
獲得)。 的張量,其中 、 和 。輸出的對數概率(例如,使用目標:大小為 或 的張量,其中 和 。它代表目標序列。目標序列中的每個元素都是一個類索引。並且目標索引不能為空(默認=0)。在 形式中,目標被填充到最長序列的長度,並堆疊。在 形式中,假定目標為un-padded 並在一維內連接。
Input_lengths:大小為 的元組或張量,其中 。它表示輸入的長度(每個都必須是 )。並且在序列被填充到相等長度的假設下,為每個序列指定長度以實現掩蔽。
Target_lengths:大小為
target_n = targets[n,0:s_n]
。每個長度都必須是 如果目標是作為單個目標串聯的一維張量給出的,則 target_lengths 必須加起來就是張量的總長度。 的元組或張量,其中 。它代表目標的長度。在序列被填充到相等長度的假設下,為每個序列指定長度以實現掩蔽。如果目標形狀是 ,則 target_lengths 實際上是每個目標序列的停止索引 ,因此批次中每個目標的輸出:標量。如果
reduction
是'none'
,那麽 ,其中 。
例子:
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 參考:
A. Graves 等人:連接主義時間分類:使用遞歸神經網絡標記未分段的序列數據:https://www.cs.toronto.edu/~graves/icml_2006.pdf
注意
為了使用 CuDNN,必須滿足以下條件:
targets
必須是連接格式,所有input_lengths
必須是T
。 ,target_lengths
,整數參數必須是 dtypetorch.int32
。常規實現使用(在 PyTorch 中更常見)
torch.long
dtype。注意
在某些情況下,當將 CUDA 後端與 CuDNN 一起使用時,此運算符可能會選擇一種非確定性算法來提高性能。如果這是不可取的,您可以嘗試通過設置
torch.backends.cudnn.deterministic = True
來使操作具有確定性(可能以性能為代價)。請參閱有關背景的可重複性說明。
參數:
相關用法
- Python PyTorch Collator用法及代碼示例
- Python PyTorch ConvTranspose3d用法及代碼示例
- Python PyTorch Conv1d用法及代碼示例
- Python PyTorch CSVParser用法及代碼示例
- Python PyTorch CosineAnnealingWarmRestarts.step用法及代碼示例
- Python PyTorch CrossEntropyLoss用法及代碼示例
- Python PyTorch ChannelShuffle用法及代碼示例
- Python PyTorch CocoCaptions用法及代碼示例
- Python PyTorch CSVDictParser用法及代碼示例
- Python PyTorch ContinuousBernoulli用法及代碼示例
- Python PyTorch Cityscapes用法及代碼示例
- Python PyTorch ChainedScheduler用法及代碼示例
- Python PyTorch Cauchy用法及代碼示例
- Python PyTorch ConstantPad2d用法及代碼示例
- Python PyTorch CriteoIterDataPipe用法及代碼示例
- Python PyTorch ComplexNorm用法及代碼示例
- Python PyTorch ConvTranspose2d用法及代碼示例
- Python PyTorch CppExtension用法及代碼示例
- Python PyTorch Concater用法及代碼示例
- Python PyTorch Compose用法及代碼示例
- Python PyTorch Chi2用法及代碼示例
- Python PyTorch ConstantLR用法及代碼示例
- Python PyTorch Conv2d用法及代碼示例
- Python PyTorch CosineSimilarity用法及代碼示例
- Python PyTorch ConstantPad1d用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.CTCLoss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。