當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch WeightedRandomSampler用法及代碼示例

本文簡要介紹python語言中 torch.utils.data.WeightedRandomSampler 的用法。

用法:

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

參數

  • weights(序列) -一係列權重,不一定總和為一個

  • num_samples(int) -要抽取的樣本數

  • replacement(bool) -如果 True ,則使用替換抽取樣本。如果不是,它們將在不替換的情況下繪製,這意味著當為一行繪製樣本索引時,不能為該行再次繪製它。

  • generator(torch.Generator) -采樣中使用的生成器。

以給定的概率(權重)從[0,..,len(weights)-1] 中采樣元素。

示例

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.data.WeightedRandomSampler。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。