当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Grouper用法及代码示例


本文简要介绍python语言中 torchdata.datapipes.iter.Grouper 的用法。

用法:

class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[torch.utils.data.datapipes.iter.grouping.T_co], group_key_fn: Callable, *, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)

参数

  • datapipe-要分组的可迭代数据管道

  • group_key_fn-用于从源数据管道的数据生成组键的函数

  • buffer_size-未分组数据的缓冲区大小

  • group_size-每组的最大大小,达到这个大小就产生一个批次

  • guaranteed_group_size-在缓冲区已满的情况下保证产生的最小组大小

  • drop_remaining-指定当缓冲区已满时是否将小于 guaranteed_group_size 的组从缓冲区中删除

根据 group_key_fn 生成的键对输入 IterDataPipe 中的数据进行分组,并生成批量大小最大为 group_size(如果已定义)的 DataChunk(函数名称:groupby )。

从源 datapipe 中顺序读取样本,当批次大小达到 group_size 时,将产生一批属于同一组的样本。当缓冲区已满时, DataPipe 将产生具有相同 key 的最大批次,前提是其大小大于 guaranteed_group_size 。如果它的大小较小,则 drop_remaining=True 将会被删除。

在遍历整个源 datapipe 之后,由于缓冲区容量而未丢弃的所有内容都将从缓冲区中产生,即使组大小小于 guaranteed_group_size

示例

>>> import os
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...    return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals to `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchdata.datapipes.iter.Grouper。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。