This article collects typical usage examples of the torchvision.datasets.STL10 attribute in Python. If you have been wondering what datasets.STL10 is, how to use it, or what working examples of it look like, the curated code samples below may help. You can also explore the containing module, torchvision.datasets, for further usage examples.
The following presents 8 code examples of the datasets.STL10 attribute, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: get_dataset
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH):
    train = (split == 'train')
    root = os.path.join(datasets_path, name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
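As a minimal usage sketch (assuming the snippet's os/torchvision imports and the module-level __DATASETS_DEFAULT_PATH constant are in scope), the helper can be combined with a DataLoader as follows; the transform and batch size are illustrative only:

import torch
from torchvision import transforms

# Illustrative only: build the STL10 train split and wrap it in a DataLoader.
stl10_train = get_dataset('stl10', split='train',
                          transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(stl10_train, batch_size=32, shuffle=True)
images, labels = next(iter(train_loader))  # images: [32, 3, 96, 96]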
Example 2: get_encoder_size
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def get_encoder_size(dataset):
    if dataset in [Dataset.C10, Dataset.C100]:
        return 32
    if dataset == Dataset.STL10:
        return 64
    if dataset in [Dataset.IN128, Dataset.PLACES205]:
        return 128
    raise RuntimeError("Couldn't get encoder size, unknown dataset: {}".format(dataset))
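The Dataset identifiers referenced here are not part of torchvision; the excerpt assumes a project-level enum that is not shown. A minimal sketch of what such an enum might look like, purely as an assumption for context:

from enum import Enum

# Hypothetical enum mirroring the names used above (not shown in the excerpt).
class Dataset(Enum):
    C10 = 1        # CIFAR-10, 32x32 images
    C100 = 2       # CIFAR-100, 32x32 images
    STL10 = 3      # STL-10, encoded at 64x64 here
    IN128 = 4      # ImageNet at 128x128
    PLACES205 = 5  # Places205 at 128x128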
Example 3: _get_directories
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def _get_directories(dataset, input_dir):
    if dataset in [Dataset.C10, Dataset.C100, Dataset.STL10]:
        # PyTorch will download these datasets automatically
        return None, None
    if dataset == Dataset.IN128:
        train_dir = os.path.join(input_dir, 'ILSVRC2012_img_train/')
        val_dir = os.path.join(input_dir, 'ILSVRC2012_img_val/')
    elif dataset == Dataset.PLACES205:
        train_dir = os.path.join(input_dir, 'places205_256_train/')
        val_dir = os.path.join(input_dir, 'places205_256_val/')
    else:
        raise RuntimeError('Data directories for dataset {} are not defined'.format(dataset))
    return train_dir, val_dir
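Assuming the hypothetical Dataset enum sketched above, the two helpers would be used together roughly like this (the input_dir path is a placeholder):

# STL10 is downloaded automatically, so no directories are needed.
size = get_encoder_size(Dataset.STL10)                                    # 64
train_dir, val_dir = _get_directories(Dataset.STL10, input_dir='/data')   # (None, None)

# ImageNet-128 expects pre-extracted folders under input_dir.
train_dir, val_dir = _get_directories(Dataset.IN128, input_dir='/data')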
Example 4: get
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def get(batch_size, data_root='/mnt/local0/public_dataset/pytorch/', train=True, val=True, **kwargs):
    data_root = os.path.expanduser(os.path.join(data_root, 'stl10-data'))
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)
    print("Building STL10 data loader with {} workers".format(num_workers))
    ds = []
    if train:
        train_loader = torch.utils.data.DataLoader(
            datasets.STL10(
                root=data_root, split='train', download=True,
                transform=transforms.Compose([
                    transforms.Pad(4),
                    transforms.RandomCrop(96),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=True, **kwargs)
        ds.append(train_loader)
    if val:
        test_loader = torch.utils.data.DataLoader(
            datasets.STL10(
                root=data_root, split='test', download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
            batch_size=batch_size, shuffle=False, **kwargs)
        ds.append(test_loader)
    ds = ds[0] if len(ds) == 1 else ds
    return ds
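The default data_root points at a machine-specific mount, so callers will typically override it. A short usage sketch (path, batch size, and worker count are illustrative):

# Build both loaders; extra keyword arguments are forwarded to DataLoader.
train_loader, test_loader = get(batch_size=128, data_root='~/data',
                                train=True, val=True, num_workers=4)
for images, labels in train_loader:
    # Padded-and-cropped 96x96 STL10 images, normalized to [-1, 1].
    break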
Example 5: get_dataset
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def get_dataset(self):
    """
    Uses torchvision.datasets.STL10 to load the dataset.
    Downloads the dataset if it doesn't already exist.

    Returns:
        trainset, valset (both torchvision.datasets.STL10)
    """
    trainset = datasets.STL10('datasets/STL10/train/', split='train', transform=self.train_transforms,
                              target_transform=None, download=True)
    valset = datasets.STL10('datasets/STL10/test/', split='test', transform=self.val_transforms,
                            target_transform=None, download=True)
    return trainset, valset
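This get_dataset is a method of a wrapper class that is not shown in the excerpt. The sketch below assumes a hypothetical class (named STL10DataModule here) exposing the train_transforms and val_transforms attributes the method reads:

from torchvision import transforms

class STL10DataModule:  # hypothetical name for the enclosing class
    def __init__(self):
        # Only the attributes read by get_dataset are sketched here.
        self.train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                                    transforms.ToTensor()])
        self.val_transforms = transforms.ToTensor()

# Attach the method shown above to the hypothetical class, then use it.
STL10DataModule.get_dataset = get_dataset
trainset, valset = STL10DataModule().get_dataset()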
Example 6: __init__
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def __init__(self, args, cur_img_size=None):
    img_size = cur_img_size if cur_img_size else args.img_size
    if args.dataset.lower() == 'cifar10':
        Dt = datasets.CIFAR10
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        args.n_classes = 10
    elif args.dataset.lower() == 'stl10':
        Dt = datasets.STL10
        transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    else:
        raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))

    if args.dataset.lower() == 'stl10':
        self.train = torch.utils.data.DataLoader(
            Dt(root=args.data_path, split='train+unlabeled', transform=transform, download=True),
            batch_size=args.dis_batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True)

        self.valid = torch.utils.data.DataLoader(
            Dt(root=args.data_path, split='test', transform=transform),
            batch_size=args.dis_batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        self.test = self.valid
    else:
        self.train = torch.utils.data.DataLoader(
            Dt(root=args.data_path, train=True, transform=transform, download=True),
            batch_size=args.dis_batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True)

        self.valid = torch.utils.data.DataLoader(
            Dt(root=args.data_path, train=False, transform=transform),
            batch_size=args.dis_batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        self.test = self.valid
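The surrounding class name is not shown in the excerpt; assuming it is a dataset wrapper (called ImageDataset here purely for illustration), it would be driven by an args object roughly like this:

from types import SimpleNamespace

# Hypothetical args object carrying only the fields read by __init__ above.
args = SimpleNamespace(dataset='stl10', data_path='./data', img_size=48,
                       dis_batch_size=64, num_workers=2)
dataset = ImageDataset(args)              # ImageDataset is an assumed class name
real_imgs, _ = next(iter(dataset.train))  # batches from the 'train+unlabeled' split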
Example 7: get_dataset
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True, datasets_path='~/Datasets'):
    train = (split == 'train')
    root = os.path.join(os.path.expanduser(datasets_path), name)
    if name == 'cifar10':
        return datasets.CIFAR10(root=root,
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=root,
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=root,
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        return datasets.STL10(root=root,
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        if train:
            root = os.path.join(root, 'train')
        else:
            root = os.path.join(root, 'val')
        return datasets.ImageFolder(root=root,
                                    transform=transform,
                                    target_transform=target_transform)
    elif name == 'imagenet_tar':
        if train:
            root = os.path.join(root, 'imagenet_train.tar')
        else:
            root = os.path.join(root, 'imagenet_validation.tar')
        return IndexedFileDataset(root,
                                  extract_target_fn=(lambda fname: fname.split('/')[0]),
                                  transform=transform,
                                  target_transform=target_transform)
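In the imagenet_tar branch, IndexedFileDataset is project-specific and not part of torchvision; the extract_target_fn simply derives the class label from the leading directory component of each tar member name. The file name below is illustrative only:

# The label is the first path component of the archived file name.
extract_target_fn = lambda fname: fname.split('/')[0]
print(extract_target_fn('n01440764/ILSVRC2012_val_00000293.JPEG'))  # -> 'n01440764'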
Example 8: __init__
# Required import: from torchvision import datasets [as alias]
# Or: from torchvision.datasets import STL10 [as alias]
def __init__(
    self,
    height: int = 96,
    width: int = 96,
    data_folder: str = "~/data/stl10",
    split: str = "train",
    name: Optional[str] = None,
    batch_size: int = 64,
    shuffle: bool = True,
):
    """
    Initializes the STL10 datalayer.

    Args:
        height: image height (DEFAULT: 96)
        width: image width (DEFAULT: 96)
        data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/stl10")
        split: one of the 4 splits {'train', 'test', 'unlabeled', 'train+unlabeled'} (DEFAULT: "train")
        name: name of the module (DEFAULT: None)
        batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER]
        shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER]
    """
    # Call the base class constructor of DataLayer.
    DataLayerNM.__init__(self, name=name)

    # Store height and width.
    self._height = height
    self._width = width

    # Create transformations: up-scale and transform to tensors.
    STL10_transforms = Compose([Resize((self._height, self._width)), ToTensor()])

    # Get absolute path.
    abs_data_folder = expanduser(data_folder)

    # Create the STL10 dataset object.
    self._dataset = STL10(root=abs_data_folder, split=split, download=True, transform=STL10_transforms)

    # Remember the params passed to DataLoader. :]
    self._batch_size = batch_size
    self._shuffle = shuffle

    # Class names.
    labels = 'airplane bird car cat deer dog horse monkey ship truck'.split(' ')
    word_to_ix = {labels[i]: i for i in range(10)}

    # Reverse mapping.
    self._ix_to_word = {value: key for (key, value) in word_to_ix.items()}
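DataLayerNM comes from NeMo and is not shown here. Sidestepping that base class, the same dataset construction and label decoding can be sketched with plain torchvision as follows (path and batch size are illustrative):

import torch
from os.path import expanduser
from torchvision.datasets import STL10
from torchvision.transforms import Compose, Resize, ToTensor

# Same transform pipeline as the datalayer: resize, then convert to tensors.
stl10_transforms = Compose([Resize((96, 96)), ToTensor()])
dataset = STL10(root=expanduser('~/data/stl10'), split='train',
                download=True, transform=stl10_transforms)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Reverse label mapping, mirroring the class names above.
labels = 'airplane bird car cat deer dog horse monkey ship truck'.split(' ')
ix_to_word = {i: label for i, label in enumerate(labels)}

images, targets = next(iter(loader))
print(ix_to_word[int(targets[0])])  # e.g. 'monkey'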