

Python cifar10.create_default_splits Method Code Examples

This article collects typical usage examples of the datasets.cifar10.create_default_splits method in Python. If you are wondering how cifar10.create_default_splits is used in practice, the selected examples below may help. You can also explore further usage examples for datasets.cifar10, the module that contains this method.


A total of 8 code examples of the cifar10.create_default_splits method are shown below, sorted by popularity by default.
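The examples below never show the body of create_default_splits itself. They only reveal that it returns a (train, test) pair, and that each split exposes an `ids` sequence and a `get_data(id)` method. The following is a minimal toy sketch of that interface, not the real datasets.cifar10 implementation: the `ToyDataset` class, the `num_samples` and `num_train` parameters, and the fake 4x4 image are all assumptions made for illustration.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for one split; not the real datasets.cifar10 class."""

    def __init__(self, ids, name):
        self._ids = list(ids)
        self.name = name

    @property
    def ids(self):
        return self._ids

    def get_data(self, sample_id):
        # Fake 4x4 RGB "image" (nested lists) and a one-hot label over 10 classes.
        image = [[[0.0, 0.0, 0.0] for _ in range(4)] for _ in range(4)]
        label = [0] * 10
        label[sample_id % 10] = 1
        return image, label

    def __len__(self):
        return len(self._ids)


def create_default_splits(num_samples=100, num_train=80, seed=0):
    """Toy version: shuffle integer ids, then cut them into train and test."""
    ids = list(range(num_samples))
    random.Random(seed).shuffle(ids)
    return ToyDataset(ids[:num_train], "train"), ToyDataset(ids[num_train:], "test")


# Same call pattern as the examples on this page.
dataset_train, dataset_test = create_default_splits()
image, label = dataset_train.get_data(dataset_train.ids[0])
print(len(dataset_train), len(dataset_test), sum(label))
```

With a stub like this, the `main` functions below can be read as: pick a dataset module, call create_default_splits, then hand both splits to a Trainer or Evaler.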

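Example 1: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: np (presumably numpy), log, and Evaler come from the rest of evaler.py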
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint_path', type=str, default=None)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'SVHN', 'CIFAR10'])
    parser.add_argument('--reconstruct', action='store_true', default=False)
    parser.add_argument('--generate', action='store_true', default=False)
    parser.add_argument('--interpolate', action='store_true', default=False)
    parser.add_argument('--data_id', nargs='*', default=None)
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    m, l = dataset_train.get_data(dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(m.shape), np.asarray(l.shape)])

    evaler = Evaler(config, dataset_test, dataset_train)

    log.warning("dataset: %s", config.dataset)
    evaler.eval_run() 
Developer ID: clvrai, Project: Generative-Latent-Optimization-Tensorflow, Lines of code: 36, Source file: evaler.py
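Several examples on this page build `config.data_info` by fetching one sample and concatenating the image shape with the label shape. The sketch below isolates that step. The sample shapes (a 32x32 RGB image and a 10-way one-hot label) are assumptions for illustration, since the page does not show what `get_data` actually returns.

```python
import numpy as np

# Stand-in sample: the shapes are assumptions, not values taken from the real dataset.
image = np.zeros((32, 32, 3), dtype=np.float32)
label = np.zeros((10,), dtype=np.float32)

# Same expression as in the examples: join both shape tuples into one 1-D array.
data_info = np.concatenate([np.asarray(image.shape), np.asarray(label.shape)])

# Downstream code can then unpack the layout it needs to build a model.
height, width, channels, num_classes = (int(v) for v in data_info)
print(data_info.tolist())
```

Note that this only works when the image is rank 3 and the label is rank 1, so the unpacking above would need adjusting for grayscale data stored without a channel axis.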

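Example 2: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: np (presumably numpy), log, and Trainer come from the rest of trainer.py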
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'SVHN', 'CIFAR10'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    parser.add_argument('--dump_result', action='store_true', default=False)
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    m, l = dataset_train.get_data(dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(m.shape), np.asarray(l.shape)])

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f",
                config.dataset, config.learning_rate)
    trainer.train(dataset_train) 
Developer ID: clvrai, Project: Generative-Latent-Optimization-Tensorflow, Lines of code: 37, Source file: trainer.py

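Example 3: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: np (presumably numpy), log, and Trainer come from the rest of trainer.py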
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    config = parser.parse_args()

    if config.dataset == 'ImageNet':
        import datasets.ImageNet as dataset
    # Note: with choices=['ImageNet'] above, the SVHN and CIFAR10 branches below cannot be reached.
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    dataset_train, dataset_test = dataset.create_default_splits()

    image, _, label, _ = dataset_train.get_data(dataset_train.ids[0], dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(image.shape), np.asarray(label.shape)])

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f",
                config.dataset, config.learning_rate)
    trainer.train(dataset_train) 
Developer ID: clvrai, Project: Representation-Learning-by-Learning-to-Count, Lines of code: 33, Source file: trainer.py

Example 4: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: log and Trainer come from the rest of trainer.py
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'SVHN', 'CIFAR10'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    parser.add_argument('--activation', type=str, default='selu', choices=['relu', 'lrelu', 'selu'])
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.data_info = dataset.get_data_info()
    config.conv_info = dataset.get_conv_info()
    config.visualize_shape = dataset.get_vis_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f", config.dataset, config.learning_rate)
    trainer.train() 
Developer ID: shaohua0116, Project: Activation-Visualization-Histogram, Lines of code: 33, Source file: trainer.py

Example 5: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: log and Trainer come from the rest of trainer.py
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'SVHN', 'CIFAR10'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    parser.add_argument('--activation', type=str, default='selu', choices=['relu', 'lrelu', 'selu'])
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.data_info = dataset.get_data_info()
    config.conv_info = dataset.get_conv_info()
    config.visualize_shape = dataset.get_vis_info()
    dataset_train, dataset_test = dataset.create_default_splits()
    trainer = Trainer(config, dataset_train, dataset_test)
    log.warning("dataset: %s, learning_rate: %f", config.dataset, config.learning_rate)
    trainer.train() 
Developer ID: IsaacChanghau, Project: AmusingPythonCodes, Lines of code: 30, Source file: trainer.py

Example 6: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: log and Evaler come from the rest of evaler.py
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--output_file', type=str, default=None)
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'Fashion', 'SVHN', 'CIFAR10'])
    parser.add_argument('--max_steps', type=int, default=1)
    config = parser.parse_args()

    # Compare against 'MNIST' (as listed in choices); the lowercase 'mnist' could never match.
    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.data_info = dataset.get_data_info()
    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    evaler = Evaler(config, dataset_test)

    log.warning("dataset: %s", config.dataset)
    evaler.eval_run() 
Developer ID: shaohua0116, Project: DCGAN-Tensorflow, Lines of code: 33, Source file: evaler.py

Example 7: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: log and Trainer come from the rest of trainer.py
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        choices=['MNIST', 'Fashion', 'SVHN', 'CIFAR10'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--update_rate', type=int, default=5)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.data_info = dataset.get_data_info()
    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f", config.dataset, config.learning_rate)
    trainer.train() 
Developer ID: shaohua0116, Project: DCGAN-Tensorflow, Lines of code: 36, Source file: trainer.py

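Example 8: main

# Required import: from datasets import cifar10 [as alias]
# Or: from datasets.cifar10 import create_default_splits [as alias]
# Not shown in this excerpt: np (presumably numpy), log, and Trainer come from the rest of trainer.py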
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--max_batch_size', type=int, default=64)
    parser.add_argument('--prefix', type=str, default='default',
                        help='the nickname of this training job')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='MNIST',
                        choices=['MNIST', 'Fashion', 'SVHN',
                                 'CIFAR10', 'ImageNet', 'TinyImageNet'])
    parser.add_argument('--norm_type', type=str, default='batch',
                        choices=['batch', 'group'])
    # Log
    parser.add_argument('--max_training_step', type=int, default=100000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--test_sample_step', type=int, default=10)
    parser.add_argument('--write_summary_step', type=int, default=10)
    parser.add_argument('--ckpt_save_step', type=int, default=1000)
    # Learning
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--no_adjust_learning_rate', action='store_true', default=False)
    config = parser.parse_args()


    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    elif config.dataset == 'TinyImageNet':
        import datasets.tiny_imagenet as dataset
    elif config.dataset == 'ImageNet':
        import datasets.imagenet as dataset
    else:
        raise ValueError(config.dataset)

    dataset_train, dataset_test = dataset.create_default_splits()
    image, label = dataset_train.get_data(dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(image.shape), np.asarray(label.shape)])

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f", config.dataset, config.learning_rate)
    trainer.train() 
Developer ID: shaohua0116, Project: Group-Normalization-Tensorflow, Lines of code: 51, Source file: trainer.py
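Every example above selects its dataset module with an if/elif chain that has to stay in sync with the `choices` list passed to argparse. One possible alternative, sketched here under assumptions, is a single lookup table resolved with `importlib.import_module`. The helper name `resolve_dataset_module` and the `DATASET_MODULES` table are hypothetical; the module paths are the ones that appear in the examples and exist only inside those projects.

```python
import argparse
import importlib

# Hypothetical registry: dataset name -> module path, as seen in the examples above.
DATASET_MODULES = {
    'MNIST': 'datasets.mnist',
    'Fashion': 'datasets.fashion_mnist',
    'SVHN': 'datasets.svhn',
    'CIFAR10': 'datasets.cifar10',
}


def resolve_dataset_module(name, registry=DATASET_MODULES):
    """Import and return the module registered under `name`."""
    try:
        module_path = registry[name]
    except KeyError:
        # Mirror the examples, which raise ValueError for an unknown dataset.
        raise ValueError(name) from None
    return importlib.import_module(module_path)


def build_parser(registry=DATASET_MODULES):
    parser = argparse.ArgumentParser()
    # Deriving choices from the registry keeps the two lists from drifting apart.
    parser.add_argument('--dataset', type=str, default='CIFAR10',
                        choices=sorted(registry))
    return parser


config = build_parser().parse_args(['--dataset', 'CIFAR10'])
print(config.dataset)
```

Inside one of the projects above, `dataset = resolve_dataset_module(config.dataset)` followed by `dataset.create_default_splits()` would replace the whole chain. Outside those projects the `datasets.*` imports fail, so the helper is only called here in a comment.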


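Note: The datasets.cifar10.create_default_splits examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright of the source code belongs to the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.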