當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.nn.ctc_loss用法及代碼示例


計算 CTC(連接主義時間分類)損失。

用法

tf.compat.v1.nn.ctc_loss(
    labels, inputs=None, sequence_length=None, preprocess_collapse_repeated=False,
    ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=False,
    time_major=True, logits=None
)

參數

  • labels 一個 int32 SparseTensorlabels.indices[i,:] == [b, t] 表示 labels.values[i] 存儲 (batch b, time t) 的 id。 labels.values[i] 必須采用 [0, num_labels) 中的值。有關詳細信息,請參閱core/ops/ctc_ops.cc
  • inputs 3-D float Tensor 。如果 time_major == False,這將是一個 Tensor 形狀:[batch_size, max_time, num_classes]。如果 time_major == True(默認),這將是一個 Tensor 形狀:[max_time, batch_size, num_classes]。日誌。
  • sequence_length 一維 int32 向量,大小 [batch_size] 。序列長度。
  • preprocess_collapse_repeated 布爾值。默認值:假。如果為 True,則在 CTC 計算之前折疊重複的標簽。
  • ctc_merge_repeated 布爾值。默認值:真。
  • ignore_longer_outputs_than_inputs 布爾值。默認值:假。如果為 True,則輸出比輸入長的序列將被忽略。
  • time_major inputs 張量的形狀格式。如果為 True,則這些 Tensors 的形狀必須為 [max_time, batch_size, num_classes] 。如果為 False,則這些 Tensors 的形狀必須為 [batch_size, max_time, num_classes] 。使用time_major = True(默認)更有效,因為它避免了在ctc_loss 計算開始時的轉置。然而,大多數 TensorFlow 數據是batch-major,所以這個函數也接受 batch-major 形式的輸入。
  • logits 輸入的別名。

返回

  • 一維 float Tensor ,大小 [batch] ,包含負對數概率。

拋出

  • TypeError 如果標簽不是 SparseTensor

該操作實現了 (Graves et al., 2006) 中提出的 CTC 損失。

輸入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

注意:

此類為您執行 softmax 操作,因此輸入應為例如LSTM 輸出的線性投影。

inputs Tensor 的最內層維度大小 num_classes 表示 num_labels + 1 類,其中 num_labels 是真實標簽的數量,最大值 (num_classes - 1) 保留給空白標簽。

例如,對於包含 3 個標簽 [a, b, c] , num_classes = 4 且標簽索引為 {a:0, b:1, c:2, blank:3} 的詞匯表。

關於參數 preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated 為 True,則在損失計算之前運行預處理步驟,其中將傳遞給損失的重複標簽合並為單個標簽。如果訓練標簽來自例如強製對齊並因此具有不必要的重複,這將很有用。

如果 ctc_merge_repeated 設置為 False,則在 CTC 計算的深處,重複的非空白標簽將不會被合並並被解釋為單獨的標簽。這是 CTC 的簡化(非標準)版本。

這是(大致)預期的一階行為的表格:

  • preprocess_collapse_repeated=False , ctc_merge_repeated=True

    經典的 CTC 行為:輸出真正的重複類,中間有空格,也可以輸出中間沒有空格的重複類,需要被解碼器折疊。

  • preprocess_collapse_repeated=True , ctc_merge_repeated=False

    永遠不會學習輸出重複的類,因為它們在訓練之前被折疊在輸入標簽中。

  • preprocess_collapse_repeated=False , ctc_merge_repeated=False

    輸出中間有空格的重複類,但通常不需要解碼器折疊/合並重複類。

  • preprocess_collapse_repeated=True , ctc_merge_repeated=True

    未經測試。很可能不會學習輸出重複的類。

ignore_longer_outputs_than_inputs 選項允許在處理輸出比輸入長的序列時指定 CTCLoss 的行為。如果為真,CTCLoss 將簡單地為這些項目返回零梯度,否則返回 InvalidArgument 錯誤,停止訓練。

參考:

Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks: Graves et al., 2006 (pdf)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.ctc_loss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。