This article collects typical usage examples of the Python method torch.utils.data.IterableDataset. If you are wondering what data.IterableDataset does, how to use it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples from its containing module, torch.utils.data.
The following shows 7 code examples of data.IterableDataset, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: _has_len
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
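The check above hinges on DataLoader.__len__ raising TypeError when the underlying IterableDataset does not define __len__. Here is a minimal standalone sketch of the two cases; it is not part of the example source, and the _InfiniteStream class is invented purely for illustration:

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

class _InfiniteStream(IterableDataset):
    # Hypothetical dataset with no __len__: yields random vectors forever.
    def __iter__(self):
        while True:
            yield torch.rand(3)

finite_loader = DataLoader(TensorDataset(torch.rand(8, 3)), batch_size=4)
infinite_loader = DataLoader(_InfiniteStream(), batch_size=4)

print(len(finite_loader))   # 2 batches -> a check like _has_len returns True
try:
    len(infinite_loader)    # DataLoader.__len__ needs len(dataset), which is missing
except TypeError:
    print('no __len__')     # -> a check like _has_len returns False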
Example 2: __init__
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __init__(self, datasets: List[TensorDataset], probs: List[float] = None, exp: float = None, mode: str = 'exp'):
    """
    :param datasets: the Dataset of each individual source
    :param probs: sample according to one probability per source; length must equal the number of datasets
    :param exp: sample with exponential smoothing, 0 < exp < 1
    :param mode: whether to sample by explicit probabilities ('prob') or by exponential smoothing ('exp')
    """
    super().__init__()
    assert len(datasets) > 0, 'datasets should not be an empty iterable'
    assert mode in ['prob', 'exp'], 'ConcatTensorRandomDataset mode must be either prob or exp'
    if mode == 'prob':
        assert probs and len(probs) == len(datasets) and sum(probs) == 1
    else:
        assert exp and 0 < exp < 1
    self.datasets = list(datasets)
    self.dataset_idxs = list(range(len(self.datasets)))
    self.dataset_lens = [len(x) for x in self.datasets]
    self.original_lengths = []  # record the original number of samples in each source
    for d in self.datasets:
        assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.original_lengths.append(len(d))
    if mode == 'exp':
        original_probs = self.original_lengths / np.sum(self.original_lengths)
        # exponential weighting
        probs_exp = original_probs ** exp
        # softmax
        pes = np.exp(probs_exp)
        self.probs = pes / np.sum(pes)
    else:
        assert isinstance(probs, list) and probs
        self.probs = np.array(probs)
    self.sample_total_length = np.sum(self.original_lengths * self.probs)
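To make the 'exp' branch concrete, here is a small standalone sketch (with made-up source sizes, not taken from the example source) of how the exponential-smoothing probabilities are derived; smaller sources end up with a larger sampling share than their raw proportion:

import numpy as np

original_lengths = np.array([1000, 100, 10])   # hypothetical source sizes
exp = 0.5

original_probs = original_lengths / np.sum(original_lengths)
probs_exp = original_probs ** exp      # exponential smoothing flattens the ratios
pes = np.exp(probs_exp)                # softmax normalisation, as in the example
probs = pes / np.sum(pes)

print(original_probs)  # approx. [0.90 0.09 0.009]
print(probs)           # approx. [0.51 0.27 0.22] -> rare sources are upsampled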
Example 3: __iter__
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __iter__(self):
    # With IterableDataset, the same __iter__ is copied over to the multiple workers of
    # a Dataloader. Hence, we need to configure the __iter__ to not yield duplicated data
    # when more than 1 workers are used.
    #
    # To avoid duplicates, we need to split the input dicts between the workers.
    # The grouper() converts a dict generator given as input and yields only the
    # dicts that are to be processed by the given worker_id.
    #
    # For instance, consider input as [dictA, dictB, dictC, ...], then the grouper
    # (with n=2) will return, [[dictA, dictB], [dictE, dictF] ...] for worker 1 and
    # [[dictC, dictD], [dictG, dictH] ...] for worker 2.

    worker_info = torch.utils.data.get_worker_info()
    if self.distributed:
        worker_id = self.rank * worker_info.num_workers + worker_info.id
        total_workers = self.world_size * worker_info.num_workers
    else:
        worker_id = worker_info.id
        total_workers = self.dataloader_workers

    dicts = grouper(self.file_to_dicts_generator, n=10, worker_id=worker_id, total_workers=total_workers)
    results = map(self._dataset_from_chunk, dicts)

    batch = []
    for datasets, tensor_names in results:
        if not datasets:
            continue
        self.tensor_names = tensor_names
        for ds in datasets:
            batch.append(ds)
            if len(batch) == self.batch_size:
                yield batch
                batch = []

    if batch:
        yield batch
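The same worker-splitting idea can be shown without grouper() or the surrounding class. Below is a minimal, self-contained sketch (the ShardedRange class is invented for illustration) in which each DataLoader worker yields every total_workers-th element, so no element is duplicated across workers:

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ShardedRange(IterableDataset):
    # Hypothetical dataset: yields 0..n-1, sharded across DataLoader workers.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        total_workers = info.num_workers if info is not None else 1
        return iter(range(worker_id, self.n, total_workers))

if __name__ == '__main__':
    loader = DataLoader(ShardedRange(10), num_workers=2, batch_size=None)
    print(sorted(loader))  # [0, 1, ..., 9], each element exactly once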
Example 4: _wrap_dataloaders
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _wrap_dataloaders(self):
    def with_sampler(loader):
        # Automatically set the DistributedSampler
        data_loader_args = {
            "dataset": loader.dataset,
            "batch_size": loader.batch_size,
            "shuffle": False,
            "num_workers": loader.num_workers,
            "collate_fn": loader.collate_fn,
            "pin_memory": loader.pin_memory,
            "drop_last": loader.drop_last,
            "timeout": loader.timeout,
            "worker_init_fn": loader.worker_init_fn,
            "sampler": DistributedSampler(loader.dataset)
        }
        return DataLoader(**data_loader_args)

    def should_wrap_dataloader(loader):
        return (isinstance(loader, DataLoader)
                and not isinstance(loader.dataset, IterableDataset))

    if should_wrap_dataloader(self.train_loader):
        if self.add_dist_sampler:
            self.train_loader = with_sampler(self.train_loader)

    if self.validation_loader and should_wrap_dataloader(
            self.validation_loader):
        if self.add_dist_sampler:
            self.validation_loader = with_sampler(self.validation_loader)
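The should_wrap_dataloader guard exists because DistributedSampler indexes into the dataset and calls len() on it, which an IterableDataset does not support. A short sketch of wrapping a map-style loader follows; it is not taken from the example source, and num_replicas/rank are passed explicitly only so the snippet runs without an initialised process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.rand(16, 3))   # map-style: has __len__ and __getitem__
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

print(len(sampler))   # 8 of the 16 samples are assigned to rank 0
print(len(loader))    # 2 batches on this rank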
Example 5: _has_iterable_dataset
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset)
Example 6: __init__
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __init__(self, dataloader: IterableDataset, dataloader_size: int):
    """ Wraps around an iterable Dataloader to report progress bars and
    increase the global step of the SummaryWriter. At the last iteration, it will
    call dataloader.__exit__ if needed (e.g. Petastorm DataLoader).

    Args:
        dataloader: the iterable dataloader to wrap around
        dataloader_size: size of the dataset we're iterating over
    """
    self.dataloader = dataloader
    self.dataloader_iter = iter(dataloader)
    self.dataloader_size = dataloader_size
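The snippet only shows the constructor. As an assumption rather than the source implementation, a minimal sketch of how such a wrapper might drive the stored iterator could look like this:

class _ProgressIterator:
    # Hypothetical counterpart to the __init__ above: counts batches against
    # dataloader_size and stops when the wrapped iterator is exhausted.
    def __init__(self, dataloader, dataloader_size):
        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)
        self.dataloader_size = dataloader_size
        self.step = 0

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.dataloader_iter)   # StopIteration ends the loop
        self.step += 1
        print(f'batch {self.step}/{self.dataloader_size}')
        return batch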
Example 7: reset_train_dataloader
# Required module import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def reset_train_dataloader(self, model: LightningModule) -> None:
    """Resets the train dataloader and initialises required variables
    (number of batches, when to validate, etc.).

    Args:
        model: The current `LightningModule`
    """
    self.train_dataloader = self.request_dataloader(model.train_dataloader)
    self.num_training_batches = 0

    # automatically add samplers
    self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

    self._worker_check(self.train_dataloader, 'train dataloader')
    self._check_batch_limits('limit_train_batches')

    if not _has_len(self.train_dataloader):
        self.num_training_batches = float('inf')
    else:
        # try getting the length
        if isinstance(self.limit_train_batches, float):
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        else:
            self.num_training_batches = self.limit_train_batches

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                f'to the number of the training batches ({self.num_training_batches}). '
                'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
    else:
        if not _has_len(self.train_dataloader):
            if self.val_check_interval == 1.0:
                self.val_check_batch = float('inf')
            else:
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                    ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                    ' checking validation every k training batches.')
        else:
            self._check_batch_limits('val_check_interval')

            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            self.val_check_batch = max(1, self.val_check_batch)
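A small standalone sketch (with invented numbers, not part of the example source) of the val_check_interval arithmetic above, assuming a finite train dataloader:

num_training_batches = 200

# float interval: run validation after that fraction of a training epoch
val_check_interval = 0.25
val_check_batch = max(1, int(num_training_batches * val_check_interval))
print(val_check_batch)   # 50 -> validate every 50 training batches

# int interval: validate every k training batches; k must not exceed
# num_training_batches, otherwise the ValueError above is raised
val_check_interval = 64
val_check_batch = val_check_interval
print(val_check_batch)   # 64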