

Python dataset.ConcatDataset Method Code Examples

This article collects typical usage examples of the Python method torch.utils.data.dataset.ConcatDataset. If you are wondering what dataset.ConcatDataset does, how to use it, or are looking for concrete examples, the curated code examples below may help. You can also explore further usage examples from the torch.utils.data.dataset module.


The following shows 7 code examples of the dataset.ConcatDataset method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
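Before the project examples, here is a minimal, self-contained sketch of what ConcatDataset does: it chains several datasets into one, indexing them end to end. The TensorDataset inputs are purely illustrative.

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataset import ConcatDataset

ds_a = TensorDataset(torch.arange(4))       # 4 samples: 0, 1, 2, 3
ds_b = TensorDataset(torch.arange(10, 16))  # 6 samples: 10 .. 15
combined = ConcatDataset([ds_a, ds_b])

print(len(combined))  # 10
print(combined[0])    # (tensor(0),)  -> served from ds_a
print(combined[5])    # (tensor(11),) -> index 5 falls into ds_b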

Example 1: get

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get(cls, args, splits=('train', 'val', 'val_video')):
        newargs1 = copy.deepcopy(args)
        newargs2 = copy.deepcopy(args)
        vars(newargs1).update({
            'train_file': args.train_file.split(';')[0],
            'val_file': args.val_file.split(';')[0],
            'data': args.data.split(';')[0]})
        vars(newargs2).update({
            'train_file': args.train_file.split(';')[1],
            'val_file': args.val_file.split(';')[1],
            'data': args.data.split(';')[1]})

        if 'train' in splits or 'val' in splits:
            train_datasetego, val_datasetego, _ = CharadesEgoMeta.get(newargs1, splits=splits)
        else:
            train_datasetego, val_datasetego = None, None
        train_dataset, val_dataset, valvideo_dataset = super(CharadesEgoPlusCharades, cls).get(newargs2, splits=splits)

        if 'train' in splits:
            train_dataset.target_transform = transforms.Lambda(lambda x: -x)
            train_dataset = ConcatDataset([train_dataset] + [train_datasetego] * 3)  # magic number to balance
        if 'val' in splits:
            val_dataset.target_transform = transforms.Lambda(lambda x: -x)
            val_dataset = ConcatDataset([val_dataset] + [val_datasetego] * 3)
        return train_dataset, val_dataset, valvideo_dataset 
Developer: gsig, Project: PyVideoResearch, Lines: 27, Source: charades_ego_plus_charades.py
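The "[train_dataset] + [train_datasetego] * 3" pattern above oversamples the smaller egocentric dataset by listing it three times before concatenation ("magic number to balance"). A minimal sketch of the same balancing idea, with TensorDataset stand-ins instead of the Charades datasets:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataset import ConcatDataset

third_person = TensorDataset(torch.zeros(900))  # stand-in for the larger dataset
egocentric = TensorDataset(torch.ones(300))     # stand-in for the smaller dataset

# Repeating the smaller dataset 3x roughly balances the two sources.
balanced = ConcatDataset([third_person] + [egocentric] * 3)
print(len(balanced))  # 1800 = 900 + 3 * 300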

Example 2: get

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get(cls, args, splits=('train', 'val', 'val_video')):
        newargs1 = copy.deepcopy(args)
        newargs2 = copy.deepcopy(args)
        vars(newargs1).update({
            'train_file': args.train_file.split(';')[0],
            'val_file': args.val_file.split(';')[0],
            'data': args.data.split(';')[0]})
        vars(newargs2).update({
            'train_file': args.train_file.split(';')[1],
            'val_file': args.val_file.split(';')[1],
            'data': args.data.split(';')[1]})

        if 'train' in splits or 'val' in splits:
            train_datasetego, val_datasetego, _ = CharadesEgoVideoMeta.get(newargs1, splits=splits)
        else:
            train_datasetego, val_datasetego = None, None
        train_dataset, val_dataset, valvideo_dataset = super(CharadesEgoVideoPlusCharades, cls).get(newargs2, splits=splits)

        if 'train' in splits:
            train_dataset.target_transform = transforms.Lambda(lambda x: -x)
            train_dataset = ConcatDataset([train_dataset] + [train_datasetego] * 3)  # magic number to balance
        if 'val' in splits:
            val_dataset.target_transform = transforms.Lambda(lambda x: -x)
            val_dataset = ConcatDataset([val_dataset] + [val_datasetego] * 3)
        return train_dataset, val_dataset, valvideo_dataset 
Developer: gsig, Project: PyVideoResearch, Lines: 27, Source: charades_ego_video_plus_charades.py

Example 3: get

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get(cls, args, splits=('train', 'val', 'val_video')):
        newargs1 = copy.deepcopy(args)
        newargs2 = copy.deepcopy(args)
        vars(newargs1).update({
            'train_file': args.train_file.split(';')[0],
            'val_file': args.val_file.split(';')[0],
            'data': args.data.split(';')[0]})
        vars(newargs2).update({
            'train_file': args.train_file.split(';')[1],
            'val_file': args.val_file.split(';')[1],
            'data': args.data.split(';')[1]})

        if 'train' in splits or 'val' in splits:
            train_datasetego, val_datasetego, _ = CharadesEgoMeta.get(newargs1, splits=splits)
        else:
            train_datasetego, val_datasetego = None, None
        train_dataset, val_dataset, valvideo_dataset = super(CharadesEgoPlusCharades3, cls).get(newargs2, splits=splits)

        if 'train' in splits:
            train_dataset.target_transform = transforms.Lambda(lambda x: -x)
            train_dataset = ConcatDataset([train_dataset] + [train_datasetego] * 1)  # magic number to balance
        if 'val' in splits:
            val_dataset.target_transform = transforms.Lambda(lambda x: -x)
            val_dataset = ConcatDataset([val_dataset] + [val_datasetego] * 1)
        return train_dataset, val_dataset, valvideo_dataset 
Developer: gsig, Project: PyVideoResearch, Lines: 27, Source: charades_ego_plus_charades3.py

Example 4: get

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get(cls, args, splits=('train', 'val', 'val_video')):
        newargs1 = copy.deepcopy(args)
        newargs2 = copy.deepcopy(args)
        vars(newargs1).update({
            'train_file': args.train_file.split(';')[0],
            'val_file': args.val_file.split(';')[0],
            'data': args.data.split(';')[0]})
        vars(newargs2).update({
            'train_file': args.train_file.split(';')[1],
            'val_file': args.val_file.split(';')[1],
            'data': args.data.split(';')[1]})

        if 'train' in splits or 'val' in splits:
            train_datasetego, val_datasetego, _ = CharadesEgoMeta.get(newargs1, splits=splits)
        else:
            train_datasetego, val_datasetego = None, None
        train_dataset, val_dataset, valvideo_dataset = super(CharadesEgoPlusCharades2, cls).get(newargs2, splits=splits)

        if 'train' in splits:
            train_dataset.target_transform = transforms.Lambda(lambda x: -x)
            train_dataset = ConcatDataset([train_dataset] + [train_datasetego] * 6)  # magic number to balance
        if 'val' in splits:
            val_dataset.target_transform = transforms.Lambda(lambda x: -x)
            val_dataset = ConcatDataset([val_dataset] + [val_datasetego] * 6)
        return train_dataset, val_dataset, valvideo_dataset 
Developer: gsig, Project: PyVideoResearch, Lines: 27, Source: charades_ego_plus_charades2.py

Example 5: __add__

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def __add__(self, other):
        from torch.utils.data.dataset import ConcatDataset
        return ConcatDataset([self, other]) 
Developer: vacancy, Project: Jacinle, Lines: 5, Source: dataset.py
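Because __add__ returns a ConcatDataset, two instances of such a dataset can be combined with the plus operator. A small illustrative sketch; the DummyDataset class is hypothetical and only mirrors the pattern in Example 5:

from torch.utils.data import Dataset
from torch.utils.data.dataset import ConcatDataset

class DummyDataset(Dataset):
    def __init__(self, values):
        self.values = list(values)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        return self.values[index]

    def __add__(self, other):
        # Same pattern as Example 5: "+" concatenates two datasets.
        return ConcatDataset([self, other])

combined = DummyDataset(range(3)) + DummyDataset(range(3, 8))
print(len(combined))  # 8
print(combined[4])    # 4, served by the second dataset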

Example 6: get

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get(cls, args):
        train_datasetego, val_datasetego, _ = charadesego.CharadesEgo.get(args)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        newargs = copy.deepcopy(args)
        vars(newargs).update({
            'train_file': args.original_charades_train,
            'val_file': args.original_charades_test,
            'data': args.original_charades_data})
        train_dataset, val_dataset, valvideo_dataset = super(CharadesEgoPlusRGB, cls).get(newargs)
        train_dataset.transform.transforms.append(transforms.Lambda(lambda x: [x, x, x]))
        val_dataset.transform.transforms.append(transforms.Lambda(lambda x: [x, x, x]))
        valvideo_dataset.transform.transforms.append(transforms.Lambda(lambda x: [x, x, x]))
        train_dataset.target_transform = transforms.Lambda(lambda x: -x)
        val_dataset.target_transform = transforms.Lambda(lambda x: -x)

        valvideoego_dataset = CharadesMeta(
            args.data, 'val_video',
            args.egocentric_test_data,
            args.cache,
            args.cache_buster,
            transform=transforms.Compose([
                transforms.Resize(int(256. / 224 * args.inputsize)),
                transforms.CenterCrop(args.inputsize),
                transforms.ToTensor(),
                normalize,
            ]))

        train_dataset = ConcatDataset([train_dataset] + [train_datasetego] * 6)
        val_dataset = ConcatDataset([val_dataset] + [val_datasetego] * 6)
        return train_dataset, val_dataset, valvideo_dataset, valvideoego_dataset 
Developer: gsig, Project: actor-observer, Lines: 33, Source: charadesegoplusrgb.py

Example 7: get_train_val_loaders

# Required import: from torch.utils.data import dataset [as alias]
# Or: from torch.utils.data.dataset import ConcatDataset [as alias]
def get_train_val_loaders(
    root_path: str,
    train_transforms: Callable,
    val_transforms: Callable,
    batch_size: int = 16,
    num_workers: int = 8,
    val_batch_size: Optional[int] = None,
    with_sbd: Optional[str] = None,
    limit_train_num_samples: Optional[int] = None,
    limit_val_num_samples: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:

    train_ds = get_train_dataset(root_path)
    val_ds = get_val_dataset(root_path)

    if with_sbd is not None:
        sbd_train_ds = get_train_noval_sbdataset(with_sbd)
        train_ds = ConcatDataset([train_ds, sbd_train_ds])

    if limit_train_num_samples is not None:
        np.random.seed(limit_train_num_samples)
        train_indices = np.random.permutation(len(train_ds))[:limit_train_num_samples]
        train_ds = Subset(train_ds, train_indices)

    if limit_val_num_samples is not None:
        np.random.seed(limit_val_num_samples)
        val_indices = np.random.permutation(len(val_ds))[:limit_val_num_samples]
        val_ds = Subset(val_ds, val_indices)

    # random samples for evaluation on training dataset
    if len(val_ds) < len(train_ds):
        np.random.seed(len(val_ds))
        train_eval_indices = np.random.permutation(len(train_ds))[: len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds, transform_fn=val_transforms)

    train_loader = idist.auto_dataloader(
        train_ds, shuffle=True, batch_size=batch_size, num_workers=num_workers, drop_last=True,
    )

    val_batch_size = batch_size * 4 if val_batch_size is None else val_batch_size
    val_loader = idist.auto_dataloader(
        val_ds, shuffle=False, batch_size=val_batch_size, num_workers=num_workers, drop_last=False,
    )

    train_eval_loader = idist.auto_dataloader(
        train_eval_ds, shuffle=False, batch_size=val_batch_size, num_workers=num_workers, drop_last=False,
    )

    return train_loader, val_loader, train_eval_loader 
Developer: pytorch, Project: ignite, Lines: 57, Source: dataloaders.py
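A hedged usage sketch of the loader factory above; the root path is a placeholder and the identity lambdas stand in for the real preprocessing/augmentation callables that a caller would normally supply:

train_loader, val_loader, train_eval_loader = get_train_val_loaders(
    root_path="/data/pascal_voc",                  # placeholder dataset root
    train_transforms=lambda datapoint: datapoint,  # placeholder transform
    val_transforms=lambda datapoint: datapoint,    # placeholder transform
    batch_size=16,
    num_workers=4,
    with_sbd=None,                                 # or a path to SBD data to concatenate it
)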


Note: The torch.utils.data.dataset.ConcatDataset method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub/MSDocs. The code snippets are selected from open-source projects contributed by various developers, and copyright remains with the original authors; when distributing or using the code, please follow the corresponding project's license. Do not reproduce this article without permission.