当前位置: 首页>>代码示例>>Python>>正文


Python vision.ImageFolderDataset方法代码示例

本文整理汇总了Python中mxnet.gluon.data.vision.ImageFolderDataset方法的典型用法代码示例。如果您正苦于以下问题:Python vision.ImageFolderDataset方法的具体用法?Python vision.ImageFolderDataset怎么用?Python vision.ImageFolderDataset使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在mxnet.gluon.data.vision的用法示例。


在下文中一共展示了vision.ImageFolderDataset方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_imagenet_iterator

# 需要导入模块: from mxnet.gluon.data import vision [as 别名]
# 或者: from mxnet.gluon.data.vision import ImageFolderDataset [as 别名]
def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='float32'):
    """Dataset loader with preprocessing."""
    train_dir = os.path.join(root, 'train')
    train_transform, val_transform = get_imagenet_transforms(data_shape, dtype)
    logging.info("Loading image folder %s, this may take a bit long...", train_dir)
    train_dataset = ImageFolderDataset(train_dir, transform=train_transform)
    train_data = DataLoader(train_dataset, batch_size, shuffle=True,
                            last_batch='discard', num_workers=num_workers)
    val_dir = os.path.join(root, 'val')
    if not os.path.isdir(os.path.expanduser(os.path.join(root, 'val', 'n01440764'))):
        user_warning = 'Make sure validation images are stored in one subdir per category, a helper script is available at https://git.io/vNQv1'
        raise ValueError(user_warning)
    logging.info("Loading image folder %s, this may take a bit long...", val_dir)
    val_dataset = ImageFolderDataset(val_dir, transform=val_transform)
    val_data = DataLoader(val_dataset, batch_size, last_batch='keep', num_workers=num_workers)
    return DataLoaderIter(train_data, dtype), DataLoaderIter(val_data, dtype) 
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:18,代码来源:data.py

示例2: get_caltech101_iterator

# 需要导入模块: from mxnet.gluon.data import vision [as 别名]
# 或者: from mxnet.gluon.data.vision import ImageFolderDataset [as 别名]
def get_caltech101_iterator(batch_size, num_workers, dtype):
    def transform(image, label):
        # resize the shorter edge to 224, the longer edge will be greater or equal to 224
        resized = mx.image.resize_short(image, 224)
        # center and crop an area of size (224,224)
        cropped, crop_info = mx.image.center_crop(resized, (224, 224))
        # transpose the channels to be (3,224,224)
        transposed = mx.nd.transpose(cropped, (2, 0, 1))
        return transposed, label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
    return DataLoaderIter(train_data), DataLoaderIter(test_data) 
开发者ID:awslabs,项目名称:dynamic-training-with-apache-mxnet-on-aws,代码行数:19,代码来源:data.py


注:本文中的mxnet.gluon.data.vision.ImageFolderDataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。