当前位置: 首页>>代码示例>>Python>>正文


Python datasets.LSUN类代码示例

本文整理汇总了Python中torchvision.datasets.LSUN类的典型用法代码示例。如果您正苦于以下问题：Python datasets.LSUN类的具体用法？Python datasets.LSUN怎么用？Python datasets.LSUN使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。


在下文中一共展示了datasets.LSUN类的13个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_lsun_dataloader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    path_to_data : str
        Root directory handed to ``datasets.LSUN``.
    dataset : str
        LSUN class/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Number of images per batch.

    Returns a shuffled DataLoader over the resized, center-cropped images.
    """
    # Shorter side -> 128, then a central 128x128 crop, then a tensor.
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])

    # The root is passed positionally: torchvision renamed this argument
    # from ``db_path`` to ``root``, and ``db_path=`` now raises TypeError.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)

    # Create dataloader
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
开发者ID:vandit15,项目名称:Self-Supervised-Gans-Pytorch,代码行数:21,代码来源:dataloaders.py

示例2: get_lsun_dataloader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_lsun_dataloader(path_to_data='/data/dgl/LSUN', dataset='bedroom_train',
                        batch_size=64, image_size=128):
    """LSUN dataloader with square images (128x128 by default).

    path_to_data : str
        Root directory handed to ``datasets.LSUN``.
    dataset : str
        LSUN class/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Number of images per batch.
    image_size : int
        Output side length: the shorter side is resized to this value and
        the center is cropped to an ``image_size`` x ``image_size`` square.

    Returns a shuffled DataLoader over the transformed images.
    """
    # Compose transforms
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor()
    ])

    # Get dataset
    lsun_dset = datasets.LSUN(root=path_to_data, classes=[dataset],
                              transform=transform)

    # Create dataloader
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
开发者ID:dingguanglei,项目名称:jdit,代码行数:22,代码来源:dataset.py

示例3: load_lsun

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_lsun(self, classes=None):
        """Build an LSUN dataset rooted at ``self.path``.

        classes : list of str, optional
            LSUN class/split names to load. When None (the default), uses
            ``['church_outdoor_train', 'classroom_train']``.

        Each image goes through ``self.transform(True, True, True, False)``
        (the meaning of those flags is defined by ``self.transform``, which
        is not shown here).
        """
        # Build a fresh list per call rather than sharing a mutable default.
        if classes is None:
            classes = ['church_outdoor_train', 'classroom_train']
        # Named ``transform`` so it does not shadow a ``transforms`` module.
        transform = self.transform(True, True, True, False)
        dataset = dsets.LSUN(self.path, classes=classes, transform=transform)
        return dataset
开发者ID:sxhxliang,项目名称:BigGAN-pytorch,代码行数:6,代码来源:data_loader.py

示例4: make_dataloader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Build a dataset of the requested type and wrap it in a DataLoader.

    Args:
        batch_size: number of samples per batch.
        dataset_type: 'folder', 'imagenet', 'lfw', 'lsun', 'cifar10' or 'fake'.
        data_path: dataset root. Must already exist for the folder-style
            types and 'lsun'; CIFAR10 is downloaded into it.
        shuffle, drop_last: forwarded to ``torch.utils.data.DataLoader``.
        dataloader_args: extra keyword arguments for the DataLoader
            (None means no extras).
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: forwarded to ``make_transform``.

    Returns:
        (dataloader, num_of_classes)

    Raises:
        ValueError: if ``dataset_type`` is not supported.
    """
    # None instead of a shared mutable ``{}`` default.
    if dataloader_args is None:
        dataloader_args = {}
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        # NOTE: uses a bare ToTensor, not the composed ``transform``.
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    else:
        # Fail loudly instead of leaving ``dataset`` unbound.
        raise ValueError("Unsupported dataset_type: {!r}".format(dataset_type))
    assert dataset
    # FakeData exposes ``num_classes`` but no ``classes`` list.
    classes = getattr(dataset, 'classes', None)
    if classes is None:
        classes = list(range(dataset.num_classes))
    num_of_classes = len(classes)
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
开发者ID:voletiv,项目名称:self-attention-GAN-pytorch,代码行数:30,代码来源:utils.py

示例5: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def __init__(self, options):
        """Build the training DataLoader described by ``options``.

        Reads ``image_size``, ``image_colors``, ``dataset``, ``data_dir``,
        ``image_class``, ``batch_size``, ``loader_workers`` and
        ``pin_memory`` from ``options``. A ``dataset`` name without a
        dedicated builder is loaded with ``ImageFolder``. Sets
        ``self.dataloader`` and ``self.iterator`` (an iterator over it).
        """
        # Per-image preprocessing: optional square resize, tensor conversion,
        # then mapping [0, 1] to [-1, 1] for 1- or 3-channel images.
        pipeline = []
        if options.image_size is not None:
            side = options.image_size
            pipeline.append(transforms.Resize((side, side)))
        pipeline.append(transforms.ToTensor())
        if options.image_colors == 1:
            pipeline.append(transforms.Normalize(mean=[0.5], std=[0.5]))
        elif options.image_colors == 3:
            pipeline.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        transform = transforms.Compose(pipeline)

        root = options.data_dir

        def build_emnist():
            # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
            # NOTE(review): this overwrites the URL on the EMNIST class itself,
            # so it persists for every later EMNIST instance in the process;
            # the cloudstor mirror may no longer be reachable -- confirm.
            datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
            return datasets.EMNIST(root, split=options.image_class, train=True, download=True, transform=transform)

        builders = {
            'mnist': lambda: datasets.MNIST(root, train=True, download=True, transform=transform),
            'emnist': build_emnist,
            'fashion-mnist': lambda: datasets.FashionMNIST(root, train=True, download=True, transform=transform),
            'lsun': lambda: datasets.LSUN(root, classes=[options.image_class + '_train'], transform=transform),
            'cifar10': lambda: datasets.CIFAR10(root, train=True, download=True, transform=transform),
            'cifar100': lambda: datasets.CIFAR100(root, train=True, download=True, transform=transform),
        }
        build = builders.get(options.dataset)
        if build is None:
            dataset = datasets.ImageFolder(root=root, transform=transform)
        else:
            dataset = build()

        self.dataloader = DataLoader(
            dataset,
            batch_size=options.batch_size,
            num_workers=options.loader_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=options.pin_memory,
        )
        self.iterator = iter(self.dataloader)
开发者ID:unicredit,项目名称:ganzo,代码行数:41,代码来源:data.py

示例6: load_lsun

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_lsun(self, classes='church_outdoor_train'):
        """Build an LSUN dataset rooted at ``self.path``.

        classes : str or iterable of str
            A single LSUN class/split name (wrapped into a one-element list)
            or an iterable of such names.

        Each image goes through ``self.transform(True, True, True, False)``
        (the meaning of those flags is defined by ``self.transform``, which
        is not shown here).
        """
        # A single name is wrapped; an iterable is passed through as a list
        # instead of being nested inside another list.
        if isinstance(classes, str):
            classes = [classes]
        else:
            classes = list(classes)
        # Named ``transform`` so it does not shadow a ``transforms`` module.
        transform = self.transform(True, True, True, False)
        dataset = dsets.LSUN(self.path, classes=classes, transform=transform)
        return dataset
开发者ID:iSarmad,项目名称:RL-GAN-Net,代码行数:6,代码来源:data_loader.py

示例7: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    """Build a training dataset and report how many labels it has.

    Args:
        name: 'image' (ImageFolder), 'npy' (DatasetFolder of .npy files read
            with ``npy_loader``), 'cifar10', 'lsun' or 'lsun_class'.
        data_dir: dataset root directory.
        size: side length after resize + center crop.
        lsun_categories: ``classes`` argument for LSUN; defaults to 'train'.

    Returns:
        (dataset, nlabels)

    Raises:
        NotImplementedError: if ``name`` is not supported.
    """
    supported = ('image', 'npy', 'cifar10', 'lsun', 'lsun_class')
    if name not in supported:
        # NotImplemented is a constant, not an exception; raising it would
        # only produce a confusing TypeError.
        raise NotImplementedError(
            'Unsupported dataset {!r}; expected one of {}'.format(name, supported))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Adds uniform noise in [0, 1/128) -- presumably for dequantization.
        transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # NOTE: ``transform`` is not applied to .npy data.
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    else:  # 'lsun_class', the only remaining supported name
        dataset = datasets.LSUNClass(data_dir, transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1

    return dataset, nlabels
开发者ID:akanazawa,项目名称:vgan,代码行数:36,代码来源:inputs.py

示例8: get_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_dataset(args):
    """Build the training set selected by ``args.data`` and its data shape.

    Args:
        args: namespace with ``data`` ('mnist', 'cifar10' or 'lsun_church'),
            ``imagesize`` (None selects the per-dataset default) and
            ``conv`` (false flattens the returned shape).

    Returns:
        (train_set, data_shape), where data_shape is (C, H, W), or
        (C * H * W,) when ``args.conv`` is false.

    Raises:
        ValueError: if ``args.data`` is not a supported name.
    """
    def trans(im_size):
        return tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root="./data", train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ]), download=True
        )
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
                tforms.ToTensor(),
                add_noise,
            ])
        )
    else:
        # Without this, an unknown name surfaces as an UnboundLocalError below.
        raise ValueError("Unsupported dataset: {!r}".format(args.data))
    data_shape = (im_dim, im_size, im_size)
    if not args.conv:
        data_shape = (im_dim * im_size * im_size,)

    return train_set, data_shape
开发者ID:rtqichen,项目名称:ffjord,代码行数:37,代码来源:viz_multiscale.py

示例9: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Create a shuffled DataLoader over an LSUN or ImageFolder dataset.

    Args:
        image_data_type: 'lsun' or 'image_folder'.
        path_to_folder: dataset root directory.
        data_transform: transform applied to every image.
        batch_size: images per batch; the last incomplete batch is dropped.
        classes: LSUN ``classes`` argument (ignored for 'image_folder').
            When None, torchvision's own LSUN default is used.
        num_workers: number of DataLoader worker processes.

    Raises:
        ValueError: for an unknown ``image_data_type``.
    """
    # torch issue
    # https://github.com/pytorch/pytorch/issues/22866
    torch.set_num_threads(1)
    if image_data_type == 'lsun':
        # LSUN rejects classes=None, so only forward it when it was given.
        lsun_kwargs = {} if classes is None else {'classes': classes}
        dataset = datasets.LSUN(path_to_folder, transform=data_transform, **lsun_kwargs)
    elif image_data_type == "image_folder":
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    else:
        raise ValueError("Invalid image data type")
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, pin_memory=True)
    return dataset_loader
开发者ID:jalola,项目名称:improved-wgan-pytorch,代码行数:14,代码来源:training_utils.py

示例10: load_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def load_data(path_to_folder, classes):
    """Create a shuffled DataLoader of 64x64 images mapped to [-1, 1].

    Uses LSUN with ``classes`` when the module-level ``IMAGE_DATA_SET`` is
    'lsun', otherwise an ImageFolder at ``path_to_folder``. The batch size
    comes from the module-level ``BATCH_SIZE``; the last incomplete batch is
    dropped.
    """
    data_transform = transforms.Compose([
        # Resize replaces the deprecated ``Scale`` alias, which recent
        # torchvision releases no longer provide.
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    if IMAGE_DATA_SET == 'lsun':
        dataset = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=5, drop_last=True, pin_memory=True)
    return dataset_loader
开发者ID:jalola,项目名称:improved-wgan-pytorch,代码行数:15,代码来源:congan_train.py

示例11: check_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def check_dataset(dataset, dataroot):
    """Build the torchvision dataset named by ``dataset``.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        tuple: ``(dataset, nc)`` -- the torchvision Dataset object and its
        number of image channels.

    Raises:
        RuntimeError: if the dataset name is not recognized.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # MNIST tensors have one channel; a 3-channel Normalize fails on them
    # with a broadcast-shape error, so they get single-channel statistics.
    normalize_gray = transforms.Normalize((0.5,), (0.5,))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "mnist":
        dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize_gray]))
        nc = 1

    elif dataset == "fake":
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
开发者ID:pytorch,项目名称:ignite,代码行数:46,代码来源:dcgan.py

示例12: __getDataSet

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def __getDataSet(opt):
    """Build the torchvision dataset selected by ``opt.dataset``.

    Supported names: 'imagenet', 'folder', 'lfw', 'lsun', 'cifar10' and
    'mnist'. Images are resized (and, for folder/LSUN data, center-cropped)
    to ``opt.imageSize`` and normalized with mean/std 0.5 per channel.

    Side effects on ``opt``: for 'mnist' sets ``opt.nc = 1`` and
    ``opt.imageSize = 32``; for 'cifar10'/'mnist' with a truthy
    ``opt.load_dict``, sets ``opt.netD``/``opt.netG`` to the module-level
    pre-trained checkpoint constants.

    Returns:
        The dataset, or None when ``opt.dataset`` is not recognized.
    """
    if isDebug:
        print(f"Getting dataset: {opt.dataset} ... ")

    dataset = None
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # Folder dataset rooted directly at opt.dataroot.
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        # Root passed positionally: torchvision renamed the keyword from
        # ``db_path`` to ``root``, so ``db_path=`` raises TypeError now.
        dataset = dset.LSUN(opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        # Load pre-trained state dict
        if opt.load_dict:
            opt.netD = NETD_CIFAR10
            opt.netG = NETG_CIFAR10
    elif opt.dataset == 'mnist':
        opt.nc = 1
        opt.imageSize = 32
        # MNIST tensors have one channel, so Normalize needs one-channel stats
        # (a 3-channel Normalize fails with a broadcast-shape error).
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,), (0.5,)),
                               ]))
        # Update opt params for mnist
        if opt.load_dict:
            opt.netD = NETD_MNIST
            opt.netG = NETG_MNIST

    return dataset
开发者ID:raeidsaqur,项目名称:CapsGAN,代码行数:62,代码来源:main.py

示例13: get_data_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import LSUN [as 别名]
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Build a shuffled DataLoader for the named dataset.

    Args:
        dataset: 'imagenet', 'folder', 'lfw', 'lsun', 'cifar10', 'mnist'
            or 'fake'.
        dataroot: dataset root directory (CIFAR10/MNIST are downloaded here).
        workers: number of DataLoader worker processes (converted with int()).
        image_size: output side length; images are normalized to [-1, 1]
            except for 'fake', which only converts to tensors.
        batch_size: samples per batch.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        ds = dset.ImageFolder(root=dataroot,
                              transform=transforms.Compose([
                                  transforms.Resize(image_size),
                                  transforms.CenterCrop(image_size),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5),
                                                       (0.5, 0.5, 0.5)),
                              ]))
    elif dataset == 'lsun':
        ds = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                       transform=transforms.Compose([
                           transforms.Resize(image_size),
                           transforms.CenterCrop(image_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5)),
                       ]))
    elif dataset == 'cifar10':
        ds = dset.CIFAR10(root=dataroot, download=True,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))
    elif dataset == 'mnist':
        # MNIST tensors have one channel; a 3-channel Normalize fails on them
        # with a broadcast-shape error.
        ds = dset.MNIST(root=dataroot, train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Resize(image_size),
                            transforms.CenterCrop(image_size),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,)),
                        ]))
    elif dataset == 'fake':
        ds = dset.FakeData(image_size=(3, image_size, image_size),
                           transform=transforms.ToTensor())
    else:
        # An explicit error, unlike ``assert False``, survives ``python -O``.
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    assert ds

    data_loader = torch.utils.data.DataLoader(ds, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader
开发者ID:uber-research,项目名称:metropolis-hastings-gans,代码行数:50,代码来源:dcgan_loader.py


注：本文中的torchvision.datasets.LSUN类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。