当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Shuffler用法及代码示例


本文简要介绍python语言中 torchdata.datapipes.iter.Shuffler 的用法。

用法:

class torchdata.datapipes.iter.Shuffler(datapipe: IterDataPipe[T_co], *, default: bool = True, buffer_size: int = 10000, unbatch_level: int = 0)

参数

  • datapipe-正在洗牌的IterDataPipe

  • buffer_size-洗牌的缓冲区大小(默认为 10000 )

  • unbatch_level-指定在应用 shuffle 之前是否需要取消批处理源数据

使用缓冲区对输入 DataPipe 进行混洗(函数名称:shuffle )。 buffer_size 的缓冲区首先填充来自数据管道的元素。然后,每个项目将通过迭代器通过容器采样从缓冲区中产生。

buffer_size 必须大于 0 。对于 buffer_size == 1 ,数据管道不会被洗牌。为了完全打乱数据管道中的所有元素,buffer_size 必须大于或等于数据管道的大小。

当它与 torch.utils.data.DataLoader 一起使用时,设置随机种子的方法根据 num_workers 不同。

对于single-process模式(num_workers == 0),随机种子设置在主进程中的DataLoader之前。对于multi-process 模式(num_worker > 0),worker_init_fn 用于为每个工作进程设置随机种子。

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(10))
>>> shuffle_dp = dp.shuffle()
[0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
>>> list(shuffle_dp)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchdata.datapipes.iter.Shuffler。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。