本文簡要介紹python語言中 torch.utils.data.IterableDataset
的用法。
用法:
class torch.utils.data.IterableDataset(*args, **kwds)
一個可迭代的數據集。
所有代表數據樣本可迭代的數據集都應該對其進行子類化。當數據來自流時,這種形式的數據集特別有用。
所有子類都應覆蓋
__iter__()
,這將返回此數據集中樣本的迭代器。當子類與
DataLoader
一起使用時,數據集中的每個項目都將從DataLoader
迭代器生成。當num_workers > 0
時,每個工作進程將擁有不同的數據集對象副本,因此通常需要獨立配置每個副本,以避免從工作進程返回重複的數據。get_worker_info()
在工作進程中調用時,返回有關工作進程的信息。它可以在數據集的__iter__()
方法或DataLoader
的worker_init_fn
選項中使用,以修改每個副本的行為。示例 1:在
__iter__()
中將工作負載分配給所有工作人員:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20))) [3, 4, 5, 6]
示例 2:使用
worker_init_fn
將工作負載分配給所有工作人員:>>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn))) [3, 4, 5, 6]
相關用法
- Python PyTorch IterableWrapper用法及代碼示例
- Python PyTorch IterDataPipe用法及代碼示例
- Python PyTorch IterKeyZipper用法及代碼示例
- Python PyTorch InMemoryCacheHolder用法及代碼示例
- Python PyTorch IndexAdder用法及代碼示例
- Python PyTorch IoPathSaver用法及代碼示例
- Python PyTorch InMemoryBinaryCriteoIterDataPipe用法及代碼示例
- Python PyTorch Identity用法及代碼示例
- Python PyTorch IoPathFileOpener用法及代碼示例
- Python PyTorch Independent用法及代碼示例
- Python PyTorch Interpreter用法及代碼示例
- Python PyTorch InstanceNorm2d用法及代碼示例
- Python PyTorch InstanceNorm3d用法及代碼示例
- Python PyTorch IoPathFileLister用法及代碼示例
- Python PyTorch ImageFolder用法及代碼示例
- Python PyTorch IWSLT2016用法及代碼示例
- Python PyTorch InteractionArch用法及代碼示例
- Python PyTorch InstanceNorm1d用法及代碼示例
- Python PyTorch InProjContainer.forward用法及代碼示例
- Python PyTorch InverseSpectrogram用法及代碼示例
- Python PyTorch IWSLT2017用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
- Python PyTorch cholesky用法及代碼示例
- Python PyTorch vdot用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.data.IterableDataset。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。