本文整理汇总了 Python 中 torch.utils.data.dataset.ConcatDataset 类的典型用法代码示例。如果您正苦于以下问题:Python dataset.ConcatDataset 的具体用法?Python dataset.ConcatDataset 怎么用?Python dataset.ConcatDataset 使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torch.utils.data.dataset 的用法示例。
在下文中一共展示了 dataset.ConcatDataset 类的 7 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: get
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build train/val datasets that mix the parent dataset with CharadesEgoMeta.

    ``args.train_file``, ``args.val_file`` and ``args.data`` are each read as
    ';'-separated strings: entry 0 configures ``CharadesEgoMeta.get`` and
    entry 1 configures the parent class's ``get``. Deep copies of ``args``
    carry the per-source values.

    Args:
        args: Object exposing ``train_file``, ``val_file`` and ``data`` string
            attributes.
        splits: Names of the splits to build. Only 'train' and 'val' are
            inspected here; the value is forwarded to both ``get`` calls.

    Returns:
        ``(train_dataset, val_dataset, valvideo_dataset)``. The train/val
        entries are ``ConcatDataset`` objects when the matching split was
        requested; otherwise they are whatever the parent ``get`` returned.
    """
    ego_repeats = 3  # magic number to balance

    def args_for_entry(position):
        # Copy args and keep only the ';'-separated entry at ``position``.
        variant = copy.deepcopy(args)
        for field in ('train_file', 'val_file', 'data'):
            vars(variant)[field] = getattr(args, field).split(';')[position]
        return variant

    ego_args = args_for_entry(0)
    base_args = args_for_entry(1)

    ego_train, ego_val = None, None
    if 'train' in splits or 'val' in splits:
        ego_train, ego_val, _ = CharadesEgoMeta.get(ego_args, splits=splits)

    parent = super(CharadesEgoPlusCharades, cls)
    train_dataset, val_dataset, valvideo_dataset = parent.get(base_args, splits=splits)

    if 'train' in splits:
        # Targets of the parent dataset are negated.
        # NOTE(review): the reason for the sign flip is not visible here.
        train_dataset.target_transform = transforms.Lambda(lambda x: -x)
        train_dataset = ConcatDataset([train_dataset] + ego_repeats * [ego_train])
    if 'val' in splits:
        val_dataset.target_transform = transforms.Lambda(lambda x: -x)
        val_dataset = ConcatDataset([val_dataset] + ego_repeats * [ego_val])

    return train_dataset, val_dataset, valvideo_dataset
示例2: get
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build train/val datasets that mix the parent dataset with CharadesEgoVideoMeta.

    ``args.train_file``, ``args.val_file`` and ``args.data`` are each read as
    ';'-separated strings: entry 0 configures ``CharadesEgoVideoMeta.get`` and
    entry 1 configures the parent class's ``get``. Deep copies of ``args``
    carry the per-source values.

    Args:
        args: Object exposing ``train_file``, ``val_file`` and ``data`` string
            attributes.
        splits: Names of the splits to build. Only 'train' and 'val' are
            inspected here; the value is forwarded to both ``get`` calls.

    Returns:
        ``(train_dataset, val_dataset, valvideo_dataset)``. The train/val
        entries are ``ConcatDataset`` objects when the matching split was
        requested; otherwise they are whatever the parent ``get`` returned.
    """
    ego_repeats = 3  # magic number to balance

    def args_for_entry(position):
        # Copy args and keep only the ';'-separated entry at ``position``.
        variant = copy.deepcopy(args)
        for field in ('train_file', 'val_file', 'data'):
            vars(variant)[field] = getattr(args, field).split(';')[position]
        return variant

    ego_args = args_for_entry(0)
    base_args = args_for_entry(1)

    ego_train, ego_val = None, None
    if 'train' in splits or 'val' in splits:
        ego_train, ego_val, _ = CharadesEgoVideoMeta.get(ego_args, splits=splits)

    parent = super(CharadesEgoVideoPlusCharades, cls)
    train_dataset, val_dataset, valvideo_dataset = parent.get(base_args, splits=splits)

    if 'train' in splits:
        # Targets of the parent dataset are negated.
        # NOTE(review): the reason for the sign flip is not visible here.
        train_dataset.target_transform = transforms.Lambda(lambda x: -x)
        train_dataset = ConcatDataset([train_dataset] + ego_repeats * [ego_train])
    if 'val' in splits:
        val_dataset.target_transform = transforms.Lambda(lambda x: -x)
        val_dataset = ConcatDataset([val_dataset] + ego_repeats * [ego_val])

    return train_dataset, val_dataset, valvideo_dataset
示例3: get
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build train/val datasets that mix the parent dataset with CharadesEgoMeta.

    ``args.train_file``, ``args.val_file`` and ``args.data`` are each read as
    ';'-separated strings: entry 0 configures ``CharadesEgoMeta.get`` and
    entry 1 configures the parent class's ``get``. Deep copies of ``args``
    carry the per-source values. The ego dataset is included once.

    Args:
        args: Object exposing ``train_file``, ``val_file`` and ``data`` string
            attributes.
        splits: Names of the splits to build. Only 'train' and 'val' are
            inspected here; the value is forwarded to both ``get`` calls.

    Returns:
        ``(train_dataset, val_dataset, valvideo_dataset)``. The train/val
        entries are ``ConcatDataset`` objects when the matching split was
        requested; otherwise they are whatever the parent ``get`` returned.
    """
    ego_repeats = 1  # magic number to balance

    def args_for_entry(position):
        # Copy args and keep only the ';'-separated entry at ``position``.
        variant = copy.deepcopy(args)
        for field in ('train_file', 'val_file', 'data'):
            vars(variant)[field] = getattr(args, field).split(';')[position]
        return variant

    ego_args = args_for_entry(0)
    base_args = args_for_entry(1)

    ego_train, ego_val = None, None
    if 'train' in splits or 'val' in splits:
        ego_train, ego_val, _ = CharadesEgoMeta.get(ego_args, splits=splits)

    parent = super(CharadesEgoPlusCharades3, cls)
    train_dataset, val_dataset, valvideo_dataset = parent.get(base_args, splits=splits)

    if 'train' in splits:
        # Targets of the parent dataset are negated.
        # NOTE(review): the reason for the sign flip is not visible here.
        train_dataset.target_transform = transforms.Lambda(lambda x: -x)
        train_dataset = ConcatDataset([train_dataset] + ego_repeats * [ego_train])
    if 'val' in splits:
        val_dataset.target_transform = transforms.Lambda(lambda x: -x)
        val_dataset = ConcatDataset([val_dataset] + ego_repeats * [ego_val])

    return train_dataset, val_dataset, valvideo_dataset
示例4: get
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get(cls, args, splits=('train', 'val', 'val_video')):
    """Build train/val datasets that mix the parent dataset with CharadesEgoMeta.

    ``args.train_file``, ``args.val_file`` and ``args.data`` are each read as
    ';'-separated strings: entry 0 configures ``CharadesEgoMeta.get`` and
    entry 1 configures the parent class's ``get``. Deep copies of ``args``
    carry the per-source values. The ego dataset is repeated six times in
    each concatenation.

    Args:
        args: Object exposing ``train_file``, ``val_file`` and ``data`` string
            attributes.
        splits: Names of the splits to build. Only 'train' and 'val' are
            inspected here; the value is forwarded to both ``get`` calls.

    Returns:
        ``(train_dataset, val_dataset, valvideo_dataset)``. The train/val
        entries are ``ConcatDataset`` objects when the matching split was
        requested; otherwise they are whatever the parent ``get`` returned.
    """
    ego_repeats = 6  # magic number to balance

    def args_for_entry(position):
        # Copy args and keep only the ';'-separated entry at ``position``.
        variant = copy.deepcopy(args)
        for field in ('train_file', 'val_file', 'data'):
            vars(variant)[field] = getattr(args, field).split(';')[position]
        return variant

    ego_args = args_for_entry(0)
    base_args = args_for_entry(1)

    ego_train, ego_val = None, None
    if 'train' in splits or 'val' in splits:
        ego_train, ego_val, _ = CharadesEgoMeta.get(ego_args, splits=splits)

    parent = super(CharadesEgoPlusCharades2, cls)
    train_dataset, val_dataset, valvideo_dataset = parent.get(base_args, splits=splits)

    if 'train' in splits:
        # Targets of the parent dataset are negated.
        # NOTE(review): the reason for the sign flip is not visible here.
        train_dataset.target_transform = transforms.Lambda(lambda x: -x)
        train_dataset = ConcatDataset([train_dataset] + ego_repeats * [ego_train])
    if 'val' in splits:
        val_dataset.target_transform = transforms.Lambda(lambda x: -x)
        val_dataset = ConcatDataset([val_dataset] + ego_repeats * [ego_val])

    return train_dataset, val_dataset, valvideo_dataset
示例5: __add__
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def __add__(self, other):
    """Return a ``ConcatDataset`` containing ``self`` followed by ``other``."""
    # The import stays inside the function body rather than at module level.
    import torch.utils.data.dataset as dataset_module

    pair = [self, other]
    return dataset_module.ConcatDataset(pair)
示例6: get
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get(cls, args):
    """Build the combined datasets plus a separate egocentric val-video set.

    ``charadesego.CharadesEgo.get(args)`` supplies the ego train/val
    datasets. The parent class's ``get`` is called with a deep copy of
    ``args`` whose ``train_file``, ``val_file`` and ``data`` are replaced by
    ``args.original_charades_train``, ``args.original_charades_test`` and
    ``args.original_charades_data``.

    Args:
        args: Object exposing the attributes read below
            (``original_charades_*``, ``data``, ``egocentric_test_data``,
            ``cache``, ``cache_buster``, ``inputsize``).

    Returns:
        ``(train_dataset, val_dataset, valvideo_dataset, valvideoego_dataset)``
        where the first two are ``ConcatDataset`` objects holding the parent
        dataset followed by six references to the ego dataset.
    """
    ego_repeats = 6

    ego_train, ego_val, _ = charadesego.CharadesEgo.get(args)

    base_args = copy.deepcopy(args)
    vars(base_args).update(
        train_file=args.original_charades_train,
        val_file=args.original_charades_test,
        data=args.original_charades_data,
    )
    parent = super(CharadesEgoPlusRGB, cls)
    train_dataset, val_dataset, valvideo_dataset = parent.get(base_args)

    # Every parent dataset gets a final transform step that turns the
    # sample x into the list [x, x, x].
    for dataset in (train_dataset, val_dataset, valvideo_dataset):
        dataset.transform.transforms.append(transforms.Lambda(lambda x: [x, x, x]))

    # Targets of the parent train/val datasets are negated.
    # NOTE(review): the reason for the sign flip is not visible here.
    for dataset in (train_dataset, val_dataset):
        dataset.target_transform = transforms.Lambda(lambda x: -x)

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    ego_video_transform = transforms.Compose([
        transforms.Resize(int(256. / 224 * args.inputsize)),
        transforms.CenterCrop(args.inputsize),
        transforms.ToTensor(),
        normalize,
    ])
    valvideoego_dataset = CharadesMeta(
        args.data,
        'val_video',
        args.egocentric_test_data,
        args.cache,
        args.cache_buster,
        transform=ego_video_transform,
    )

    train_dataset = ConcatDataset([train_dataset] + ego_repeats * [ego_train])
    val_dataset = ConcatDataset([val_dataset] + ego_repeats * [ego_val])
    return train_dataset, val_dataset, valvideo_dataset, valvideoego_dataset
示例7: get_train_val_loaders
# 需要导入模块: from torch.utils.data import dataset [as 别名]
# 或者: from torch.utils.data.dataset import ConcatDataset [as 别名]
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    with_sbd: Optional[str] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create the training, validation and train-evaluation data loaders.

    Args:
        root_path: Passed to ``get_train_dataset`` and ``get_val_dataset``.
        train_transforms: Transform applied to the training dataset.
        val_transforms: Transform applied to the validation dataset and to
            the train-evaluation dataset.
        batch_size: Batch size of the training loader.
        num_workers: Worker count used by all three loaders.
        val_batch_size: Batch size of the two evaluation loaders; defaults to
            ``batch_size * 4`` when ``None``.
        with_sbd: When not ``None``, passed to ``get_train_noval_sbdataset``
            and the result is concatenated after the training dataset.
        limit_train_num_samples: When not ``None``, keep only this many
            randomly chosen training samples (the value is also the seed).
        limit_val_num_samples: Same, for the validation dataset.

    Returns:
        ``(train_loader, val_loader, train_eval_loader)``.
    """

    def seeded_subset(dataset, seed, count):
        # Seeds NumPy's global RNG so the chosen indices are reproducible
        # for a given seed; note this resets the global random state.
        np.random.seed(seed)
        chosen = np.random.permutation(len(dataset))[:count]
        return Subset(dataset, chosen)

    train_ds = get_train_dataset(root_path)
    val_ds = get_val_dataset(root_path)

    if with_sbd is not None:
        train_ds = ConcatDataset([train_ds, get_train_noval_sbdataset(with_sbd)])

    if limit_train_num_samples is not None:
        train_ds = seeded_subset(train_ds, limit_train_num_samples, limit_train_num_samples)

    if limit_val_num_samples is not None:
        val_ds = seeded_subset(val_ds, limit_val_num_samples, limit_val_num_samples)

    # Random samples for evaluation on the training dataset: as many as the
    # validation set holds, or the full training set if it is not larger.
    val_size = len(val_ds)
    if val_size < len(train_ds):
        train_eval_ds = seeded_subset(train_ds, val_size, val_size)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds, transform_fn=val_transforms)

    train_loader = idist.auto_dataloader(
        train_ds,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )

    if val_batch_size is None:
        val_batch_size = batch_size * 4

    val_loader = idist.auto_dataloader(
        val_ds,
        shuffle=False,
        batch_size=val_batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
    train_eval_loader = idist.auto_dataloader(
        train_eval_ds,
        shuffle=False,
        batch_size=val_batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    return train_loader, val_loader, train_eval_loader