当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch gumbel_softmax用法及代码示例


本文简要介绍python语言中 torch.nn.functional.gumbel_softmax 的用法。

用法:

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=- 1)

参数

  • logits-[…, num_features] 非标准化日志概率

  • tau-非负标量温度

  • hard-如果 True ,返回的样本将被离散化为 one-hot 向量,但会被微分,就好像它是 autograd 中的软样本

  • dim(int) -计算 softmax 的维度。默认值:-1。

返回

Gumbel-Softmax 分布中与 logits 形状相同的采样张量。如果是 hard=True ,则返回的样本将为 one-hot,否则它们将是在 dim 中总和为 1 的概率分布。

来自 Gumbel-Softmax 分布 ( Link 1 Link 2 ) 的样本,并可选择离散化。

注意

由于遗留原因,此函数在这里,将来可能会从 nn.Functional 中删除。

注意

hard 的主要技巧是执行 y_hard - y_soft.detach() + y_soft

它实现了两件事: - 使输出值精确one-hot(因为我们添加然后减去 y_soft 值) - 使梯度等于 y_soft 梯度(因为我们去除所有其他梯度)

例子::
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.functional.gumbel_softmax。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。