本文整理汇总了Python中torch.utils.data.ConcatDataset方法的典型用法代码示例。如果您正苦于以下问题：Python data.ConcatDataset方法的具体用法？Python data.ConcatDataset怎么用？Python data.ConcatDataset使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.utils.data的用法示例。
在下文中一共展示了data.ConcatDataset方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: from_dataframe
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def from_dataframe(cls,
                   dataframe: 'DataFrame',
                   group_colname: str,
                   time_colname: str,
                   dt_unit: Optional[str],
                   measure_colnames: Optional[Sequence[str]] = None,
                   X_colnames: Optional[Sequence[str]] = None,
                   y_colnames: Optional[Sequence[str]] = None,
                   **kwargs) -> 'TimeSeriesDataLoader':
    """Build a loader whose dataset holds one ``TimeSeriesDataset`` per group.

    ``dataframe`` is split with ``groupby(group_colname)``. Each group's frame is
    converted with ``TimeSeriesDataset.from_dataframe`` using the same column
    configuration. The per-group datasets are then joined in one ``ConcatDataset``.

    Any remaining ``kwargs`` are forwarded to ``cls`` together with ``dataset``.
    """
    # Column configuration shared by every group.
    column_config = dict(
        group_colname=group_colname,
        time_colname=time_colname,
        measure_colnames=measure_colnames,
        X_colnames=X_colnames,
        y_colnames=y_colnames,
        dt_unit=dt_unit,
    )
    per_group = []
    for _group_key, group_df in dataframe.groupby(group_colname):
        per_group.append(TimeSeriesDataset.from_dataframe(dataframe=group_df, **column_config))
    return cls(dataset=ConcatDataset(datasets=per_group), **kwargs)
示例2: get_training_set_gt
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_training_set_gt(dataset_path: str, image_size: ImageSize):
    """Return the ground-truth training set as a ``ConcatDataset``.

    Loads ``JOURNAL_2019_03_GT_30fps`` under *dataset_path* for 15 joints. The
    transform pipeline does the following, in order:

    1. removes joints outside the image;
    2. scales;
    3. translates;
    4. flips left/right;
    5. normalizes.

    The label statistics of each part are printed before concatenating.
    """
    joint_count = 15
    # Joint indices handed to FlipEhpi as the left and right body sides.
    left_side: List[int] = [3, 4, 5, 9, 10, 11]
    right_side: List[int] = [6, 7, 8, 12, 13, 14]

    gt_path = os.path.join(dataset_path, "JOURNAL_2019_03_GT_30fps")
    augmentation = transforms.Compose([
        RemoveJointsOutsideImgEhpi(image_size),
        ScaleEhpi(image_size),
        TranslateEhpi(image_size),
        FlipEhpi(left_indexes=left_side, right_indexes=right_side),
        NormalizeEhpi(image_size),
    ])
    parts: List[EhpiLSTMDataset] = [
        EhpiLSTMDataset(gt_path, transform=augmentation, num_joints=joint_count),
    ]
    for part in parts:
        part.print_label_statistics()
    return ConcatDataset(parts)
示例3: get_test_set_lab
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_test_set_lab(dataset_path: str, image_size: ImageSize):
    """Return the lab test recordings VUE01 and VUE02 as one ``ConcatDataset``.

    Test data gets only two transforms: out-of-image joint removal and
    normalization. No augmentation is applied. Both parts are loaded with
    ``DatasetPart.TEST``, and their label statistics are printed.
    """
    joint_count = 15

    def _make_part(folder_name):
        # Each part receives its own Compose instance.
        return EhpiLSTMDataset(
            os.path.join(dataset_path, folder_name),
            transform=transforms.Compose([
                RemoveJointsOutsideImgEhpi(image_size),
                NormalizeEhpi(image_size),
            ]),
            num_joints=joint_count,
            dataset_part=DatasetPart.TEST,
        )

    folders = ("JOURNAL_2019_03_TEST_VUE01_30FPS", "JOURNAL_2019_03_TEST_VUE02_30FPS")
    parts = [_make_part(folder) for folder in folders]
    for part in parts:
        part.print_label_statistics()
    return ConcatDataset(parts)
示例4: get_CINIC10
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_CINIC10(root="./"):
    """Build a ``DataSource`` for CINIC-10 stored at ``<root>/data/CINIC-10``.

    The official ``train`` and ``valid`` splits are concatenated into one
    training set to get more samples. The ``test`` split is kept separate.

    Args:
        root: Directory containing ``data/CINIC-10``. It may be given with or
            without a trailing path separator.

    Returns:
        A ``DataSource`` with ``train_dataset`` (train + valid), ``test_dataset``,
        ``shared_transform`` (to-tensor + normalization) and ``train_transform``
        (random crop + horizontal flip).
    """
    import os  # function-scope import: file-level imports are not visible here

    # os.path.join instead of string '+', so roots without a trailing '/' work.
    cinic_directory = os.path.join(root, "data", "CINIC-10")
    # Per-channel mean/std used for normalization.
    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]
    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip()])
    shared_transform = transforms.Compose([transforms.ToTensor(),
                                           transforms.Normalize(mean=cinic_mean,
                                                                std=cinic_std)])
    train_dataset = datasets.ImageFolder(os.path.join(cinic_directory, "train"))
    validation_dataset = datasets.ImageFolder(os.path.join(cinic_directory, "valid"))
    # Concatenate train and validation set to have more samples.
    merged_train_dataset = torch.utils.data.ConcatDataset([train_dataset, validation_dataset])
    test_dataset = datasets.ImageFolder(os.path.join(cinic_directory, "test"))
    return DataSource(
        train_dataset=merged_train_dataset,
        test_dataset=test_dataset,
        shared_transform=shared_transform,
        train_transform=train_transform,
    )
示例5: get_targets
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_targets(dataset):
    """Get the targets of a dataset without applying any target transforms(!).

    Wrapper datasets are unwrapped recursively:

    * ``TransformedDataset``: targets of its ``.dataset``;
    * ``data.Subset``: targets of its ``.dataset``, indexed by ``.indices``;
    * ``data.ConcatDataset``: targets of every sub-dataset, concatenated in order.

    For leaf datasets (MNIST, ImageFolder, SVHN), the stored label attribute is
    read directly, so no ``target_transform`` is ever invoked.

    Returns:
        A ``torch.Tensor`` of targets. Every branch returns a tensor, which is
        what lets the ConcatDataset branch ``torch.cat`` the pieces.

    Raises:
        NotImplementedError: for any other dataset type.
    """
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])
    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        # SVHN stores its labels in ``.labels`` (a NumPy array in torchvision),
        # not in ``.targets``. Convert it so this branch returns a tensor like the
        # others; otherwise torch.cat above fails for SVHN inside a ConcatDataset.
        return torch.as_tensor(dataset.labels)
    raise NotImplementedError(f"Unknown dataset {dataset}!")
示例6: build_dataset
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def build_dataset(dataset_list, transform=None, target_transform=None, is_train=True):
    """Instantiate the catalog datasets named in *dataset_list*.

    Each name is resolved with ``DatasetCatalog.get``. The ``'factory'`` entry
    selects the class from ``_DATASETS``. The ``'args'`` entry supplies the
    constructor keyword arguments, extended with the given transforms.

    Args:
        dataset_list: Non-empty sequence of dataset names registered in the catalog.
        transform: Transform passed to every dataset.
        target_transform: Target transform passed to every dataset.
        is_train: Training mode. For ``VOCDataset``, ``keep_difficult`` is set to
            ``not is_train``. For ``COCODataset``, ``remove_empty`` is set to
            ``is_train``.

    Returns:
        A list of datasets. When ``is_train`` is False, it holds one dataset per
        name. Otherwise it holds a single dataset, which is a ``ConcatDataset``
        if several names were given.
    """
    assert len(dataset_list) > 0
    built = []
    for dataset_name in dataset_list:
        # Named 'entry' (not 'data') to avoid shadowing the torch.utils.data alias.
        entry = DatasetCatalog.get(dataset_name)
        # Work on a copy so the transforms/flags added below are never written
        # back into the dict returned by the catalog, which may be shared state.
        args = dict(entry['args'])
        factory = _DATASETS[entry['factory']]
        args['transform'] = transform
        args['target_transform'] = target_transform
        if factory == VOCDataset:
            args['keep_difficult'] = not is_train
        elif factory == COCODataset:
            args['remove_empty'] = is_train
        built.append(factory(**args))
    # For testing, return a list of datasets.
    if not is_train:
        return built
    # For training, merge everything into one dataset (still wrapped in a list).
    dataset = built[0] if len(built) == 1 else ConcatDataset(built)
    return [dataset]
示例7: __init__
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def __init__(self, root, transform=None, target_transform=None, download=False):
    """Set up the combined Omniglot dataset (background + evaluation sets).

    Args:
        root: Root directory; ``~`` is expanded before use.
        transform: Stored as ``self.transform``.
        target_transform: Stored as ``self.target_transform``.
        download: Forwarded to both ``Omniglot`` constructors.
    """
    self.root = os.path.expanduser(root)
    self.transform, self.target_transform = transform, target_transform

    background_set = Omniglot(self.root, background=True, download=download)
    # Evaluation labels also start from 0. Shift them by the number of background
    # characters (964) so they don't overwrite the background label range.
    n_background_classes = len(background_set._characters)

    def _shift_label(label):
        return label + n_background_classes

    evaluation_set = Omniglot(self.root,
                              background=False,
                              download=download,
                              target_transform=_shift_label)
    self.dataset = ConcatDataset((background_set, evaluation_set))
    self._bookkeeping_path = os.path.join(self.root, 'omniglot-bookkeeping.pkl')
示例8: _create_mapping_loader
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def _create_mapping_loader(config, dataset_class, partitions):
    """Build a DataLoader over the test tuples (image, label) of *partitions*.

    One ``dataset_class`` instance is created per partition with
    ``purpose="test"``. Each instance is wrapped in ``DoerschDataset`` when
    ``config.use_doersch_datasets`` is set, and all are concatenated in order.
    """
    def _load(partition):
        part = dataset_class(config=config, split=partition, purpose="test")
        return DoerschDataset(config, part) if config.use_doersch_datasets else part

    combined = ConcatDataset([_load(partition) for partition in partitions])
    # Keep sample order (no point shuffling, since this data is not trained on),
    # load in-process, and keep the last partial batch.
    return torch.utils.data.DataLoader(
        combined,
        batch_size=config.batch_sz,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
示例9: build_training_data
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def build_training_data(args, tokenizer, tasks):
    """Assemble the training set, optionally extended with dev and GAP data.

    - Base data: Wiki WSCR data from ``args.wiki_data`` if given, otherwise DPRD
      train data from ``args.data_dir``.
    - ``args.dev_train``: also append the ``DPRD-test`` dev split.
    - ``args.gap_data``: append GAP training data. When ``args.dev_train`` is
      also set, the GAP dev splits are prepended as well.
    """
    dprd_task = DPRDTask(tokenizer)
    seq_len = args.max_seq_length

    if not args.wiki_data:
        train_data = dprd_task.get_train_dataset(args.data_dir, seq_len, input_type=tasks)
    else:
        train_data = WikiWSCRTask(tokenizer).get_train_dataset(args.wiki_data, seq_len, input_type=tasks)

    if args.dev_train:
        dprd_dev = dprd_task.get_dev_dataset(args.data_dir, seq_len, input_type=tasks)
        dprd_test_split = [e.data for e in dprd_dev if e.name == 'DPRD-test'][0]
        train_data = ConcatDataset([train_data, dprd_test_split])

    if args.gap_data:
        # NOTE(review): `gap_task` is not defined in this function. It must be a
        # module-level global, or this branch raises NameError. Confirm.
        gap_train = gap_task.get_train_dataset(args.gap_data, seq_len, input_type=tasks)
        train_data = ConcatDataset([train_data, gap_train])
        if args.dev_train:
            gap_dev = [e.data for e in gap_task.get_dev_dataset(args.gap_data, seq_len, input_type=tasks)]
            train_data = ConcatDataset(gap_dev + [train_data])
    return train_data
示例10: random_split_ConcatDataset
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def random_split_ConcatDataset(self, ds, lengths):
    """
    Roughly split a ConcatDataset into two non-overlapping ConcatDatasets.

    The cut is made on sub-dataset (chunk) boundaries, so the resulting sizes
    only approximate ``lengths``. The train part gets every leading chunk whose
    cumulative size does not exceed ``lengths[0]``; the dev/test part gets the
    rest. Samples inside the ConcatDataset should already be shuffled.

    :param ds: Dataset to be split (uses its ``cumulative_sizes`` and ``datasets``)
    :type ds: ConcatDataset
    :param lengths: requested split sizes; must sum to ``len(ds)``
    :type lengths: list
    :return: ``(train, test)`` tuple of ConcatDatasets
    :raises ValueError: if the lengths do not sum to ``len(ds)``
    :raises Exception: if every chunk would go to the train part
    :raises AssertionError: if no chunk would go to the train part
    """
    if sum(lengths) != len(ds):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    try:
        # Index of the first chunk that would overflow the requested train size.
        split_at = np.where(np.array(ds.cumulative_sizes) > lengths[0])[0][0]
    except IndexError:
        raise Exception("All dataset chunks are being assigned to train set leaving no samples for dev set. "
                        "Either consider increasing dev_split or setting it to 0.0\n"
                        f"Cumulative chunk sizes: {ds.cumulative_sizes}\n"
                        f"train/dev split: {lengths}")
    assert split_at >= 1, ("Dev_split ratio is too large, there is no data in train set. "
                           f"Please lower dev_split = {self.processor.dev_split}")
    leading, trailing = ds.datasets[:split_at], ds.datasets[split_at:]
    return ConcatDataset(leading), ConcatDataset(trailing)
示例11: load_data
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def load_data(data_root, dataset, phase, batch_size, sampler_dic=None, num_workers=4, test_open=False, shuffle=True):
    """Create a DataLoader for one phase of a dataset described by a txt list.

    - Sample list: ``./data/<dataset>/<dataset>_<phase>.txt``. The
      ``train_plain`` phase reuses the ``train`` list.
    - Transform: ``train`` and ``val`` use their own ``data_transforms`` entry;
      every other phase uses ``data_transforms['test']``.
    - Open set: for ``phase == 'test'`` with ``test_open``, the open-set samples
      from ``./data/<dataset>/<dataset>_open.txt`` are appended.
    - Sampling: if ``sampler_dic`` is given in the ``train`` phase, the loader
      uses ``sampler_dic['sampler'](set_, sampler_dic['num_samples_cls'])``.
      Otherwise it shuffles according to *shuffle*.
    """
    list_phase = 'train' if phase == 'train_plain' else phase
    txt = './data/%s/%s_%s.txt' % (dataset, dataset, list_phase)
    print('Loading data from %s' % (txt))
    transform = data_transforms[phase if phase in ('train', 'val') else 'test']
    print('Use data transformation:', transform)
    set_ = LT_Dataset(data_root, txt, transform)

    if phase == 'test' and test_open:
        open_txt = './data/%s/%s_open.txt' % (dataset, dataset)
        print('Testing with opensets from %s' % (open_txt))
        open_root = './data/%s/%s_open' % (dataset, dataset)
        set_ = ConcatDataset([set_, LT_Dataset(open_root, open_txt, transform)])

    use_sampler = bool(sampler_dic) and phase == 'train'
    if not use_sampler:
        print('No sampler.')
        print('Shuffle is %s.' % (shuffle))
        return DataLoader(dataset=set_, batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers)

    print('Using sampler.')
    print('Sample %s samples per-class.' % sampler_dic['num_samples_cls'])
    sampler = sampler_dic['sampler'](set_, sampler_dic['num_samples_cls'])
    return DataLoader(dataset=set_, batch_size=batch_size, shuffle=False,
                      sampler=sampler, num_workers=num_workers)
示例12: MultiSlideData
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def MultiSlideData(self, paths, size=(224, 224), level=0, transform=lambda x: x):
    """Concatenate one ``SingleSlideData`` per entry of *paths*, in order.

    ``size``, ``level`` and ``transform`` are forwarded unchanged to every slide.
    """
    return ConcatDataset([
        SingleSlideData(slide_path, size=size, level=level, transform=transform)
        for slide_path in paths
    ])
示例13: load_dataset
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def load_dataset(self, split, combine=False):
    """Load a dataset split and store it in ``self.datasets[split]``.

    Shards are looked up under ``self.args.data`` as ``<split>``, ``<split>1``,
    ``<split>2``, and so on. Which file type is used depends on
    ``self.args.raw_text``: raw text or the in-memory indexed format.

    With ``combine=False``, only the first shard is loaded. Otherwise shards are
    loaded until the next one is missing. Each shard is wrapped in a
    ``TokenBlockDataset``; several shards are joined with ``ConcatDataset``. The
    result goes into a ``MonolingualDataset`` with ``shuffle=False``.

    Raises:
        FileNotFoundError: if the first shard (``<split>`` itself) is missing.
    """
    opts = self.args
    shards = []
    for shard_idx in itertools.count():
        shard_name = split + (str(shard_idx) if shard_idx > 0 else '')
        shard_path = os.path.join(opts.data, shard_name)
        if opts.raw_text and IndexedRawTextDataset.exists(shard_path):
            ds = IndexedRawTextDataset(shard_path, self.dictionary)
            # Flatten ds.tokens_list by one level into a single token list.
            tokens = list(itertools.chain.from_iterable(ds.tokens_list))
        elif not opts.raw_text and IndexedInMemoryDataset.exists(shard_path):
            ds = IndexedInMemoryDataset(shard_path, fix_lua_indexing=True)
            tokens = ds.buffer
        elif shard_idx == 0:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, opts.data))
        else:
            break  # ran past the last numbered shard

        shards.append(TokenBlockDataset(
            tokens, ds.sizes, opts.tokens_per_sample, opts.sample_break_mode,
            include_targets=True,
        ))
        print('| {} {} {} examples'.format(opts.data, shard_name, len(shards[-1])))
        if not combine:
            break

    if len(shards) == 1:
        dataset = shards[0]
        sizes = dataset.sizes
    else:
        dataset = ConcatDataset(shards)
        sizes = np.concatenate([shard.sizes for shard in shards])
    self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
示例14: get_concated_datasets
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_concated_datasets(meta_dir_list: List[str], batch_size: int, num_workers: int,
                          meta_cls_list: List[MetaFrame],
                          fix_len: int = 0, skip_audio: bool = False, sample_rate: int = 44100,
                          audio_mask: bool = False) -> Tuple[SpeechDataLoader, SpeechDataLoader]:
    """Build train/valid loaders that span several speech corpora.

    For each ``(meta_cls, meta_dir)`` pair:

    1. The train and valid meta files named by ``meta_cls.frame_file_names[1:]``
       are loaded from ``meta_dir`` with ``sr=sample_rate``.
    2. Each meta file is wrapped in an ``AugmentSpeechDataset``.

    All corpora are then concatenated per split.

    Args:
        meta_dir_list: One directory per corpus; every entry must exist.
        batch_size: Batch size of both loaders.
        num_workers: Worker count of both loaders.
        meta_cls_list: One meta-frame class per directory (same length).
        fix_len, skip_audio, audio_mask: Forwarded to ``AugmentSpeechDataset``.
        sample_rate: Forwarded to the meta-frame class as ``sr``.

    Returns:
        ``(train_loader, valid_loader)``, both created with ``is_bucket=False``
        and ``skip_last_bucket=False``.

    Raises:
        AssertionError: if a directory does not exist or the two lists differ
            in length.
    """
    # The old message called .format() with no placeholder, so offending paths
    # were never reported. List them explicitly.
    invalid_dirs = [d for d in meta_dir_list if not os.path.isdir(d)]
    assert not invalid_dirs, 'There are not valid directory paths! {}'.format(invalid_dirs)
    assert len(meta_dir_list) == len(meta_cls_list), 'meta_dir_list, meta_cls_list are must have same length!'

    # datasets
    train_datasets = []
    valid_datasets = []
    for meta_cls, meta_dir in zip(meta_cls_list, meta_dir_list):
        train_file, valid_file = meta_cls.frame_file_names[1:]
        # load meta file
        train_meta = meta_cls(os.path.join(meta_dir, train_file), sr=sample_rate)
        valid_meta = meta_cls(os.path.join(meta_dir, valid_file), sr=sample_rate)
        # create dataset
        train_datasets.append(AugmentSpeechDataset(train_meta, fix_len=fix_len,
                                                   skip_audio=skip_audio, audio_mask=audio_mask))
        valid_datasets.append(AugmentSpeechDataset(valid_meta, fix_len=fix_len,
                                                   skip_audio=skip_audio, audio_mask=audio_mask))

    # make concat dataset
    train_conc_dataset = ConcatDataset(train_datasets)
    valid_conc_dataset = ConcatDataset(valid_datasets)
    # create data loader
    train_loader = SpeechDataLoader(train_conc_dataset, batch_size=batch_size, is_bucket=False,
                                    num_workers=num_workers, skip_last_bucket=False)
    valid_loader = SpeechDataLoader(valid_conc_dataset, batch_size=batch_size, is_bucket=False,
                                    num_workers=num_workers, skip_last_bucket=False)
    return train_loader, valid_loader
示例15: get_sim_pose_algo_only
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import ConcatDataset [as 别名]
def get_sim_pose_algo_only(dataset_path: str, image_size: ImageSize):
    """Return the two simulated-pose datasets as one ``ConcatDataset``.

    Both ``ofp_sim_pose_algo_equal_30fps`` and ``ofp_from_mocap_pose_algo_30fps``
    use the same 15-joint pipeline, in this order:

    1. remove joints outside the image;
    2. remove foot/knee joints with probability 0.25;
    3. scale;
    4. translate;
    5. flip left/right;
    6. normalize.

    Label statistics of each part are printed before concatenating.
    ``foot_indexes`` and ``knee_indexes`` are read from the enclosing module.
    """
    joint_count = 15
    # Joint indices handed to FlipEhpi as the left and right body sides.
    left_side: List[int] = [3, 4, 5, 9, 10, 11]
    right_side: List[int] = [6, 7, 8, 12, 13, 14]

    def _augmentation():
        # Build a fresh pipeline so each dataset owns its transform objects.
        return transforms.Compose([
            RemoveJointsOutsideImgEhpi(image_size),
            RemoveJointsEhpi(indexes_to_remove=foot_indexes, indexes_to_remove_2=knee_indexes,
                             probability=0.25),
            ScaleEhpi(image_size),
            TranslateEhpi(image_size),
            FlipEhpi(left_indexes=left_side, right_indexes=right_side),
            NormalizeEhpi(image_size),
        ])

    folders = ("ofp_sim_pose_algo_equal_30fps", "ofp_from_mocap_pose_algo_30fps")
    parts: List[EhpiDataset] = [
        EhpiDataset(os.path.join(dataset_path, folder), transform=_augmentation(),
                    num_joints=joint_count)
        for folder in folders
    ]
    for part in parts:
        part.print_label_statistics()
    return ConcatDataset(parts)