计算 CTC(连接主义时间分类)损失。
用法
tf.compat.v1.nn.ctc_loss(
labels, inputs=None, sequence_length=None, preprocess_collapse_repeated=False,
ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=False,
time_major=True, logits=None
)
参数
-
labels
一个int32
SparseTensor
。labels.indices[i,:] == [b, t]
表示labels.values[i]
存储 (batch b, time t) 的 id。labels.values[i]
必须采用[0, num_labels)
中的值。有关详细信息,请参阅core/ops/ctc_ops.cc
。 -
inputs
3-Dfloat
Tensor
。如果 time_major == False,这将是一个Tensor
形状:[batch_size, max_time, num_classes]
。如果 time_major == True(默认),这将是一个Tensor
形状:[max_time, batch_size, num_classes]
。日志。 -
sequence_length
一维int32
向量,大小[batch_size]
。序列长度。 -
preprocess_collapse_repeated
布尔值。默认值:假。如果为 True,则在 CTC 计算之前折叠重复的标签。 -
ctc_merge_repeated
布尔值。默认值:真。 -
ignore_longer_outputs_than_inputs
布尔值。默认值:假。如果为 True,则输出比输入长的序列将被忽略。 -
time_major
inputs
张量的形状格式。如果为 True,则这些Tensors
的形状必须为[max_time, batch_size, num_classes]
。如果为 False,则这些Tensors
的形状必须为[batch_size, max_time, num_classes]
。使用time_major = True
(默认)更有效,因为它避免了在ctc_loss 计算开始时的转置。然而,大多数 TensorFlow 数据是batch-major,所以这个函数也接受 batch-major 形式的输入。 -
logits
输入的别名。
返回
-
一维
float
Tensor
,大小[batch]
,包含负对数概率。
抛出
-
TypeError
如果标签不是SparseTensor
。
该操作实现了 (Graves et al., 2006) 中提出的 CTC 损失。
输入要求:
sequence_length(b) <= time for all b
max(labels.indices(labels.indices[:, 1] == b, 2))
<= sequence_length(b) for all b.
注意:
此类为您执行 softmax 操作,因此输入应为例如LSTM 输出的线性投影。
inputs
Tensor 的最内层维度大小 num_classes
表示 num_labels + 1
类,其中 num_labels 是真实标签的数量,最大值 (num_classes - 1)
保留给空白标签。
例如,对于包含 3 个标签 [a, b, c]
, num_classes = 4
且标签索引为 {a:0, b:1, c:2, blank:3}
的词汇表。
关于参数 preprocess_collapse_repeated
和 ctc_merge_repeated
:
如果preprocess_collapse_repeated
为 True,则在损失计算之前运行预处理步骤,其中将传递给损失的重复标签合并为单个标签。如果训练标签来自例如强制对齐并因此具有不必要的重复,这将很有用。
如果 ctc_merge_repeated
设置为 False,则在 CTC 计算的深处,重复的非空白标签将不会被合并并被解释为单独的标签。这是 CTC 的简化(非标准)版本。
这是(大致)预期的一阶行为的表格:
preprocess_collapse_repeated=False
,ctc_merge_repeated=True
经典的 CTC 行为:输出真正的重复类,中间有空格,也可以输出中间没有空格的重复类,需要被解码器折叠。
preprocess_collapse_repeated=True
,ctc_merge_repeated=False
永远不会学习输出重复的类,因为它们在训练之前被折叠在输入标签中。
preprocess_collapse_repeated=False
,ctc_merge_repeated=False
输出中间有空格的重复类,但通常不需要解码器折叠/合并重复类。
preprocess_collapse_repeated=True
,ctc_merge_repeated=True
未经测试。很可能不会学习输出重复的类。
ignore_longer_outputs_than_inputs
选项允许在处理输出比输入长的序列时指定 CTCLoss 的行为。如果为真,CTCLoss 将简单地为这些项目返回零梯度,否则返回 InvalidArgument 错误,停止训练。
参考:
Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks: Graves et al., 2006 (pdf)
相关用法
- Python tf.compat.v1.nn.convolution用法及代码示例
- Python tf.compat.v1.nn.conv2d用法及代码示例
- Python tf.compat.v1.nn.static_rnn用法及代码示例
- Python tf.compat.v1.nn.sufficient_statistics用法及代码示例
- Python tf.compat.v1.nn.dynamic_rnn用法及代码示例
- Python tf.compat.v1.nn.embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.separable_conv2d用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d_native用法及代码示例
- Python tf.compat.v1.nn.weighted_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.depthwise_conv2d用法及代码示例
- Python tf.compat.v1.nn.safe_embedding_lookup_sparse用法及代码示例
- Python tf.compat.v1.nn.nce_loss用法及代码示例
- Python tf.compat.v1.nn.sampled_softmax_loss用法及代码示例
- Python tf.compat.v1.nn.pool用法及代码示例
- Python tf.compat.v1.nn.sigmoid_cross_entropy_with_logits用法及代码示例
- Python tf.compat.v1.nn.rnn_cell.MultiRNNCell用法及代码示例
- Python tf.compat.v1.nn.erosion2d用法及代码示例
- Python tf.compat.v1.nn.raw_rnn用法及代码示例
- Python tf.compat.v1.nn.dilation2d用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.ctc_loss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。