

Python data.IterableDataset Code Examples

This article collects typical usage examples of torch.utils.data.IterableDataset in Python. If you are wondering how data.IterableDataset is used in practice, the hand-picked code examples below should help; you can also explore further usage examples from the containing module, torch.utils.data.


The following shows 7 code examples of data.IterableDataset, sorted by popularity by default.
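Before the project-specific examples, here is a minimal, self-contained sketch of the IterableDataset contract itself: subclass it, implement __iter__, and hand it to a DataLoader. The class and values (RangeStream, the 0-10 range) are illustrative and do not come from any of the projects below.

from torch.utils.data import DataLoader, IterableDataset


class RangeStream(IterableDataset):
    """Illustrative stream that yields the integers of a half-open range."""

    def __init__(self, start: int, end: int):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        # An IterableDataset only has to return an iterator; __len__ and __getitem__ are optional.
        return iter(range(self.start, self.end))


loader = DataLoader(RangeStream(0, 10), batch_size=4)
for batch in loader:
    print(batch)  # tensors of up to 4 consecutive integers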

Example 1: _has_len

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 24, Source: data_loading.py
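The try/except above works because calling len() on a DataLoader whose dataset is an IterableDataset without __len__ raises TypeError (on recent PyTorch versions). The standalone sketch below reproduces that behaviour; InfiniteStream is an illustrative class, not part of pytorch-lightning.

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset


class InfiniteStream(IterableDataset):
    def __iter__(self):
        while True:
            yield torch.rand(3)


finite = DataLoader(TensorDataset(torch.rand(8, 3)), batch_size=2)
infinite = DataLoader(InfiniteStream(), batch_size=2)

print(len(finite))  # 4 batches
try:
    len(infinite)  # the underlying IterableDataset defines no __len__
except TypeError:
    print("length unknown -> treated as an infinite dataloader")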

Example 2: __init__

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __init__(self, datasets: List[TensorDataset], probs: List[float] = None, exp: float = None, mode: str = 'exp'):
        """

        :param datasets: 各个源本身的Data Set
        :param probs: 按照概率采样,对应每个源的概率,长度等于datasets的数量
        :param exp: 按照指数平滑采样,0<exp<1
        :param mode:指示是采用概率采样还是采用指数平滑采样
        """
        super().__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        assert mode in ['prob', 'exp'], "ConcatTensorRandomDataset mode must be 'prob' or 'exp'"
        if mode == 'prob':
            assert probs and len(probs) == len(datasets) and sum(probs) == 1
        else:
            assert exp and 0 < exp < 1
        self.datasets = list(datasets)
        self.dataset_idxs = list(range(len(self.datasets)))
        self.dataset_lens = [len(x) for x in self.datasets]
        self.original_lengths = []  # record the original data length of each source
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
            self.original_lengths.append(len(d))
        if mode == 'exp':
            original_probs = self.original_lengths / np.sum(self.original_lengths)
            # exponential weighting
            probs_exp = original_probs ** exp
            # softmax
            pes = np.exp(probs_exp)
            self.probs = pes / np.sum(pes)
        else:
            assert isinstance(probs, list) and probs
            self.probs = np.array(probs)
        self.sample_total_length = np.sum(self.original_lengths * self.probs) 
Developer: NLPInBLCU, Project: BiaffineDependencyParsing, Lines: 35, Source: custom_dataset.py
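The 'exp' branch above turns the raw source sizes into sampling weights by normalising them, raising the result to the power exp, and then applying a softmax. The few lines below replay that arithmetic on illustrative lengths so the smoothing effect is visible; the numbers are made up, not taken from the project.

import numpy as np

original_lengths = np.array([1000, 100, 10])  # illustrative source sizes
exp = 0.5

original_probs = original_lengths / original_lengths.sum()  # raw proportion of each source
probs_exp = original_probs ** exp                           # exponential smoothing flattens the gap
pes = np.exp(probs_exp)
probs = pes / pes.sum()                                     # softmax-normalised sampling weights

print(probs)  # the small sources receive a larger share than their raw proportion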

Example 3: __iter__

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __iter__(self):
        #  With IterableDataset, the same __iter__ is copied over to the multiple workers of
        #  a Dataloader. Hence, we need to configure the __iter__ to not yield duplicated data
        #  when more than 1 workers are used.
        #
        #  To avoid duplicates, we need to split the input dicts between the workers.
        #  The grouper() converts a dict generator given as input and yields only the
        #  dicts that are to be processed by the given worker_id.
        #
        #  For instance, consider the input [dictA, dictB, dictC, ...]; then the grouper
        #  (with n=2) will return [[dictA, dictB], [dictE, dictF], ...] for worker 1 and
        #  [[dictC, dictD], [dictG, dictH], ...] for worker 2.

        worker_info = torch.utils.data.get_worker_info()
        if self.distributed:
            worker_id = self.rank * worker_info.num_workers + worker_info.id
            total_workers = self.world_size * worker_info.num_workers
        else:
            worker_id = worker_info.id
            total_workers = self.dataloader_workers

        dicts = grouper(self.file_to_dicts_generator, n=10, worker_id=worker_id, total_workers=total_workers)
        results = map(self._dataset_from_chunk, dicts)

        batch = []
        for datasets, tensor_names in results:
            if not datasets:
                continue
            self.tensor_names = tensor_names
            for ds in datasets:
                batch.append(ds)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
        if batch:
            yield batch 
Developer: deepset-ai, Project: FARM, Lines: 38, Source: data_silo.py
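The duplication problem described in the comments can be reproduced and fixed in isolation. The sketch below (ShardedStream is illustrative, not FARM code) shards a stream across DataLoader workers with get_worker_info, which is the same idea grouper() applies to chunks of dicts.

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedStream(IterableDataset):
    def __init__(self, items):
        super().__init__()
        self.items = list(items)

    def __iter__(self):
        info = get_worker_info()
        if info is None:                 # single-process loading: yield everything
            worker_id, total = 0, 1
        else:                            # each worker yields an interleaved shard
            worker_id, total = info.id, info.num_workers
        for i, item in enumerate(self.items):
            if i % total == worker_id:
                yield item


# On platforms that spawn worker processes, run this under `if __name__ == "__main__":`.
loader = DataLoader(ShardedStream(range(10)), num_workers=2, batch_size=None)
print(sorted(loader))  # every item 0..9 appears exactly once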

Example 4: _wrap_dataloaders

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _wrap_dataloaders(self):
        def with_sampler(loader):
            # Automatically set the DistributedSampler
            data_loader_args = {
                "dataset": loader.dataset,
                "batch_size": loader.batch_size,
                "shuffle": False,
                "num_workers": loader.num_workers,
                "collate_fn": loader.collate_fn,
                "pin_memory": loader.pin_memory,
                "drop_last": loader.drop_last,
                "timeout": loader.timeout,
                "worker_init_fn": loader.worker_init_fn,
                "sampler": DistributedSampler(loader.dataset)
            }
            return DataLoader(**data_loader_args)

        def should_wrap_dataloader(loader):
            return (isinstance(loader, DataLoader)
                    and not isinstance(loader.dataset, IterableDataset))

        if should_wrap_dataloader(self.train_loader):
            if self.add_dist_sampler:
                self.train_loader = with_sampler(self.train_loader)

        if self.validation_loader and should_wrap_dataloader(
                self.validation_loader):
            if self.add_dist_sampler:
                self.validation_loader = with_sampler(self.validation_loader) 
Developer: ray-project, Project: ray, Lines: 31, Source: distributed_torch_runner.py
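Outside of Ray, the same rule (re-create a DataLoader with a DistributedSampler unless its dataset is iterable-style, since a sampler needs indices to partition) fits in a small helper. The sketch below is illustrative and assumes torch.distributed has already been initialised, because DistributedSampler reads the world size and rank from the default process group.

from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler


def add_dist_sampler(loader: DataLoader) -> DataLoader:
    # An IterableDataset exposes no indices to partition, so leave such loaders untouched.
    if isinstance(loader.dataset, IterableDataset):
        return loader
    return DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        shuffle=False,                      # partitioning/shuffling is delegated to the sampler
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        sampler=DistributedSampler(loader.dataset),
    )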

Example 5: _has_iterable_dataset

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset) 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 5, Source: data_loading.py

Example 6: __init__

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def __init__(self, dataloader: IterableDataset, dataloader_size: int):
        """ Wraps around an Iterable Dataloader to report progress bars and
        increase global step of SummaryWriter. At last iteration, will call
        dataloader.__exit__ if needed (e.g. Petastorm DataLoader).

        Args:
            dataloader: the iterable dataloader to wrap around
            dataloader_size: size of the dataset we're iterating over
        """

        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)
        self.dataloader_size = dataloader_size 
Developer: facebookresearch, Project: ReAgent, Lines: 15, Source: iterators.py
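Only the constructor is reproduced above; the rest of the wrapper, which advances dataloader_iter and closes the underlying loader once it is exhausted, is not shown on this page. The illustrative version below sketches that pattern; it is not the ReAgent implementation.

class CountingIterator:
    """Illustrative wrapper that tracks progress over a dataloader of known size."""

    def __init__(self, dataloader, dataloader_size: int):
        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)
        self.dataloader_size = dataloader_size
        self.step = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataloader_iter)
        except StopIteration:
            # Some loaders (e.g. Petastorm's DataLoader) need an explicit close at the end.
            if hasattr(self.dataloader, "__exit__"):
                self.dataloader.__exit__(None, None, None)
            raise
        self.step += 1  # e.g. used to advance the global step of a SummaryWriter
        return batch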

Example 7: reset_train_dataloader

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import IterableDataset [as alias]
def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)

        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._worker_check(self.train_dataloader, 'train dataloader')
        self._check_batch_limits('limit_train_batches')

        if not _has_len(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            if isinstance(self.limit_train_batches, float):
                self.num_training_batches = len(self.train_dataloader)
                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
            else:
                self.num_training_batches = self.limit_train_batches

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
        else:
            if not _has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an infinite DataLoader (e.g. with an IterableDataset'
                        ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self._check_batch_limits('val_check_interval')

                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch) 
Developer: PyTorchLightning, Project: pytorch-lightning, Lines: 54, Source: data_loading.py
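The val_check_interval arithmetic above is easiest to see with concrete numbers; the values below are illustrative only. A float in (0, 1] is interpreted as a fraction of the training epoch, while an int means "every k training batches" and is the only option (besides 1.0) when the train dataloader has no length.

num_training_batches = 1000

# float in (0, 1]: validate after that fraction of an epoch
val_check_interval = 0.25
val_check_batch = max(1, int(num_training_batches * val_check_interval))
print(val_check_batch)  # 250

# int: validate every k training batches, independent of the epoch length;
# this also works when the dataloader is infinite (num_training_batches = inf)
val_check_interval = 100
val_check_batch = val_check_interval
print(val_check_batch)  # 100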


Note: The torch.utils.data.IterableDataset examples in this article were compiled by 纯净天空 from code hosted on GitHub, MSDocs, and other open-source code and documentation platforms. The snippets were selected from open-source projects contributed by their respective authors, and the copyright of the source code remains with those authors; please consult each project's license before distributing or using the code. Do not reproduce without permission.