当前位置: 首页>>代码示例>>Python>>正文


Python torchvision.datasets模块代码示例

本文整理汇总了Python中torchvision.datasets模块的典型用法代码示例。如果您想了解：torchvision.datasets的具体用法是什么？torchvision.datasets怎么用？有哪些使用torchvision.datasets的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在的torchvision库的其他用法示例。


下文共展示了torchvision.datasets模块的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: load_data

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def load_data(train_split, val_split, root):
    """Create the training and validation datasets and their DataLoaders.

    ``Dataset``, ``batch_size`` and ``collate_fn`` are module-level names
    defined elsewhere in the original script.

    Args:
        train_split: training split specification; when it is empty
            (``len(train_split) == 0``) no training dataset/loader is built.
        val_split: validation split specification.
        root: data root directory; also attached to each loader as ``.root``.

    Returns:
        A ``(dataloaders, datasets)`` pair of dicts keyed by ``'train'`` and
        ``'val'``. The ``'train'`` values are ``None`` when ``train_split``
        is empty.
    """
    def _build(split, mode, loader_bs, workers):
        # Every Dataset receives the global batch_size, whatever the loader's
        # own batch size is.
        ds = Dataset(split, mode, root, batch_size)
        loader = torch.utils.data.DataLoader(
            ds, batch_size=loader_bs, shuffle=True, num_workers=workers,
            pin_memory=True, collate_fn=collate_fn)
        loader.root = root
        return ds, loader

    train_pair = (None, None)
    if len(train_split) > 0:
        train_pair = _build(train_split, 'training', batch_size, 8)
    # NOTE(review): the validation loader also shuffles (shuffle=True), as in
    # the original; confirm this is intended.
    val_pair = _build(val_split, 'testing', 1, 2)

    train_ds, train_loader = train_pair
    val_ds, val_loader = val_pair
    loaders = {'train': train_loader, 'val': val_loader}
    sets = {'train': train_ds, 'val': val_ds}
    return loaders, sets


# train the model 
开发者ID:piergiaj,项目名称:super-events-cvpr18,代码行数:24,代码来源:train_model.py

示例2: get_loaders

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    """Build the training and validation loaders.

    Args:
        traindir / valdir: ImageFolder-style image directories.
        sz: target crop size.
        bs: training batch size; also used for validation when ``val_bs``
            is falsy.
        fp16: forwarded to ``BatchTransformDataLoader`` (defined elsewhere).
        val_bs: validation batch size.
        workers: DataLoader worker count for both loaders.
        rect_val: forwarded to ``create_validation_set``.
        min_scale: lower bound of the RandomResizedCrop area scale.
        distributed: when true, the training set is sharded with a
            ``DistributedSampler`` (and the flag is forwarded to validation).

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``;
        ``train_sampler`` is ``None`` when not distributed.
    """
    if not val_bs:
        val_bs = bs

    augment = transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(traindir, augment)
    if distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank())
    else:
        train_sampler = None

    # Options shared by both loaders.
    common = dict(num_workers=workers, pin_memory=True, collate_fn=fast_collate)
    # A custom sampler and shuffle=True are mutually exclusive in DataLoader.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=train_sampler is None,
        sampler=train_sampler, **common)

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_sampler=val_sampler, **common)

    # Wrap both loaders in BatchTransformDataLoader (defined elsewhere).
    return (BatchTransformDataLoader(train_loader, fp16=fp16),
            BatchTransformDataLoader(val_loader, fp16=fp16),
            train_sampler,
            val_sampler)
开发者ID:cybertronai,项目名称:imagenet18_old,代码行数:26,代码来源:dataloader.py

示例3: __init__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __init__(self, split='train'):
    """Load one ImageNet split through ``datasets.ImageFolder``.

    Args:
        split: ``'train'`` or ``'val'``; selects the matching sub-directory
            of ``_IMAGENET_DATASET_DIR`` (a module-level constant defined
            elsewhere).

    Raises:
        AssertionError: if ``split`` is neither ``'train'`` nor ``'val'``.

    Sets ``self.split``, ``self.name``, ``self.transform``, ``self.data`` (the
    ImageFolder) and ``self.labels`` (the class index of every image).
    """
    self.split = split
    assert split in ('train', 'val'), "split must be 'train' or 'val'"
    self.name = 'ImageNet_Split_' + split

    print('Loading ImageNet dataset - split {0}'.format(split))
    mean_pix = [0.485, 0.456, 0.406]
    std_pix = [0.229, 0.224, 0.225]
    self.transform = transforms.Compose([
        # transforms.Scale was deprecated and later removed from torchvision;
        # Resize with an int likewise matches the shorter edge to 256.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # Convert the PIL image to a NumPy array before ToTensor.
        lambda x: np.asarray(x),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_pix, std=std_pix),
    ])

    traindir = os.path.join(_IMAGENET_DATASET_DIR, 'train')
    valdir = os.path.join(_IMAGENET_DATASET_DIR, 'val')
    self.data = datasets.ImageFolder(
        traindir if split == 'train' else valdir, self.transform)
    # ImageFolder.imgs is a list of (path, class_index) pairs.
    self.labels = [item[1] for item in self.data.imgs]
开发者ID:gidariss,项目名称:FewShotWithoutForgetting,代码行数:23,代码来源:dataloader.py

示例4: __getitem__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return one sample, together with its index.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target, index) where target is the index of the target
        class and index is the requested sample index.
    """
    img, target = self.data[index], int(self.labels[index])

    # Reorder the stored array's axes (1, 2, 0), i.e. CHW -> HWC, so that a
    # PIL Image can be built; this keeps the output consistent with the other
    # datasets, which return PIL Images.
    img = Image.fromarray(np.transpose(img, (1, 2, 0)))

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, index
开发者ID:k-han,项目名称:DTC,代码行数:23,代码来源:svhnloader.py

示例5: __getitem__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return one sample, together with its index.

    Args:
        index (int): Index

    Returns:
        tuple: (image, target, index) where target is the index of the target
        class and index is the requested sample index.
    """
    img, target = self.data[index], self.targets[index]

    # Convert the stored array to a PIL Image so the output is consistent
    # with all the other datasets.
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target, index
开发者ID:k-han,项目名称:DTC,代码行数:23,代码来源:cifarloader.py

示例6: __init__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __init__(self, *datasets):
    """Remember the datasets that will be read in lockstep.

    ``datasets`` arrives as a tuple via ``*args`` and is stored unchanged.
    """
    members = datasets
    self.datasets = members
开发者ID:AlexiaJM,项目名称:Deep-learning-with-cats,代码行数:4,代码来源:CycleGAN.py

示例7: __getitem__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, i):
    """Return a tuple holding the ``i``-th item of every wrapped dataset."""
    picked = [member[i] for member in self.datasets]
    return tuple(picked)
开发者ID:AlexiaJM,项目名称:Deep-learning-with-cats,代码行数:4,代码来源:CycleGAN.py

示例8: __len__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __len__(self):
    """Length of the zipped view: the size of the shortest wrapped dataset."""
    return min(map(len, self.datasets))
开发者ID:AlexiaJM,项目名称:Deep-learning-with-cats,代码行数:4,代码来源:CycleGAN.py

示例9: test_input_block

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def test_input_block():
    """Check that the inflated 3D stem reproduces the 2D DenseNet-121 stem.

    Each image batch is fed to the pretrained 2D stem
    (conv0/norm0/relu0/pool0). The same batch, repeated ``frame_nb`` times
    along a new axis 2 and replication-padded by one frame on each side of
    that axis, is fed to the inflated 3D stem. The 3D output must equal the
    2D output repeated ``frame_nb`` times.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = datasets.ImageFolder(
        '/sequoia/data1/yhasson/datasets/test-dataset',
        transforms.Compose([
            # RandomResizedCrop replaces RandomSizedCrop, which was deprecated
            # and later removed from torchvision.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    densenet = torchvision.models.densenet121(pretrained=True)
    features = densenet.features
    seq2d = torch.nn.Sequential(
        features.conv0, features.norm0, features.relu0, features.pool0)
    seq3d = torch.nn.Sequential(
        inflate.inflate_conv(features.conv0, 3),
        inflate.inflate_batch_norm(features.norm0),
        features.relu0,
        inflate.inflate_pool(features.pool0, 1))

    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)
    frame_nb = 4
    # Padding order is (W_left, W_right, H_top, H_bottom, D_front, D_back), so
    # this pads one replicated frame on each side of axis 2 only.
    time_pad = torch.nn.ReplicationPad3d((0, 0, 0, 0, 1, 1))
    # Inference only: no autograd graph is needed. The unused target is no
    # longer moved with .cuda(), which crashed on CPU-only machines.
    with torch.no_grad():
        for input_2d, _target in loader:
            out2d = seq2d(input_2d)
            input_3d = input_2d.unsqueeze(2).repeat(1, 1, frame_nb, 1, 1)
            out3d = seq3d(time_pad(input_3d))
            expected_out_3d = out2d.unsqueeze(2).repeat(1, 1, frame_nb, 1, 1)
            # Take the absolute difference so negative deviations also fail
            # the check.
            out_diff = (expected_out_3d - out3d).abs()
            print(out_diff.max())
            assert out_diff.max() < 0.0001
开发者ID:hassony2,项目名称:kinetics_i3d_pytorch,代码行数:38,代码来源:test_first_block.py

示例10: __getitem__

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return ``(image, label)`` for ``index``.

    The stored array is converted to a PIL Image, matching the other
    datasets, and passed through ``self.transform`` when one is set.
    """
    raw, label = self.data[index], self.labels[index]
    pil_img = Image.fromarray(raw)
    transformed = pil_img if self.transform is None else self.transform(pil_img)
    return transformed, label
开发者ID:kjunelee,项目名称:MetaOptNet,代码行数:10,代码来源:tiered_imagenet.py

示例11: toy

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def toy(dataset,
        root='~/data/torchvision/',
        transforms=None):
    """Load the train and test splits of a ``torchvision.datasets`` class.

    Args:
        dataset: name of a class in ``torchvision.datasets`` (e.g.
            ``'MNIST'``) whose constructor accepts the ``root``, ``train``,
            ``download`` and ``transform`` keywords.
        root: download/cache directory; ``~`` is expanded.
        transforms: optional sequence of dicts ``{'def': name, 'kwargs': {...}}``
            naming ``torchvision.transforms`` classes to apply in order before
            the final ``ToTensor``. ``'kwargs'`` may be omitted, meaning no
            arguments.

    Returns:
        ``(trainset, testset)``, both downloaded if needed and sharing one
        composed transform.

    Raises:
        ValueError: if ``dataset`` or a transform ``'def'`` does not exist.
    """
    if not hasattr(torchvision.datasets, dataset):
        raise ValueError(
            'Unknown torchvision dataset: {!r}'.format(dataset))
    loader_def = getattr(torchvision.datasets, dataset)

    transform_funcs = []
    if transforms is not None:
        for transform in transforms:
            name = transform['def']
            if not hasattr(torchvision.transforms, name):
                raise ValueError(
                    'Unknown torchvision transform: {!r}'.format(name))
            transform_def = getattr(torchvision.transforms, name)
            transform_funcs.append(transform_def(**transform.get('kwargs', {})))
    transform_funcs.append(torchvision.transforms.ToTensor())

    composed_transform = torchvision.transforms.Compose(transform_funcs)
    data_root = os.path.expanduser(root)
    trainset = loader_def(
        root=data_root, train=True,
        download=True, transform=composed_transform)
    testset = loader_def(
        root=data_root, train=False,
        download=True, transform=composed_transform)
    return trainset, testset
开发者ID:hav4ik,项目名称:Hydra,代码行数:28,代码来源:toy.py

示例12: create_validation_set

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def create_validation_set(valdir, batch_size, target_size, rect_val, distributed):
    """Build the validation dataset and its batch sampler.

    Without ``rect_val``: a plain ImageFolder with Resize(1.14 * target_size)
    and CenterCrop(target_size), sampled in natural index order.

    With ``rect_val``: indices are ordered by aspect ratio (``sort_ar``), each
    batch is mapped to an aspect ratio (``map_idx2ar``) and images are
    cropped by ``CropArTfm``. ``sort_ar``, ``map_idx2ar``, ``CropArTfm``,
    ``ValDataset`` and ``DistValSampler`` are defined elsewhere.

    Returns:
        ``(val_dataset, val_sampler)``.
    """
    if not rect_val:
        val_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)]
        plain_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
        every_index = list(range(len(plain_dataset)))
        return plain_dataset, DistValSampler(every_index, batch_size=batch_size, distributed=distributed)

    idx_ar_sorted = sort_ar(valdir)
    idx_sorted, _ = zip(*idx_ar_sorted)
    idx2ar = map_idx2ar(idx_ar_sorted, batch_size)

    # NOTE(review): ValDataset receives the transform list itself, not a
    # Compose; presumably it applies them itself — confirm.
    ar_tfms = [transforms.Resize(int(target_size*1.14)), CropArTfm(idx2ar, target_size)]
    rect_dataset = ValDataset(valdir, transform=ar_tfms)
    rect_sampler = DistValSampler(idx_sorted, batch_size=batch_size, distributed=distributed)
    return rect_dataset, rect_sampler
开发者ID:cybertronai,项目名称:imagenet18_old,代码行数:17,代码来源:dataloader.py

示例13: sort_ar

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def sort_ar(valdir):
    """Return validation-image indices sorted by aspect ratio, with caching.

    The result is a list of ``(index, width / height)`` pairs, with the ratio
    rounded to 5 decimals, sorted by ratio. ``index`` is the position in
    ``datasets.ImageFolder(valdir)``. The list is pickled to
    ``sorted_idxar.p`` in the parent directory of ``valdir``, and later calls
    return that cached file directly.
    """
    idx2ar_file = os.path.join(valdir, '..', 'sorted_idxar.p')
    if os.path.isfile(idx2ar_file):
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files this function created.
        with open(idx2ar_file, 'rb') as f:
            return pickle.load(f)
    print('Creating AR indexes. Please be patient this may take a couple minutes...')
    val_dataset = datasets.ImageFolder(valdir)  # AS: TODO: use Image.open instead of looping through dataset
    # Each sample is (PIL image, class index); PIL's .size is (width, height).
    sizes = [sample[0].size for sample in tqdm(val_dataset, total=len(val_dataset))]
    idx_ar = [(i, round(w / h, 5)) for i, (w, h) in enumerate(sizes)]
    sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
    # Use a context manager so the cache file is flushed and closed.
    with open(idx2ar_file, 'wb') as f:
        pickle.dump(sorted_idxar, f)
    print('Done')
    return sorted_idxar
开发者ID:cybertronai,项目名称:imagenet18_old,代码行数:13,代码来源:dataloader.py

示例14: getdataloader

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def getdataloader(datatype, train_db_path, test_db_path, batch_size):
    """Build CIFAR-10 or CIFAR-100 train/test DataLoaders.

    ``datatype`` is lower-cased and compared against the module constants
    ``CIFAR10`` and ``CIFAR100``. The transforms come from
    ``_getdatatransformsdb``; these names are defined elsewhere.

    Returns:
        ``(trainloader, testloader, n_classes)``, or ``(None, None, None)``
        when the dataset is not supported.
    """
    # Transforms are requested before the dataset check, as in the original.
    transform_train, transform_test = _getdatatransformsdb(datatype=datatype)

    key = datatype.lower()
    if key == CIFAR10:
        print("Using CIFAR10 dataset.")
        dataset_cls, n_classes = torchvision.datasets.CIFAR10, 10
    elif key == CIFAR100:
        print("Using CIFAR100 dataset.")
        dataset_cls, n_classes = torchvision.datasets.CIFAR100, 100
    else:
        print("Dataset is not supported.")
        return None, None, None

    trainset = dataset_cls(root=train_db_path, train=True, download=True,
                           transform=transform_train)
    testset = dataset_cls(root=test_db_path, train=False, download=True,
                          transform=transform_test)

    def _loader(ds, shuffle):
        # Shared DataLoader settings; only the training split is shuffled.
        return torch.utils.data.DataLoader(ds, batch_size=batch_size,
                                           shuffle=shuffle, num_workers=4)

    return _loader(trainset, True), _loader(testset, False), n_classes
开发者ID:adiyoss,项目名称:WatermarkNN,代码行数:35,代码来源:loaders.py

示例15: create_dataset

# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def create_dataset(args, train):
    """Instantiate ``datasets.<args.dataset>`` rooted at ``args.dataroot``.

    Every image is converted to a tensor and normalized with fixed
    per-channel mean/std, given on the 0-255 scale and divided by 255. For
    the training split the following are applied first: reflect padding by 4
    pixels, a random horizontal flip, and a random 32x32 crop. The dataset is
    downloaded if needed.
    """
    channel_mean = np.array([125.3, 123.0, 113.9]) / 255.0
    channel_std = np.array([63.0, 62.1, 66.7]) / 255.0
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ])

    if train:
        augmentation = [
            T.Pad(4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
            T.RandomCrop(32),
        ]
        transform = T.Compose(augmentation + [to_normalized_tensor])
    else:
        transform = to_normalized_tensor

    dataset_cls = getattr(datasets, args.dataset)
    return dataset_cls(args.dataroot, train=train, download=True, transform=transform)
开发者ID:szagoruyko,项目名称:binary-wide-resnet,代码行数:16,代码来源:main.py


注:本文中的torchvision.datasets方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。