当前位置: 首页>>代码示例>>Python>>正文


Python datasets.MNIST属性代码示例

本文整理汇总了Python中torchvision.datasets.MNIST类的典型用法代码示例。如果您正苦于以下问题：Python datasets.MNIST的具体用法是什么？Python datasets.MNIST怎么用？有没有Python datasets.MNIST的使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。


在下文中一共展示了datasets.MNIST类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_mnist_m

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_mnist_m(train, get_dataset=False, batch_size=cfg.batch_size):
    """Build the MNIST-M dataset, or a shuffling DataLoader over it.

    Args:
        train: Forwarded to ``MNIST_M`` as ``train`` (split selector).
        get_dataset: When true, return the dataset object instead of a loader.
        batch_size: Batch size for the returned DataLoader; unused when
            ``get_dataset`` is true. The default ``cfg.batch_size`` is read
            once, when this function is defined.

    Returns:
        The ``MNIST_M`` dataset, or a ``torch.utils.data.DataLoader`` over it
        with ``shuffle=True``.
    """
    # Tensor conversion followed by normalization with the configured statistics.
    normalize = transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std)
    pre_process = transforms.Compose([transforms.ToTensor(), normalize])

    # MNIST_M is defined elsewhere; download=True presumably fetches the data
    # into cfg.data_root when it is missing.
    dataset = MNIST_M(root=cfg.data_root, train=train,
                      transform=pre_process, download=True)
    if get_dataset:
        return dataset
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
开发者ID:corenel,项目名称:pytorch-atda,代码行数:24,代码来源:mnist_m.py

示例2: mnist_noniid

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300, shards_per_user=2):
    """Sample non-I.I.D. client data from an MNIST-style dataset.

    The first ``num_shards * num_imgs`` samples are sorted by label, cut into
    ``num_shards`` contiguous shards of ``num_imgs`` indices each, and every
    user receives ``shards_per_user`` distinct shards chosen at random (using
    the global ``np.random`` state) without replacement. Because shards are
    cut from label-sorted indices, each user only sees a few distinct labels.

    :param dataset: dataset exposing its labels as ``targets`` (current
        torchvision) or the deprecated ``train_labels``; any value accepted by
        ``np.asarray`` works (CPU tensor, ndarray, list).
    :param num_users: number of clients to create.
    :param num_shards: number of shards (default 200; 200 * 300 = 60000).
    :param num_imgs: number of sample indices per shard (default 300).
    :param shards_per_user: shards assigned to each user (default 2).
    :return: dict mapping user id ``0..num_users-1`` to an ``int64`` array of
        sample indices.
    :raises ValueError: if there are not enough shards for every user, or
        fewer than ``num_shards * num_imgs`` labels.
    """
    if num_users * shards_per_user > num_shards:
        raise ValueError(
            f"need num_users * shards_per_user <= num_shards, got "
            f"{num_users} * {shards_per_user} > {num_shards}")
    total = num_shards * num_imgs

    # `train_labels` is deprecated in torchvision in favour of `targets`.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels
    labels = np.asarray(labels)
    if len(labels) < total:
        raise ValueError(
            f"dataset has {len(labels)} labels, fewer than "
            f"num_shards * num_imgs = {total}")

    # Sort sample indices by label. A stable sort keeps the original order
    # within each label, so the shard contents are reproducible.
    idxs = np.argsort(labels[:total], kind="stable").astype("int64")

    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype="int64") for i in range(num_users)}

    # divide and assign
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, shards_per_user, replace=False))
        idx_shard = [s for s in idx_shard if s not in rand_set]
        for rand in sorted(rand_set):
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
    return dict_users
开发者ID:shaoxiongji,项目名称:federated-learning,代码行数:27,代码来源:sampling.py

示例3: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __init__(self):
        # Batch sizes, SGD settings and logging cadence for this trainer.
        self.batch_size = 64
        self.test_batch_size = 100
        self.learning_rate = 0.01
        self.sgd_momentum = 0.9
        self.log_interval = 100

        def _mnist_transform():
            # ToTensor, then Normalize with mean 0.1307 / std 0.3081
            # (presumably the MNIST training-set statistics).
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])

        # Only the training split passes download=True; the test split reads
        # from the same directory.
        train_set = datasets.MNIST('/tmp/mnist/data', train=True, download=True,
                                   transform=_mnist_transform())
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self.batch_size, shuffle=True)

        test_set = datasets.MNIST('/tmp/mnist/data', train=False,
                                  transform=_mnist_transform())
        self.test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=self.test_batch_size, shuffle=True)

        self.network = Net()

    # Train the network for several epochs, validating after each epoch. 
开发者ID:aimuch,项目名称:iAI,代码行数:26,代码来源:model.py

示例4: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __init__(self):
        # Optimisation and logging settings consumed by the training code.
        self.batch_size, self.test_batch_size = 64, 100
        self.learning_rate = 0.0025
        self.sgd_momentum = 0.9
        self.log_interval = 100

        data_dir = '/tmp/mnist/data'

        def _pipeline():
            # Tensor conversion plus normalization with mean 0.1307 / std 0.3081.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])

        # The training split may be downloaded; the test split is read from
        # the same directory without a download request.
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_dir, train=True, download=True, transform=_pipeline()),
            batch_size=self.batch_size,
            shuffle=True,
        )
        self.test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_dir, train=False, transform=_pipeline()),
            batch_size=self.test_batch_size,
            shuffle=True,
        )
        self.network = Net()

    # Train the network for one or more epochs, validating after each epoch. 
开发者ID:aimuch,项目名称:iAI,代码行数:26,代码来源:model.py

示例5: loaders_mnist

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def loaders_mnist(dataset, batch_size=64, cuda=0,
                  train_size=50000, val_size=10000, test_size=10000,
                  test_batch_size=1000, **kwargs):
    """Build MNIST train/validation/test loaders through ``create_loaders``.

    Data is read from ``$VISION_DATA/mnist``. No download is requested, so the
    files must already be present there.

    Args:
        dataset: Dataset name; only ``'mnist'`` is accepted.
        batch_size: Training batch size, forwarded to ``create_loaders``.
        cuda: Forwarded to ``create_loaders``.
        train_size, val_size, test_size: Split sizes forwarded to
            ``create_loaders``.
        test_batch_size: Evaluation batch size, forwarded to ``create_loaders``.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        Whatever ``create_loaders`` (defined elsewhere) returns.

    Raises:
        ValueError: If ``dataset`` is not ``'mnist'``.
        KeyError: If the ``VISION_DATA`` environment variable is not set.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if dataset != 'mnist':
        raise ValueError(
            "loaders_mnist only supports dataset='mnist', got {!r}".format(dataset))
    root = os.path.join(os.environ['VISION_DATA'], dataset)

    # Data loading code
    normalize = transforms.Normalize(mean=(0.1307,),
                                     std=(0.3081,))

    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Separate dataset instances for training and validation (both built from
    # the MNIST training split) so each can be handled independently by
    # create_loaders and could receive its own transform; currently both
    # use the same transform.
    dataset_train = datasets.MNIST(root=root, train=True, transform=transform)
    dataset_val = datasets.MNIST(root=root, train=True, transform=transform)
    dataset_test = datasets.MNIST(root=root, train=False, transform=transform)

    return create_loaders(dataset_train, dataset_val,
                          dataset_test, train_size, val_size, test_size,
                          batch_size=batch_size,
                          test_batch_size=test_batch_size,
                          cuda=cuda, num_workers=0)
开发者ID:oval-group,项目名称:dfw,代码行数:26,代码来源:loaders.py

示例6: __getitem__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __getitem__(self, index):
        """Return one sample; overrides the MNIST class's ``__getitem__``.

        Beyond the usual (image, target) pair, also returns the
        semi-supervised label and the index that was requested.

        Args:
            index (int): Position of the sample.

        Returns:
            tuple: ``(image, target, semi_target, index)``. ``image`` is passed
            through ``self.transform`` and ``target`` through
            ``self.target_transform`` when those are not ``None``.
        """
        raw = self.data[index]
        target = int(self.targets[index])
        semi_target = int(self.semi_targets[index])

        # Hand the transforms a PIL Image (mode 'L', 8-bit grayscale), the same
        # input type the other datasets provide. Assumes self.data[index] is a
        # tensor exposing .numpy().
        image = Image.fromarray(raw.numpy(), mode='L')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, semi_target, index
开发者ID:lukasruff,项目名称:Deep-SAD-PyTorch,代码行数:23,代码来源:mnist.py

示例7: get_mnist

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_mnist(train, get_dataset=False, batch_size=cfg.batch_size):
    """Build the torchvision MNIST dataset with 3-channel tensors, or a loader over it.

    Args:
        train: Forwarded to ``datasets.MNIST`` as ``train`` (split selector).
        get_dataset: When true, return the dataset instead of a DataLoader.
        batch_size: Batch size for the returned DataLoader; the default
            ``cfg.batch_size`` is read once, when this function is defined.

    Returns:
        The MNIST dataset, or a shuffling ``torch.utils.data.DataLoader``.
    """
    def _replicate_channel(x):
        # Stack the tensor three times along dim 0, turning one channel into
        # three (presumably to match a 3-channel target domain).
        return torch.cat([x, x, x], 0)

    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std),
        transforms.Lambda(_replicate_channel),
    ])

    dataset = datasets.MNIST(root=cfg.data_root, train=train,
                             transform=pre_process, download=True)
    if get_dataset:
        return dataset
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
开发者ID:corenel,项目名称:pytorch-atda,代码行数:27,代码来源:mnist.py

示例8: get_mnist_dataloaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_mnist_dataloaders(batch_size=128):
    """MNIST dataloader with (32, 32) sized images.

    Images are resized so their side is 32 (a power of two) and converted to
    tensors. Data lives in ``'../data'``; only the training split requests a
    download. Both loaders shuffle.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    resize_and_tensorize = transforms.Compose([transforms.Resize(32),
                                               transforms.ToTensor()])
    loaders = {}
    for is_train in (True, False):
        split = datasets.MNIST('../data', train=is_train, download=is_train,
                               transform=resize_and_tensorize)
        loaders[is_train] = DataLoader(split, batch_size=batch_size, shuffle=True)
    return loaders[True], loaders[False]
开发者ID:vandit15,项目名称:Self-Supervised-Gans-Pytorch,代码行数:18,代码来源:dataloaders.py

示例9: get_fashion_mnist_dataloaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_fashion_mnist_dataloaders(batch_size=128):
    """Fashion MNIST dataloader with (32, 32) sized images.

    Images are resized so their side is 32 (a power of two) and converted to
    tensors. Data lives in ``'../fashion_data'``; only the training split
    requests a download. Both loaders shuffle.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    to_32px_tensor = transforms.Compose([transforms.Resize(32),
                                         transforms.ToTensor()])

    def _make_loader(train, **extra):
        # `extra` carries download=True for the training split only.
        split = datasets.FashionMNIST('../fashion_data', train=train,
                                      transform=to_32px_tensor, **extra)
        return DataLoader(split, batch_size=batch_size, shuffle=True)

    train_loader = _make_loader(True, download=True)
    test_loader = _make_loader(False)
    return train_loader, test_loader
开发者ID:vandit15,项目名称:Self-Supervised-Gans-Pytorch,代码行数:18,代码来源:dataloaders.py

示例10: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __init__(self, min_len: int, max_len: int, dataset_len: int, train: bool = True, transform: Compose = None):
        """Build ``dataset_len`` random "sets" of MNIST sample indices.

        Each set's size is drawn uniformly from ``[min_len, max_len]``, and its
        members are drawn uniformly with replacement from all MNIST indices, so
        a set may contain the same image more than once. All draws use the
        global ``np.random`` state.

        Args:
            min_len: Smallest allowed set size (inclusive).
            max_len: Largest allowed set size (inclusive).
            dataset_len: Number of sets to generate.
            train: Forwarded to ``MNIST`` as ``train`` (split selector).
            transform: Forwarded to ``MNIST`` as ``transform``.
        """
        self.min_len = min_len
        self.max_len = max_len
        self.dataset_len = dataset_len
        self.train = train
        self.transform = transform

        self.mnist = MNIST(DATA_ROOT, train=self.train, transform=self.transform, download=True)
        all_indices = np.arange(0, len(self.mnist))

        # First draw every set size, then draw each set's members in order.
        allowed_sizes = np.arange(self.min_len, self.max_len + 1)
        set_sizes = np.random.choice(allowed_sizes, size=self.dataset_len, replace=True)
        self.mnist_items = [
            np.random.choice(all_indices, size=size, replace=True)
            for size in set_sizes
        ]
开发者ID:yassersouri,项目名称:pytorch-deep-sets,代码行数:18,代码来源:datasets.py

示例11: get_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_loader(config):
    """Builds and returns Dataloader for MNIST and SVHN dataset.

    Args:
        config: Object providing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``. Both loaders shuffle and share
        the batch size and worker count from ``config``.
    """
    # `transforms.Scale` was a deprecated alias of `transforms.Resize` and has
    # been removed from recent torchvision releases.
    resize = transforms.Resize(config.image_size)

    # Normalize needs one mean/std entry per channel. SVHN images are RGB,
    # MNIST images are grayscale (1 channel). Recent torchvision applies the
    # statistics by in-place broadcasting and fails on a 1-channel tensor given
    # 3-channel stats; older releases silently used only the first entry,
    # which is what the MNIST transform below reproduces.
    svhn_transform = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    mnist_transform = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform)

    def _loader(dataset):
        # Identical loader settings for both domains.
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers)

    return _loader(svhn), _loader(mnist)
开发者ID:yunjey,项目名称:mnist-svhn-transfer,代码行数:23,代码来源:data_loader.py

示例12: _get_train_data_loader

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
    """Create the MNIST training DataLoader and its sampler.

    Args:
        training_dir: Directory holding the MNIST training data. Nothing is
            downloaded, so the data must already be there.
        is_distributed: When true, wrap the dataset in a ``DistributedSampler``.
        batch_size: Batch size of the loader.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        tuple: ``(train_sampler, train_loader)``; ``train_sampler`` is ``None``
        when ``is_distributed`` is false.
    """
    logger.info("Get train data loader")
    normalize_mnist = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # Downloading stays disabled: enabling it would make canary runs depend on
    # an external site.
    dataset = datasets.MNIST(training_dir, train=True,
                             transform=normalize_mnist, download=False)

    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    # DataLoader rejects shuffle=True together with an explicit sampler, so
    # shuffling is only requested when no sampler is used.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,
        sampler=sampler,
        **kwargs
    )
    return sampler, loader
开发者ID:aws,项目名称:sagemaker-python-sdk,代码行数:23,代码来源:mnist.py

示例13: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __init__(self):
        """Load the MNIST classifier and reference activation statistics.

        Weights are loaded from ``mnist.pth``. Activation statistics (stored
        under the keys ``'mu'`` and ``'sigma'``, presumably mean and covariance
        for FID) are cached in ``mnist_act_{feature_layer}.npz``. If that file
        does not exist, they are computed once over the MNIST training split
        and saved.

        ``device``, ``use_cuda``, ``feature_layer``, ``mnist_model`` and
        ``calculate_activation_statistics`` are module-level names defined
        elsewhere.
        """
        model = mnist_model.Net().to(device)
        model.eval()
        # Without CUDA, remap stored tensors onto the CPU while loading.
        map_location = None if use_cuda else 'cpu'
        model.load_state_dict(
            torch.load('mnist.pth', map_location=map_location))

        stats_file = f'mnist_act_{feature_layer}.npz'
        try:
            # The context manager closes the .npz handle even if a key lookup
            # raises, unlike a manual close() after the lookups.
            with np.load(stats_file) as f:
                m_mnist, s_mnist = f['mu'][:], f['sigma'][:]
        except FileNotFoundError:
            data = datasets.MNIST('data', train=True, download=True,
                                  transform=transforms.ToTensor())
            images = len(data)
            batch_size = 64
            # Keep only the images; the labels are not needed for statistics.
            data_loader = DataLoader([image for image, _ in data],
                                     batch_size=batch_size)
            m_mnist, s_mnist = calculate_activation_statistics(
                data_loader, images, model, verbose=True)
            np.savez(stats_file, mu=m_mnist, sigma=s_mnist)

        self.model = model
        self.mnist_stats = m_mnist, s_mnist
开发者ID:steveli,项目名称:misgan,代码行数:27,代码来源:mnist_fid.py

示例14: __init__

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def __init__(self, root, mnist_root="data", train=True, transform=None, target_transform=None, download=False):
        """Init MNIST-M dataset.

        Args:
            root: Dataset directory; ``~`` is expanded.
            mnist_root: Directory of the source MNIST data; ``~`` is expanded.
            train: Load the training file when true, the test file otherwise.
            transform: Stored as ``self.transform``.
            target_transform: Stored as ``self.target_transform``.
            download: When true, call ``self.download()`` before checking for
                the files.

        Raises:
            RuntimeError: If ``self._check_exists()`` reports the data missing.
        """
        super(MNISTM, self).__init__()
        self.root = os.path.expanduser(root)
        self.mnist_root = os.path.expanduser(mnist_root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # True: training split, False: test split

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." + " You can use download=True to download it")

        # The processed file is loaded with torch.load and unpacked as a
        # (data, labels) pair; the attribute names depend on the split.
        split_file = self.training_file if self.train else self.test_file
        data, labels = torch.load(os.path.join(self.root, self.processed_folder, split_file))
        if self.train:
            self.train_data, self.train_labels = data, labels
        else:
            self.test_data, self.test_labels = data, labels
开发者ID:eriklindernoren,项目名称:PyTorch-GAN,代码行数:25,代码来源:mnistm.py

示例15: get_mnist_dataloaders

# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import MNIST [as 别名]
def get_mnist_dataloaders(root=None, batch_size=128):
    """MNIST dataloaders with 28x28 images.

    Args:
        root: Directory for the MNIST files (downloaded there for the training
            split if missing). Defaults to ``../data`` built with
            ``os.path.join``, so it points at the parent's ``data`` folder on
            every OS. A hard-coded backslash path only did so on Windows.
        batch_size: Batch size of both loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``; both shuffle.
    """
    import os

    if root is None:
        root = os.path.join('..', 'data')

    # Resize(28) scales the shorter image edge to 28 pixels (the native MNIST
    # size), then images are converted to tensors without normalization.
    all_transforms = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
    ])
    # Get train and test data
    train_data = datasets.MNIST(root, train=True, download=True,
                                transform=all_transforms)
    test_data = datasets.MNIST(root, train=False,
                               transform=all_transforms)
    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
开发者ID:dingguanglei,项目名称:jdit,代码行数:19,代码来源:dataset.py


注:本文中的torchvision.datasets.MNIST属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。