当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.nn.ctc_loss用法及代码示例


计算 CTC(连接主义时间分类)损失。

用法

tf.compat.v1.nn.ctc_loss(
    labels, inputs=None, sequence_length=None, preprocess_collapse_repeated=False,
    ctc_merge_repeated=True, ignore_longer_outputs_than_inputs=False,
    time_major=True, logits=None
)

参数

  • labels 一个 int32 SparseTensorlabels.indices[i,:] == [b, t] 表示 labels.values[i] 存储 (batch b, time t) 的 id。 labels.values[i] 必须采用 [0, num_labels) 中的值。有关详细信息,请参阅core/ops/ctc_ops.cc
  • inputs 3-D float Tensor 。如果 time_major == False,这将是一个 Tensor 形状:[batch_size, max_time, num_classes]。如果 time_major == True(默认),这将是一个 Tensor 形状:[max_time, batch_size, num_classes]。日志。
  • sequence_length 一维 int32 向量,大小 [batch_size] 。序列长度。
  • preprocess_collapse_repeated 布尔值。默认值:假。如果为 True,则在 CTC 计算之前折叠重复的标签。
  • ctc_merge_repeated 布尔值。默认值:真。
  • ignore_longer_outputs_than_inputs 布尔值。默认值:假。如果为 True,则输出比输入长的序列将被忽略。
  • time_major inputs 张量的形状格式。如果为 True,则这些 Tensors 的形状必须为 [max_time, batch_size, num_classes] 。如果为 False,则这些 Tensors 的形状必须为 [batch_size, max_time, num_classes] 。使用time_major = True(默认)更有效,因为它避免了在ctc_loss 计算开始时的转置。然而,大多数 TensorFlow 数据是batch-major,所以这个函数也接受 batch-major 形式的输入。
  • logits 输入的别名。

返回

  • 一维 float Tensor ,大小 [batch] ,包含负对数概率。

抛出

  • TypeError 如果标签不是 SparseTensor

该操作实现了 (Graves et al., 2006) 中提出的 CTC 损失。

输入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

注意:

此类为您执行 softmax 操作,因此输入应为例如LSTM 输出的线性投影。

inputs Tensor 的最内层维度大小 num_classes 表示 num_labels + 1 类,其中 num_labels 是真实标签的数量,最大值 (num_classes - 1) 保留给空白标签。

例如,对于包含 3 个标签 [a, b, c] , num_classes = 4 且标签索引为 {a:0, b:1, c:2, blank:3} 的词汇表。

关于参数 preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated 为 True,则在损失计算之前运行预处理步骤,其中将传递给损失的重复标签合并为单个标签。如果训练标签来自例如强制对齐并因此具有不必要的重复,这将很有用。

如果 ctc_merge_repeated 设置为 False,则在 CTC 计算的深处,重复的非空白标签将不会被合并并被解释为单独的标签。这是 CTC 的简化(非标准)版本。

这是(大致)预期的一阶行为的表格:

  • preprocess_collapse_repeated=False , ctc_merge_repeated=True

    经典的 CTC 行为:输出真正的重复类,中间有空格,也可以输出中间没有空格的重复类,需要被解码器折叠。

  • preprocess_collapse_repeated=True , ctc_merge_repeated=False

    永远不会学习输出重复的类,因为它们在训练之前被折叠在输入标签中。

  • preprocess_collapse_repeated=False , ctc_merge_repeated=False

    输出中间有空格的重复类,但通常不需要解码器折叠/合并重复类。

  • preprocess_collapse_repeated=True , ctc_merge_repeated=True

    未经测试。很可能不会学习输出重复的类。

ignore_longer_outputs_than_inputs 选项允许在处理输出比输入长的序列时指定 CTCLoss 的行为。如果为真,CTCLoss 将简单地为这些项目返回零梯度,否则返回 InvalidArgument 错误,停止训练。

参考:

Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks: Graves et al., 2006 (pdf)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.nn.ctc_loss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。