當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch multinomial用法及代碼示例


本文簡要介紹python語言中 torch.multinomial 的用法。

用法:

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor

參數

  • input(Tensor) -包含概率的輸入張量

  • num_samples(int) -要抽取的樣本數

  • replacement(bool,可選的) -是否用替換繪製

關鍵字參數

  • generator(torch.Generator, 可選的) -用於采樣的偽隨機數發生器

  • out(Tensor,可選的) -輸出張量。

返回一個張量,其中每行包含從位於張量 input 的相應行中的多項概率分布中采樣的 num_samples 索引。

注意

input 的行不需要總和為 1(在這種情況下,我們將值用作權重),但必須是非負的、有限的並且總和非零。

索引根據每個采樣的時間從左到右排序(第一個樣本放在第一列)。

如果 input 是向量,則 out 是大小為 num_samples 的向量。

如果 input 是具有 m 行的矩陣,則 out 是形狀為 的矩陣。

如果替換是 True ,則使用替換抽取樣本。

如果不是,它們將在不替換的情況下繪製,這意味著當為一行繪製樣本索引時,不能為該行再次繪製它。

注意

在不替換的情況下繪製時,num_samples 必須小於 input 中非零元素的數量(如果是矩陣,則為 input 每行中非零元素的最小數量)。

例子:

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.multinomial。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。