當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch Grouper用法及代碼示例


本文簡要介紹python語言中 torchdata.datapipes.iter.Grouper 的用法。

用法:

class torchdata.datapipes.iter.Grouper(datapipe: IterDataPipe[torch.utils.data.datapipes.iter.grouping.T_co], group_key_fn: Callable, *, buffer_size: int = 10000, group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False)

參數

  • datapipe-要分組的可迭代數據管道

  • group_key_fn-用於從源數據管道的數據生成組鍵的函數

  • buffer_size-未分組數據的緩衝區大小

  • group_size-每組的最大大小,達到這個大小就產生一個批次

  • guaranteed_group_size-在緩衝區已滿的情況下保證產生的最小組大小

  • drop_remaining-指定當緩衝區已滿時是否將小於 guaranteed_group_size 的組從緩衝區中刪除

根據 group_key_fn 生成的鍵對輸入 IterDataPipe 中的數據進行分組,並生成批量大小最大為 group_size(如果已定義)的 DataChunk(函數名稱:groupby )。

從源 datapipe 中順序讀取樣本,當批次大小達到 group_size 時,將產生一批屬於同一組的樣本。當緩衝區已滿時, DataPipe 將產生具有相同 key 的最大批次,前提是其大小大於 guaranteed_group_size 。如果它的大小較小,則 drop_remaining=True 將會被刪除。

在遍曆整個源 datapipe 之後,由於緩衝區容量而未丟棄的所有內容都將從緩衝區中產生,即使組大小小於 guaranteed_group_size

示例

>>> import os
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def group_fn(file):
...    return os.path.basename(file).split(".")[0]
>>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
>>> dp0 = source_dp.groupby(group_key_fn=group_fn)
>>> list(dp0)
[['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
>>> # A group is yielded as soon as its size equals to `group_size`
>>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
>>> list(dp1)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
>>> # Scenario where `buffer` is full, and group 'a' needs to be yielded since its size > `guaranteed_group_size`
>>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
>>> list(dp2)
[['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchdata.datapipes.iter.Grouper。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。