当前位置: 首页>>代码示例>>Python>>正文


Python datasets.ImageFolder方法代码示例

本文整理汇总了Python中torchvision.datasets.ImageFolder方法的典型用法代码示例。如果您正苦于以下问题:Python datasets.ImageFolder方法的具体用法?Python datasets.ImageFolder怎么用?Python datasets.ImageFolder使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torchvision.datasets的用法示例。


在下文中一共展示了datasets.ImageFolder方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def load_data(root_path, dir, batch_size, phase):
    """Build a shuffled DataLoader over an ImageFolder for one domain.

    Args:
        root_path: Prefix of the dataset location. It is concatenated directly
            with ``dir`` (no separator is inserted), so it should end with a
            path separator.
        dir: Sub-folder name appended to ``root_path``.
        batch_size: Number of samples per batch.
        phase: ``'src'`` selects the augmented pipeline (random resized crop +
            horizontal flip); ``'tar'`` selects a deterministic resize.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles, uses 4 workers and
        keeps the last incomplete batch.

    Raises:
        KeyError: If ``phase`` is neither ``'src'`` nor ``'tar'``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_dict = {
        'src': transforms.Compose(
            [transforms.RandomResizedCrop(224),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize,
             ]),
        'tar': transforms.Compose(
            # Resize to a fixed 224x224. ``Resize(224)`` only fixes the shorter
            # side, so images with different aspect ratios would yield tensors
            # of different shapes that cannot be stacked into one batch.
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             normalize,
             ])}
    data = datasets.ImageFolder(root=root_path + dir, transform=transform_dict[phase])
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)
    return data_loader
开发者ID:jindongwang,项目名称:transferlearning,代码行数:20,代码来源:data_loader.py

示例2: train_fine_tuning

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=4):
    """Fine-tune ``net`` on the ``train``/``test`` image folders under ``data_dir``.

    Training batches are shuffled; test batches keep folder order. The loss is
    cross-entropy, and the actual loop is delegated to ``utils.train``.

    Relies on module-level names defined elsewhere in the file: ``data_dir``,
    ``train_augs``, ``test_augs``, ``device`` and ``utils``.
    """
    train_set = ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs)
    train_iter = DataLoader(train_set, batch_size, shuffle=True)

    test_set = ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs)
    test_iter = DataLoader(test_set, batch_size)

    criterion = torch.nn.CrossEntropyLoss()
    utils.train(train_iter, test_iter, net, criterion, optimizer, device, num_epochs)
开发者ID:wdxtub,项目名称:deep-learning-note,代码行数:7,代码来源:48_fine_tune_hotdog.py

示例3: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def __init__(self, config):
    """Set up an image DataLoader according to ``config``.

    In ``"imgs"`` mode, ``config.data_folder`` is loaded with ImageFolder,
    converted with ``ToTensor`` and normalized per channel with mean 0.5 and
    std 0.5, then wrapped in a shuffled DataLoader.

    Sets ``self.config`` and, in ``"imgs"`` mode, ``self.dataset_len``,
    ``self.num_iterations`` (batches per epoch, rounded up) and ``self.loader``.

    Raises:
        NotImplementedError: If ``config.data_mode == "numpy"``.
        ValueError: For any other unrecognized ``config.data_mode``.
    """
    self.config = config

    if config.data_mode == "imgs":
        transform = v_transforms.Compose(
            [v_transforms.ToTensor(),
             v_transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

        dataset = v_datasets.ImageFolder(self.config.data_folder, transform=transform)

        self.dataset_len = len(dataset)

        # Ceiling division: the loader below keeps the final partial batch,
        # so it still counts as one iteration.
        self.num_iterations = (self.dataset_len + config.batch_size - 1) // config.batch_size

        self.loader = DataLoader(dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 num_workers=config.data_loader_workers,
                                 pin_memory=config.pin_memory)
    elif config.data_mode == "numpy":
        raise NotImplementedError("This mode is not implemented YET")
    else:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch it, while callers can now catch it precisely.
        raise ValueError("Please specify in the json a specified mode in data_mode")
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:25,代码来源:celebA.py

示例4: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def load_data(data_folder, batch_size, phase='train', train_val_split=True, train_ratio=.8):
    """Build DataLoader(s) over an ImageFolder rooted at ``data_folder``.

    Args:
        data_folder: Root directory passed to ``datasets.ImageFolder``.
        batch_size: Samples per batch for every loader returned.
        phase: ``'train'`` (augmented pipeline) or ``'test'`` (deterministic
            pipeline).
        train_val_split: In the ``'train'`` phase, randomly split the data
            into a training part and a validation part.
        train_ratio: Fraction of samples assigned to the training part.

    Returns:
        ``[train_loader, val_loader]`` when ``phase == 'train'`` and
        ``train_val_split`` is true; otherwise a single DataLoader. Training
        loaders shuffle and drop the last incomplete batch; validation and
        test loaders do neither.

    Raises:
        KeyError: If ``phase`` is neither ``'train'`` nor ``'test'``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_dict = {
        'train': transforms.Compose(
            [transforms.Resize(256),
             transforms.RandomCrop(224),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize,
             ]),
        'test': transforms.Compose(
            # Fixed 224x224 output so batches collate regardless of the
            # images' aspect ratio (``Resize(224)`` only fixes the shorter side).
            [transforms.Resize([224, 224]),
             transforms.ToTensor(),
             normalize,
             ])}

    data = datasets.ImageFolder(root=data_folder, transform=transform_dict[phase])

    if phase != 'train':
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, drop_last=False,
                                           num_workers=4)

    if not train_val_split:
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True,
                                           num_workers=4)

    train_size = int(train_ratio * len(data))
    val_size = len(data) - train_size
    data_train, data_val = torch.utils.data.random_split(data, [train_size, val_size])

    # Both subsets returned by random_split wrap ``data`` and therefore its
    # random training augmentation. Re-wrap the validation indices around a
    # second scan of the same folder that uses the deterministic 'test'
    # transform, so validation is not computed on randomly cropped/flipped
    # images.
    val_source = datasets.ImageFolder(root=data_folder, transform=transform_dict['test'])
    data_val = torch.utils.data.Subset(val_source, data_val.indices)

    train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, drop_last=True,
                                               num_workers=4)
    val_loader = torch.utils.data.DataLoader(data_val, batch_size=batch_size, shuffle=False, drop_last=False,
                                             num_workers=4)
    return [train_loader, val_loader]

## Below are for ImageCLEF datasets 
开发者ID:jindongwang,项目名称:transferlearning,代码行数:40,代码来源:data_load.py

示例5: load_training

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def load_training(root_path, dir, batch_size, kwargs):
    """Return a shuffled training DataLoader for the folder ``root_path + dir``.

    Images are resized to 256x256, randomly cropped to 224 and randomly
    flipped horizontally, then converted with ``ToTensor`` (no normalization).
    The last incomplete batch is dropped. ``kwargs`` is a dict of extra
    ``DataLoader`` keyword arguments, forwarded unchanged.
    """
    augment = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    folder = datasets.ImageFolder(root=root_path + dir, transform=augment)
    return torch.utils.data.DataLoader(
        folder,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )
开发者ID:jindongwang,项目名称:transferlearning,代码行数:12,代码来源:data_loader.py

示例6: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def load_data(data_folder, batch_size, train, kwargs):
    """Return a shuffled DataLoader over ``data_folder`` for training or testing.

    Args:
        data_folder: Root directory passed to ``datasets.ImageFolder``.
        batch_size: Samples per batch.
        train: Truthy selects the augmented pipeline (256x256 resize, random
            224 crop, random horizontal flip) and drops the last incomplete
            batch. Falsy selects a plain 224x224 resize and keeps that batch.
        kwargs: Dict of extra keyword arguments forwarded to ``DataLoader``.

    Both pipelines finish with ``ToTensor`` and a mean/std ``Normalize``.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    if train:
        steps = [transforms.Resize([256, 256]),
                 transforms.RandomCrop(224),
                 transforms.RandomHorizontalFlip()]
    else:
        steps = [transforms.Resize([224, 224])]
    steps += [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]

    folder = datasets.ImageFolder(root=data_folder, transform=transforms.Compose(steps))
    return torch.utils.data.DataLoader(folder, batch_size=batch_size, shuffle=True,
                                       drop_last=bool(train), **kwargs)
开发者ID:jindongwang,项目名称:transferlearning,代码行数:20,代码来源:data_loader.py

示例7: main

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def main():
    """Run MTCNN face alignment over every image under ``args.input_folder``.

    For each image found by ImageFolder, the aligned result is requested at
    ``<output_folder>/<class name>/<file name>``. Images whose output path
    already exists are skipped, so the script can be re-run to resume.
    Progress is printed per image.
    """
    args = parse_args()
    trans = transforms.Compose([
        preprocessing.ExifOrientationNormalize(),
        transforms.Resize(1024)
    ])

    images = datasets.ImageFolder(root=args.input_folder)
    idx_to_class = {v: k for k, v in images.class_to_idx.items()}
    create_dirs(args.output_folder, images.classes)

    mtcnn = MTCNN(prewhiten=False)

    total = len(images)
    for idx, (path, y) in enumerate(images.imgs):
        print("Aligning {} {}/{} ".format(path, idx + 1, total), end='')
        aligned_path = os.path.join(args.output_folder, idx_to_class[y], os.path.basename(path))
        if os.path.exists(aligned_path):
            print('Already aligned')
            continue
        # Context manager releases the image file handle deterministically
        # instead of relying on garbage collection across thousands of files.
        with Image.open(path) as raw:
            img = mtcnn(img=trans(raw.convert('RGB')), save_path=aligned_path)
        print("No face found" if img is None else '')
开发者ID:arsfutura,项目名称:face-recognition,代码行数:23,代码来源:align_mtcnn.py

示例8: get_imagenet_dataflow

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_imagenet_dataflow(is_train, data_dir, batch_size, augmentor, workers=18, is_distributed=False):
    """Build a DataLoader over the ImageFolder at ``data_dir``.

    Worker count is capped at the machine's CPU count. Training data is
    sharded with a ``DistributedSampler`` when ``is_distributed`` is set;
    otherwise it is shuffled. Non-training data is never shuffled or sharded.
    Batches use pinned memory.
    """
    num_workers = min(workers, multiprocessing.cpu_count())
    dataset = datasets.ImageFolder(data_dir, augmentor)

    if is_train and is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None
    # When a sampler is used it decides the order, so shuffle stays off.
    shuffle = bool(is_train) and sampler is None

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, pin_memory=True, sampler=sampler)
开发者ID:IBM,项目名称:BigLittleNet,代码行数:18,代码来源:imagenet_utils.py

示例9: make_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def make_dataset():
    """Build a DataLoader (batch size 100) for the dataset named by ``opt.dataset``.

    Folder-based datasets (``imagenet``, ``dog_and_cat_64``, ``dog_and_cat_128``)
    are read with ImageFolder from ``opt.root`` in fixed order; ``cifar10``
    loads the CIFAR-10 training split (no download) and is shuffled. Both are
    resized to ``opt.img_width`` and normalized per channel with mean/std 0.5.

    Raises:
        ValueError: If ``opt.dataset`` is not one of the names above.
    """
    folder_datasets = ("imagenet", "dog_and_cat_64", "dog_and_cat_128")
    name = opt.dataset
    is_folder = name in folder_datasets
    if not is_folder and name != "cifar10":
        raise ValueError(f"Unknown dataset: {name}")

    trans = tfs.Compose([
        tfs.Resize(opt.img_width),
        tfs.ToTensor(),
        tfs.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])])

    if is_folder:
        data = ImageFolder(opt.root, transform=trans)
    else:
        data = CIFAR10(root=opt.root, train=True, download=False, transform=trans)
    return DataLoader(data, batch_size=100, shuffle=not is_folder, num_workers=opt.workers)
开发者ID:xuanqing94,项目名称:RobGAN,代码行数:20,代码来源:acc_under_attack.py

示例10: get_train_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_train_loader(batch_size=25):
    """Build a Horovod-sharded training DataLoader over ``datapath + '/train'``.

    Each process receives its shard through a ``DistributedSampler`` keyed on
    the Horovod size/rank. Only rank 0 prints a summary of the images and
    classes found.

    Returns:
        ``(train_loader, train_sampler)``; the sampler is returned so the
        caller can drive it between epochs.
    """
    rank = hvd.rank()
    if rank == 0:
        print('Train: ', end="")

    dataset = datasets.ImageFolder(root=datapath + '/train', transform=data_transform)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=hvd.size(), rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        num_workers=4, pin_memory=True)

    if rank == 0:
        print('Found', len(dataset), 'images belonging to',
              len(dataset.classes), 'classes')
    return loader, sampler
开发者ID:csc-training,项目名称:intro-to-dl,代码行数:18,代码来源:pytorch_dvc_cnn_hvd.py

示例11: get_pytorch_train_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_pytorch_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None, input_size=224):
    """Build the training loader over ``<data_path>/train``.

    Augmentation is a random resized crop to ``input_size`` plus a random
    horizontal flip. No ``ToTensor`` is applied here; presumably ``fast_collate``
    performs the conversion (TODO confirm). When ``torch.distributed`` is
    initialized, a ``DistributedSampler`` shards the data; otherwise the
    loader shuffles.

    Returns:
        ``(PrefetchedWrapper(loader), len(loader))``.
    """
    augment = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
    ])
    train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), augment)

    train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
                     if torch.distributed.is_initialized() else None)

    loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        sampler=train_sampler,
        collate_fn=fast_collate,
    )
    return PrefetchedWrapper(loader), len(loader)
开发者ID:d-li14,项目名称:HBONet,代码行数:21,代码来源:dataloaders.py

示例12: get_pytorch_val_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_pytorch_val_loader(data_path, batch_size, workers=5, _worker_init_fn=None, input_size=224):
    """Build the validation loader over ``<data_path>/val``.

    Images are resized so the shorter side is ``int(input_size / 0.875)``, then
    center-cropped to ``input_size``. The loader never shuffles. When
    ``torch.distributed`` is initialized, a ``DistributedSampler`` shards the
    data.

    Returns:
        ``(PrefetchedWrapper(loader), len(loader))``.
    """
    eval_transform = transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
    ])
    val_dataset = datasets.ImageFolder(os.path.join(data_path, 'val'), eval_transform)

    val_sampler = (torch.utils.data.distributed.DistributedSampler(val_dataset)
                   if torch.distributed.is_initialized() else None)

    loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        collate_fn=fast_collate,
    )
    return PrefetchedWrapper(loader), len(loader)
开发者ID:d-li14,项目名称:HBONet,代码行数:23,代码来源:dataloaders.py

示例13: get_loaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False):
    """Build wrapped train and validation loaders.

    Training images get a random resized crop to ``sz`` (scale range
    ``[min_scale, 1.0]``) plus a horizontal flip. A ``DistributedSampler`` is
    used when ``distributed`` is set; otherwise the loader shuffles. The
    validation dataset and its batch sampler come from
    ``create_validation_set``. A falsy ``val_bs`` falls back to ``bs``.

    Returns:
        ``(train_loader, val_loader, train_sampler, val_sampler)``; both
        loaders are wrapped in ``BatchTransformDataLoader(..., fp16=fp16)``.
    """
    if not val_bs:
        val_bs = bs

    train_dataset = datasets.ImageFolder(traindir, transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]))
    if distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank())
    else:
        train_sampler = None

    raw_train = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=train_sampler is None,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        sampler=train_sampler,
    )

    val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed)
    raw_val = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=workers,
        pin_memory=True,
        collate_fn=fast_collate,
        batch_sampler=val_sampler,
    )

    return (BatchTransformDataLoader(raw_train, fp16=fp16),
            BatchTransformDataLoader(raw_val, fp16=fp16),
            train_sampler,
            val_sampler)
开发者ID:cybertronai,项目名称:imagenet18_old,代码行数:26,代码来源:dataloader.py

示例14: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def get_dataset():
    """Build the ImageNet validation DataLoader used for evaluation.

    Reads ``args.imagenet_path``, ``args.batch_size`` and ``args.n_worker``
    from the module-level ``args``. Images from ``<imagenet_path>/val`` are
    resized to ``int(224 / 0.875)``, center-cropped to 224, converted with
    ``ToTensor`` and normalized. The loader does not shuffle.

    Returns:
        ``(val_loader, n_class)``, where ``n_class`` is the fixed value 1000.

    Raises:
        ValueError: If ``args.imagenet_path`` is falsy (e.g. empty or None).
    """
    # Validate configuration before paying for the torchvision import.
    if not args.imagenet_path:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers keep working.
        raise ValueError('Please provide valid ImageNet path!')

    # lazy import
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    print('=> Preparing data..')
    valdir = os.path.join(args.imagenet_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    input_size = 224
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(int(input_size / 0.875)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.n_worker, pin_memory=True)
    n_class = 1000
    return val_loader, n_class
开发者ID:mit-han-lab,项目名称:amc-models,代码行数:25,代码来源:eval_mobilenet_torch.py

示例15: load_testing

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import ImageFolder [as 别名]
def load_testing(root_path, dir, batch_size, kwargs, resize_size=256, crop_size=224):
    """Return a deterministic (unshuffled) test DataLoader for ``root_path + dir``.

    Each image is resized to ``resize_size`` x ``resize_size``. Then a
    ``crop_size`` x ``crop_size`` window centred in the resized image is
    taken with ``PlaceCrop`` and converted with ``ToTensor``. The defaults
    match the 256 -> 224 resize/crop geometry of ``load_training``.

    Args:
        root_path: Prefix concatenated directly with ``dir``.
        dir: Sub-folder name appended to ``root_path``.
        batch_size: Samples per batch.
        kwargs: Dict of extra ``DataLoader`` keyword arguments.
        resize_size: Side length of the square resize (default 256).
        crop_size: Side length of the centred crop (default 224).

    Raises:
        ValueError: If ``crop_size`` exceeds ``resize_size``.
    """
    if crop_size > resize_size:
        raise ValueError("crop_size must not exceed resize_size")
    # Resize to resize_size first so the crop window starting at
    # start_center lies entirely inside the image. Resizing straight to
    # crop_size would make the offset window run past the image border.
    start_center = (resize_size - crop_size) // 2
    transform = transforms.Compose(
        [transforms.Resize([resize_size, resize_size]),
         PlaceCrop(crop_size, start_center, start_center),
         transforms.ToTensor()])
    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, drop_last=False, **kwargs)
    return test_loader
开发者ID:jindongwang,项目名称:transferlearning,代码行数:11,代码来源:data_loader.py


注:本文中的torchvision.datasets.ImageFolder方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。