當前位置: 首頁>>代碼示例>>Python>>正文


Python datasets.SVHN類代碼示例

本文整理匯總了Python中torchvision.datasets.SVHN類的典型用法代碼示例。如果您正苦於以下問題:Python datasets.SVHN類的具體用法是什麼?Python datasets.SVHN怎麼用?有沒有Python datasets.SVHN的使用例子?那麼,這裡精選的代碼示例或許可以為您提供幫助。您也可以進一步了解該類所在模塊torchvision.datasets的用法示例。


在下文中一共展示了datasets.SVHN類的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_svhn

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_svhn(train, get_dataset=False, batch_size=cfg.batch_size):
    """Build the SVHN dataset, or a shuffling DataLoader over it.

    Args:
        train: If truthy, use the ``'train'`` split; otherwise ``'test'``.
        get_dataset: If True, return the dataset object instead of a loader.
        batch_size: Loader batch size. The default is read from
            ``cfg.batch_size`` once, when this function is defined.

    Returns:
        The ``datasets.SVHN`` instance when ``get_dataset`` is True, otherwise
        a ``torch.utils.data.DataLoader`` over it with ``shuffle=True``.
    """
    # Tensor conversion followed by per-channel normalisation from the config.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std)
    pre_process = transforms.Compose([to_tensor, normalize])

    split_name = 'train' if train else 'test'
    svhn_dataset = datasets.SVHN(root=cfg.data_root, split=split_name,
                                 transform=pre_process, download=True)

    if get_dataset:
        return svhn_dataset
    return torch.utils.data.DataLoader(dataset=svhn_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
開發者ID:corenel,項目名稱:pytorch-atda,代碼行數:24,代碼來源:svhn.py

示例2: get_loader

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_loader(config):
    """Build and return shuffling DataLoaders for the SVHN and MNIST datasets.

    Both datasets are resized to ``config.image_size``, converted to tensors
    and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        config: Object providing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``.
    """
    def _pipeline(num_channels):
        # transforms.Scale was deprecated in favour of transforms.Resize and is
        # gone from current torchvision; Resize(int) has the same semantics.
        stats = (0.5,) * num_channels
        return transforms.Compose([
            transforms.Resize(config.image_size),
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)])

    # SVHN images are RGB, MNIST images are single-channel. A 3-channel
    # Normalize applied to a 1-channel tensor raises in current torchvision;
    # older releases zipped over channels and effectively used only the first
    # mean/std pair, which is what the 1-channel pipeline uses explicitly.
    svhn = datasets.SVHN(root=config.svhn_path, download=True,
                         transform=_pipeline(3))
    mnist = datasets.MNIST(root=config.mnist_path, download=True,
                           transform=_pipeline(1))

    def _loader(dataset):
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers)

    return _loader(svhn), _loader(mnist)
開發者ID:yunjey,項目名稱:mnist-svhn-transfer,代碼行數:23,代碼來源:data_loader.py

示例3: get_targets

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_targets(dataset):
    """Return the targets of *dataset* as a tensor, without any target transforms.

    Wrapper datasets (``TransformedDataset``, ``data.Subset`` and
    ``data.ConcatDataset``) are unwrapped recursively until a torchvision
    dataset is reached, whose raw label storage is then read directly.

    Raises:
        NotImplementedError: if the dataset type is not recognised.
    """
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])

    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        # torchvision's SVHN keeps its labels in ``labels`` (a NumPy array)
        # rather than ``targets``. Convert so every branch returns a tensor;
        # otherwise torch.cat in the ConcatDataset branch rejects the array.
        return torch.as_tensor(dataset.labels)

    raise NotImplementedError(f"Unknown dataset {dataset}!")
開發者ID:BlackHC,項目名稱:BatchBALD,代碼行數:20,代碼來源:dataset_enum.py

示例4: get_dataset

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_dataset(self):
        """
        Load SVHN via torchvision.datasets.SVHN, downloading it if necessary.

        The 'train' and 'extra' splits are concatenated into a single training
        set and the 'test' split is used for validation. Train and extra
        samples use ``self.train_transforms``; test samples use
        ``self.val_transforms``. Each split is stored under its own directory
        below ``datasets/SVHN/``.

        Returns:
            tuple: ``(trainset, valset)`` where ``trainset`` is a
            ``torch.utils.data.ConcatDataset`` of the train and extra splits and
            ``valset`` is the ``torchvision.datasets.SVHN`` test split.
        """

        trainset = datasets.SVHN('datasets/SVHN/train/', split='train', transform=self.train_transforms,
                                 target_transform=None, download=True)
        valset = datasets.SVHN('datasets/SVHN/test/', split='test', transform=self.val_transforms,
                               target_transform=None, download=True)
        extraset = datasets.SVHN('datasets/SVHN/extra', split='extra', transform=self.train_transforms,
                                 target_transform=None, download=True)

        # 'extra' is used as additional training data.
        trainset = torch.utils.data.ConcatDataset([trainset, extraset])

        return trainset, valset
開發者ID:MrtnMndt,項目名稱:OCDVAEContinualLearning,代碼行數:20,代碼來源:datasets.py

示例5: __init__

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        """Load SVHN and, for the training split, subsample it to balance classes.

        Args:
            root: Dataset root directory, forwarded to the parent SVHN class.
            train: Selects the ``'train'`` split if True, ``'test'`` otherwise.
            transform: Forwarded unchanged to the parent constructor.
            target_transform: Forwarded unchanged to the parent constructor.
            download: Forwarded unchanged to the parent constructor.

        For the training split every class is randomly subsampled down to the
        size of the rarest class and the kept indices are shuffled, so
        ``self.data`` / ``self.labels`` end up class-balanced. Randomness comes
        from the global NumPy RNG (``np.random.shuffle``).
        """
        split = 'train' if train else 'test'
        super(SVHN, self).__init__(root, split=split, transform=transform,
                target_transform=target_transform, download=download)

        # Subsample images to balance the training set.
        if split == 'train':
            flat_labels = self.labels.squeeze()
            # np.unique counts every distinct label exactly. The previous
            # np.histogram(..., bins=num_cls) only matched per-label counts
            # when the labels were contiguous integers.
            label_set, count = np.unique(flat_labels, return_counts=True)
            num_cls = len(label_set)
            min_num = count.min()

            # Row r holds min_num randomly chosen sample indices of label_set[r].
            ind = np.zeros((num_cls, min_num), dtype=int)
            for row, label in enumerate(label_set):
                binary_ind = np.where(flat_labels == label)[0]
                np.random.shuffle(binary_ind)
                ind[row, :] = binary_ind[:min_num]

            ind = ind.flatten()
            # Shuffle so classes are interleaved. One shuffle already yields a
            # uniformly random order; the 100 repetitions are kept so seeded
            # runs reproduce the original ordering.
            for _ in range(100):
                np.random.shuffle(ind)
            self.labels = self.labels[ind]
            self.data = self.data[ind]
開發者ID:jhoffman,項目名稱:cycada_release,代碼行數:34,代碼來源:svhn_balanced.py

示例6: __init__

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def __init__(self, root, train=True,
            transform=None, target_transform=None, download=False):
        """Adapt the parent SVHN constructor to a boolean ``train`` flag.

        ``train=True`` selects the ``'train'`` split and ``train=False`` the
        ``'test'`` split; every other argument is forwarded unchanged.
        """
        split = 'train' if train else 'test'
        super(SVHN, self).__init__(
                root,
                split=split,
                transform=transform,
                target_transform=target_transform,
                download=download)
開發者ID:jhoffman,項目名稱:cycada_release,代碼行數:10,代碼來源:svhn.py

示例7: __init__

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def __init__(self):
        """Initialise the base metadata, then fill in the SVHN-specific fields."""
        super(SVHNMetaInfo, self).__init__()
        # Display label and on-disk directory name for this dataset.
        self.label, self.root_dir_name = "SVHN", "svhn"
        self.dataset_class = SVHNFine
        # Presumably the size of the SVHN 'train' split — TODO confirm.
        self.num_training_samples = 73257
開發者ID:osmr,項目名稱:imgclsmob,代碼行數:8,代碼來源:svhn_cls_dataset.py

示例8: svhn_loaders

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def svhn_loaders(batch_size):
    """Return ``(train_loader, test_loader)`` for SVHN stored under ``./data``.

    Images are converted with ``transforms.ToTensor`` and targets are passed
    through ``replace_10_with_0``. The train loader shuffles and the test
    loader does not; both use pinned memory.
    """
    def _split(name):
        return datasets.SVHN("./data", split=name, download=True,
                             transform=transforms.ToTensor(),
                             target_transform=replace_10_with_0)

    train = _split('train')
    test = _split('test')
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                               shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                              shuffle=False, pin_memory=True)
    return train_loader, test_loader
開發者ID:locuslab,項目名稱:convex_adversarial,代碼行數:8,代碼來源:problems.py

示例9: LoadSVHN

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def LoadSVHN(data_root, batch_size=32, split='train', shuffle=True):
    """Return a DataLoader over one SVHN split, downloading it into *data_root*.

    Args:
        data_root: Directory for the dataset; created if it does not exist.
        batch_size: Samples per batch.
        split: Split name passed to ``datasets.SVHN`` (e.g. 'train', 'test').
        shuffle: Whether the loader reshuffles every epoch.

    Returns:
        A ``DataLoader`` over ``ToTensor``-converted images that drops the last
        incomplete batch.
    """
    # exist_ok=True replaces the check-then-create pattern, which could raise
    # FileExistsError if another process created the directory in between.
    os.makedirs(data_root, exist_ok=True)
    svhn_dataset = datasets.SVHN(data_root, split=split, download=True,
                                 transform=transforms.ToTensor())
    return DataLoader(svhn_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True)
開發者ID:Alexander-H-Liu,項目名稱:UFDN,代碼行數:8,代碼來源:data.py

示例10: get_train_val_loaders

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_train_val_loaders(self):
        """Build train/validation DataLoaders for ``self.args.dataset``.

        Supported names are 'cifar10', 'cifar100' and 'svhn'. The official
        train split is used for training and the test split for validation.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform, valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not a supported name.
        """
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(
                root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(
                root=self.args.data, train=False, download=True, transform=valid_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(
                root=self.args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(
                root=self.args.data, split='test', download=True, transform=valid_transform)
        else:
            # Without this branch an unknown name fell through and surfaced as
            # a confusing UnboundLocalError on ``train_data`` below.
            raise ValueError(
                "Unsupported dataset {!r}; expected one of "
                "'cifar10', 'cifar100', 'svhn'".format(self.args.dataset))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            shuffle=True, pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data, batch_size=self.args.batch_size,
            shuffle=False, pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
開發者ID:automl,項目名稱:RobustDARTS,代碼行數:31,代碼來源:args.py

示例11: get_train_val_loaders

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_train_val_loaders(self):
        """Split the training set of ``self.args.dataset`` into train/valid loaders.

        Supported names are 'cifar10', 'cifar100' and 'svhn'. The first
        ``floor(self.args.train_portion * N)`` indices of the training split
        feed ``train_queue`` and the remaining indices feed ``valid_queue``.
        Each loader draws its subset in random order via SubsetRandomSampler.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform, valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not a supported name.
        """
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(root=self.args.data, split='train', download=True, transform=train_transform)
        else:
            # Without this branch an unknown name fell through and surfaced as
            # a confusing UnboundLocalError on ``train_data`` below.
            raise ValueError(
                "Unsupported dataset {!r}; expected one of "
                "'cifar10', 'cifar100', 'svhn'".format(self.args.dataset))

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(self.args.train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
開發者ID:automl,項目名稱:RobustDARTS,代碼行數:28,代碼來源:args.py

示例12: getSVHN

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def getSVHN(batch_size, img_size=32, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    """Build SVHN DataLoaders with images resized to *img_size*.

    Args:
        batch_size: Samples per batch.
        img_size: Size passed to ``transforms.Resize``.
        data_root: Base directory; the data lives in ``<data_root>/svhn-data``
            (``~`` is expanded).
        train: Build a shuffling loader over the 'train' split.
        val: Build a non-shuffling loader over the 'test' split.
        **kwargs: Extra ``DataLoader`` arguments. ``num_workers`` defaults to 1;
            ``input_size`` is accepted and discarded.

    Returns:
        A single loader when exactly one of *train*/*val* is set, otherwise the
        list ``[train_loader, test_loader]`` (empty if neither is set).
    """
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def target_transform(target):
        # Shift labels down by one and send digit 0 to class 9. Digit 0 may be
        # stored as 10 (10 - 1 = 9) or as 0 (0 - 1 = -1, remapped to 9 here),
        # depending on the torchvision version.
        new_target = target - 1
        if new_target == -1:
            new_target = 9
        return new_target

    # transforms.Scale was renamed to transforms.Resize and later removed from
    # torchvision; Resize(int) has the same semantics. The pipeline is
    # stateless, so one instance is shared by both splits.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='train', download=True,
                transform=transform,
                target_transform=target_transform,
            ),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)

    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='test', download=True,
                transform=transform,
                target_transform=target_transform
            ),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds
開發者ID:alinlab,項目名稱:Confident_classifier,代碼行數:42,代碼來源:data_loader.py

示例13: get

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'svhn-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building SVHN data loader with {} workers".format(num_workers))

    def target_transform(target):
        return int(target) - 1

    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='train', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                target_transform=target_transform,
            ),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)

    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.SVHN(
                root=data_root, split='test', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
                target_transform=target_transform
            ),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds 
開發者ID:aaron-xichen,項目名稱:pytorch-playground,代碼行數:39,代碼來源:dataset.py

示例14: get_svhn_loaders

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def get_svhn_loaders(config):
    """Build the semi-supervised SVHN loaders and one example image per class.

    Labels equal to 10 (digit 0 in the raw SVHN encoding) are rewritten to 0
    in place on both splits. ``config.size_labeled_data // 10`` randomly chosen
    training images per class form the labeled set. The unlabeled set is the
    whole shuffled training set, labeled images included.

    Returns:
        tuple: ``(labeled_loader, unlabeled_loader, unlabeled_loader2,
        dev_loader, special_set)``. ``special_set`` stacks the first image of
        each class 0-9 in the shuffled training order.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    training_set = SVHN(config.data_root, split='train', download=True, transform=transform)
    dev_set = SVHN(config.data_root, split='test', download=True, transform=transform)

    def preprocess(data_set):
        # Vectorised in-place remap that works whether ``labels`` has shape
        # (N,) or (N, 1). Indexing ``labels[i][0]`` fails on the 1-D layout.
        np.place(data_set.labels, data_set.labels == 10, 0)
    preprocess(training_set)
    preprocess(dev_set)

    indices = np.arange(len(training_set))
    np.random.shuffle(indices)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    mask = np.zeros(indices.shape[0], dtype=bool)
    # Read labels straight from the dataset, in shuffled order. Calling
    # training_set[i] would decode and transform every image just for its label.
    labels = np.asarray(training_set.labels).reshape(-1)[indices]
    # Floor division keeps the slice bound an int on Python 2 and 3.
    per_class = config.size_labeled_data // 10
    for i in range(10):
        mask[np.where(labels == i)[0][:per_class]] = True
    # Every training image (labeled ones included) also serves as unlabeled data.
    labeled_indices, unlabeled_indices = indices[mask], indices
    print('labeled size %d unlabeled size %d dev size %d' % (
        labeled_indices.shape[0], unlabeled_indices.shape[0], len(dev_set)))

    labeled_loader = DataLoader(config, training_set, labeled_indices, config.train_batch_size)
    unlabeled_loader = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size)
    unlabeled_loader2 = DataLoader(config, training_set, unlabeled_indices, config.train_batch_size_2)
    dev_loader = DataLoader(config, dev_set, np.arange(len(dev_set)), config.dev_batch_size)

    # First image (in shuffled order) of each class, used as a fixed reference set.
    special_set = []
    for i in range(10):
        special_set.append(training_set[indices[np.where(labels == i)[0][0]]][0])
    special_set = torch.stack(special_set)

    return labeled_loader, unlabeled_loader, unlabeled_loader2, dev_loader, special_set
開發者ID:kimiyoung,項目名稱:ssl_bad_gan,代碼行數:35,代碼來源:data.py

示例15: __new__

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import SVHN [as 別名]
def __new__(cls, root, train=True, transform=None, download=False):
        """Return the combined SVHN training data, or the SVHN test split.

        With ``train=True`` the 'train' and 'extra' splits are joined with
        ``+``, i.e. the torch ``Dataset.__add__`` concatenation. Otherwise an
        ``OriginalSVHN`` test set is returned. No instance of ``cls`` is created.
        """
        if not train:
            return OriginalSVHN(root, train=False, transform=transform, download=download)
        train_split = datasets.SVHN(root, split='train', transform=transform, download=download)
        extra_split = datasets.SVHN(root, split='extra', transform=transform, download=download)
        return train_split + extra_split
開發者ID:moskomule,項目名稱:homura,代碼行數:12,代碼來源:datasets.py


注:本文中的torchvision.datasets.SVHN類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。