當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch gumbel_softmax用法及代碼示例


本文簡要介紹python語言中 torch.nn.functional.gumbel_softmax 的用法。

用法:

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=- 1)

參數

  • logits-[…, num_features] 非標準化日誌概率

  • tau-非負標量溫度

  • hard-如果 True ,返回的樣本將被離散化為 one-hot 向量,但會被微分,就好像它是 autograd 中的軟樣本

  • dim(int) -計算 softmax 的維度。默認值:-1。

返回

Gumbel-Softmax 分布中與 logits 形狀相同的采樣張量。如果是 hard=True ,則返回的樣本將為 one-hot,否則它們將是在 dim 中總和為 1 的概率分布。

來自 Gumbel-Softmax 分布 ( Link 1 Link 2 ) 的樣本,並可選擇離散化。

注意

由於遺留原因,此函數在這裏,將來可能會從 nn.Functional 中刪除。

注意

hard 的主要技巧是執行 y_hard - y_soft.detach() + y_soft

它實現了兩件事: - 使輸出值精確one-hot(因為我們添加然後減去 y_soft 值) - 使梯度等於 y_soft 梯度(因為我們去除所有其他梯度)

例子::
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.functional.gumbel_softmax。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。