計算 CTC(連接主義時間分類)損失。
用法
tf.compat.v1.nn.ctc_loss(
labels, inputs=None, sequence_length=None, preprocess_collapse_repeated=False,
ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=False,
time_major=True, logits=None
)
參數
-
labels
一個int32
SparseTensor
。labels.indices[i,:] == [b, t]
表示labels.values[i]
存儲 (batch b, time t) 的 id。labels.values[i]
必須采用[0, num_labels)
中的值。有關詳細信息,請參閱core/ops/ctc_ops.cc
。 -
inputs
3-Dfloat
Tensor
。如果 time_major == False,這將是一個Tensor
形狀:[batch_size, max_time, num_classes]
。如果 time_major == True(默認),這將是一個Tensor
形狀:[max_time, batch_size, num_classes]
。日誌。 -
sequence_length
一維int32
向量,大小[batch_size]
。序列長度。 -
preprocess_collapse_repeated
布爾值。默認值:假。如果為 True,則在 CTC 計算之前折疊重複的標簽。 -
ctc_merge_repeated
布爾值。默認值:真。 -
ignore_longer_outputs_than_inputs
布爾值。默認值:假。如果為 True,則輸出比輸入長的序列將被忽略。 -
time_major
inputs
張量的形狀格式。如果為 True,則這些Tensors
的形狀必須為[max_time, batch_size, num_classes]
。如果為 False,則這些Tensors
的形狀必須為[batch_size, max_time, num_classes]
。使用time_major = True
(默認)更有效,因為它避免了在ctc_loss 計算開始時的轉置。然而,大多數 TensorFlow 數據是batch-major,所以這個函數也接受 batch-major 形式的輸入。 -
logits
輸入的別名。
返回
-
一維
float
Tensor
,大小[batch]
,包含負對數概率。
拋出
-
TypeError
如果標簽不是SparseTensor
。
該操作實現了 (Graves et al., 2006) 中提出的 CTC 損失。
輸入要求:
sequence_length(b) <= time for all b
max(labels.indices(labels.indices[:, 1] == b, 2))
<= sequence_length(b) for all b.
注意:
此類為您執行 softmax 操作,因此輸入應為例如LSTM 輸出的線性投影。
inputs
Tensor 的最內層維度大小 num_classes
表示 num_labels + 1
類,其中 num_labels 是真實標簽的數量,最大值 (num_classes - 1)
保留給空白標簽。
例如,對於包含 3 個標簽 [a, b, c]
, num_classes = 4
且標簽索引為 {a:0, b:1, c:2, blank:3}
的詞匯表。
關於參數 preprocess_collapse_repeated
和 ctc_merge_repeated
:
如果preprocess_collapse_repeated
為 True,則在損失計算之前運行預處理步驟,其中將傳遞給損失的重複標簽合並為單個標簽。如果訓練標簽來自例如強製對齊並因此具有不必要的重複,這將很有用。
如果 ctc_merge_repeated
設置為 False,則在 CTC 計算的深處,重複的非空白標簽將不會被合並並被解釋為單獨的標簽。這是 CTC 的簡化(非標準)版本。
這是(大致)預期的一階行為的表格:
preprocess_collapse_repeated=False
,ctc_merge_repeated=True
經典的 CTC 行為:輸出真正的重複類,中間有空格,也可以輸出中間沒有空格的重複類,需要被解碼器折疊。
preprocess_collapse_repeated=True
,ctc_merge_repeated=False
永遠不會學習輸出重複的類,因為它們在訓練之前被折疊在輸入標簽中。
preprocess_collapse_repeated=False
,ctc_merge_repeated=False
輸出中間有空格的重複類,但通常不需要解碼器折疊/合並重複類。
preprocess_collapse_repeated=True
,ctc_merge_repeated=True
未經測試。很可能不會學習輸出重複的類。
ignore_longer_outputs_than_inputs
選項允許在處理輸出比輸入長的序列時指定 CTCLoss 的行為。如果為真,CTCLoss 將簡單地為這些項目返回零梯度,否則返回 InvalidArgument 錯誤,停止訓練。
參考:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks: Graves et al., 2006 (pdf)
相關用法
- Python tf.compat.v1.nn.convolution用法及代碼示例
- Python tf.compat.v1.nn.conv2d用法及代碼示例
- Python tf.compat.v1.nn.static_rnn用法及代碼示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代碼示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代碼示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.separable_conv2d用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代碼示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代碼示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代碼示例
- Python tf.compat.v1.nn.nce_loss用法及代碼示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代碼示例
- Python tf.compat.v1.nn.pool用法及代碼示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代碼示例
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代碼示例
- Python tf.compat.v1.nn.erosion2d用法及代碼示例
- Python tf.compat.v1.nn.raw_rnn用法及代碼示例
- Python tf.compat.v1.nn.dilation2d用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.nn.ctc_loss。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。