本文整理汇总了Python中torchvision.datasets模块的典型用法代码示例。如果您正苦于以下问题：Python torchvision.datasets的具体用法是什么？Python torchvision.datasets怎么用？有哪些使用torchvision.datasets的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包torchvision的用法示例。
下文共展示了torchvision.datasets模块的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。
示例1: load_data
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def load_data(train_split, val_split, root):
    """Build the training/validation datasets and their DataLoaders.

    Args:
        train_split: Split spec for training; when it has length 0 no training
            dataset or loader is built and both ``'train'`` entries are None.
        val_split: Split spec for validation.
        root: Data root passed to ``Dataset`` and attached to each loader as
            ``.root``.

    Returns:
        ``(dataloaders, datasets)`` — two dicts keyed by ``'train'`` and ``'val'``.

    NOTE(review): relies on module-level ``Dataset``, ``batch_size`` and
    ``collate_fn`` defined elsewhere; confirm they exist before calling.
    """
    train_dataset = None
    train_loader = None
    if len(train_split) > 0:
        train_dataset = Dataset(train_split, 'training', root, batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
            collate_fn=collate_fn,
        )
        train_loader.root = root

    val_dataset = Dataset(val_split, 'testing', root, batch_size)
    # NOTE(review): the validation loader is shuffled (batch_size=1); confirm
    # this is intended, since evaluation order is usually kept fixed.
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    val_loader.root = root

    loaders_by_phase = {'train': train_loader, 'val': val_loader}
    datasets_by_phase = {'train': train_dataset, 'val': val_dataset}
    return loaders_by_phase, datasets_by_phase
# train the model
示例2: get_loaders
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    """Create the ImageNet-style train/val loaders wrapped in ``BatchTransformDataLoader``.

    Args:
        traindir, valdir: ``ImageFolder`` roots for training and validation.
        sz: Target crop size.
        bs: Training batch size; also the validation batch size unless
            ``val_bs`` is truthy.
        fp16: Forwarded to ``BatchTransformDataLoader``.
        workers: DataLoader worker count for both loaders.
        rect_val: Forwarded to ``create_validation_set``.
        min_scale: Lower bound of the ``RandomResizedCrop`` scale range.
        distributed: Use a ``DistributedSampler`` for training (and forwarded
            to ``create_validation_set``).

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``;
        ``train_sampler`` is None when not distributed.
    """
    val_bs = val_bs or bs

    # No ToTensor here: batches are collated by ``fast_collate``.
    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(traindir, train_pipeline)

    if distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank())
    else:
        train_sampler = None

    shared_kwargs = dict(num_workers=workers, pin_memory=True, collate_fn=fast_collate)
    # A custom sampler and shuffle=True are mutually exclusive in DataLoader.
    raw_train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **shared_kwargs
    )

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    raw_val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_sampler=val_sampler,
        **shared_kwargs
    )

    wrapped_train = BatchTransformDataLoader(raw_train_loader, fp16=fp16)
    wrapped_val = BatchTransformDataLoader(raw_val_loader, fp16=fp16)
    return wrapped_train, wrapped_val, train_sampler, val_sampler
示例3: __init__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __init__(self, split='train'):
    """Load one split of ImageNet from ``_IMAGENET_DATASET_DIR`` as an ImageFolder.

    Images are resized to 256, center-cropped to 224, converted to tensors and
    normalized with the given per-channel mean/std.

    Args:
        split: ``'train'`` or ``'val'``; selects the matching subdirectory of
            ``_IMAGENET_DATASET_DIR``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """
    self.split = split
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if split not in ('train', 'val'):
        raise ValueError("split must be 'train' or 'val', got {0!r}".format(split))
    self.name = 'ImageNet_Split_' + split
    print('Loading ImageNet dataset - split {0}'.format(split))

    mean_pix = [0.485, 0.456, 0.406]
    std_pix = [0.229, 0.224, 0.225]
    self.transform = transforms.Compose([
        # ``Resize`` replaces the deprecated ``Scale``, which newer torchvision removed.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        # A module-level function (not a lambda) keeps the transform picklable
        # when DataLoader workers are started with the "spawn" method.
        np.asarray,
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_pix, std=std_pix),
    ])

    traindir = os.path.join(_IMAGENET_DATASET_DIR, 'train')
    valdir = os.path.join(_IMAGENET_DATASET_DIR, 'val')
    self.data = datasets.ImageFolder(
        traindir if split == 'train' else valdir, self.transform)
    # Class index of every image, in ImageFolder's sample order.
    self.labels = [item[1] for item in self.data.imgs]
示例4: __getitem__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return ``(image, target, index)`` for sample ``index``.

    The stored array is transposed with axes ``(1, 2, 0)`` before being wrapped
    in a PIL Image, so transforms receive the same input type as other
    torchvision datasets. (Presumably the array is stored channel-first,
    C x H x W — confirm against the loader.) ``target`` is the label cast to
    ``int``; ``index`` is returned unchanged.
    """
    raw = self.data[index]
    label = int(self.labels[index])

    picture = Image.fromarray(np.transpose(raw, (1, 2, 0)))
    picture = picture if self.transform is None else self.transform(picture)
    label = label if self.target_transform is None else self.target_transform(label)
    return picture, label, index
示例5: __getitem__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return ``(image, target, index)`` for sample ``index``.

    The stored array is wrapped in a PIL Image (without reordering axes) so
    transforms see the same input type as other torchvision datasets. The
    optional ``transform``/``target_transform`` are applied when set, and the
    sample ``index`` is returned as the third element.
    """
    array = self.data[index]
    label = self.targets[index]

    picture = Image.fromarray(array)
    picture = picture if self.transform is None else self.transform(picture)
    label = label if self.target_transform is None else self.target_transform(label)
    return picture, label, index
示例6: __init__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __init__(self, *datasets):
    """Remember the positional ``datasets`` (as a tuple) to be indexed in lockstep."""
    self.datasets = tuple(datasets)
示例7: __getitem__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, i):
    """Return a tuple holding element ``i`` of every wrapped dataset, in order."""
    return tuple([source[i] for source in self.datasets])
示例8: __len__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __len__(self):
    """Return the length of the shortest wrapped dataset (zip semantics).

    With no wrapped datasets the length is 0, instead of ``min()`` raising
    ``ValueError`` on an empty sequence.
    """
    return min((len(d) for d in self.datasets), default=0)
示例9: test_input_block
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def test_input_block():
    """Check that the inflated 3-D DenseNet-121 stem reproduces the 2-D stem.

    Each 2-D image batch is repeated ``frame_nb`` times along a new time axis,
    replication-padded by one frame on each side, and passed through the
    inflated conv/norm/relu/pool layers. The result must match the 2-D output
    repeated over time, element-wise within 1e-4 in absolute value.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = datasets.ImageFolder(
        '/sequoia/data1/yhasson/datasets/test-dataset',
        transforms.Compose([
            # RandomResizedCrop replaces the deprecated RandomSizedCrop.
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    densenet = torchvision.models.densenet121(pretrained=True)
    features = densenet.features
    seq2d = torch.nn.Sequential(
        features.conv0, features.norm0, features.relu0, features.pool0)
    seq3d = torch.nn.Sequential(
        inflate.inflate_conv(features.conv0, 3),
        inflate.inflate_batch_norm(features.norm0),
        features.relu0,
        inflate.inflate_pool(features.pool0, 1))
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)
    frame_nb = 4
    # Pads the time axis (last pair of the 6-tuple) by one frame on each side.
    time_pad = torch.nn.ReplicationPad3d((0, 0, 0, 0, 1, 1))
    # Inference-only comparison: no autograd graph needed. Targets are unused,
    # so nothing is moved to CUDA (the old target.cuda() forced a GPU).
    with torch.no_grad():
        for input_2d, _ in loader:
            out2d = seq2d(input_2d)
            input_3d = input_2d.unsqueeze(2).repeat(1, 1, frame_nb, 1, 1)
            out3d = seq3d(time_pad(input_3d))
            expected_out_3d = out2d.unsqueeze(2).repeat(1, 1, frame_nb, 1, 1)
            # Compare absolute differences: a plain max() would let large
            # negative deviations pass unnoticed.
            max_abs_diff = (expected_out_3d - out3d).abs().max()
            print(max_abs_diff)
            assert max_abs_diff < 0.0001
示例10: __getitem__
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def __getitem__(self, index):
    """Return ``(image, label)`` for sample ``index``.

    The stored array is wrapped in a PIL Image so the optional ``transform``
    receives the same input type as other torchvision datasets; the label is
    returned as stored.
    """
    array = self.data[index]
    label = self.labels[index]
    picture = Image.fromarray(array)
    if self.transform is None:
        return picture, label
    return self.transform(picture), label
示例11: toy
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def toy(dataset,
        root='~/data/torchvision/',
        transforms=None):
    """Load the train and test splits of a torchvision dataset by class name.

    Args:
        dataset: Name of a class in ``torchvision.datasets``. It is called with
            ``root``, ``train``, ``download=True`` and ``transform``.
        root: Download/cache directory; ``~`` is expanded.
        transforms: Optional iterable of dicts
            ``{'def': <name in torchvision.transforms>, 'kwargs': {...}}``,
            applied in order before ``ToTensor``. ``'kwargs'`` may be omitted
            for transforms that take no arguments.

    Returns:
        ``(trainset, testset)``; both share the same composed transform.

    Raises:
        ValueError: if ``dataset`` or a transform ``'def'`` name is unknown.
    """
    loader_def = getattr(torchvision.datasets, dataset, None)
    if loader_def is None:
        raise ValueError('Unknown torchvision dataset: {0!r}'.format(dataset))

    transform_funcs = []
    if transforms is not None:
        for spec in transforms:
            transform_def = getattr(torchvision.transforms, spec['def'], None)
            if transform_def is None:
                raise ValueError(
                    'Unknown torchvision transform: {0!r}'.format(spec['def']))
            transform_funcs.append(transform_def(**spec.get('kwargs', {})))
    transform_funcs.append(torchvision.transforms.ToTensor())
    composed_transform = torchvision.transforms.Compose(transform_funcs)

    data_root = os.path.expanduser(root)
    trainset = loader_def(
        root=data_root, train=True,
        download=True, transform=composed_transform)
    testset = loader_def(
        root=data_root, train=False,
        download=True, transform=composed_transform)
    return trainset, testset
示例12: create_validation_set
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def create_validation_set(valdir, batch_size, target_size, rect_val, distributed):
    """Build the validation dataset and its ``DistValSampler``.

    Args:
        valdir: Validation image directory.
        batch_size: Per-batch size for the sampler (and for ``map_idx2ar``).
        target_size: Final crop size; images are first resized to
            ``int(target_size * 1.14)``.
        rect_val: When true, iterate images in aspect-ratio-sorted order and
            crop with ``CropArTfm``; otherwise use a plain center crop.
        distributed: Forwarded to ``DistValSampler``.

    Returns:
        ``(val_dataset, val_sampler)``.
    """
    if not rect_val:
        square_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)]
        dataset = datasets.ImageFolder(valdir, transforms.Compose(square_tfms))
        order = list(range(len(dataset)))
    else:
        idx_ar_sorted = sort_ar(valdir)
        order, _ = zip(*idx_ar_sorted)
        idx2ar = map_idx2ar(idx_ar_sorted, batch_size)
        # ValDataset receives the transforms as a plain list, not a Compose.
        rect_tfms = [transforms.Resize(int(target_size*1.14)), CropArTfm(idx2ar, target_size)]
        dataset = ValDataset(valdir, transform=rect_tfms)
    sampler = DistValSampler(order, batch_size=batch_size, distributed=distributed)
    return dataset, sampler
示例13: sort_ar
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def sort_ar(valdir):
    """Return ``(index, aspect_ratio)`` pairs for ``valdir`` sorted by aspect ratio.

    The aspect ratio is ``width / height`` of each image as loaded by
    ``ImageFolder``, rounded to 5 decimals. The result is cached as a pickle in
    ``sorted_idxar.p`` in the parent directory of ``valdir``; if that file
    exists it is returned as-is without scanning the images.
    """
    idx2ar_file = os.path.join(valdir, '..', 'sorted_idxar.p')
    if os.path.isfile(idx2ar_file):
        # NOTE: unpickling can execute arbitrary code; only trust a cache file
        # that this function wrote itself.
        with open(idx2ar_file, 'rb') as cache:
            return pickle.load(cache)
    print('Creating AR indexes. Please be patient this may take a couple minutes...')
    val_dataset = datasets.ImageFolder(valdir)  # AS: TODO: use Image.open instead of looping through dataset
    sizes = [sample[0].size for sample in tqdm(val_dataset, total=len(val_dataset))]
    idx_ar = [(i, round(s[0]/s[1], 5)) for i, s in enumerate(sizes)]
    sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
    # ``with`` guarantees the cache is flushed and the handle closed.
    with open(idx2ar_file, 'wb') as cache:
        pickle.dump(sorted_idxar, cache)
    print('Done')
    return sorted_idxar
示例14: getdataloader
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def getdataloader(datatype, train_db_path, test_db_path, batch_size):
    """Build CIFAR-10/100 train and test DataLoaders.

    Args:
        datatype: Dataset name, compared case-insensitively (lower-cased)
            against the ``CIFAR10`` / ``CIFAR100`` constants.
        train_db_path, test_db_path: Roots for the train/test downloads.
        batch_size: Batch size for both loaders.

    Returns:
        ``(trainloader, testloader, n_classes)``; ``(None, None, None)`` for an
        unsupported ``datatype``. The train loader shuffles, the test loader
        does not; both use 4 workers.
    """
    transform_train, transform_test = _getdatatransformsdb(datatype=datatype)

    normalized = datatype.lower()
    if normalized == CIFAR10:
        print("Using CIFAR10 dataset.")
        dataset_cls, n_classes = torchvision.datasets.CIFAR10, 10
    elif normalized == CIFAR100:
        print("Using CIFAR100 dataset.")
        dataset_cls, n_classes = torchvision.datasets.CIFAR100, 100
    else:
        print("Dataset is not supported.")
        return None, None, None

    trainset = dataset_cls(root=train_db_path, train=True, download=True,
                           transform=transform_train)
    testset = dataset_cls(root=test_db_path, train=False, download=True,
                          transform=transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=4)
    return trainloader, testloader, n_classes
示例15: create_dataset
# 需要导入模块: import torchvision [as 别名]
# 或者: from torchvision import datasets [as 别名]
def create_dataset(args, train):
    """Instantiate ``datasets.<args.dataset>`` with normalization (and augmentation).

    Args:
        args: Object with ``dataset`` (class name in ``datasets``) and
            ``dataroot`` (download/root directory).
        train: When true, prepend reflect-pad(4), horizontal flip and a 32x32
            random crop before the tensor conversion and normalization.

    Returns:
        The dataset instance, created with ``download=True``.
    """
    to_tensor = T.ToTensor()
    normalize = T.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0,
                            np.array([63.0, 62.1, 66.7]) / 255.0)
    pipeline = T.Compose([to_tensor, normalize])

    if train:
        # The base pipeline stays nested as the final step of the augmentation.
        augment = [
            T.Pad(4, padding_mode='reflect'),
            T.RandomHorizontalFlip(),
            T.RandomCrop(32),
        ]
        pipeline = T.Compose(augment + [pipeline])

    dataset_cls = getattr(datasets, args.dataset)
    return dataset_cls(args.dataroot, train=train, download=True, transform=pipeline)