當前位置: 首頁>>代碼示例>>Python>>正文


Python datasets.LSUN類代碼示例

本文整理匯總了Python中torchvision.datasets.LSUN類的典型用法代碼示例。如果您正苦於以下問題:Python datasets.LSUN類的具體用法是什麼?Python datasets.LSUN怎麼用?Python datasets.LSUN的使用例子有哪些?那麼, 這裏精選的代碼示例或許可以為您提供幫助。您也可以進一步了解該類所在模塊torchvision.datasets的用法示例。


在下文中一共展示了datasets.LSUN類的13個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_lsun_dataloader

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def get_lsun_dataloader(path_to_data='../lsun', dataset='bedroom_train',
                        batch_size=64):
    """Build a shuffled LSUN DataLoader with (128, 128) sized images.

    path_to_data : str
        Root directory holding the LSUN databases.
    dataset : str
        LSUN category/split to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Number of images per batch.

    Returns a DataLoader over the selected LSUN category. Every image is
    resized to 128 on its shorter side, center-cropped to 128x128 and
    converted to a tensor.
    """
    transform = transforms.Compose([
        transforms.Resize(128),      # shorter side -> 128, aspect ratio kept
        transforms.CenterCrop(128),  # central 128x128 square
        transforms.ToTensor(),
    ])

    # The first LSUN parameter is called ``db_path`` in old torchvision
    # releases and ``root`` in current ones. Passing it positionally works
    # with both, whereas ``db_path=`` raises TypeError on current versions.
    lsun_dset = datasets.LSUN(path_to_data, classes=[dataset],
                              transform=transform)

    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
開發者ID:vandit15,項目名稱:Self-Supervised-Gans-Pytorch,代碼行數:21,代碼來源:dataloaders.py

示例2: get_lsun_dataloader

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def get_lsun_dataloader(path_to_data='/data/dgl/LSUN', dataset='bedroom_train',
                        batch_size=64):
    """Return a shuffling DataLoader of 128x128 LSUN images.

    path_to_data : str
        Directory holding the LSUN databases (passed to LSUN as ``root``).
    dataset : str
        Category to load, e.g. 'bedroom_val' or 'bedroom_train'.
    batch_size : int
        Images per batch.
    """
    side = 128
    # Resize the short edge, keep the central square, then convert to a tensor.
    pipeline = [transforms.Resize(side),
                transforms.CenterCrop(side),
                transforms.ToTensor()]
    lsun = datasets.LSUN(root=path_to_data,
                         classes=[dataset],
                         transform=transforms.Compose(pipeline))
    return DataLoader(lsun, shuffle=True, batch_size=batch_size)
開發者ID:dingguanglei,項目名稱:jdit,代碼行數:22,代碼來源:dataset.py

示例3: load_lsun

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def load_lsun(self, classes=None):
        """Load LSUN categories from ``self.path`` through ``dsets.LSUN``.

        Args:
            classes: list of LSUN category names. When None, defaults to
                ['church_outdoor_train', 'classroom_train'].

        Returns:
            The LSUN dataset, with ``self.transform(True, True, True, False)``
            applied to its images.
        """
        # Build a fresh list per call instead of sharing one mutable default.
        if classes is None:
            classes = ['church_outdoor_train', 'classroom_train']
        # Named ``image_transform`` so it cannot be confused with a
        # ``transforms`` module imported elsewhere in the file.
        image_transform = self.transform(True, True, True, False)
        return dsets.LSUN(self.path, classes=classes, transform=image_transform)
開發者ID:sxhxliang,項目名稱:BigGAN-pytorch,代碼行數:6,代碼來源:data_loader.py

示例4: make_dataloader

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Build a DataLoader for the requested dataset and count its classes.

    Args:
        batch_size: Number of samples per batch.
        dataset_type: One of 'folder', 'imagenet', 'lfw', 'lsun' (bedroom_train),
            'cifar10' or 'fake'.
        data_path: Dataset root directory. It must exist for the folder-style
            and LSUN datasets; CIFAR-10 is downloaded into it.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        dataloader_args: Extra keyword arguments for ``DataLoader``. None means
            no extra arguments.
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: Forwarded to ``make_transform``.

    Returns:
        tuple: ``(dataloader, num_of_classes)``.

    Raises:
        ValueError: If ``dataset_type`` is not supported.
        AssertionError: If ``data_path`` is missing for folder/LSUN data, or if
            the dataset is empty.
    """
    # None instead of a shared ``{}`` default object.
    if dataloader_args is None:
        dataloader_args = {}

    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        # NOTE(review): uses a bare ToTensor instead of ``transform`` (images are
        # generated at centercrop_size); assumes ``transforms`` is imported.
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    else:
        # Without this branch ``dataset`` would be unbound below.
        raise ValueError("Unsupported dataset_type: {!r}".format(dataset_type))
    # Truthiness falls back to len(), so this rejects an empty dataset.
    assert dataset
    # ImageFolder, LSUN and CIFAR10 expose a ``classes`` list. torchvision's
    # FakeData only has ``num_classes``, so derive placeholder labels from it.
    classes = getattr(dataset, 'classes', None)
    if classes is None:
        classes = list(range(dataset.num_classes))
    num_of_classes = len(classes)
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
開發者ID:voletiv,項目名稱:self-attention-GAN-pytorch,代碼行數:30,代碼來源:utils.py

示例5: __init__

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def __init__(self, options):
        """Create ``self.dataloader`` and ``self.iterator`` for ``options.dataset``.

        ``options`` is read for: image_size, image_colors, dataset, data_dir,
        image_class, batch_size, loader_workers and pin_memory. Unrecognised
        dataset names fall back to an ``ImageFolder`` at ``options.data_dir``.
        """
        steps = []
        size = options.image_size
        if size is not None:
            steps.append(transforms.Resize((size, size)))
        steps.append(transforms.ToTensor())

        # Map pixel values to [-1, 1] for 1- or 3-channel images only.
        if options.image_colors == 1:
            half = [0.5]
        elif options.image_colors == 3:
            half = [0.5, 0.5, 0.5]
        else:
            half = None
        if half is not None:
            steps.append(transforms.Normalize(mean=half, std=list(half)))
        transform = transforms.Compose(steps)

        name = options.dataset
        root = options.data_dir
        # Datasets that share the same (root, train, download, transform) call.
        downloadable = {
            'mnist': 'MNIST',
            'fashion-mnist': 'FashionMNIST',
            'cifar10': 'CIFAR10',
            'cifar100': 'CIFAR100',
        }
        if name in downloadable:
            dataset_cls = getattr(datasets, downloadable[name])
            dataset = dataset_cls(root, train=True, download=True, transform=transform)
        elif name == 'emnist':
            # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
            # (patches the class attribute globally before downloading).
            datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
            dataset = datasets.EMNIST(root, split=options.image_class, train=True, download=True, transform=transform)
        elif name == 'lsun':
            dataset = datasets.LSUN(root, classes=[options.image_class + '_train'], transform=transform)
        else:
            dataset = datasets.ImageFolder(root=root, transform=transform)

        loader_kwargs = dict(
            batch_size=options.batch_size,
            num_workers=options.loader_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=options.pin_memory,
        )
        self.dataloader = DataLoader(dataset, **loader_kwargs)
        self.iterator = iter(self.dataloader)
開發者ID:unicredit,項目名稱:ganzo,代碼行數:41,代碼來源:data.py

示例6: load_lsun

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def load_lsun(self, classes='church_outdoor_train'):
        """Return ``dsets.LSUN`` rooted at ``self.path`` for one category.

        ``classes`` is a single category name and is wrapped in a list for
        LSUN. Images go through ``self.transform(True, True, True, False)``.
        """
        preprocess = self.transform(True, True, True, False)
        category_list = [classes]
        return dsets.LSUN(self.path, classes=category_list, transform=preprocess)
開發者ID:iSarmad,項目名稱:RL-GAN-Net,代碼行數:6,代碼來源:data_loader.py

示例7: get_dataset

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def get_dataset(name, data_dir, size=64, lsun_categories=None):
    """Create a dataset by name and report how many labels it has.

    Args:
        name: 'image', 'npy', 'cifar10', 'lsun' or 'lsun_class'.
        data_dir: Dataset root directory.
        size: Output side length (resize, then center crop).
        lsun_categories: ``classes`` argument for LSUN; defaults to 'train'.

    Returns:
        tuple: ``(dataset, nlabels)``.

    Raises:
        NotImplementedError: If ``name`` is not supported.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Add uniform noise in [0, 1/128) to every value (presumably to
        # dequantize the 8-bit pixels; confirm with the training code).
        transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),
    ])

    if name == 'image':
        dataset = datasets.ImageFolder(data_dir, transform)
        nlabels = len(dataset.classes)
    elif name == 'npy':
        # Only support normalization for now (``transform`` is not applied here).
        dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
        nlabels = len(dataset.classes)
    elif name == 'cifar10':
        dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
                                   transform=transform)
        nlabels = 10
    elif name == 'lsun':
        if lsun_categories is None:
            lsun_categories = 'train'
        dataset = datasets.LSUN(data_dir, lsun_categories, transform)
        nlabels = len(dataset.classes)
    elif name == 'lsun_class':
        # Single-class LSUN: every target is mapped to label 0.
        dataset = datasets.LSUNClass(data_dir, transform,
                                     target_transform=(lambda t: 0))
        nlabels = 1
    else:
        # ``raise NotImplemented`` would itself fail with TypeError, because
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError("Unsupported dataset name: {!r}".format(name))

    return dataset, nlabels
開發者ID:akanazawa,項目名稱:vgan,代碼行數:36,代碼來源:inputs.py

示例8: get_dataset

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def get_dataset(args):
    """Build the training set selected by ``args.data`` and its data shape.

    Args:
        args: Namespace read for ``data`` ('mnist', 'cifar10' or
            'lsun_church'), ``imagesize`` (None -> per-dataset default) and
            ``conv`` (false -> flattened data shape).

    Returns:
        tuple: ``(train_set, data_shape)``. ``data_shape`` is
        ``(channels, size, size)``, or ``(channels * size * size,)`` when
        ``args.conv`` is false.

    Raises:
        ValueError: If ``args.data`` is not a supported dataset name.
    """
    def trans(im_size):
        # Resize, convert to tensor, then apply ``add_noise``.
        return tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])

    if args.data == "mnist":
        im_dim = 1
        im_size = 28 if args.imagesize is None else args.imagesize
        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
    elif args.data == "cifar10":
        im_dim = 3
        im_size = 32 if args.imagesize is None else args.imagesize
        train_set = dset.CIFAR10(
            root="./data", train=True, transform=tforms.Compose([
                tforms.Resize(im_size),
                tforms.RandomHorizontalFlip(),
                tforms.ToTensor(),
                add_noise,
            ]), download=True
        )
    elif args.data == 'lsun_church':
        im_dim = 3
        im_size = 64 if args.imagesize is None else args.imagesize
        train_set = dset.LSUN(
            'data', ['church_outdoor_train'], transform=tforms.Compose([
                tforms.Resize(96),
                tforms.RandomCrop(64),
                tforms.Resize(im_size),
                tforms.ToTensor(),
                add_noise,
            ])
        )
    else:
        # Without this branch train_set/im_dim/im_size would be unbound below.
        raise ValueError("Unsupported dataset: {!r}".format(args.data))
    data_shape = (im_dim, im_size, im_size)
    if not args.conv:
        data_shape = (im_dim * im_size * im_size,)

    return train_set, data_shape
開發者ID:rtqichen,項目名稱:ffjord,代碼行數:37,代碼來源:viz_multiscale.py

示例9: load_data

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def load_data(image_data_type, path_to_folder, data_transform, batch_size, classes=None, num_workers=5):
    """Create a shuffled DataLoader over an LSUN database or an image folder.

    Args:
        image_data_type: 'lsun' or 'image_folder'.
        path_to_folder: Dataset root directory.
        data_transform: Transform applied to every image.
        batch_size: Images per batch; the last incomplete batch is dropped.
        classes: LSUN categories. When None, LSUN's own default is used.
            Ignored for 'image_folder'.
        num_workers: Number of DataLoader worker processes.

    Raises:
        ValueError: For an unknown ``image_data_type``.
    """
    # torch issue
    # https://github.com/pytorch/pytorch/issues/22866
    # (workaround: limits torch to one intra-op thread, process-wide)
    torch.set_num_threads(1)
    if image_data_type == 'lsun':
        lsun_kwargs = {'transform': data_transform}
        # torchvision's LSUN expects a str or an iterable for ``classes``, and
        # passing None explicitly fails. Omit it so LSUN applies its default.
        if classes is not None:
            lsun_kwargs['classes'] = classes
        dataset = datasets.LSUN(path_to_folder, **lsun_kwargs)
    elif image_data_type == "image_folder":
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    else:
        raise ValueError("Invalid image data type")
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                                 num_workers=num_workers, drop_last=True, pin_memory=True)
    return dataset_loader
開發者ID:jalola,項目名稱:improved-wgan-pytorch,代碼行數:14,代碼來源:training_utils.py

示例10: load_data

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def load_data(path_to_folder, classes):
    """Create a shuffled DataLoader of 64x64 images from LSUN or an image folder.

    The source is chosen by the module-level ``IMAGE_DATA_SET`` setting and the
    batch size by ``BATCH_SIZE``. ``classes`` is only used for LSUN. Images are
    normalized to [-1, 1] per channel, and the last incomplete batch is dropped.
    """
    data_transform = transforms.Compose([
        # ``Resize`` is the replacement for the deprecated ``Scale`` transform,
        # which newer torchvision releases no longer provide.
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    if IMAGE_DATA_SET == 'lsun':
        dataset = datasets.LSUN(path_to_folder, classes=classes, transform=data_transform)
    else:
        dataset = datasets.ImageFolder(root=path_to_folder, transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                                                 num_workers=5, drop_last=True, pin_memory=True)
    return dataset_loader
開發者ID:jalola,項目名稱:improved-wgan-pytorch,代碼行數:15,代碼來源:congan_train.py

示例11: check_dataset

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def check_dataset(dataset, dataroot):
    """Create the torchvision dataset named *dataset* and report its channel count.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details.
            Recognised values: "imagenet", "folder", "lfw", "lsun", "cifar10",
            "mnist", "fake".
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        tuple: ``(dataset, nc)``. ``dataset`` is the torchvision Dataset object
        and ``nc`` its number of image channels (1 for MNIST, 3 otherwise).

    Raises:
        RuntimeError: If *dataset* is not a recognised name.
    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "mnist":
        # MNIST tensors have a single channel. A 3-value Normalize fails on them,
        # because the in-place subtraction cannot broadcast 1 channel to 3.
        normalize_gray = transforms.Normalize((0.5,), (0.5,))
        dataset = dset.MNIST(root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize_gray]))
        nc = 1

    elif dataset == "fake":
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
開發者ID:pytorch,項目名稱:ignite,代碼行數:46,代碼來源:dcgan.py

示例12: __getDataSet

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def __getDataSet(opt):
    """Build the dataset selected by ``opt.dataset``.

    ``opt`` is read for ``dataset``, ``dataroot``, ``imageSize`` and
    ``load_dict``. Recognised datasets:

    * 'imagenet', 'folder', 'lfw': ``ImageFolder`` rooted at ``opt.dataroot``.
    * 'lsun': the 'bedroom_train' LSUN category.
    * 'cifar10': downloaded CIFAR-10. With ``opt.load_dict`` set,
      ``opt.netD``/``opt.netG`` become ``NETD_CIFAR10``/``NETG_CIFAR10``.
    * 'mnist': downloaded MNIST. Sets ``opt.nc = 1`` and ``opt.imageSize = 32``.
      With ``opt.load_dict`` set, ``opt.netD``/``opt.netG`` become
      ``NETD_MNIST``/``NETG_MNIST``.

    Returns:
        The dataset, or None when ``opt.dataset`` is not recognised.
    """
    if isDebug:
        print(f"Getting dataset: {opt.dataset} ... ")

    dataset = None
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # Folder dataset. ``Resize`` replaces the deprecated ``Scale`` transform.
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        # Root passed positionally: the keyword is ``db_path`` in old torchvision
        # and ``root`` in current releases.
        dataset = dset.LSUN(opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        # Load pre-trained state dict
        if opt.load_dict:
            opt.netD = NETD_CIFAR10
            opt.netG = NETG_CIFAR10
    elif opt.dataset == 'mnist':
        opt.nc = 1
        opt.imageSize = 32
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   # MNIST is single-channel; a 3-value Normalize cannot be applied to it.
                                   transforms.Normalize((0.5,), (0.5,)),
                               ]))
        # Update opt params for mnist
        if opt.load_dict:
            opt.netD = NETD_MNIST
            opt.netG = NETG_MNIST

    return dataset
開發者ID:raeidsaqur,項目名稱:CapsGAN,代碼行數:62,代碼來源:main.py

示例13: get_data_loader

# 需要導入模塊: from torchvision import datasets [as 別名]
# 或者: from torchvision.datasets import LSUN [as 別名]
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Create a shuffled DataLoader for the dataset named *dataset*.

    Args:
        dataset: One of 'imagenet', 'folder', 'lfw', 'lsun' (bedroom_train),
            'cifar10', 'mnist' or 'fake'.
        dataroot: Dataset root directory (CIFAR-10/MNIST are downloaded there).
        workers: Number of DataLoader worker processes (converted with int()).
        image_size: Output image side length.
        batch_size: Images per batch.

    Returns:
        torch.utils.data.DataLoader over the selected dataset.

    Raises:
        ValueError: If *dataset* is not a supported name.
        AssertionError: If the selected dataset is empty.
    """
    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(image_size),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif dataset == 'lsun':
        dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif dataset == 'cifar10':
        dataset = dset.CIFAR10(root=dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))
    elif dataset == 'mnist':
        dataset = dset.MNIST(root=dataroot, train=True, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(image_size),
                                 transforms.CenterCrop(image_size),
                                 transforms.ToTensor(),
                                 # MNIST is single-channel; a 3-value Normalize
                                 # fails on 1-channel tensors.
                                 transforms.Normalize((0.5,), (0.5,)),
                             ]))
    elif dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())
    else:
        # An explicit error instead of ``assert False``, which disappears under -O.
        raise ValueError("Unsupported dataset: {!r}".format(dataset))
    # Truthiness falls back to len(), so this rejects an empty dataset.
    assert dataset

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader
開發者ID:uber-research,項目名稱:metropolis-hastings-gans,代碼行數:50,代碼來源:dcgan_loader.py


注:本文中的torchvision.datasets.LSUN類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。