本文简要介绍python语言中 torch.nn.CTCLoss
的用法。
用法:
class torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
联结主义时间分类损失。
计算连续(未分段)时间序列和目标序列之间的损失。 CTCLoss 对输入与目标可能对齐的概率求和,产生一个相对于每个输入节点可微分的损失值。输入到目标的对齐被假定为“many-to-one”,这限制了目标序列的长度,因此它必须是 的输入长度。
- 形状:
Log_probs:大小为
torch.nn.functional.log_softmax()
获得)。 的张量,其中 、 和 。输出的对数概率(例如,使用目标:大小为 或 的张量,其中 和 。它代表目标序列。目标序列中的每个元素都是一个类索引。并且目标索引不能为空(默认=0)。在 形式中,目标被填充到最长序列的长度,并堆叠。在 形式中,假定目标为un-padded 并在一维内连接。
Input_lengths:大小为 的元组或张量,其中 。它表示输入的长度(每个都必须是 )。并且在序列被填充到相等长度的假设下,为每个序列指定长度以实现掩蔽。
Target_lengths:大小为
target_n = targets[n,0:s_n]
。每个长度都必须是 如果目标是作为单个目标串联的一维张量给出的,则 target_lengths 必须加起来就是张量的总长度。 的元组或张量,其中 。它代表目标的长度。在序列被填充到相等长度的假设下,为每个序列指定长度以实现掩蔽。如果目标形状是 ,则 target_lengths 实际上是每个目标序列的停止索引 ,因此批次中每个目标的输出:标量。如果
reduction
是'none'
,那么 ,其中 。
例子:
>>> # Target are to be padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> S = 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward() >>> >>> >>> # Target are to be un-padded >>> T = 50 # Input sequence length >>> C = 20 # Number of classes (including blank) >>> N = 16 # Batch size >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
- 参考:
A. Graves 等人:连接主义时间分类:使用递归神经网络标记未分段的序列数据:https://www.cs.toronto.edu/~graves/icml_2006.pdf
注意
为了使用 CuDNN,必须满足以下条件:
targets
必须是连接格式,所有input_lengths
必须是T
。 ,target_lengths
,整数参数必须是 dtypetorch.int32
。常规实现使用(在 PyTorch 中更常见)
torch.long
dtype。注意
在某些情况下,当将 CUDA 后端与 CuDNN 一起使用时,此运算符可能会选择一种非确定性算法来提高性能。如果这是不可取的,您可以尝试通过设置
torch.backends.cudnn.deterministic = True
来使操作具有确定性(可能以性能为代价)。请参阅有关背景的可重复性说明。
参数:
相关用法
- Python PyTorch Collator用法及代码示例
- Python PyTorch ConvTranspose3d用法及代码示例
- Python PyTorch Conv1d用法及代码示例
- Python PyTorch CSVParser用法及代码示例
- Python PyTorch CosineAnnealingWarmRestarts.step用法及代码示例
- Python PyTorch CrossEntropyLoss用法及代码示例
- Python PyTorch ChannelShuffle用法及代码示例
- Python PyTorch CocoCaptions用法及代码示例
- Python PyTorch CSVDictParser用法及代码示例
- Python PyTorch ContinuousBernoulli用法及代码示例
- Python PyTorch Cityscapes用法及代码示例
- Python PyTorch ChainedScheduler用法及代码示例
- Python PyTorch Cauchy用法及代码示例
- Python PyTorch ConstantPad2d用法及代码示例
- Python PyTorch CriteoIterDataPipe用法及代码示例
- Python PyTorch ComplexNorm用法及代码示例
- Python PyTorch ConvTranspose2d用法及代码示例
- Python PyTorch CppExtension用法及代码示例
- Python PyTorch Concater用法及代码示例
- Python PyTorch Compose用法及代码示例
- Python PyTorch Chi2用法及代码示例
- Python PyTorch ConstantLR用法及代码示例
- Python PyTorch Conv2d用法及代码示例
- Python PyTorch CosineSimilarity用法及代码示例
- Python PyTorch ConstantPad1d用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.CTCLoss。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。