当前位置: 首页>>代码示例>>Python>>正文


Python datasets.CIFAR10类代码示例

本文整理汇总了Python中torchvision.datasets.CIFAR10类的典型用法代码示例。如果您正苦于以下问题:Python datasets.CIFAR10类的具体用法?Python datasets.CIFAR10怎么用?Python datasets.CIFAR10使用的例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。


在下文中一共展示了datasets.CIFAR10类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def main():
    """Train the model returned by ``pyramidnet()`` on the CIFAR-10 training split.

    Builds an augmented training pipeline, wraps the model in
    ``nn.DataParallel``, moves it to the selected device and hands everything
    to ``train``.  Batch size, worker count and learning rate are read from
    the module-level ``args`` object.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    # Augmentation: random 32x32 crop after 4-pixel padding and a random
    # horizontal flip, followed by per-channel normalization.
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=transforms_train)
    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    print('==> Making model..')
    net = pyramidnet()
    net = nn.DataParallel(net)
    net = net.to(device)

    # Only parameters that receive gradients are counted.
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)

    train(net, criterion, optimizer, train_loader, device)
开发者ID:dnddnjs,项目名称:pytorch-multigpu,代码行数:38,代码来源:train.py

示例2: __getitem__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def __getitem__(self, index):
        """Return one sample together with its semi-supervised label and its index.

        Overrides the original method of the CIFAR10 class.

        Args:
            index (int): Position of the sample in the dataset.

        Returns:
            tuple: (image, target, semi_target, index)
        """
        raw = self.data[index]
        target = self.targets[index]
        semi_target = int(self.semi_targets[index])

        # Hand a PIL Image to the transform, consistent with the other datasets.
        img = Image.fromarray(raw)

        transform = self.transform
        if transform is not None:
            img = transform(img)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return img, target, semi_target, index
开发者ID:lukasruff,项目名称:Deep-SAD-PyTorch,代码行数:23,代码来源:cifar10.py

示例3: get_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def get_data(train):
	"""Load CIFAR-10 as flattened 20x20 grayscale vectors with binarized labels.

	Args:
		train: Forwarded to ``datasets.CIFAR10`` as its ``train`` argument.

	Returns:
		tuple: ``(data, mean, std)``. ``data`` is a DataFrame of the flattened
		pixel values plus a ``COLUMN_LABEL`` column that is 0 for original
		labels below 5 and 1 for all other labels; ``mean`` and ``std`` are
		taken over every pixel value of the loaded split.
	"""
	def to_flat_array(x):
		# Tensor -> 1-D numpy vector
		return x.numpy().flatten()

	pipeline = transforms.Compose([
		transforms.Grayscale(),
		transforms.Resize((20, 20)),
		transforms.ToTensor(),
		to_flat_array])
	data_raw = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=pipeline)

	# Iterating the dataset applies the transform to every sample.
	features, labels = zip(*data_raw)

	data_x = np.array(features)
	data_y = np.array(labels, dtype='int32').reshape(-1, 1)

	# binarize: labels below 5 become 0, everything else becomes 1
	data_y = (data_y >= 5).astype('int32')

	data = pd.DataFrame(data_x)
	data[COLUMN_LABEL] = data_y

	return data, data_x.mean(), data_x.std()

#--- 
开发者ID:jaromiru,项目名称:cwcf,代码行数:27,代码来源:conv_cifar_2.py

示例4: get_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def get_data(train):
	"""Load CIFAR-10 as flattened 20x20 grayscale vectors with their class labels.

	Args:
		train: Forwarded to ``datasets.CIFAR10`` as its ``train`` argument.

	Returns:
		tuple: ``(data, mean, std)``. ``data`` is a DataFrame of the flattened
		pixel values plus a ``COLUMN_LABEL`` column holding the labels;
		``mean`` and ``std`` are taken over every pixel value of the loaded
		split.
	"""
	def to_flat_array(x):
		# Tensor -> 1-D numpy vector
		return x.numpy().flatten()

	pipeline = transforms.Compose([
		transforms.Grayscale(),
		transforms.Resize((20, 20)),
		transforms.ToTensor(),
		to_flat_array])
	data_raw = datasets.CIFAR10('../data/dl/', train=train, download=True, transform=pipeline)

	# Iterating the dataset applies the transform to every sample.
	features, labels = zip(*data_raw)

	data_x = np.array(features)
	data_y = np.array(labels, dtype='int32').reshape(-1, 1)

	data = pd.DataFrame(data_x)
	data[COLUMN_LABEL] = data_y

	return data, data_x.mean(), data_x.std()

#--- 
开发者ID:jaromiru,项目名称:cwcf,代码行数:20,代码来源:conv_cifar.py

示例5: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def __init__(self):
        """Populate the dataset descriptor with the CIFAR-10 settings."""
        super(CIFAR10MetaInfo, self).__init__()

        # Identification and storage
        self.label = "CIFAR10"
        self.short_label = "cifar"
        self.root_dir_name = "cifar10"
        self.dataset_class = CIFAR10Fine
        self.ml_type = "imgcls"

        # Data geometry
        self.num_training_samples = 50000
        self.in_channels = 3
        self.num_classes = 10
        self.input_image_size = (32, 32)

        # Training metrics
        self.train_metric_capts = ["Train.Err"]
        self.train_metric_names = ["Top1Error"]
        self.train_metric_extra_kwargs = [{"name": "err"}]

        # Validation metrics (separate list objects from the training ones)
        self.val_metric_capts = ["Val.Err"]
        self.val_metric_names = ["Top1Error"]
        self.val_metric_extra_kwargs = [{"name": "err"}]
        self.saver_acc_ind = 0

        # Transforms: validation and test share the same callable
        self.train_transform = cifar10_train_transform
        self.val_transform = self.test_transform = cifar10_val_transform
开发者ID:osmr,项目名称:imgclsmob,代码行数:23,代码来源:cifar10_cls_dataset.py

示例6: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def load_data():
    """Build unshuffled CIFAR-10 loaders for the train and test splits.

    Both splits use the same preprocessing: resize to 227, convert to tensor
    and normalize per channel.  Data is downloaded to ``./data`` if needed.

    Returns:
        tuple: ``(trainloader, testloader)``, each with batch size 100.
    """
    def make_transform():
        # A fresh pipeline per split, as both splits are preprocessed alike.
        return transforms.Compose([
            transforms.Resize(227),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    def make_loader(is_train, transform):
        split = datasets.CIFAR10(root='./data', train=is_train, download=True,
                                 transform=transform)
        return torch.utils.data.DataLoader(split, batch_size=100,
                                           shuffle=False, num_workers=0)

    transform_train = make_transform()
    transform_test = make_transform()

    trainloader = make_loader(True, transform_train)
    testloader = make_loader(False, transform_test)
    return trainloader, testloader
开发者ID:flyingpot,项目名称:pytorch_deephash,代码行数:21,代码来源:evaluate.py

示例7: init_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def init_dataset():
    """Build shuffled CIFAR-10 loaders for training and testing.

    Training images are resized to 256, randomly cropped to 227 and randomly
    flipped; test images are only resized to 227.  Both are converted to
    tensors and normalized per channel.  Data is downloaded to ``./data`` if
    needed.

    Returns:
        tuple: ``(trainloader, testloader)`` with batch sizes 128 and 100.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(227),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(227),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True,
                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=0)

    testset = datasets.CIFAR10(root='./data', train=False, download=True,
                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=True, num_workers=0)

    return trainloader, testloader
开发者ID:flyingpot,项目名称:pytorch_deephash,代码行数:23,代码来源:train.py

示例8: make_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def make_dataset():
    """Create the data loader selected by the module-level ``opt.dataset``.

    Folder-based datasets are loaded with ``ImageFolder`` and are not
    shuffled; ``"cifar10"`` loads the CIFAR-10 training split (without
    downloading) and is shuffled.  Both use batch size 100.

    Returns:
        DataLoader: loader over the selected dataset.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    def build_transform():
        # Resize, convert to tensor, then normalize each channel with 0.5/0.5.
        return tfs.Compose([
            tfs.Resize(opt.img_width),
            tfs.ToTensor(),
            tfs.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])])

    if opt.dataset in ("imagenet", "dog_and_cat_64", "dog_and_cat_128"):
        folder_data = ImageFolder(opt.root, transform=build_transform())
        return DataLoader(folder_data, batch_size=100, shuffle=False,
                          num_workers=opt.workers)

    if opt.dataset == "cifar10":
        cifar_data = CIFAR10(root=opt.root, train=True, download=False,
                             transform=build_transform())
        return DataLoader(cifar_data, batch_size=100, shuffle=True,
                          num_workers=opt.workers)

    raise ValueError(f"Unknown dataset: {opt.dataset}")
开发者ID:xuanqing94,项目名称:RobGAN,代码行数:20,代码来源:acc_under_attack.py

示例9: cifar_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def cifar_loaders(batch_size, shuffle_test=False):
    """Build CIFAR-10 train and test loaders that share one normalization.

    The training pipeline adds a random horizontal flip and a random 32x32
    crop with padding 4.  Only the training dataset is constructed with
    ``download=True``.

    Args:
        batch_size: Batch size used by both loaders.
        shuffle_test: Whether the test loader shuffles its data.

    Returns:
        tuple: ``(train_loader, test_loader)``; the train loader always
        shuffles and both loaders pin memory.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.225, 0.225, 0.225])

    train_pipeline = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ])
    train = datasets.CIFAR10('./data', train=True, download=True,
                             transform=train_pipeline)

    test_pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    test = datasets.CIFAR10('./data', train=False, transform=test_pipeline)

    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=shuffle_test, pin_memory=True)
    return train_loader, test_loader
开发者ID:locuslab,项目名称:convex_adversarial,代码行数:19,代码来源:problems.py

示例10: build_datasets

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def build_datasets(self) -> None:
        """Override this method to load your own datasets.

        The base implementation does nothing.

        * :attr:`self.dataset_train` . Assign the training ``dataset`` to this.
        * :attr:`self.dataset_valid` . Assign the validation ``dataset`` to this.
        * :attr:`self.dataset_test` is optional. Assign a test ``dataset`` to this.
          If it is not set, it will be replaced by ``self.dataset_valid`` .

        Example::

            self.dataset_train = datasets.CIFAR10(root, train=True, download=True,
                                                  transform=transforms.Compose(self.train_transform_list))
            self.dataset_valid = datasets.CIFAR10(root, train=False, download=True,
                                                  transform=transforms.Compose(self.valid_transform_list))
        """
        pass
开发者ID:dingguanglei,项目名称:jdit,代码行数:18,代码来源:dataset.py

示例11: __getitem__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def __getitem__(self, index):
        """Return one sample and additionally report its index.

        Overrides the original method of the CIFAR10 class; the only
        difference is the extra ``index`` element in the returned tuple.

        Args:
            index (int): Position of the sample in the active split.

        Returns:
            triple: (image, target, index) where target is index of the target class.
        """
        # Pick the storage of the active split, then read the requested sample.
        if self.train:
            images, labels = self.train_data, self.train_labels
        else:
            images, labels = self.test_data, self.test_labels
        target = labels[index]

        # Hand a PIL Image to the transform, consistent with the other datasets.
        img = Image.fromarray(images[index])

        transform = self.transform
        if transform is not None:
            img = transform(img)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return img, target, index
开发者ID:lukasruff,项目名称:Deep-SVDD-PyTorch,代码行数:25,代码来源:cifar10.py

示例12: get_datasets

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def get_datasets(initial_pool):
    """Create the active-learning training set and the CIFAR-10 test set.

    Args:
        initial_pool: Forwarded to ``ActiveLearningDataset.label_randomly``
            to label the first samples at random.

    Returns:
        tuple: ``(active_set, test_set)``.
    """
    image_size = (224, 224)

    # Training pipeline: resize plus flip/rotation augmentation.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])
    # Evaluation pipeline: same preprocessing without augmentation.
    test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(3 * [0.5], 3 * [0.5]),
    ])

    # Note: We use the test set here as an example. You should make your own validation set.
    train_ds = datasets.CIFAR10('.', train=True, transform=transform,
                                target_transform=None, download=True)
    test_set = datasets.CIFAR10('.', train=False, transform=test_transform,
                                target_transform=None, download=True)

    # The pool is configured with the evaluation transform.
    active_set = ActiveLearningDataset(
        train_ds, pool_specifics={'transform': test_transform})

    # We start labeling randomly.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
开发者ID:ElementAI,项目名称:baal,代码行数:27,代码来源:vgg_mcdropout_cifar10.py

示例13: main

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def main():
    """Train the model returned by ``pyramidnet()`` on the CIFAR-10 training split.

    Builds an augmented training pipeline, moves the model to the selected
    device and hands everything to ``train``.  Batch size, worker count and
    learning rate are read from the module-level ``args`` object.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    # Augmentation: random 32x32 crop after 4-pixel padding and a random
    # horizontal flip, followed by per-channel normalization.
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=transforms_train)
    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    print('==> Making model..')
    net = pyramidnet()
    net = net.to(device)

    # Only parameters that receive gradients are counted.
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=1e-4)

    train(net, criterion, optimizer, train_loader, device)
开发者ID:dnddnjs,项目名称:pytorch-multigpu,代码行数:36,代码来源:train.py

示例14: test

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def test():
    """Evaluate the module-level ``model`` on the selected CIFAR test split.

    The dataset (``'cifar10'`` or ``'cifar100'``), the test batch size and
    the CUDA flag are read from the module-level ``args`` object.

    Returns:
        float: Top-1 accuracy as a fraction of the test set size.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model.eval()
    correct = 0
    # torch.no_grad() replaces Variable(..., volatile=True): the volatile flag
    # no longer disables gradient tracking in PyTorch >= 0.4.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the highest score per sample
            # .item() keeps the running count a plain Python int so the
            # accuracy below is computed with ordinary float division.
            correct += pred.eq(target.view_as(pred)).sum().item()

    total = len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, total, 100. * correct / total))
    return correct / float(total)
开发者ID:Eric-mingjie,项目名称:network-slimming,代码行数:31,代码来源:prune_mask.py

示例15: test

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import CIFAR10 [as 别名]
def test(model):
    """Evaluate ``model`` on the selected CIFAR test split.

    The dataset (``'cifar10'`` or ``'cifar100'``), the test batch size and
    the CUDA flag are read from the module-level ``args`` object.

    Args:
        model: Network to evaluate; it is switched to eval mode.

    Returns:
        float: Top-1 accuracy as a fraction of the test set size.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model.eval()
    correct = 0
    # torch.no_grad() replaces Variable(..., volatile=True): the volatile flag
    # no longer disables gradient tracking in PyTorch >= 0.4.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the highest score per sample
            # .item() keeps the running count a plain Python int so the
            # accuracy below is computed with ordinary float division.
            correct += pred.eq(target.view_as(pred)).sum().item()

    total = len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, total, 100. * correct / total))
    return correct / float(total)
开发者ID:Eric-mingjie,项目名称:network-slimming,代码行数:31,代码来源:vggprune.py


注:本文中的torchvision.datasets.CIFAR10类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。