当前位置: 首页>>代码示例>>Python>>正文


Python datasets.STL10类代码示例

本文整理汇总了Python中torchvision.datasets.STL10类的典型用法代码示例。如果您正苦于以下问题:Python datasets.STL10类的具体用法?Python datasets.STL10怎么用?Python datasets.STL10使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。


在下文中一共展示了datasets.STL10类的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    """Build the torchvision dataset registered under ``name``.

    Args:
        name: one of 'cifar10', 'cifar100', 'mnist', 'stl10' or 'imagenet'.
        split: 'train' selects the training data. Any other value selects the
            non-training data for the datasets driven by a ``train`` flag and
            the 'val' folder for 'imagenet'. For 'stl10' the value is
            forwarded unchanged as the ``split`` argument.
        transform: forwarded to the dataset as ``transform``.
        target_transform: forwarded to the dataset as ``target_transform``.
        download: forwarded as ``download`` (not passed for 'imagenet').
        datasets_path: parent directory; the dataset root is
            ``<datasets_path>/<name>``.

    Returns:
        The constructed dataset, or None when ``name`` is not recognised.
    """
    is_train = split == 'train'
    dataset_root = os.path.join(datasets_path, name)
    common = dict(transform=transform, target_transform=target_transform)

    # Datasets whose constructor picks the subset with a boolean `train` flag.
    flag_based = {'cifar10': 'CIFAR10', 'cifar100': 'CIFAR100', 'mnist': 'MNIST'}
    if name in flag_based:
        dataset_cls = getattr(datasets, flag_based[name])
        return dataset_cls(root=dataset_root, train=is_train,
                           download=download, **common)

    if name == 'stl10':
        # STL10 takes the split name itself rather than a boolean.
        return datasets.STL10(root=dataset_root, split=split,
                              download=download, **common)

    if name == 'imagenet':
        subset_dir = 'train' if is_train else 'val'
        return datasets.ImageFolder(root=os.path.join(dataset_root, subset_dir),
                                    **common)

    # Unrecognised names fall through and yield None.
    return None
开发者ID:eladhoffer,项目名称:bigBatch,代码行数:38,代码来源:data.py

示例2: get_encoder_size

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def get_encoder_size(dataset):
    """Return the encoder size associated with ``dataset``.

    The mapping is 32 for C10/C100, 64 for STL10 and 128 for
    IN128/PLACES205.

    Raises:
        RuntimeError: if ``dataset`` is none of the members listed above.
    """
    if dataset in (Dataset.C10, Dataset.C100):
        encoder_size = 32
    elif dataset == Dataset.STL10:
        encoder_size = 64
    elif dataset in (Dataset.IN128, Dataset.PLACES205):
        encoder_size = 128
    else:
        raise RuntimeError("Couldn't get encoder size, unknown dataset: {}".format(dataset))
    return encoder_size
开发者ID:Philip-Bachman,项目名称:amdim-public,代码行数:10,代码来源:datasets.py

示例3: _get_directories

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def _get_directories(dataset, input_dir):
    """Return the ``(train_dir, val_dir)`` pair for ``dataset``.

    Args:
        dataset: a ``Dataset`` member.
        input_dir: directory that holds the dataset sub-folders.

    Returns:
        ``(None, None)`` for C10, C100 and STL10; otherwise the training and
        validation folder paths joined onto ``input_dir``.

    Raises:
        RuntimeError: if no directories are defined for ``dataset``.
    """
    if dataset in [Dataset.C10, Dataset.C100, Dataset.STL10]:
        # Pytorch will download those datasets automatically
        return None, None
    if dataset == Dataset.IN128:
        train_dir = os.path.join(input_dir, 'ILSVRC2012_img_train/')
        val_dir = os.path.join(input_dir, 'ILSVRC2012_img_val/')
    elif dataset == Dataset.PLACES205:
        train_dir = os.path.join(input_dir, 'places205_256_train/')
        val_dir = os.path.join(input_dir, 'places205_256_val/')
    else:
        # Raise a real exception: a bare string cannot be raised in Python 3,
        # and concatenating a str with a non-str dataset fails before the
        # message is even built. format() renders any dataset value.
        raise RuntimeError(
            'Data directories for dataset {} are not defined'.format(dataset))
    return train_dir, val_dir
开发者ID:Philip-Bachman,项目名称:amdim-public,代码行数:15,代码来源:datasets.py

示例4: get

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def get(batch_size, data_root='/mnt/local0/public_dataset/pytorch/', train=True, val=True, **kwargs):
    """Build STL10 data loaders for the requested splits.

    Args:
        batch_size: batch size given to every DataLoader.
        data_root: parent directory; data lives in ``<data_root>/stl10-data``
            (with ``~`` expanded).
        train: when true, build the training loader (padded random crop,
            random horizontal flip, shuffled).
        val: when true, build the test loader (no augmentation, not shuffled).
        **kwargs: forwarded to ``torch.utils.data.DataLoader``. ``num_workers``
            defaults to 1 and ``input_size`` is discarded if present.

    Returns:
        The single loader when exactly one split was requested; otherwise a
        list of the loaders built, training loader first (empty when both
        flags are false).
    """
    stl10_dir = os.path.expanduser(os.path.join(data_root, 'stl10-data'))
    worker_count = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building STL10 data loader with {} workers".format(worker_count))

    def _pipeline(*augmentations):
        # Every split ends with tensor conversion and the same normalisation.
        return transforms.Compose(list(augmentations) + [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    loaders = []

    if train:
        train_set = datasets.STL10(
            root=stl10_dir, split='train', download=True,
            transform=_pipeline(
                transforms.Pad(4),
                transforms.RandomCrop(96),
                transforms.RandomHorizontalFlip(),
            ))
        loaders.append(torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, **kwargs))

    if val:
        test_set = datasets.STL10(
            root=stl10_dir, split='test', download=True,
            transform=_pipeline())
        loaders.append(torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=False, **kwargs))

    if len(loaders) == 1:
        return loaders[0]
    return loaders
开发者ID:aaron-xichen,项目名称:pytorch-playground,代码行数:35,代码来源:dataset.py

示例5: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def get_dataset(self):
        """
        Build the STL10 training and test datasets with torchvision.datasets.STL10.

        Both datasets are created with download=True. The training split is
        rooted at 'datasets/STL10/train/' and uses self.train_transforms; the
        test split is rooted at 'datasets/STL10/test/' and uses
        self.val_transforms.

        Returns:
             tuple: (trainset, valset) as datasets.STL10 objects
        """

        # Options shared by both splits.
        shared_options = {'target_transform': None, 'download': True}

        training_data = datasets.STL10(
            'datasets/STL10/train/',
            split='train',
            transform=self.train_transforms,
            **shared_options)
        validation_data = datasets.STL10(
            'datasets/STL10/test/',
            split='test',
            transform=self.val_transforms,
            **shared_options)

        return training_data, validation_data
开发者ID:MrtnMndt,项目名称:Deep_Openset_Recognition_through_Uncertainty,代码行数:16,代码来源:datasets.py

示例6: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def __init__(self, args, cur_img_size=None):
        """Create the train/valid/test loaders for the dataset named in ``args``.

        Args:
            args: options object. Reads ``dataset`` ('cifar10' or 'stl10',
                case-insensitive), ``img_size``, ``data_path``,
                ``dis_batch_size`` and ``num_workers``. For 'cifar10',
                ``args.n_classes`` is set to 10.
            cur_img_size: resize target; falls back to ``args.img_size`` when
                falsy.

        Raises:
            NotImplementedError: if ``args.dataset`` is neither supported name.
        """
        target_size = cur_img_size or args.img_size

        dataset_name = args.dataset.lower()
        if dataset_name not in ('cifar10', 'stl10'):
            raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
        is_stl10 = dataset_name == 'stl10'

        dataset_cls = datasets.STL10 if is_stl10 else datasets.CIFAR10
        # Both datasets share the same preprocessing pipeline.
        preprocessing = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        if not is_stl10:
            args.n_classes = 10

        # STL10 selects its subset by split name, CIFAR10 by a boolean flag.
        if is_stl10:
            train_selector = {'split': 'train+unlabeled'}
            eval_selector = {'split': 'test'}
        else:
            train_selector = {'train': True}
            eval_selector = {'train': False}

        def _make_loader(dataset, shuffle):
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=args.dis_batch_size, shuffle=shuffle,
                num_workers=args.num_workers, pin_memory=True)

        # Only the training dataset is constructed with download=True.
        self.train = _make_loader(
            dataset_cls(root=args.data_path, transform=preprocessing,
                        download=True, **train_selector),
            True)
        self.valid = _make_loader(
            dataset_cls(root=args.data_path, transform=preprocessing,
                        **eval_selector),
            False)

        # The test loader is the validation loader itself.
        self.test = self.valid
开发者ID:TAMU-VITA,项目名称:AutoGAN,代码行数:46,代码来源:datasets.py

示例7: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path='~/Datasets'):
    """Build the dataset registered under ``name``.

    Args:
        name: one of 'cifar10', 'cifar100', 'mnist', 'stl10', 'imagenet' or
            'imagenet_tar'.
        split: 'train' selects the training data. Any other value selects the
            non-training data for the datasets driven by a ``train`` flag, the
            'val' folder for 'imagenet' and the validation archive for
            'imagenet_tar'. For 'stl10' the value is forwarded unchanged.
        transform: forwarded to the dataset as ``transform``.
        target_transform: forwarded to the dataset as ``target_transform``.
        download: forwarded as ``download`` to the CIFAR, MNIST and STL10
            datasets only.
        datasets_path: parent directory (``~`` is expanded); the dataset root
            is ``<datasets_path>/<name>``.

    Returns:
        The constructed dataset, or None when ``name`` is not recognised.
    """
    wants_train = split == 'train'
    dataset_root = os.path.join(os.path.expanduser(datasets_path), name)
    transform_kwargs = {'transform': transform,
                        'target_transform': target_transform}

    # Datasets whose constructor picks the subset with a boolean `train` flag.
    train_flag_datasets = {'cifar10': 'CIFAR10',
                           'cifar100': 'CIFAR100',
                           'mnist': 'MNIST'}
    if name in train_flag_datasets:
        dataset_cls = getattr(datasets, train_flag_datasets[name])
        return dataset_cls(root=dataset_root, train=wants_train,
                           download=download, **transform_kwargs)

    if name == 'stl10':
        return datasets.STL10(root=dataset_root, split=split,
                              download=download, **transform_kwargs)

    if name == 'imagenet':
        folder = 'train' if wants_train else 'val'
        return datasets.ImageFolder(root=os.path.join(dataset_root, folder),
                                    **transform_kwargs)

    if name == 'imagenet_tar':
        archive = 'imagenet_train.tar' if wants_train else 'imagenet_validation.tar'
        # The target is the first path component of each archive member name.
        return IndexedFileDataset(os.path.join(dataset_root, archive),
                                  extract_target_fn=(
                                      lambda fname: fname.split('/')[0]),
                                  **transform_kwargs)

    # Unrecognised names fall through and yield None.
    return None
开发者ID:eladhoffer,项目名称:convNet.pytorch,代码行数:47,代码来源:data.py

示例8: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import STL10 [as 别名]
def __init__(
        self,
        height: int = 96,
        width: int = 96,
        data_folder: str = "~/data/st10",
        split: str = "train",
        name: Optional[str] = None,
        batch_size: int = 64,
        shuffle: bool = True,
    ):
        """
        Set up the STL10 data layer.

        Args:
            height: height images are resized to (DEFAULT: 96)
            width: width images are resized to (DEFAULT: 96)
            data_folder: dataset folder, "~" is expanded (DEFAULT: "~/data/st10")
            split: split name passed to STL10, e.g. "train", "test",
                "unlabeled" or "train+unlabeled" (DEFAULT: "train")
            name: name of the module (DEFAULT: None)
            batch_size: batch size remembered for the DataLoader (DEFAULT: 64)
            shuffle: shuffle flag remembered for the DataLoader (DEFAULT: True)
        """
        # Initialize the DataLayerNM base class.
        DataLayerNM.__init__(self, name=name)

        # Remember the target image size.
        self._height = height
        self._width = width

        # Resize to the requested size, then convert to tensors.
        target_size = (self._height, self._width)
        preprocessing = Compose([Resize(target_size), ToTensor()])

        # Build the dataset from the user-expanded folder, downloading it if needed.
        dataset_root = expanduser(data_folder)
        self._dataset = STL10(
            root=dataset_root,
            split=split,
            download=True,
            transform=preprocessing,
        )

        # Settings kept for the DataLoader.
        self._batch_size = batch_size
        self._shuffle = shuffle

        # Map each class index to its class name.
        class_names = 'airplane bird car cat deer dog horse monkey ship truck'.split(' ')
        self._ix_to_word = dict(enumerate(class_names))
开发者ID:NVIDIA,项目名称:NeMo,代码行数:50,代码来源:stl10_datalayer.py


注:本文中的torchvision.datasets.STL10类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。