当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch BucketBatcher用法及代码示例


本文简要介绍python语言中 torchdata.datapipes.iter.BucketBatcher 的用法。

用法:

class torchdata.datapipes.iter.BucketBatcher(datapipe: IterDataPipe[T_co], batch_size: int, drop_last: bool = False, batch_num: int = 100, bucket_num: int = 1, sort_key: Optional[Callable] = None, in_batch_shuffle: bool = True)

参数

  • datapipe-可迭代DataPipe正在批量处理

  • batch_size-每批的大小

  • drop_last-如果最后一批未满,可选择删除最后一批

  • batch_num-桶内的批次数(即 bucket_size = batch_size * batch_num )

  • bucket_num-构成洗牌池的桶数(即 pool_size = bucket_size * bucket_num )

  • sort_key-可调用以对存储桶(列表)进行排序

  • in_batch_shuffle-如果为真,请执行 in-batch Shuffle[洗牌];如果为 False,则缓冲区洗牌

从排序存储桶创建小批量数据(函数名称:bucketbatch)。如果 drop_last 设置为 True ,则外部尺寸将添加为 batch_size ;如果 drop_last 设置为 False ,则最后一批的外部尺寸将添加为 length % batch_size

这个DataPipe的目的是根据传递的排序函数对具有一定相似性的样本进行批处理。对于文本域中的示例,它可以对具有相似数量的令牌的示例进行批处理,以最小化填充并增加吞吐量。

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(range(10))
>>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)
>>> list(batch_dp)
[[5, 6, 7], [9, 0, 1], [4, 3, 2]]
>>> def sort_bucket(bucket):
>>>     return sorted(bucket)
>>> batch_dp = source_dp.bucketbatch(
>>>     batch_size=3, drop_last=True, batch_num=100,
>>>     bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket
>>> )
>>> list(batch_dp)
[[3, 4, 5], [6, 7, 8], [0, 1, 2]]

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchdata.datapipes.iter.BucketBatcher。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。