当前位置: 首页>>代码示例>>Python>>正文


Python data.Dataset方法代码示例

本文整理汇总了 Python 中 torch.utils.data.Dataset 类的典型用法代码示例。如果您正苦于以下问题:Python data.Dataset 类的具体用法?Python data.Dataset 怎么用?Python data.Dataset 使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torch.utils.data 的用法示例。


下文共展示了 data.Dataset 类的 15 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: get_test_loader

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def get_test_loader(split_name, data_name, vocab, crop_size, batch_size,
                    workers, opt, cap_suffix='caps'):
    """Build the data loader used to evaluate one split.

    When ``opt.data_name`` ends with ``'_precomp'`` the loader comes from
    ``get_precomp_loader`` or, if ``opt.use_external_captions`` is truthy,
    from ``get_precomp_train_caption_loader``.  Every other dataset goes
    through ``get_loader_single`` with the paths returned by ``get_paths``
    and the transform returned by ``get_transform`` (``shuffle=False``).

    Note: ``crop_size`` is accepted but never read by this function.
    """
    dpath = os.path.join(opt.data_path, data_name)

    if not opt.data_name.endswith('_precomp'):
        # Raw dataset: resolve image/caption roots and ids for every split,
        # then pick the entries of the requested one.
        roots, ids = get_paths(dpath, data_name, opt.use_restval)
        transform = get_transform(data_name, split_name, opt)
        return get_loader_single(
            opt.data_name,
            split_name,
            roots[split_name]['img'],
            roots[split_name]['cap'],
            vocab,
            transform,
            ids=ids[split_name],
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            collate_fn=collate_fn,
        )

    # Precomputed features: both factories take the same positional arguments.
    if opt.use_external_captions:
        build_loader = get_precomp_train_caption_loader
    else:
        build_loader = get_precomp_loader
    return build_loader(dpath, split_name, vocab, opt,
                        batch_size, False, workers, cap_suffix)
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:26,代码来源:data.py

示例2: get_seq

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def get_seq(pairs,lang,batch_size,type,max_len):
    """Wrap ``pairs`` in a ``DataLoader``.

    Each pair contributes its elements 0, 1 and 2 to the source, target and
    pointer sequences respectively.  When ``type`` is truthy the words of
    elements 0 and 1 are also registered through ``lang.index_words`` and the
    loader shuffles its data; otherwise the vocabulary is left untouched and
    the order is kept.
    """
    sources, targets, pointers = [], [], []
    for pair in pairs:
        src, tgt, ptr = pair[0], pair[1], pair[2]
        sources.append(src)
        targets.append(tgt)
        pointers.append(ptr)
        if type:
            lang.index_words(src)
            lang.index_words(tgt)

    # The same word-to-index table is passed for both vocabulary arguments.
    dataset = Dataset(sources, targets, pointers,
                      lang.word2index, lang.word2index, max_len)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=type,
        collate_fn=collate_fn,
    )
开发者ID:ConvLab,项目名称:ConvLab,代码行数:20,代码来源:utils_NMT.py

示例3: get_seq

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def get_seq(pairs,lang,batch_size,type,max_len):
    """Wrap ``pairs`` in a ``DataLoader``.

    Each pair contributes its elements 0, 1, 2 and 3 to the source, target,
    pointer and gate sequences respectively.  When ``type`` is truthy the
    words of elements 0 and 1 are also registered through
    ``lang.index_words`` and the loader shuffles its data; otherwise the
    vocabulary is left untouched and the order is kept.
    """
    sources, targets, pointers, gates = [], [], [], []
    for pair in pairs:
        src, tgt, ptr, gate = pair[0], pair[1], pair[2], pair[3]
        sources.append(src)
        targets.append(tgt)
        pointers.append(ptr)
        gates.append(gate)
        if type:
            lang.index_words(src)
            lang.index_words(tgt)

    # The same word-to-index table is passed for both vocabulary arguments.
    dataset = Dataset(sources, targets, pointers, gates,
                      lang.word2index, lang.word2index, max_len)
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=type,
        collate_fn=collate_fn,
    )
开发者ID:ConvLab,项目名称:ConvLab,代码行数:22,代码来源:utils_babi.py

示例4: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, data_path=None, split_set='train', num_points=20000,
        use_color=False, use_height=False, use_v1=False,
        augment=False, scan_idx_list=None):
        """Index the scans available for one split of the v1 vote data.

        Args:
            data_path: directory that contains the
                ``sunrgbd_pc_bbox_votes_50k_v1_<split_set>`` folder.
            split_set: split name appended to the folder name.
            num_points: number of points per scan; must not exceed 50000.
            use_color: stored on the instance as ``use_color``.
            use_height: stored on the instance as ``use_height``.
            use_v1: must be truthy; only the v1 data layout is supported.
            augment: stored on the instance as ``augment``.
            scan_idx_list: optional indices selecting a subset of the sorted
                scan names.

        Raises:
            AssertionError: if ``num_points`` is above 50000 (when assertions
                are enabled) or if ``use_v1`` is falsy.
        """
        assert(num_points<=50000)
        self.use_v1 = use_v1
        if not use_v1:
            # The exception used to be constructed without `raise`, so this
            # branch fell through and failed later with an unrelated error
            # because `self.data_path` was never set.
            raise AssertionError("v2 data is not prepared")
        self.data_path = os.path.join(data_path, 'sunrgbd_pc_bbox_votes_50k_v1_' + split_set)

        self.raw_data_path = os.path.join(ROOT_DIR, 'sunrgbd/sunrgbd_trainval')
        # A scan name is the first six characters of a file name; the set
        # removes the duplicates produced by multiple files per scan.
        self.scan_names = sorted(
            {os.path.basename(x)[0:6] for x in os.listdir(self.data_path)})

        if scan_idx_list is not None:
            self.scan_names = [self.scan_names[i] for i in scan_idx_list]
        self.num_points = num_points
        self.augment = augment
        self.use_color = use_color
        self.use_height = use_height
开发者ID:zaiweizhang,项目名称:H3DNet,代码行数:24,代码来源:sunrgbd_detection_dataset_hd.py

示例5: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, num_epochs=None, *args, **kwargs):
        """Create the loader and prepare its epoch bookkeeping.

        All arguments other than ``num_epochs`` are forwarded unchanged to the
        parent constructor.  Because ``num_epochs`` is the first parameter, a
        first positional argument is bound to it rather than forwarded.

        Args:
            num_epochs: int or None, number of epochs to iterate over the
                dataset. If None, defaults to infinity.
            dataset: A `Dataset` object to be loaded (forwarded).
            batch_size: int, the size of each batch (forwarded).
            shuffle: bool, whether to shuffle the dataset after each epoch
                (forwarded).
            drop_last: bool, whether to drop last batch if its size is less
                than `batch_size` (forwarded).
        """
        super().__init__(*args, **kwargs)
        # Iterator obtained from the parent class's own `__iter__`.
        self.finite_iterable = super().__iter__()
        self.counter = 0
        if num_epochs is None:
            self.num_epochs = float('inf')
        else:
            self.num_epochs = num_epochs
开发者ID:bayesiains,项目名称:nsf,代码行数:20,代码来源:base.py

示例6: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, root_dir, split_name, transform):
        """Store the configuration, validate the folder layout, load the split.

        Raises:
            IOError: if the image, mask or list sub-folder of ``root_dir``
                is not an existing directory.
        """
        super(ISSDataset, self).__init__()
        self.root_dir, self.split_name, self.transform = root_dir, split_name, transform

        # Expected sub-folders of the dataset root
        self._img_dir = path.join(root_dir, ISSDataset._IMG_DIR)
        self._msk_dir = path.join(root_dir, ISSDataset._MSK_DIR)
        self._lst_dir = path.join(root_dir, ISSDataset._LST_DIR)
        for folder in (self._img_dir, self._msk_dir, self._lst_dir):
            if path.isdir(folder):
                continue
            raise IOError("Dataset sub-folder {} does not exist".format(folder))

        # Meta-data and image list of the requested split
        self._meta, self._images = self._load_split()
开发者ID:mapillary,项目名称:seamseg,代码行数:18,代码来源:dataset.py

示例7: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False, seed=0):
        """Load the processed training or test file of the dataset.

        Args:
            root: passed to the parent constructor.
            train: selects ``self.training_file`` when truthy, otherwise
                ``self.test_file``.
            transform: stored on the instance as ``transform``.
            target_transform: stored on the instance as ``target_transform``.
            download: when truthy, ``self.download()`` runs before the
                existence check.
            seed: stored on the instance as ``seed``.

        Raises:
            RuntimeError: if ``self._check_exists()`` reports that the data
                is missing.
        """
        super(MNIST, self).__init__(root)
        self.seed = seed
        self.train = train  # training set or test set
        self.transform, self.target_transform = transform, target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               ' You can use download=True to download it')

        data_file = self.training_file if self.train else self.test_file
        # NOTE(review): torch.load deserialises the file; presumably only
        # trusted, locally produced files are expected here — confirm.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
开发者ID:sato9hara,项目名称:sgd-influence,代码行数:21,代码来源:MyMNIST.py

示例8: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load the processed training or test tensors of the dataset.

        Args:
            root: dataset root; ``~`` is expanded before it is stored.
            train: when truthy the training file is loaded into
                ``train_data``/``train_labels``, otherwise the test file is
                loaded into ``test_data``/``test_labels``.
            transform: stored on the instance as ``transform``.
            target_transform: stored on the instance as ``target_transform``.
            download: when truthy, ``self.download()`` runs before the
                existence check.

        Raises:
            RuntimeError: if ``self._check_exists()`` reports that the data
                is missing.
        """
        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set
        self.transform, self.target_transform = transform, target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               ' You can use download=True to download it')

        if self.train:
            train_path = os.path.join(self.root, self.processed_folder, self.training_file)
            self.train_data, self.train_labels = torch.load(train_path)
        else:
            test_path = os.path.join(self.root, self.processed_folder, self.test_file)
            self.test_data, self.test_labels = torch.load(test_path)
开发者ID:igolan,项目名称:bgd,代码行数:21,代码来源:datasets.py

示例9: __getitem__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __getitem__(self, idx):
        """Assemble the record stored at position ``idx``.

        The returned ``OrderedDict`` holds, in order: one entry per number
        and category feature, then for every sequence feature the padded
        sequence plus a ``__<name>_length`` entry with its original length,
        and finally ``'label'`` when targets are available.
        """
        scalar_features = chain(
            self.features.number_features,
            self.features.category_features)
        record = OrderedDict(
            (feat.name, self.X_map[feat.name][idx]) for feat in scalar_features)

        for feat in self.features.sequence_features:
            raw_sequence = self.X_map[feat.name][idx]
            record[feat.name] = Dataset.__pad_sequence(feat, raw_sequence)
            # Length is taken before padding, from the first axis.
            record[f"__{feat.name}_length"] = np.int64(raw_sequence.shape[0])

        if self.y is None:
            return record
        record['label'] = self.y[idx]
        return record
开发者ID:GitHub-HongweiZhang,项目名称:prediction-flow,代码行数:18,代码来源:dataset.py

示例10: transform

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def transform(self, fn, lazy=True):
        """Return a new dataset whose samples are passed through `fn`.

        Parameters
        ----------
        fn : callable
            Function that receives one sample and returns the transformed
            sample.
        lazy : bool, default True
            If True, each sample is transformed when it is accessed.
            If False, every sample is transformed immediately and the
            results are stored. A stochastic `fn` therefore needs
            ``lazy=True``; otherwise the same results are returned on
            every epoch.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        lazy_view = _LazyTransformDataset(self, fn)
        if not lazy:
            # Materialise every transformed sample once, up front.
            return SimpleDataset([sample for sample in lazy_view])
        return lazy_view
开发者ID:AceCoooool,项目名称:LEDNet,代码行数:26,代码来源:base.py

示例11: _dataset_from_chunk

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def _dataset_from_chunk(cls, chunk, processor):
        """
        Convert one chunk (= subset) of dicts into one dataset.

        This is the per-process step of the multiprocessing pipeline:
          * all dicts are read from a file
          * the dicts are split into chunks
          * *one chunk* is fed to *one process*
          => that chunk is converted to *one dataset* (done here)
          * all datasets are collected and concatenated afterwards

        :param chunk: Pairs of an ascending int index and a dict,
            e.g. [(0, dict), (1, dict) ...]
        :type chunk: list of tuples
        :param processor: FARM Processor (e.g. TextClassificationProcessor)
        :return: PyTorch Dataset
        """
        sample_dicts = [entry[1] for entry in chunk]
        sample_indices = [entry[0] for entry in chunk]
        return processor.dataset_from_dicts(dicts=sample_dicts, indices=sample_indices)
开发者ID:deepset-ai,项目名称:FARM,代码行数:20,代码来源:data_silo.py

示例12: covert_dataset_to_dataloader

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def covert_dataset_to_dataloader(dataset, sampler, batch_size):
    """
    Build a DataLoader around a PyTorch Dataset.

    :param dataset: Dataset to be wrapped.
    :type dataset: Dataset
    :param sampler: Sampler class (or factory); it is called with
        ``dataset`` and the result is used to pick the samples of a batch.
    :type sampler: Sampler
    :param batch_size: Number of samples in the batch.
    :return: A DataLoader that wraps the input Dataset.
    """
    return DataLoader(dataset, sampler=sampler(dataset), batch_size=batch_size)
开发者ID:deepset-ai,项目名称:FARM,代码行数:18,代码来源:dataloader.py

示例13: resume_training

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def resume_training(self, train_data, model_path, valid_data=None):
        """Resume training of a classifier from a saved checkpoint.

        The checkpoint at ``model_path`` is restored through
        ``self.checkpointer`` and the restored state is passed to
        ``_train_model`` together with a newly obtained loss function.

        Args:
            train_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader of
                X (data) and Y (labels) for the train split
            model_path: the path to the saved checkpoint for resuming training
            valid_data: a tuple of Tensors (X,Y), a Dataset, or a DataLoader of
                X (data) and Y (labels) for the dev split
        """
        checkpoint_state = self.checkpointer.restore(model_path)
        criterion = self._get_loss_fn()
        # presumably switches the model to training mode — TODO confirm
        self.train()
        training_kwargs = dict(
            train_data=train_data,
            loss_fn=criterion,
            valid_data=valid_data,
            restore_state=checkpoint_state,
        )
        self._train_model(**training_kwargs)
开发者ID:HazyResearch,项目名称:metal,代码行数:21,代码来源:classifier.py

示例14: _create_data_loader

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def _create_data_loader(self, data, **kwargs):
        """Return ``data`` as a DataLoader (or None when ``data`` is None).

        An existing DataLoader is returned unchanged. A Dataset is wrapped
        directly; a tuple or list is first turned into a dataset through
        ``self._create_dataset``. New loaders are configured from
        ``config["train_config"]["data_loader_config"]``, overridden by
        ``kwargs``, with ``pin_memory`` enabled whenever the configured
        device is not ``"cpu"``.

        Raises:
            ValueError: if ``data`` is none of the supported types.
        """
        if data is None:
            return None

        # Settings for a newly built loader; unused when `data` already is
        # a DataLoader.
        loader_kwargs = dict(self.config["train_config"]["data_loader_config"])
        loader_kwargs.update(kwargs)
        loader_kwargs["pin_memory"] = self.config["device"] != "cpu"

        if isinstance(data, DataLoader):
            return data
        if isinstance(data, Dataset):
            return DataLoader(data, **loader_kwargs)
        if isinstance(data, (tuple, list)):
            return DataLoader(self._create_dataset(*data), **loader_kwargs)
        raise ValueError("Input data type not recognized.")
开发者ID:HazyResearch,项目名称:metal,代码行数:23,代码来源:classifier.py

示例15: __init__

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import Dataset [as 别名]
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False):
        """
        Dataset backed by a csv manifest of audio/transcript path pairs.

        Every line of the manifest is one sample: the path of an audio file
        and the path of its transcript, separated by a comma. Example:

        /path/to/audio.wav,/path/to/audio.txt
        ...

        :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
        :param manifest_filepath: Path to manifest csv as described above
        :param labels: String containing all the possible characters to map to
        :param normalize: Apply standard mean and deviation normalization to audio tensor
        :param augment(default False):  Apply random tempo and gain perturbations
        """
        with open(manifest_filepath) as manifest:
            ids = [row.strip().split(',') for row in manifest]
        self.ids = ids
        self.size = len(ids)
        # Maps every label character to its position in `labels`.
        self.labels_map = {label: index for index, label in enumerate(labels)}
        super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment)
开发者ID:joseph-zhong,项目名称:LipReading,代码行数:23,代码来源:data_loader.py


注:本文中的torch.utils.data.Dataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。