当前位置: 首页>>代码示例>>Python>>正文


Python datasets.SVHN属性代码示例

本文整理汇总了 Python 中 torchvision.datasets.SVHN 类的典型用法代码示例。如果您正苦于以下问题：Python datasets.SVHN 类的具体用法？Python datasets.SVHN 怎么用？Python datasets.SVHN 使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torchvision.datasets 的用法示例。


下文一共展示了 datasets.SVHN 类的 15 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: get_svhn

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_svhn(train, get_dataset=False, batch_size=cfg.batch_size):
    """Build the SVHN dataset, or a shuffling data loader over it.

    Args:
        train: truthy selects the 'train' split, falsy the 'test' split.
        get_dataset: when truthy, return the dataset itself, not a loader.
        batch_size: loader batch size (defaults to ``cfg.batch_size``).

    Returns:
        The ``datasets.SVHN`` instance when ``get_dataset`` is truthy,
        otherwise a ``torch.utils.data.DataLoader`` over it with
        ``shuffle=True``.
    """
    split_name = 'train' if train else 'test'

    # Pre-processing: tensor conversion, then normalisation with the
    # statistics taken from the configuration module.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=cfg.dataset_mean,
                                     std=cfg.dataset_std)
    pipeline = transforms.Compose([to_tensor, normalize])

    # download=True is passed on every call.
    dataset = datasets.SVHN(root=cfg.data_root,
                            split=split_name,
                            transform=pipeline,
                            download=True)

    if get_dataset:
        return dataset
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
开发者ID:corenel,项目名称:pytorch-atda,代码行数:24,代码来源:svhn.py

示例2: get_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_loader(config):
    """Build shuffling DataLoaders for the SVHN and MNIST datasets.

    Both datasets share one transform: resize to ``config.image_size``,
    convert to a tensor, then normalise with mean 0.5 / std 0.5 given for
    three channels.

    Args:
        config: object providing ``image_size``, ``svhn_path``,
            ``mnist_path``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``.
    """
    # torchvision deprecated ``Scale`` in favour of ``Resize`` and later
    # removed it. Prefer ``Resize``; fall back to ``Scale`` only on releases
    # that predate the rename.
    resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale

    # NOTE(review): the 3-value mean/std is also applied to MNIST, whose
    # images are presumably single-channel -- confirm the installed
    # torchvision accepts that.
    transform = transforms.Compose([
                    resize_cls(config.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Neither dataset is given an explicit split, so each uses its default.
    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=transform)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_workers)

    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    return svhn_loader, mnist_loader
开发者ID:yunjey,项目名称:mnist-svhn-transfer,代码行数:23,代码来源:data_loader.py

示例3: get_targets

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_targets(dataset):
    """Get the targets of a dataset without any target transforms(!).

    Wrappers are unwrapped recursively:

    * ``TransformedDataset``: targets of the wrapped ``dataset``.
    * ``data.Subset``: targets of the wrapped dataset, indexed by
      ``dataset.indices``.
    * ``data.ConcatDataset``: targets of every sub-dataset, concatenated.

    Leaf datasets: ``datasets.MNIST`` / ``datasets.ImageFolder`` yield
    ``torch.as_tensor(dataset.targets)``; ``datasets.SVHN`` yields
    ``dataset.labels`` unchanged.

    Raises:
        NotImplementedError: for any other dataset type.
    """
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        # torch.cat only accepts tensors, while the SVHN branch below hands
        # back ``dataset.labels`` without conversion. Convert every part so a
        # concatenation containing SVHN does not fail; as_tensor leaves
        # existing tensors untouched.
        return torch.cat([torch.as_tensor(get_targets(sub_dataset))
                          for sub_dataset in dataset.datasets])

    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        # Returned as stored, so existing callers keep the same type.
        return dataset.labels

    raise NotImplementedError(f"Unknown dataset {dataset}!")
开发者ID:BlackHC,项目名称:BatchBALD,代码行数:20,代码来源:dataset_enum.py

示例4: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_dataset(self):
        """
        Load SVHN through torchvision's ``datasets.SVHN``.

        Three splits are constructed, each with ``download=True`` and no
        target transform: 'train' and 'extra' use ``self.train_transforms``,
        'test' uses ``self.val_transforms``.

        Returns:
            tuple: ``(trainset, valset)`` where ``trainset`` is a
            ``torch.utils.data.ConcatDataset`` of the 'train' and 'extra'
            splits and ``valset`` is the 'test' split.
        """
        shared = dict(target_transform=None, download=True)

        # Construction order (train, test, extra) follows the original code.
        train_part = datasets.SVHN('datasets/SVHN/train/', split='train',
                                   transform=self.train_transforms, **shared)
        valset = datasets.SVHN('datasets/SVHN/test/', split='test',
                               transform=self.val_transforms, **shared)
        extra_part = datasets.SVHN('datasets/SVHN/extra', split='extra',
                                   transform=self.train_transforms, **shared)

        # Training data is the 'train' split followed by the 'extra' split.
        return torch.utils.data.ConcatDataset([train_part, extra_part]), valset
开发者ID:MrtnMndt,项目名称:OCDVAEContinualLearning,代码行数:20,代码来源:datasets.py

示例5: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        """Initialise the parent dataset, then subsample the training split.

        ``train`` selects the 'train' split when truthy and 'test' otherwise;
        the remaining arguments are forwarded to the parent constructor.

        For the 'train' split only, ``self.labels`` and ``self.data`` are
        reduced so that every distinct label keeps the same number of
        randomly chosen samples (the smallest histogram count), in shuffled
        order. Randomness comes from the global ``np.random`` state.
        """
        split = 'train' if train else 'test'
        super(SVHN, self).__init__(root, split=split, transform=transform,
                target_transform=target_transform, download=download)

        # Only the training split is balanced.
        if split != 'train':
            return

        flat_labels = self.labels.squeeze()

        # Histogram of the original labels, one bin per distinct label.
        label_set = np.unique(self.labels)
        num_cls = len(label_set)
        count, _ = np.histogram(flat_labels, bins=num_cls)
        min_num = min(count)

        # One row per label, holding ``min_num`` randomly picked indices.
        ind = np.zeros((num_cls, min_num), dtype=int)
        for i in label_set:
            members = np.where(flat_labels == i)[0]
            np.random.shuffle(members)
            ind[i % num_cls, :] = members[:min_num]

        # Flatten and shuffle the chosen indices; the loop runs 100 times.
        ind = ind.flatten()
        for _ in range(100):
            np.random.shuffle(ind)

        self.labels = self.labels[ind]
        self.data = self.data[ind]
开发者ID:jhoffman,项目名称:cycada_release,代码行数:34,代码来源:svhn_balanced.py

示例6: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        """Translate the boolean ``train`` flag into a split name.

        A truthy ``train`` selects the 'train' split, anything else 'test';
        every other argument is forwarded unchanged to the parent
        constructor.
        """
        split_name = 'train' if train else 'test'
        super(SVHN, self).__init__(root, split=split_name, transform=transform,
                target_transform=target_transform, download=download)
开发者ID:jhoffman,项目名称:cycada_release,代码行数:10,代码来源:svhn.py

示例7: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self):
        """Run the parent initialiser, then set the SVHN-specific fields."""
        super(SVHNMetaInfo, self).__init__()

        # Applied in the same order as the original plain assignments.
        svhn_fields = (
            ("label", "SVHN"),
            ("root_dir_name", "svhn"),
            ("dataset_class", SVHNFine),
            ("num_training_samples", 73257),
        )
        for field_name, field_value in svhn_fields:
            setattr(self, field_name, field_value)
开发者ID:osmr,项目名称:imgclsmob,代码行数:8,代码来源:svhn_cls_dataset.py

示例8: svhn_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def svhn_loaders(batch_size):
    """Return ``(train_loader, test_loader)`` for SVHN stored under ./data.

    Both datasets are built with ``download=True``, a ``ToTensor``
    transform and ``replace_10_with_0`` as the target transform. Only the
    training loader shuffles; both loaders use ``pin_memory=True``.
    """
    # Build both datasets first, then both loaders, as the original did.
    train_set, test_set = [
        datasets.SVHN("./data", split=split_name, download=True,
                      transform=transforms.ToTensor(),
                      target_transform=replace_10_with_0)
        for split_name in ('train', 'test')
    ]

    common = dict(batch_size=batch_size, pin_memory=True)
    loader_for_train = torch.utils.data.DataLoader(train_set, shuffle=True, **common)
    loader_for_test = torch.utils.data.DataLoader(test_set, shuffle=False, **common)
    return loader_for_train, loader_for_test
开发者ID:locuslab,项目名称:convex_adversarial,代码行数:8,代码来源:problems.py

示例9: LoadSVHN

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def LoadSVHN(data_root, batch_size=32, split='train', shuffle=True):
    """Return a DataLoader over one SVHN split stored in ``data_root``.

    Args:
        data_root: dataset directory; created when it does not exist yet.
        batch_size: number of samples per batch.
        split: split name handed to ``datasets.SVHN``.
        shuffle: forwarded to the loader.

    Returns:
        A ``DataLoader`` built with ``drop_last=True``.
    """
    root_missing = not os.path.exists(data_root)
    if root_missing:
        os.makedirs(data_root)

    to_tensor = transforms.ToTensor()
    dataset = datasets.SVHN(data_root, split=split, download=True,
                            transform=to_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)
开发者ID:Alexander-H-Liu,项目名称:UFDN,代码行数:8,代码来源:data.py

示例10: get_train_val_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_train_val_loaders(self):
        """Create train and validation loaders for ``self.args.dataset``.

        Supported names are 'cifar10', 'cifar100' and 'svhn'. The training
        loader reads the train split and shuffles; the validation loader
        reads the held-out split (``train=False`` for CIFAR, 'test' for SVHN)
        without shuffling. Both use ``self.args.batch_size``, pinned memory
        and two workers.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform,
            valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not a supported name.
        """
        dataset_name = self.args.dataset
        if dataset_name == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif dataset_name == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif dataset_name == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(
                root=self.args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(
                root=self.args.data, split='test', download=True, transform=valid_transform)
        else:
            # Without this branch an unknown name fell through and only
            # failed further down with an unbound local variable.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10', 'cifar100' or 'svhn'".format(
                    dataset_name))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            shuffle=True, pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=self.args.batch_size,
            shuffle=False, pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
开发者ID:automl,项目名称:RobustDARTS,代码行数:31,代码来源:args.py

示例11: get_train_val_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_train_val_loaders(self):
        """Split the training set of ``self.args.dataset`` into two loaders.

        Supported names are 'cifar10', 'cifar100' and 'svhn'. Only the
        training split is loaded. The first ``train_portion`` fraction of its
        indices (rounded down) feeds the training loader and the remainder
        the validation loader, each through a ``SubsetRandomSampler``. Both
        loaders therefore read ``train_data`` with the training transform;
        ``valid_transform`` is returned but not applied here.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform,
            valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not a supported name.
        """
        dataset_name = self.args.dataset
        if dataset_name == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
        elif dataset_name == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
        elif dataset_name == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(root=self.args.data, split='train', download=True, transform=train_transform)
        else:
            # Without this branch an unknown name fell through and only
            # failed further down with an unbound local variable.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10', 'cifar100' or 'svhn'".format(
                    dataset_name))

        num_train = len(train_data)
        indices = list(range(num_train))
        # Index at which the training part ends and the validation part begins.
        split = int(np.floor(self.args.train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
开发者ID:automl,项目名称:RobustDARTS,代码行数:28,代码来源:args.py

示例12: getSVHN

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def getSVHN(batch_size, img_size=32, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Build SVHN data loader(s) with images resized to ``img_size``.

    Args:
        batch_size: batch size used by every loader.
        img_size: size handed to the resize transform.
        data_root: parent directory; data is stored in
            ``<data_root>/svhn-data`` (``~`` is expanded).
        train: include a shuffling loader over the 'train' split.
        val: include a non-shuffling loader over the 'test' split.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.
            ``num_workers`` defaults to 1 and ``input_size`` is discarded.

    Returns:
        The single loader when exactly one of ``train`` / ``val`` is set,
        otherwise a list of the requested loaders (train first; empty when
        neither is requested).
    """
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    # torchvision deprecated ``Scale`` in favour of ``Resize`` and later
    # removed it. Prefer ``Resize``; fall back to ``Scale`` only on releases
    # that predate the rename.
    resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale

    def target_transform(target):
        # Shift labels down by one; a result of -1 is remapped to 9.
        new_target = target - 1
        if new_target == -1:
            new_target = 9
        return new_target

    def build_loader(split, shuffle):
        # One dataset and one loader per split; each gets its own transform.
        dataset = datasets.SVHN(
            root=data_root, split=split, download=True,
            transform=transforms.Compose([
                resize_cls(img_size),
                transforms.ToTensor(),
            ]),
            target_transform=target_transform,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    ds = []
    if train:
        ds.append(build_loader('train', True))
    if val:
        ds.append(build_loader('test', False))
    ds = ds[0] if len(ds) == 1 else ds
    return ds
开发者ID:alinlab,项目名称:Confident_classifier,代码行数:42,代码来源:data_loader.py

示例13: get

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Build SVHN data loader(s) with images normalised by mean 0.5 / std 0.5.

    Args:
        batch_size: batch size used by every loader.
        data_root: parent directory; data is stored in
            ``<data_root>/svhn-data`` (``~`` is expanded).
        train: include a shuffling loader over the 'train' split.
        val: include a non-shuffling loader over the 'test' split.
        **kwargs: forwarded to ``torch.utils.data.DataLoader``.
            ``num_workers`` defaults to 1 and ``input_size`` is discarded.

    Returns:
        The single loader when exactly one of ``train`` / ``val`` is set,
        otherwise a list of the requested loaders (train first; empty when
        neither is requested).
    """
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def target_transform(target):
        # Labels 1..10 map to 0..9 exactly as before. A label of 0 (how
        # current torchvision encodes the digit zero) used to become -1,
        # which is not a valid class index; the modulo wraps it to 9, the
        # same slot the old label 10 occupies.
        return (int(target) - 1) % 10

    def build_loader(split, shuffle):
        # One dataset and one loader per split; each gets its own transform.
        dataset = datasets.SVHN(
            root=data_root, split=split, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
            target_transform=target_transform,
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    ds = []
    if train:
        ds.append(build_loader('train', True))
    if val:
        ds.append(build_loader('test', False))
    ds = ds[0] if len(ds) == 1 else ds
    return ds
开发者ID:aaron-xichen,项目名称:pytorch-playground,代码行数:39,代码来源:dataset.py

示例14: get_svhn_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_svhn_loaders(config):
    """Build the labeled, unlabeled and dev loaders for SVHN.

    The 'train' and 'test' splits are loaded with a ToTensor + Normalize
    (mean 0.5 / std 0.5) transform, and any label equal to 10 is rewritten
    to 0 in place. From a random permutation of the training indices, the
    first ``config.size_labeled_data // 10`` samples of each class 0..9 form
    the labeled subset.

    Args:
        config: object providing ``data_root``, ``size_labeled_data``,
            ``train_batch_size``, ``train_batch_size_2`` and
            ``dev_batch_size``; it is also passed to every ``DataLoader``.

    Returns:
        tuple: ``(labeled_loader, unlabeled_loader, unlabeled_loader2,
        dev_loader, special_set)``, where ``special_set`` stacks the first
        image of each class 0..9 in the shuffled order.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    training_set = SVHN(config.data_root, split='train', download=True, transform=transform)
    dev_set = SVHN(config.data_root, split='test', download=True, transform=transform)

    def preprocess(data_set):
        # Rewrite label 10 to 0 in place. The boolean-mask form replaces the
        # per-element loop and works for labels shaped (N,) as well as (N, 1).
        data_set.labels[data_set.labels == 10] = 0
    preprocess(training_set)
    preprocess(dev_set)

    indices = np.arange(len(training_set))
    np.random.shuffle(indices)
    # The ``np.bool`` alias was removed from NumPy; the builtin is equivalent.
    mask = np.zeros(indices.shape[0], dtype=bool)
    labels = np.array([training_set[i][1] for i in indices], dtype=np.int64)
    # Floor division keeps the slice bound an integer on Python 3 as well.
    per_class = config.size_labeled_data // 10
    for i in range(10):
        mask[np.where(labels == i)[0][:per_class]] = True
    # NOTE: the unlabeled pool is the full shuffled index set, so it also
    # contains the labeled samples (it is not ``indices[~mask]``).
    labeled_indices, unlabeled_indices = indices[mask], indices
    # A single pre-formatted string prints identically on Python 2 and 3.
    print('labeled size {} unlabeled size {} dev size {}'.format(
        labeled_indices.shape[0], unlabeled_indices.shape[0], len(dev_set)))

    labeled_loader = DataLoader(config, training_set, labeled_indices, config.train_batch_size)
    unlabeled_loader = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size)
    unlabeled_loader2 = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size_2)
    dev_loader = DataLoader(config, dev_set, np.arange(len(dev_set)), config.dev_batch_size)

    # One example image per class: the first occurrence in shuffled order.
    special_set = []
    for i in range(10):
        special_set.append(training_set[indices[np.where(labels == i)[0][0]]][0])
    special_set = torch.stack(special_set)

    return labeled_loader, unlabeled_loader, unlabeled_loader2, dev_loader, special_set
开发者ID:kimiyoung,项目名称:ssl_bad_gan,代码行数:35,代码来源:data.py

示例15: __new__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __new__(cls,
                root,
                train=True,
                transform=None,
                download=False):
        """Return the dataset object to use instead of a ``cls`` instance.

        With a falsy ``train``, an ``OriginalSVHN`` built with
        ``train=False`` is returned. Otherwise the 'train' and 'extra'
        splits of ``datasets.SVHN`` are created and joined with ``+``.
        """
        if not train:
            return OriginalSVHN(root, train=False, transform=transform, download=download)

        train_part = datasets.SVHN(root, split='train', transform=transform, download=download)
        extra_part = datasets.SVHN(root, split='extra', transform=transform, download=download)
        return train_part + extra_part
开发者ID:moskomule,项目名称:homura,代码行数:12,代码来源:datasets.py


注:本文中的torchvision.datasets.SVHN属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。