當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch BucketBatcher用法及代碼示例


本文簡要介紹python語言中 torchdata.datapipes.iter.BucketBatcher 的用法。

用法:

class torchdata.datapipes.iter.BucketBatcher(datapipe: IterDataPipe[T_co], batch_size: int, drop_last: bool = False, batch_num: int = 100, bucket_num: int = 1, sort_key: Optional[Callable] = None, in_batch_shuffle: bool = True)

參數

  • datapipe-可迭代DataPipe正在批量處理

  • batch_size-每批的大小

  • drop_last-如果最後一批未滿,可選擇刪除最後一批

  • batch_num-桶內的批次數(即 bucket_size = batch_size * batch_num )

  • bucket_num-構成洗牌池的桶數(即 pool_size = bucket_size * bucket_num )

  • sort_key-可調用以對存儲桶(列表)進行排序

  • in_batch_shuffle-如果為真,請執行 in-batch Shuffle[洗牌];如果為 False,則緩衝區洗牌

從排序存儲桶創建小批量數據(函數名稱:bucketbatch)。如果 drop_last 設置為 True ,則外部尺寸將添加為 batch_size ;如果 drop_last 設置為 False ,則最後一批的外部尺寸將添加為 length % batch_size

這個DataPipe的目的是根據傳遞的排序函數對具有一定相似性的樣本進行批處理。對於文本域中的示例,它可以對具有相似數量的令牌的示例進行批處理,以最小化填充並增加吞吐量。

示例

>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper(range(10))
>>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True)
>>> list(batch_dp)
[[5, 6, 7], [9, 0, 1], [4, 3, 2]]
>>> def sort_bucket(bucket):
>>>     return sorted(bucket)
>>> batch_dp = source_dp.bucketbatch(
>>>     batch_size=3, drop_last=True, batch_num=100,
>>>     bucket_num=1, in_batch_shuffle=False, sort_key=sort_bucket
>>> )
>>> list(batch_dp)
[[3, 4, 5], [6, 7, 8], [0, 1, 2]]

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchdata.datapipes.iter.BucketBatcher。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。