本文简要介绍python语言中 torch.nn.functional.gumbel_softmax
的用法。
用法:
torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=- 1)
logits-
[…, num_features]
非标准化日志概率tau-非负标量温度
hard-如果
True
,返回的样本将被离散化为 one-hot 向量,但会被微分,就好像它是 autograd 中的软样本dim(int) -计算 softmax 的维度。默认值:-1。
Gumbel-Softmax 分布中与
logits
形状相同的采样张量。如果是hard=True
,则返回的样本将为 one-hot,否则它们将是在dim
中总和为 1 的概率分布。来自 Gumbel-Softmax 分布 ( Link 1 Link 2 ) 的样本,并可选择离散化。
注意
由于遗留原因,此函数在这里,将来可能会从 nn.Functional 中删除。
注意
hard
的主要技巧是执行y_hard - y_soft.detach() + y_soft
它实现了两件事: - 使输出值精确one-hot(因为我们添加然后减去 y_soft 值) - 使梯度等于 y_soft 梯度(因为我们去除所有其他梯度)
- 例子::
>>> logits = torch.randn(20, 32) >>> # Sample soft categorical using reparametrization trick: >>> F.gumbel_softmax(logits, tau=1, hard=False) >>> # Sample hard categorical using "Straight-through" trick: >>> F.gumbel_softmax(logits, tau=1, hard=True)
参数:
返回:
相关用法
- Python PyTorch get_tokenizer用法及代码示例
- Python PyTorch gammainc用法及代码示例
- Python PyTorch gradient用法及代码示例
- Python PyTorch gammaincc用法及代码示例
- Python PyTorch global_unstructured用法及代码示例
- Python PyTorch greedy_partition用法及代码示例
- Python PyTorch gammaln用法及代码示例
- Python PyTorch get_gradients用法及代码示例
- Python PyTorch get_ignored_functions用法及代码示例
- Python PyTorch get_default_dtype用法及代码示例
- Python PyTorch gt用法及代码示例
- Python PyTorch gather用法及代码示例
- Python PyTorch gcd用法及代码示例
- Python PyTorch get_graph_node_names用法及代码示例
- Python PyTorch get_testing_overrides用法及代码示例
- Python PyTorch generate_sp_model用法及代码示例
- Python PyTorch gather_object用法及代码示例
- Python PyTorch ge用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch vdot用法及代码示例
- Python PyTorch ELU用法及代码示例
- Python PyTorch ScaledDotProduct.__init__用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.functional.gumbel_softmax。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。