当前位置: 首页>>代码示例>>Python>>正文


Python datasets.FashionMNIST方法代码示例

本文整理汇总了Python中torchvision.datasets.FashionMNIST方法的典型用法代码示例。如果您正苦于以下问题:Python datasets.FashionMNIST方法的具体用法?Python datasets.FashionMNIST怎么用?Python datasets.FashionMNIST使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torchvision.datasets的用法示例。


在下文中一共展示了datasets.FashionMNIST方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_fashion_mnist_dataloaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def get_fashion_mnist_dataloaders(batch_size=128):
    """Build shuffled train and test loaders over resized Fashion MNIST images.

    The training split is downloaded into ``../fashion_data`` when missing;
    the test split is read from the same directory without a download flag.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    data_root = '../fashion_data'
    # Resize to 32 so the image side length is a power of 2.
    pipeline = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

    train_split = datasets.FashionMNIST(
        data_root, train=True, download=True, transform=pipeline)
    test_split = datasets.FashionMNIST(
        data_root, train=False, transform=pipeline)

    # Both loaders shuffle, including the test loader.
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=True)
        for split in (train_split, test_split)
    )
开发者ID:vandit15,项目名称:Self-Supervised-Gans-Pytorch,代码行数:18,代码来源:dataloaders.py

示例2: get_fashion_mnist_dataloaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def get_fashion_mnist_dataloaders(root=r'.\dataset\fashion_data', batch_size=128, resize=32, transform_list=None,
                                  num_workers=-1):
    """Build shuffled Fashion MNIST train and test dataloaders.

    Args:
        root: Directory holding the dataset (the train split is downloaded
            there when missing).
        batch_size: Number of samples per batch for both loaders.
        resize: Size handed to ``transforms.Resize`` in the default pipeline;
            unused when ``transform_list`` is supplied.
        transform_list: Optional list of transforms to compose instead of the
            default resize -> tensor -> normalize pipeline.
        num_workers: Worker count for both loaders; ``-1`` means use the CPU
            count reported by ``psutil.cpu_count()``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders shuffle and
        drop the last incomplete batch.
    """
    if num_workers == -1:
        # psutil.cpu_count() returns None when the count cannot be
        # determined; fall back to 0 (load in the main process) so neither
        # the %d format nor DataLoader receives None.
        num_workers = psutil.cpu_count() or 0
        print("use %d thread!" % num_workers)

    if transform_list is None:
        transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    all_transforms = transforms.Compose(transform_list)

    # Get train and test data
    train_data = datasets.FashionMNIST(root, train=True, download=True,
                                       transform=all_transforms)
    test_data = datasets.FashionMNIST(root, train=False,
                                      transform=all_transforms)

    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
    return train_loader, test_loader
开发者ID:dingguanglei,项目名称:jdit,代码行数:25,代码来源:dataset.py

示例3: load_fashion_mnist

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def load_fashion_mnist(args):
    """Create Fashion MNIST train and test loaders rooted at ``data/fashion_mnist``.

    Args:
        args: Object exposing ``batch_size``, used for the training loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles;
        the test loader uses a fixed batch size of 100 and does not shuffle.
        Both drop the last incomplete batch.
    """
    data_dir = 'data/fashion_mnist'
    torch.cuda.manual_seed(1)
    loader_opts = dict(num_workers=1, pin_memory=True, drop_last=True)

    # The same tensor conversion + normalization is applied to both splits.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_split = datasets.FashionMNIST(
        data_dir, train=True, download=True, transform=preprocess)
    train_loader = torch.utils.data.DataLoader(
        train_split, batch_size=args.batch_size, shuffle=True, **loader_opts)

    test_split = datasets.FashionMNIST(
        data_dir, train=False, download=True, transform=preprocess)
    test_loader = torch.utils.data.DataLoader(
        test_split, batch_size=100, shuffle=False, **loader_opts)

    return train_loader, test_loader
开发者ID:neale,项目名称:Adversarial-Autoencoder,代码行数:21,代码来源:datagen.py

示例4: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def __init__(self, root: str, normal_class: int = 0, known_outlier_class: int = 1, n_known_outlier_classes: int = 0,
                 ratio_known_normal: float = 0.0, ratio_known_outlier: float = 0.0, ratio_pollution: float = 0.0):
        """Set up a semi-supervised normal-vs-outlier view of Fashion MNIST.

        Args:
            root: Dataset root, forwarded to the parent initializer.
            normal_class: The single class (0-9) treated as normal; every
                other class becomes an outlier class.
            known_outlier_class: Outlier class used when exactly one known
                outlier class is requested.
            n_known_outlier_classes: Number of known outlier classes; 0 gives
                none, 1 uses ``known_outlier_class``, more are sampled at
                random from the outlier classes.
            ratio_known_normal: Forwarded to ``create_semisupervised_setting``.
            ratio_known_outlier: Forwarded to ``create_semisupervised_setting``.
            ratio_pollution: Forwarded to ``create_semisupervised_setting``.
        """
        super().__init__(root)

        # Binary task: label 0 is normal, label 1 is outlier.
        self.n_classes = 2
        self.normal_classes = (normal_class,)
        remaining = list(range(0, 10))
        remaining.remove(normal_class)
        self.outlier_classes = tuple(remaining)

        if n_known_outlier_classes == 0:
            known = ()
        elif n_known_outlier_classes == 1:
            known = (known_outlier_class,)
        else:
            known = tuple(random.sample(self.outlier_classes, n_known_outlier_classes))
        self.known_outlier_classes = known

        # Images are scaled to [0, 1]; targets are mapped to 1 when the
        # original class is one of the outlier classes, else 0.
        to_tensor = transforms.ToTensor()
        to_binary_label = transforms.Lambda(lambda x: int(x in self.outlier_classes))

        full_train = MyFashionMNIST(root=self.root, train=True, download=True,
                                    transform=to_tensor, target_transform=to_binary_label)

        # Pick the samples that make up the semi-supervised setting and
        # write their semi-supervised labels back onto the dataset.
        original_targets = full_train.targets.cpu().data.numpy()
        keep_idx, _, semi_labels = create_semisupervised_setting(
            original_targets, self.normal_classes, self.outlier_classes, self.known_outlier_classes,
            ratio_known_normal, ratio_known_outlier, ratio_pollution)
        full_train.semi_targets[keep_idx] = torch.tensor(semi_labels)

        # Only the selected samples are exposed as the training set.
        self.train_set = Subset(full_train, keep_idx)

        self.test_set = MyFashionMNIST(root=self.root, train=False, download=True,
                                       transform=to_tensor, target_transform=to_binary_label)
开发者ID:lukasruff,项目名称:Deep-SAD-PyTorch,代码行数:40,代码来源:fmnist.py

示例5: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def __init__(self, batch_size=256, subset_size=50000, test_batch_size=256, dirpath=None):
        """Download Fashion MNIST and expose train/test loaders plus metadata.

        Args:
            batch_size: Batch size of the training loader.
            subset_size: Number of randomly chosen training samples to keep.
            test_batch_size: Batch size of the test loader.
            dirpath: Dataset directory; when falsy, a ``data/fmnit_data``
                directory one level above this file's directory is used.
        """
        to_tensor = transforms.Compose([transforms.ToTensor()])

        if dirpath:
            self._dirpath = dirpath
        else:
            here = os.path.dirname(os.path.abspath(__file__))
            self._dirpath = os.path.join(here, '..', 'data', 'fmnit_data')

        full_train = dset.FashionMNIST(root=self._dirpath, train=True, transform=to_tensor, download=True)
        test_set = dset.FashionMNIST(root=self._dirpath, train=False, transform=to_tensor, download=True)

        # Keep a random subset of at most `subset_size` training samples.
        chosen = torch.randperm(len(full_train))[:subset_size]
        train_set = torch.utils.data.Subset(full_train, chosen)

        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=test_batch_size, shuffle=False)

        # Static description of the dataset.
        self.name = "fmnist"
        self.data_dims = [28, 28, 1]
        self.range = [0.0, 1.0]
        self.likelihood_type = 'gaussian'
        self.output_activation_type = 'sigmoid'

        # Sizes: *_size is the length of the loader (batches), whereas
        # num_*_instances is the length of the underlying dataset (samples).
        self.batch_size = batch_size
        self.train_size = len(self.train_loader)
        self.test_size = len(self.test_loader)
        self.num_training_instances = len(train_set)
        self.num_test_instances = len(test_set)
开发者ID:IBM,项目名称:AIX360,代码行数:36,代码来源:fashion_mnist_dataset.py

示例6: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def __init__(self, options):
        """Build the training dataset selected by ``options`` and its loader.

        Reads ``image_size``, ``image_colors``, ``dataset``, ``data_dir``,
        ``image_class`` (EMNIST/LSUN only), ``batch_size``,
        ``loader_workers`` and ``pin_memory`` from ``options``. Any dataset
        name that is not matched explicitly is loaded as an image folder.
        """
        steps = []
        if options.image_size is not None:
            side = options.image_size
            steps.append(transforms.Resize((side, side)))
        steps.append(transforms.ToTensor())

        # Normalization is only added for 1- or 3-channel images.
        if options.image_colors == 1:
            channel_stats = [0.5]
        elif options.image_colors == 3:
            channel_stats = [0.5, 0.5, 0.5]
        else:
            channel_stats = None
        if channel_stats is not None:
            steps.append(transforms.Normalize(mean=list(channel_stats), std=list(channel_stats)))
        transform = transforms.Compose(steps)

        name = options.dataset
        if name == 'mnist':
            dataset = datasets.MNIST(options.data_dir, train=True, download=True, transform=transform)
        elif name == 'emnist':
            # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
            # NOTE: this overrides the download URL on the EMNIST class itself.
            datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
            dataset = datasets.EMNIST(
                options.data_dir, split=options.image_class, train=True, download=True, transform=transform)
        elif name == 'fashion-mnist':
            dataset = datasets.FashionMNIST(options.data_dir, train=True, download=True, transform=transform)
        elif name == 'lsun':
            lsun_class = options.image_class + '_train'
            dataset = datasets.LSUN(options.data_dir, classes=[lsun_class], transform=transform)
        elif name == 'cifar10':
            dataset = datasets.CIFAR10(options.data_dir, train=True, download=True, transform=transform)
        elif name == 'cifar100':
            dataset = datasets.CIFAR100(options.data_dir, train=True, download=True, transform=transform)
        else:
            dataset = datasets.ImageFolder(root=options.data_dir, transform=transform)

        self.dataloader = DataLoader(
            dataset,
            batch_size=options.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=options.loader_workers,
            pin_memory=options.pin_memory,
        )
        self.iterator = iter(self.dataloader)
开发者ID:unicredit,项目名称:ganzo,代码行数:41,代码来源:data.py

示例7: get_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def get_data(dataset, data_path, cutout_length, validation):
    """Get a torchvision dataset together with its basic geometry.

    Args:
        dataset: Dataset name (case-insensitive): 'cifar10', 'mnist' or
            'fashionmnist'.
        data_path: Root directory the dataset is downloaded to / read from.
        cutout_length: Forwarded to ``preproc.data_transforms``.
        validation: When truthy, the test split is appended to the result.

    Returns:
        ``[input_size, input_channels, n_classes, trn_data]``, followed by
        the validation dataset when ``validation`` is truthy.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    dataset = dataset.lower()

    if dataset == 'cifar10':
        dset_cls = dset.CIFAR10
        n_classes = 10
    elif dataset == 'mnist':
        dset_cls = dset.MNIST
        n_classes = 10
    elif dataset == 'fashionmnist':
        dset_cls = dset.FashionMNIST
        n_classes = 10
    else:
        raise ValueError(dataset)

    trn_transform, val_transform = preproc.data_transforms(dataset, cutout_length)
    trn_data = dset_cls(root=data_path, train=True, download=True, transform=trn_transform)

    # Current torchvision exposes the raw samples as `.data`; `.train_data`
    # is the legacy name (deprecated on MNIST, gone from CIFAR10), so it is
    # only used as a fallback for old installations.
    raw = trn_data.data if hasattr(trn_data, 'data') else trn_data.train_data

    # assuming shape is NHW or NHWC
    shape = raw.shape
    input_channels = 3 if len(shape) == 4 else 1
    assert shape[1] == shape[2], "not expected shape = {}".format(shape)
    input_size = shape[1]

    ret = [input_size, input_channels, n_classes, trn_data]
    if validation:  # append validation data
        ret.append(dset_cls(root=data_path, train=False, download=True, transform=val_transform))

    return ret
开发者ID:khanrc,项目名称:pt.darts,代码行数:32,代码来源:utils.py

示例8: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def __init__(self, root="datasets/fashion_data", batch_size=64, num_workers=-1):
        """Initialize by forwarding ``root``, ``batch_size`` and ``num_workers`` to the parent initializer."""
        super(FashionMNIST, self).__init__(root, batch_size, num_workers) 
开发者ID:dingguanglei,项目名称:jdit,代码行数:4,代码来源:dataset.py

示例9: build_datasets

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def build_datasets(self):
        """Create the Fashion MNIST train and validation datasets on ``self``.

        Both splits live under ``self.root`` and are downloaded when missing.
        The train split uses ``self.train_transform_list`` and the test split
        uses ``self.valid_transform_list``, each wrapped in a ``Compose``.
        """
        train_pipeline = transforms.Compose(self.train_transform_list)
        self.dataset_train = datasets.FashionMNIST(
            self.root, train=True, download=True, transform=train_pipeline)

        valid_pipeline = transforms.Compose(self.valid_transform_list)
        self.dataset_valid = datasets.FashionMNIST(
            self.root, train=False, download=True, transform=valid_pipeline)
开发者ID:dingguanglei,项目名称:jdit,代码行数:7,代码来源:dataset.py

示例10: load_dataset_test

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def load_dataset_test(data_dir, dataset, batch_size):
    """Load the test split of ``dataset`` together with its per-sample labels.

    Args:
        data_dir: Base directory; data is looked up under
            ``<data_dir>/Datasets/<dataset>``.
        dataset: One of 'mnist', 'fashion', 'cifar10', 'celebA' or
            'timagenet'.
        batch_size: Only forwarded to ``utils.load_celebA``.

    Returns:
        ``(dataset_test, list_classes_test)`` where ``list_classes_test`` is
        a numpy array built from element 1 of every item of ``dataset_test``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    path = os.path.join(data_dir, 'Datasets', dataset)

    if dataset == 'mnist':
        dataset_test = datasets.MNIST(path, train=False, download=True,
                                      transform=transforms.Compose([transforms.ToTensor()]))
    elif dataset == 'fashion':
        dataset_test = fashion(path, train=False, download=True, transform=transforms.ToTensor())
    elif dataset == 'cifar10':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10(root=path, train=False,
                                        download=True, transform=transform)
    elif dataset == 'celebA':
        # NOTE(review): `path + 'celebA'` concatenates without a path
        # separator -- confirm this is the intended directory.
        # NOTE(review): transforms.Scale is the legacy name of
        # transforms.Resize -- confirm it exists in the pinned torchvision.
        dataset_test = utils.load_celebA(path + 'celebA', transform=transforms.Compose(
            [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=batch_size)
    elif dataset == 'timagenet':
        dataset_test, labels = get_test_image_folders(path)
        # Keep only the samples whose label is below 10.
        all_labels = np.asarray([labels[i] for i in range(len(dataset_test))])
        dataset_test = Subset(dataset_test, np.where(all_labels < 10)[0])
    else:
        # Fail with a clear message instead of an unbound-variable error
        # further down.
        raise ValueError('unknown dataset: {}'.format(dataset))

    # Labels are always re-read from the final dataset so that they line up
    # with whatever subset was selected above.
    list_classes_test = np.asarray([dataset_test[i][1] for i in range(len(dataset_test))])

    return dataset_test, list_classes_test
开发者ID:TLESORT,项目名称:Generative_Continual_Learning,代码行数:40,代码来源:load_dataset.py

示例11: main

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def main(args):
    """Prepare Fashion MNIST loaders and hand them to ``main_worker``.

    Reads ``seed``, ``data``, ``batch_size`` and ``workers`` from ``args``.
    When a seed is given, Python's and torch's RNGs are seeded and cuDNN is
    switched to deterministic mode (with a warning about the slowdown).
    """
    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Training data: shuffled, incomplete last batch dropped.
    print('=> creating training set...')
    train_dataset = datasets.FashionMNIST(
        args.data,
        train=True,
        transform=transforms.Compose([transforms.ToTensor()]),
        target_transform=None,
        download=True,
    )
    print('=> create train dataloader...')
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )

    # Validation data: fixed order, every sample kept.
    print('=> creating validation set...')
    val_dataset = datasets.FashionMNIST(
        args.data,
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
        target_transform=None,
        download=True,
    )
    print('=> creating validation dataloader...')
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        shuffle=False,
    )

    main_worker(train_loader, val_loader, NUM_CLASSES, args)
开发者ID:Cerebras,项目名称:online-normalization,代码行数:42,代码来源:fmnist_main.py

示例12: data_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def data_loader(dataset, batch_size, shuffle_test=False):
    """Build train and test loaders for one of the supported datasets.

    Args:
        dataset: 'mnist', 'fmnist', 'cifar10' or 'gts'.
        batch_size: Batch size used for both loaders.
        shuffle_test: Whether the test loader shuffles its samples.

    Returns:
        A ``(train_loader, test_loader)`` tuple; the training loader always
        shuffles and both loaders use pinned memory.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'mnist':
        train_data = datasets.MNIST("./data/mnist", train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.MNIST("./data/mnist", train=False, download=True, transform=transforms.ToTensor())
    elif dataset == 'fmnist':
        train_data = datasets.FashionMNIST("./data/fmnist", train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.FashionMNIST("./data/fmnist", train=False, download=True, transform=transforms.ToTensor())
    elif dataset == 'cifar10':
        # Only the training split is augmented (random flip + padded crop).
        train_data = datasets.CIFAR10("./data/cifar10", train=True, download=True,
                                      transform=transforms.Compose([
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(32, 4),
                                          transforms.ToTensor(),
                                      ]))
        test_data = datasets.CIFAR10('./data/cifar10', train=False, download=True, transform=transforms.ToTensor())
    elif dataset == 'gts':
        train = scipy.io.loadmat('datasets/{}/{}_int_train.mat'.format(dataset, dataset))
        # The test split must come from its own file; loading the *_train
        # file here would make the test loader serve training samples.
        # NOTE(review): assumes the test file follows the same naming scheme
        # as the train file -- confirm against the data directory.
        test = scipy.io.loadmat('datasets/{}/{}_int_test.mat'.format(dataset, dataset))
        x_train, y_train, x_test, y_test = train['images'], train['labels'], test['images'], test['labels']

        X_te = torch.from_numpy(x_test).float().permute([0, 3, 1, 2])  # NHWC to NCHW
        X_tr = torch.from_numpy(x_train).float().permute([0, 3, 1, 2])  # NHWC to NCHW
        y_te = torch.from_numpy(y_test).long()
        y_tr = torch.from_numpy(y_train).long()

        train_data = td.TensorDataset(X_tr, y_tr)
        test_data = td.TensorDataset(X_te, y_te)
    else:
        raise ValueError('wrong dataset')

    pin_memory = True
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=shuffle_test, pin_memory=pin_memory)
    return train_loader, test_loader
开发者ID:max-andr,项目名称:provable-robustness-max-linear-regions,代码行数:37,代码来源:utils.py

示例13: make_32x32_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def make_32x32_dataset(dataset, batch_size, drop_remainder=True, shuffle=True, num_workers=4, pin_memory=False):
    """Create an image-only loader for a dataset served at 32x32 resolution.

    Args:
        dataset: 'mnist', 'fashion_mnist' or 'cifar10'.
        batch_size: Batch size of the loader.
        drop_remainder: Passed to the loader as ``drop_last``.
        shuffle: Whether the loader shuffles.
        num_workers: Number of loader worker processes.
        pin_memory: Whether the loader pins memory.

    Returns:
        ``(data_loader, img_shape)`` with ``img_shape`` given as
        ``[height, width, channels]``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ('mnist', 'fashion_mnist'):
        # Single-channel datasets are resized to 32x32 and normalized with
        # mean 0.5 / std 0.5.
        gray_pipeline = transforms.Compose([
            transforms.Resize(size=(32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        if dataset == 'mnist':
            images = datasets.MNIST('data/MNIST', transform=gray_pipeline, download=True)
        else:
            images = datasets.FashionMNIST('data/FashionMNIST', transform=gray_pipeline, download=True)
        img_shape = [32, 32, 1]
    elif dataset == 'cifar10':
        # No resize step is applied for CIFAR10.
        rgb_pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        images = datasets.CIFAR10('data/CIFAR10', transform=rgb_pipeline, download=True)
        img_shape = [32, 32, 3]
    else:
        raise NotImplementedError

    data_loader = DataLoader(
        OnlyImage(images),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_remainder,
        pin_memory=pin_memory,
    )
    return data_loader, img_shape
开发者ID:LynnHo,项目名称:DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch,代码行数:37,代码来源:data.py

示例14: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def get_dataset(self):
        """
        Load the Fashion MNIST train and test splits with torchvision,
        downloading them when they are not present yet.

        The train split uses ``self.train_transforms`` and the test split uses
        ``self.val_transforms``; each split is stored under its own directory.

        Returns:
             tuple: (trainset, valset) as ``datasets.FashionMNIST`` instances
        """

        train_root = 'datasets/FashionMNIST/train/'
        val_root = 'datasets/FashionMNIST/test/'

        trainset = datasets.FashionMNIST(
            train_root, train=True, download=True,
            transform=self.train_transforms, target_transform=None)
        valset = datasets.FashionMNIST(
            val_root, train=False, download=True,
            transform=self.val_transforms, target_transform=None)

        return trainset, valset
开发者ID:MrtnMndt,项目名称:OCDVAEContinualLearning,代码行数:17,代码来源:datasets.py

示例15: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FashionMNIST [as 别名]
def __init__(self, batch_size, binarize=False, logit_transform=False):
        """Create Fashion MNIST loaders and training-set statistics.

        Batches have shape [-1, 1, 28, 28].

        Args:
            batch_size: Batch size for both loaders.
            binarize: Not supported; any truthy value raises.
            logit_transform: When truthy, the statistics are computed on
                dequantized, logit-transformed training data.

        Raises:
            NotImplementedError: If ``binarize`` is truthy.
        """
        if binarize:
            raise NotImplementedError

        self.logit_transform = logit_transform

        directory = './datasets/FashionMNIST'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)

        use_cuda = torch.cuda.is_available()
        # `num_workers` is not defined in this method; presumably a
        # module-level setting -- verify.
        kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
        self.train_loader = DataLoader(
            datasets.FashionMNIST(directory, train=True, download=True,
                                  transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=True, **kwargs)
        self.test_loader = DataLoader(
            datasets.FashionMNIST(directory, train=False, download=True, transform=transforms.ToTensor()),
            batch_size=batch_size, shuffle=False, **kwargs)

        self.dim = [1, 28, 28]

        # Stack every training image into one tensor and flatten each image
        # to a vector. The tensor is moved to the GPU only when one exists,
        # matching the CUDA guard used for the loader kwargs above.
        train = torch.stack([data for data, _ in self.train_loader.dataset], 0)
        if use_cuda:
            train = train.cuda()
        train = train.view(train.shape[0], -1)
        if self.logit_transform:
            # Dequantize with uniform noise, squeeze away from 0 and 1 by
            # `lamb`, then apply the logit. `lamb` is not defined in this
            # method; presumably a module-level constant -- verify.
            train = train * 255.0
            train = (train + torch.rand_like(train)) / 256.0
            train = lamb + (1 - 2.0 * lamb) * train
            train = torch.log(train) - torch.log(1.0 - train)

        # Per-dimension mean, and the log of the mean squared deviation
        # taken over all elements (a single value).
        self.mean = train.mean(0)
        self.logvar = torch.log(torch.mean((train - self.mean)**2)).unsqueeze(0)
开发者ID:yookoon,项目名称:VLAE,代码行数:36,代码来源:datasets.py


注:本文中的torchvision.datasets.FashionMNIST方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。