

Python data.ConcatDataset Method Code Examples

This article collects typical usage examples of the torch.utils.data.ConcatDataset method in Python. If you are wondering how exactly data.ConcatDataset is used, how it works, or what real examples of it look like, the curated code examples below may help. You can also explore further usage examples from the torch.utils.data module.


The following shows 15 code examples of the data.ConcatDataset method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
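Before the examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of what ConcatDataset does: it chains several datasets with the same sample structure so they can be indexed, measured with len(), and passed to a DataLoader as a single dataset.

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Two small datasets with the same sample structure (features, label).
ds_a = TensorDataset(torch.arange(4).float().unsqueeze(1), torch.zeros(4, dtype=torch.long))
ds_b = TensorDataset(torch.arange(6).float().unsqueeze(1), torch.ones(6, dtype=torch.long))

combined = ConcatDataset([ds_a, ds_b])
print(len(combined))               # 10 = 4 + 6
print(combined.cumulative_sizes)   # [4, 10]
print(combined[4])                 # first sample of ds_b

# A ConcatDataset is an ordinary Dataset, so it works with DataLoader as usual.
loader = DataLoader(combined, batch_size=5, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)        # torch.Size([5, 1]) torch.Size([5])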

Example 1: from_dataframe

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def from_dataframe(cls,
                       dataframe: 'DataFrame',
                       group_colname: str,
                       time_colname: str,
                       dt_unit: Optional[str],
                       measure_colnames: Optional[Sequence[str]] = None,
                       X_colnames: Optional[Sequence[str]] = None,
                       y_colnames: Optional[Sequence[str]] = None,
                       **kwargs) -> 'TimeSeriesDataLoader':
        dataset = ConcatDataset(
            datasets=[
                TimeSeriesDataset.from_dataframe(
                    dataframe=df,
                    group_colname=group_colname,
                    time_colname=time_colname,
                    measure_colnames=measure_colnames,
                    X_colnames=X_colnames,
                    y_colnames=y_colnames,
                    dt_unit=dt_unit
                )
                for g, df in dataframe.groupby(group_colname)
            ]
        )
        return cls(dataset=dataset, **kwargs) 
Developer: strongio, Project: torch-kalman, Lines: 26, Source: data.py

Example 2: get_training_set_gt

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_training_set_gt(dataset_path: str, image_size: ImageSize):
    num_joints = 15
    left_indexes: List[int] = [3, 4, 5, 9, 10, 11]
    right_indexes: List[int] = [6, 7, 8, 12, 13, 14]

    datasets: List[EhpiLSTMDataset] = [
        EhpiLSTMDataset(os.path.join(dataset_path, "JOURNAL_2019_03_GT_30fps"),
                        transform=transforms.Compose([
                            RemoveJointsOutsideImgEhpi(image_size),
                            ScaleEhpi(image_size),
                            TranslateEhpi(image_size),
                            FlipEhpi(left_indexes=left_indexes, right_indexes=right_indexes),
                            NormalizeEhpi(image_size)
                        ]), num_joints=num_joints),
    ]
    for dataset in datasets:
        dataset.print_label_statistics()

    return ConcatDataset(datasets) 
Developer: noboevbo, Project: ehpi_action_recognition, Lines: 21, Source: train_its_journal_2019.py

Example 3: get_test_set_lab

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_test_set_lab(dataset_path: str, image_size: ImageSize):
    num_joints = 15
    datasets = [
        EhpiLSTMDataset(os.path.join(dataset_path, "JOURNAL_2019_03_TEST_VUE01_30FPS"),
                        transform=transforms.Compose([
                            RemoveJointsOutsideImgEhpi(image_size),
                            NormalizeEhpi(image_size)
                        ]), num_joints=num_joints, dataset_part=DatasetPart.TEST),
        EhpiLSTMDataset(os.path.join(dataset_path, "JOURNAL_2019_03_TEST_VUE02_30FPS"),
                        transform=transforms.Compose([
                            RemoveJointsOutsideImgEhpi(image_size),
                            NormalizeEhpi(image_size)
                        ]), num_joints=num_joints, dataset_part=DatasetPart.TEST),
    ]
    for dataset in datasets:
        dataset.print_label_statistics()
    return ConcatDataset(datasets) 
Developer: noboevbo, Project: ehpi_action_recognition, Lines: 19, Source: test_its_journal_2019.py

Example 4: get_CINIC10

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_CINIC10(root="./"):
    cinic_directory = root + "data/CINIC-10"
    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()])
    shared_transform = transforms.Compose([transforms.ToTensor(),
                                           transforms.Normalize(mean=cinic_mean,
                                                                std=cinic_std)])

    train_dataset = datasets.ImageFolder(cinic_directory + '/train')
    validation_dataset = datasets.ImageFolder(cinic_directory + '/valid')

    # Concatenate train and validation set to have more samples.
    merged_train_dataset = torch.utils.data.ConcatDataset([train_dataset, validation_dataset])

    test_dataset = datasets.ImageFolder(cinic_directory + '/test')

    return DataSource(
        train_dataset=merged_train_dataset,
        test_dataset=test_dataset,
        shared_transform=shared_transform,
        train_transform=train_transform,
    ) 
Developer: BlackHC, Project: BatchBALD, Lines: 26, Source: dataset_enum.py

Example 5: get_targets

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_targets(dataset):
    """Get the targets of a dataset without any target target transforms(!)."""
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])

    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        return dataset.labels

    raise NotImplementedError(f"Unknown dataset {dataset}!") 
Developer: BlackHC, Project: BatchBALD, Lines: 20, Source: dataset_enum.py
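The data.ConcatDataset branch in get_targets above concatenates the targets of each sub-dataset in order. A minimal, hypothetical illustration of that idea (ToyDataset and the tensors are made up for this sketch; torchvision datasets expose .targets in the same way):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

class ToyDataset(TensorDataset):
    """A tiny dataset exposing a .targets attribute, like torchvision's MNIST."""
    def __init__(self, x, y):
        super().__init__(x, y)
        self.targets = y

a = ToyDataset(torch.randn(3, 2), torch.tensor([0, 1, 0]))
b = ToyDataset(torch.randn(2, 2), torch.tensor([1, 1]))
combined = ConcatDataset([a, b])

# Same idea as the ConcatDataset branch above: concatenate per-dataset targets in order.
all_targets = torch.cat([torch.as_tensor(d.targets) for d in combined.datasets])
print(all_targets)  # tensor([0, 1, 0, 1, 1])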

Example 6: build_dataset

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def build_dataset(dataset_list, transform=None, target_transform=None, is_train=True):
    assert len(dataset_list) > 0
    datasets = []
    for dataset_name in dataset_list:
        data = DatasetCatalog.get(dataset_name)
        args = data['args']
        factory = _DATASETS[data['factory']]
        args['transform'] = transform
        args['target_transform'] = target_transform
        if factory == VOCDataset:
            args['keep_difficult'] = not is_train
        elif factory == COCODataset:
            args['remove_empty'] = is_train
        dataset = factory(**args)
        datasets.append(dataset)
    # for testing, return a list of datasets
    if not is_train:
        return datasets
    dataset = datasets[0]
    if len(datasets) > 1:
        dataset = ConcatDataset(datasets)

    return [dataset] 
Developer: lufficc, Project: SSD, Lines: 25, Source: __init__.py

Example 7: __init__

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform

        # Set up both the background and eval dataset
        omni_background = Omniglot(self.root, background=True, download=download)
        # Eval labels also start from 0.
        # It's important to add 964 to label values in eval so they don't overwrite background dataset.
        omni_evaluation = Omniglot(self.root,
                                   background=False,
                                   download=download,
                                   target_transform=lambda x: x + len(omni_background._characters))

        self.dataset = ConcatDataset((omni_background, omni_evaluation))
        self._bookkeeping_path = os.path.join(self.root, 'omniglot-bookkeeping.pkl') 
Developer: learnables, Project: learn2learn, Lines: 18, Source: full_omniglot.py

Example 8: _create_mapping_loader

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def _create_mapping_loader(config, dataset_class, partitions):
  imgs_list = []
  for partition in partitions:
    imgs_curr = dataset_class(
      **{"config": config,
         "split": partition,
         "purpose": "test"}  # return testing tuples, image and label
    )
    if config.use_doersch_datasets:
      imgs_curr = DoerschDataset(config, imgs_curr)
    imgs_list.append(imgs_curr)

  imgs = ConcatDataset(imgs_list)
  dataloader = torch.utils.data.DataLoader(imgs,
                                           batch_size=config.batch_sz,
                                           # full batch
                                           shuffle=False,
                                           # no point since not trained on
                                           num_workers=0,
                                           drop_last=False)
  return dataloader 
Developer: xu-ji, Project: IIC, Lines: 23, Source: data.py

Example 9: build_training_data

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def build_training_data(args, tokenizer, tasks):
  dprd_task = DPRDTask(tokenizer)
  if args.wiki_data:
    wiki_task = WikiWSCRTask(tokenizer)
    train_data = wiki_task.get_train_dataset(args.wiki_data, args.max_seq_length, input_type=tasks)
  else:
    train_data = dprd_task.get_train_dataset(args.data_dir, args.max_seq_length, input_type=tasks)
    if args.dev_train:
      _data = dprd_task.get_dev_dataset(args.data_dir, args.max_seq_length, input_type=tasks)
      _data = [e.data for e in _data if e.name=='DPRD-test'][0]
      train_data = ConcatDataset([train_data, _data])
    if args.gap_data:
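      # NOTE: gap_task is assumed to be defined elsewhere in the original module (e.g. a GAP task wrapper); it is not shown in this snippet.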
      gap_data = gap_task.get_train_dataset(args.gap_data, args.max_seq_length, input_type=tasks)
      train_data = ConcatDataset([train_data, gap_data])
      if args.dev_train:
        gap_data = [e.data for e in gap_task.get_dev_dataset(args.gap_data, args.max_seq_length, input_type=tasks)]
        train_data = ConcatDataset(gap_data + [train_data])
  return train_data 
Developer: namisan, Project: mt-dnn, Lines: 20, Source: run_hnn.py

Example 10: random_split_ConcatDataset

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def random_split_ConcatDataset(self, ds, lengths):
        """
        Roughly split a ConcatDataset into non-overlapping new datasets of given lengths.
        Samples inside the ConcatDataset should already be shuffled.

        :param ds: Dataset to be split
        :type ds: Dataset
        :param lengths: lengths of splits to be produced
        :type lengths: list
        """
        if sum(lengths) != len(ds):
            raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

        try:
            idx_dataset = np.where(np.array(ds.cumulative_sizes) > lengths[0])[0][0]
        except IndexError:
            raise Exception("All dataset chunks are being assigned to train set leaving no samples for dev set. "
                            "Either consider increasing dev_split or setting it to 0.0\n"
                            f"Cumulative chunk sizes: {ds.cumulative_sizes}\n"
                            f"train/dev split: {lengths}")

        assert idx_dataset >= 1, "Dev_split ratio is too large, there is no data in train set. " \
                             f"Please lower dev_split = {self.processor.dev_split}"

        train = ConcatDataset(ds.datasets[:idx_dataset])
        test = ConcatDataset(ds.datasets[idx_dataset:])
        return train, test 
Developer: deepset-ai, Project: FARM, Lines: 29, Source: data_silo.py
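The split above relies on ConcatDataset.cumulative_sizes, which stores the running lengths of the concatenated chunks; the dev set starts at the first chunk whose cumulative size exceeds the requested train length, so the split never cuts through a chunk. A small, hypothetical illustration with made-up chunk sizes:

import numpy as np
import torch
from torch.utils.data import ConcatDataset, TensorDataset

chunks = [TensorDataset(torch.zeros(n, 1)) for n in (4, 4, 2)]
ds = ConcatDataset(chunks)
print(ds.cumulative_sizes)  # [4, 8, 10]

lengths = [8, 2]  # desired train/dev sizes
idx_dataset = np.where(np.array(ds.cumulative_sizes) > lengths[0])[0][0]  # -> 2

train = ConcatDataset(ds.datasets[:idx_dataset])  # first two chunks, 8 samples
dev = ConcatDataset(ds.datasets[idx_dataset:])    # last chunk, 2 samples
print(len(train), len(dev))  # 8 2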

Example 11: load_data

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def load_data(data_root, dataset, phase, batch_size, sampler_dic=None, num_workers=4, test_open=False, shuffle=True):
    
    txt = './data/%s/%s_%s.txt'%(dataset, dataset, (phase if phase != 'train_plain' else 'train'))

    print('Loading data from %s' % (txt))

    if phase not in ['train', 'val']:
        transform = data_transforms['test']
    else:
        transform = data_transforms[phase]

    print('Use data transformation:', transform)

    set_ = LT_Dataset(data_root, txt, transform)

    if phase == 'test' and test_open:
        open_txt = './data/%s/%s_open.txt'%(dataset, dataset)
        print('Testing with opensets from %s'%(open_txt))
        open_set_ = LT_Dataset('./data/%s/%s_open'%(dataset, dataset), open_txt, transform)
        set_ = ConcatDataset([set_, open_set_])

    if sampler_dic and phase == 'train':
        print('Using sampler.')
        print('Sample %s samples per-class.' % sampler_dic['num_samples_cls'])
        return DataLoader(dataset=set_, batch_size=batch_size, shuffle=False,
                           sampler=sampler_dic['sampler'](set_, sampler_dic['num_samples_cls']),
                           num_workers=num_workers)
    else:
        print('No sampler.')
        print('Shuffle is %s.' % (shuffle))
        return DataLoader(dataset=set_, batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers) 
Developer: zhmiao, Project: OpenLongTailRecognition-OLTR, Lines: 34, Source: dataloader.py

Example 12: MultiSlideData

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def MultiSlideData(self, paths, size=(224, 224), level=0, transform=lambda x: x):
  datasets = []
  for path in paths:
    datasets.append(SingleSlideData(path, size=size, level=level, transform=transform))
  return ConcatDataset(datasets) 
Developer: mjendrusch, Project: torchsupport, Lines: 7, Source: slides.py

Example 13: load_dataset

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def load_dataset(self, split, combine=False):
        """Load a dataset split."""

        loaded_datasets = []

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(self.args.data, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
                tokens = [t for l in ds.tokens_list for t in l]
            elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
                ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
                tokens = ds.buffer
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))

            loaded_datasets.append(
                TokenBlockDataset(
                    tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
                    include_targets=True
                ))

            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False) 
Developer: mlperf, Project: training_results_v0.5, Lines: 42, Source: language_modeling.py

Example 14: get_concated_datasets

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_concated_datasets(meta_dir_list: List[str], batch_size: int, num_workers: int,
                          meta_cls_list: List[MetaFrame],
                          fix_len: int = 0, skip_audio: bool = False, sample_rate: int = 44100,
                          audio_mask: bool = False) -> Tuple[SpeechDataLoader, SpeechDataLoader]:

    assert all([os.path.isdir(x) for x in meta_dir_list]), 'There are invalid directory paths in meta_dir_list!'
    assert len(meta_dir_list) == len(meta_cls_list), 'meta_dir_list and meta_cls_list must have the same length!'

    # datasets
    train_datasets = []
    valid_datasets = []

    for meta_cls, meta_dir in zip(meta_cls_list, meta_dir_list):
        train_file, valid_file = meta_cls.frame_file_names[1:]

        # load meta file
        train_meta = meta_cls(os.path.join(meta_dir, train_file), sr=sample_rate)
        valid_meta = meta_cls(os.path.join(meta_dir, valid_file), sr=sample_rate)

        # create dataset
        train_dataset = AugmentSpeechDataset(train_meta, fix_len=fix_len, skip_audio=skip_audio, audio_mask=audio_mask)
        valid_dataset = AugmentSpeechDataset(valid_meta, fix_len=fix_len, skip_audio=skip_audio, audio_mask=audio_mask)

        train_datasets.append(train_dataset)
        valid_datasets.append(valid_dataset)

    # make concat dataset
    train_conc_dataset = ConcatDataset(train_datasets)
    valid_conc_dataset = ConcatDataset(valid_datasets)

    # create data loader
    train_loader = SpeechDataLoader(train_conc_dataset, batch_size=batch_size, is_bucket=False,
                                    num_workers=num_workers, skip_last_bucket=False)
    valid_loader = SpeechDataLoader(valid_conc_dataset, batch_size=batch_size, is_bucket=False,
                                    num_workers=num_workers, skip_last_bucket=False)

    return train_loader, valid_loader 
Developer: AppleHolic, Project: source_separation, Lines: 39, Source: dataset.py

Example 15: get_sim_pose_algo_only

# Required import: from torch.utils import data [as alias]
# Or: from torch.utils.data import ConcatDataset [as alias]
def get_sim_pose_algo_only(dataset_path: str, image_size: ImageSize):
    num_joints = 15
    left_indexes: List[int] = [3, 4, 5, 9, 10, 11]
    right_indexes: List[int] = [6, 7, 8, 12, 13, 14]
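    # NOTE: foot_indexes and knee_indexes (used below) are assumed to be module-level constants in the original source; they are not defined in this snippet.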

    datasets: List[EhpiDataset] = [
        EhpiDataset(os.path.join(dataset_path, "ofp_sim_pose_algo_equal_30fps"),
                    transform=transforms.Compose([
                        RemoveJointsOutsideImgEhpi(image_size),
                        RemoveJointsEhpi(indexes_to_remove=foot_indexes, indexes_to_remove_2=knee_indexes,
                                         probability=0.25),
                        ScaleEhpi(image_size),
                        TranslateEhpi(image_size),
                        FlipEhpi(left_indexes=left_indexes, right_indexes=right_indexes),
                        NormalizeEhpi(image_size)
                    ]), num_joints=num_joints),
        EhpiDataset(os.path.join(dataset_path, "ofp_from_mocap_pose_algo_30fps"),
                    transform=transforms.Compose([
                        RemoveJointsOutsideImgEhpi(image_size),
                        RemoveJointsEhpi(indexes_to_remove=foot_indexes, indexes_to_remove_2=knee_indexes,
                                         probability=0.25),
                        ScaleEhpi(image_size),
                        TranslateEhpi(image_size),
                        FlipEhpi(left_indexes=left_indexes, right_indexes=right_indexes),
                        NormalizeEhpi(image_size)
                    ]), num_joints=num_joints),
    ]
    for dataset in datasets:
        dataset.print_label_statistics()

    return ConcatDataset(datasets) 
Developer: noboevbo, Project: ehpi_action_recognition, Lines: 33, Source: train_ehpi_itsc_2019_ofp.py


Note: The torch.utils.data.ConcatDataset method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets were selected from open-source projects contributed by many developers; copyright of the source code belongs to the original authors. For distribution and use, please refer to the corresponding project's License; do not reproduce without permission.