本文整理汇总了Python中mxnet.gluon.data.vision.ImageFolderDataset类的典型用法代码示例。如果您正苦于以下问题：Python vision.ImageFolderDataset类的具体用法？Python vision.ImageFolderDataset怎么用？Python vision.ImageFolderDataset使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块mxnet.gluon.data.vision的用法示例。
在下文中一共展示了vision.ImageFolderDataset类的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_imagenet_iterator
# 需要导入模块: from mxnet.gluon.data import vision [as 别名]
# 或者: from mxnet.gluon.data.vision import ImageFolderDataset [as 别名]
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Dataset loader with preprocessing for an ImageNet-style folder tree.

    ``root`` must contain a ``train`` and a ``val`` sub-directory. The
    validation images must be stored in one sub-directory per category.

    Parameters
    ----------
    root : str
        Dataset root directory. ``~`` is expanded for the validation-layout
        check.
    batch_size : int
        Number of samples per batch for both loaders.
    num_workers : int
        Number of workers passed to each ``DataLoader``.
    data_shape : int, optional
        Forwarded to ``get_imagenet_transforms`` (default 224).
    dtype : str, optional
        Forwarded to ``get_imagenet_transforms`` and to both
        ``DataLoaderIter`` wrappers (default ``'float32'``).

    Returns
    -------
    tuple
        ``(train_iter, val_iter)``: two ``DataLoaderIter`` objects. The
        training loader shuffles and uses ``last_batch='discard'``; the
        validation loader does not shuffle and uses ``last_batch='keep'``.

    Raises
    ------
    ValueError
        If ``root/val`` is not split into per-category sub-directories.
        This is checked before the training folder is scanned, so a bad
        layout fails fast.
    """
    train_dir = os.path.join(root, 'train')
    val_dir = os.path.join(root, 'val')

    # Check the validation layout first. Scanning the training folder can
    # take a long time, and there is no point doing it if we are going to
    # fail on the validation set anyway. The 'n01440764' sub-directory is
    # probed as a representative per-category folder.
    if not os.path.isdir(os.path.expanduser(os.path.join(val_dir, 'n01440764'))):
        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
        raise ValueError(user_warning)

    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)

    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                            last_batch='discard', num_workers=num_workers)

    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)

    # NOTE(review): dtype is passed positionally. Confirm that DataLoaderIter's
    # second positional parameter is the dtype and not, e.g., a data name.
    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype)
示例2: get_caltech101_iterator
# 需要导入模块: from mxnet.gluon.data import vision [as 别名]
# 或者: from mxnet.gluon.data.vision import ImageFolderDataset [as 别名]
def get_caltech101_iterator(batch_size, num_workers, dtype, data_shape=224):
    """Build training and test iterators over the Caltech-101 image folders.

    Each image is transformed as follows:

    1. Resized so its shorter edge is ``data_shape``.
    2. Center-cropped to ``(data_shape, data_shape)``.
    3. Transposed so the channel axis comes first.
    4. Cast to ``dtype``.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch for both loaders.
    num_workers : int
        Number of workers passed to each ``DataLoader``.
    dtype : str
        Element type the transformed images are cast to (e.g. ``'float32'``).
    data_shape : int, optional
        Side length of the square crop (default 224).

    Returns
    -------
    tuple
        ``(train_iter, test_iter)``: two ``DataLoaderIter`` objects. Only the
        training loader shuffles.
    """
    def transform(image, label):
        # Resize the shorter edge to data_shape; the longer edge will be
        # greater than or equal to data_shape.
        resized = mx.image.resize_short(image, data_shape)
        # Center-crop a (data_shape, data_shape) area; the crop box is unused.
        cropped, _ = mx.image.center_crop(resized, (data_shape, data_shape))
        # Move channels first: (H, W, 3) -> (3, H, W).
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
        # Cast so that the caller's dtype argument actually takes effect.
        return transposed.astype(dtype), label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)
    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
    return DataLoaderIter(train_data), DataLoaderIter(test_data)