当前位置: 首页>>代码示例>>Python>>正文


Python datasets.FakeData方法代码示例

本文整理汇总了 Python 中 torchvision.datasets.FakeData 类的典型用法代码示例。如果您正苦于以下问题:Python datasets.FakeData 的具体用法是什么?Python datasets.FakeData 怎么用?有哪些 datasets.FakeData 的使用示例?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 torchvision.datasets 的用法示例。


下文共展示了 datasets.FakeData 类的 9 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞,您的评价将有助于系统推荐出更好的 Python 代码示例。

示例1: check_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def check_dataset(args):
    """Build the training ``DataLoader`` selected by ``args.dataset``.

    Every image is resized and center-cropped to ``args.image_size``,
    converted to a tensor and finally multiplied by 255.

    Args:
        args: object exposing the attributes ``dataset``, ``dataroot``,
            ``image_size`` and ``batch_size``.

    Returns:
        A shuffling ``DataLoader`` over the chosen dataset, created with
        ``num_workers=0``.

    Raises:
        RuntimeError: if ``args.dataset`` is not ``"folder"``, ``"mscoco"``
            or ``"test"``.
    """
    steps = [
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ]
    transform = transforms.Compose(steps)

    if args.dataset in {"folder", "mscoco"}:
        # Real images read from a directory tree rooted at ``args.dataroot``.
        train_dataset = datasets.ImageFolder(args.dataroot, transform)
    elif args.dataset == "test":
        # Synthetic single-class data sized to exactly ``args.batch_size`` samples.
        fake_kwargs = dict(
            size=args.batch_size,
            image_size=(3, 32, 32),
            num_classes=1,
            transform=transform,
        )
        train_dataset = datasets.FakeData(**fake_kwargs)
    else:
        raise RuntimeError("Invalid dataset name: {}".format(args.dataset))

    return DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:neural_style.py

示例2: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def __init__(self, batch_size=256, subset_size=256, test_batch_size=256):
        """Create train/test loaders over synthetic 1x28x28 images.

        Args:
            batch_size: batch size of the training loader.
            subset_size: upper bound on the number of training samples kept;
                a random subset of the fake training set is selected.
            test_batch_size: batch size of the test loader.
        """
        # A single converter is shared by both datasets; the previously
        # unused ``trans`` and ``root`` locals were dead code and are gone.
        to_tensor = transforms.ToTensor()
        train_set = dset.FakeData(image_size=(1, 28, 28), transform=to_tensor)
        test_set = dset.FakeData(image_size=(1, 28, 28), transform=to_tensor)

        # Keep a random selection of at most ``subset_size`` training samples.
        indices = torch.randperm(len(train_set))[:subset_size]
        train_set = torch.utils.data.Subset(train_set, indices)

        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=batch_size,
            shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=test_batch_size,
            shuffle=False)

        self.name = "fakemnist"
        self.data_dims = [28, 28, 1]
        # ``len`` of a loader, i.e. measured on the loaders rather than on
        # the datasets (the per-sample counts are stored further below).
        self.train_size = len(self.train_loader)
        self.test_size = len(self.test_loader)
        self.range = [0.0, 1.0]
        self.batch_size = batch_size
        self.num_training_instances = len(train_set)
        self.num_test_instances = len(test_set)
        self.likelihood_type = 'gaussian'
        self.output_activation_type = 'sigmoid'
开发者ID:IBM,项目名称:AIX360,代码行数:31,代码来源:test_DIPVAE.py

示例3: make_dataloader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args=None,
                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,
                    normalize=False, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):
    """Create a ``DataLoader`` for the requested dataset type.

    Args:
        batch_size: batch size passed to the ``DataLoader``.
        dataset_type: one of ``'folder'``, ``'imagenet'``, ``'lfw'``,
            ``'lsun'``, ``'cifar10'`` or ``'fake'``.
        data_path: dataset root directory (not used for ``'fake'``).
        shuffle: passed to the ``DataLoader``.
        drop_last: passed to the ``DataLoader``.
        dataloader_args: optional dict of extra ``DataLoader`` keyword
            arguments; ``None`` means no extra arguments.
        resize, imsize, centercrop, centercrop_size, totensor, normalize,
        norm_mean, norm_std: forwarded unchanged to ``make_transform``.
            ``centercrop_size`` is additionally the side length of the
            ``'fake'`` images.

    Returns:
        Tuple ``(dataloader, num_of_classes)``.

    Raises:
        AssertionError: if ``data_path`` does not exist for the folder-like
            or LSUN datasets, or if the resulting dataset is falsy.
        ValueError: if ``dataset_type`` is not recognised.
    """
    # Avoid a shared mutable default argument.
    if dataloader_args is None:
        dataloader_args = {}
    # Make transform
    transform = make_transform(resize=resize, imsize=imsize,
                               centercrop=centercrop, centercrop_size=centercrop_size,
                               totensor=totensor,
                               normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)
    # Make dataset
    if dataset_type in ['folder', 'imagenet', 'lfw']:
        # folder dataset
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.ImageFolder(root=data_path, transform=transform)
    elif dataset_type == 'lsun':
        assert os.path.exists(data_path), "data_path does not exist! Given: " + data_path
        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)
    elif dataset_type == 'cifar10':
        if not os.path.exists(data_path):
            print("data_path does not exist! Given: {}\nDownloading CIFAR10 dataset...".format(data_path))
        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)
    elif dataset_type == 'fake':
        # The fake dataset ignores the transform built above and only converts to tensors.
        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())
    else:
        # Previously an unknown type fell through to ``assert dataset`` and
        # crashed with an UnboundLocalError; report the real problem instead.
        raise ValueError("Unknown dataset_type: {!r}".format(dataset_type))
    assert dataset
    # FakeData exposes ``num_classes`` but no ``classes`` attribute, so fall
    # back to integer labels rather than failing with an AttributeError.
    classes = getattr(dataset, 'classes', None)
    if classes is None:
        classes = list(range(dataset.num_classes))
    num_of_classes = len(classes)
    print("Data found!  # of images =", len(dataset), ", # of classes =", num_of_classes, ", classes:", classes)
    # Make dataloader from dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)
    return dataloader, num_of_classes
开发者ID:voletiv,项目名称:self-attention-GAN-pytorch,代码行数:30,代码来源:utils.py

示例4: test_init_fit_predict

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def test_init_fit_predict(self):
        """Smoke test: build lale's ``ResNet50``, then fit and predict on fake data.

        No assertion is made on the predictions; the test passes as long as
        construction, ``fit`` and ``predict`` complete without raising.
        """
        from torchvision import datasets
        from torchvision import transforms
        from lale.lib.pytorch import ResNet50

        to_tensor = transforms.Compose([transforms.ToTensor()])

        # 50 synthetic samples spread over two classes.
        fake_train = datasets.FakeData(size=50, num_classes=2, transform=to_tensor)

        model = ResNet50(num_classes=2, num_epochs=1)
        model.fit(fake_train)
        predictions = model.predict(fake_train)
开发者ID:IBM,项目名称:lale,代码行数:13,代码来源:test_interoperability.py

示例5: genFakeData

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def genFakeData(
        self, imgSize: Tuple[int, int, int], batch_size: int = 1, num_batches: int = 1
    ) -> DataLoader:
        """Create a two-class ``FakeData`` set, store it on ``self.ds`` and wrap it in a loader.

        Args:
            imgSize: image size forwarded to ``FakeData`` as ``image_size``.
            batch_size: batch size of the returned ``DataLoader``.
            num_batches: forwarded to ``FakeData`` as ``size``.
                NOTE(review): the dataset therefore holds ``num_batches``
                samples in total, not ``num_batches * batch_size`` - confirm
                this is what callers expect.

        Returns:
            A ``DataLoader`` over ``self.ds``.
        """
        to_tensor = transforms.Compose([transforms.ToTensor()])
        fake_kwargs = {
            "size": num_batches,
            "image_size": imgSize,
            "num_classes": 2,
            "transform": to_tensor,
        }
        self.ds = FakeData(**fake_kwargs)
        loader = DataLoader(self.ds, batch_size=batch_size)
        return loader
开发者ID:facebookresearch,项目名称:pytorch-dp,代码行数:12,代码来源:utils_test.py

示例6: setUp_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def setUp_data(self):
        """Populate ``self.ds`` and ``self.dl`` with synthetic 1x35x35, 10-class data.

        The dataset holds ``self.DATA_SIZE`` samples and the loader uses
        ``self.DATA_SIZE`` as its batch size as well.
        """
        # Presumably the usual MNIST mean/std - confirm against the tests using this data.
        mean, std = (0.1307,), (0.3081,)
        preprocess = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )
        self.ds = FakeData(
            size=self.DATA_SIZE,
            image_size=(1, 35, 35),
            num_classes=10,
            transform=preprocess,
        )
        self.dl = DataLoader(self.ds, batch_size=self.DATA_SIZE)
开发者ID:facebookresearch,项目名称:pytorch-dp,代码行数:12,代码来源:per_sample_gradient_clip_test.py

示例7: setUp_data

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def setUp_data(self):
        """Populate ``self.ds`` and ``self.dl`` with synthetic 1x35x35, 10-class data.

        The dataset holds ``self.DATA_SIZE`` samples; the loader serves them
        in batches of ``self.BATCH_SIZE``.
        """
        steps = [
            transforms.ToTensor(),
            # Presumably the usual MNIST mean/std - confirm against the tests using this data.
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        dataset_options = {
            "size": self.DATA_SIZE,
            "image_size": (1, 35, 35),
            "num_classes": 10,
            "transform": transforms.Compose(steps),
        }
        self.ds = FakeData(**dataset_options)
        self.dl = DataLoader(self.ds, batch_size=self.BATCH_SIZE)
开发者ID:facebookresearch,项目名称:pytorch-dp,代码行数:12,代码来源:virtual_step_test.py

示例8: check_dataset

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def check_dataset(dataset, dataroot):
    """Instantiate the torchvision dataset identified by name.

    Args:
        dataset (str): Name of the dataset to use. See CLI help for details
        dataroot (str): root directory where the dataset will be stored.

    Returns:
        tuple: ``(dataset, nc)`` where ``dataset`` is the torchvision Dataset
        object and ``nc`` is the number of image channels (1 for ``"mnist"``,
        3 for every other supported name).

    Raises:
        RuntimeError: if ``dataset`` is not a supported name.

    """
    resize = transforms.Resize(64)
    crop = transforms.CenterCrop(64)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # MNIST images have a single channel (nc = 1 below), so they need
    # single-element statistics; the three-channel normalizer does not match
    # a 1xHxW tensor.
    normalize_gray = transforms.Normalize((0.5,), (0.5,))

    if dataset in {"imagenet", "folder", "lfw"}:
        dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([resize, crop, to_tensor, normalize]))
        nc = 3

    elif dataset == "lsun":
        dataset = dset.LSUN(
            root=dataroot, classes=["bedroom_train"], transform=transforms.Compose([resize, crop, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "cifar10":
        dataset = dset.CIFAR10(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize])
        )
        nc = 3

    elif dataset == "mnist":
        dataset = dset.MNIST(
            root=dataroot, download=True, transform=transforms.Compose([resize, to_tensor, normalize_gray])
        )
        nc = 1

    elif dataset == "fake":
        # Synthetic data is only converted to tensors (no resize/normalize).
        dataset = dset.FakeData(size=256, image_size=(3, 64, 64), transform=to_tensor)
        nc = 3

    else:
        raise RuntimeError("Invalid dataset name: {}".format(dataset))

    return dataset, nc
开发者ID:pytorch,项目名称:ignite,代码行数:46,代码来源:dcgan.py

示例9: get_data_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import FakeData [as 别名]
def get_data_loader(dataset, dataroot, workers, image_size, batch_size):
    """Build a shuffling ``DataLoader`` for the dataset identified by name.

    Args:
        dataset: dataset name; one of ``'imagenet'``, ``'folder'``, ``'lfw'``,
            ``'lsun'``, ``'cifar10'``, ``'mnist'`` or ``'fake'``.
        dataroot: root directory of the dataset (not used for ``'fake'``).
        workers: number of loader worker processes; converted with ``int``.
        image_size: target size used for resizing/cropping, and the side
            length of the ``'fake'`` images.
        batch_size: batch size of the returned loader.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
        AssertionError: if the created dataset is falsy.
    """

    def _make_transform(channels, center_crop=True):
        # Resize -> (optional) center crop -> tensor -> normalize with 0.5
        # mean/std for each of ``channels`` channels.
        steps = [transforms.Resize(image_size)]
        if center_crop:
            steps.append(transforms.CenterCrop(image_size))
        half = (0.5,) * channels
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(half, half))
        return transforms.Compose(steps)

    if dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=dataroot, transform=_make_transform(3))
    elif dataset == 'lsun':
        dataset = dset.LSUN(root=dataroot, classes=['bedroom_train'],
                            transform=_make_transform(3))
    elif dataset == 'cifar10':
        # CIFAR10 is resized but not center-cropped.
        dataset = dset.CIFAR10(root=dataroot, download=True,
                               transform=_make_transform(3, center_crop=False))
    elif dataset == 'mnist':
        # MNIST is single-channel: three-element normalization statistics do
        # not match a 1xHxW tensor, so use one-element statistics here.
        dataset = dset.MNIST(root=dataroot, train=True, download=True,
                             transform=_make_transform(1))
    elif dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, image_size, image_size),
                                transform=transforms.ToTensor())
    else:
        # An explicit exception instead of ``assert False``, which is
        # stripped when Python runs with -O.
        raise ValueError("Invalid dataset name: {!r}".format(dataset))
    assert dataset

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(workers))
    return data_loader
开发者ID:uber-research,项目名称:metropolis-hastings-gans,代码行数:50,代码来源:dcgan_loader.py


注:本文中的torchvision.datasets.FakeData方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。