本文整理汇总了Python中datasets.mnist.create_default_splits方法的典型用法代码示例。如果您正苦于以下问题:Python mnist.create_default_splits方法的具体用法?Python mnist.create_default_splits怎么用?Python mnist.create_default_splits使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块datasets.mnist的用法示例。
在下文中一共展示了mnist.create_default_splits方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse evaluation options, load the chosen dataset, and run the evaluator.

    The module selected by ``--dataset`` supplies ``conv_info`` and
    ``deconv_info`` (stored on the parsed options) and the default
    train/test splits.  ``data_info`` is the concatenation of the shapes of
    the two arrays returned for the first training id.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', type=int, default=256)
    arg_parser.add_argument('--prefix', type=str, default='default')
    arg_parser.add_argument('--checkpoint_path', type=str, default=None)
    arg_parser.add_argument('--train_dir', type=str)
    arg_parser.add_argument(
        '--dataset', type=str, default='MNIST',
        choices=['MNIST', 'SVHN', 'CIFAR10'])
    # The three mode switches are plain on/off flags with identical settings.
    for mode_flag in ('--reconstruct', '--generate', '--interpolate'):
        arg_parser.add_argument(mode_flag, action='store_true', default=False)
    arg_parser.add_argument('--data_id', nargs='*', default=None)
    cfg = arg_parser.parse_args()

    # Import only the dataset module that was actually requested.
    dataset_name = cfg.dataset
    if dataset_name == 'MNIST':
        import datasets.mnist as dataset
    elif dataset_name == 'SVHN':
        import datasets.svhn as dataset
    elif dataset_name == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(dataset_name)

    cfg.conv_info = dataset.get_conv_info()
    cfg.deconv_info = dataset.get_deconv_info()

    dataset_train, dataset_test = dataset.create_default_splits()

    # Probe one training sample to record the shapes of its two arrays.
    first_id = dataset_train.ids[0]
    image, label = dataset_train.get_data(first_id)
    sample_shapes = [np.asarray(image.shape), np.asarray(label.shape)]
    cfg.data_info = np.concatenate(sample_shapes)

    evaler = Evaler(cfg, dataset_test, dataset_train)
    log.warning("dataset: %s", cfg.dataset)
    evaler.eval_run()
示例2: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse training options, load the chosen dataset, and start training.

    The module selected by ``--dataset`` supplies ``conv_info`` and
    ``deconv_info`` (stored on the parsed options) and the default
    train/test splits.  ``data_info`` is the concatenation of the shapes of
    the two arrays returned for the first training id.  The training split
    is also passed to ``Trainer.train``.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', type=int, default=16)
    arg_parser.add_argument('--prefix', type=str, default='default')
    arg_parser.add_argument('--checkpoint', type=str, default=None)
    arg_parser.add_argument(
        '--dataset', type=str, default='MNIST',
        choices=['MNIST', 'SVHN', 'CIFAR10'])
    arg_parser.add_argument('--learning_rate', type=float, default=1e-4)
    arg_parser.add_argument('--alpha', type=float, default=1.0)
    # Both remaining options are simple on/off switches.
    for switch in ('--lr_weight_decay', '--dump_result'):
        arg_parser.add_argument(switch, action='store_true', default=False)
    cfg = arg_parser.parse_args()

    # Import only the dataset module that was actually requested.
    dataset_name = cfg.dataset
    if dataset_name == 'MNIST':
        import datasets.mnist as dataset
    elif dataset_name == 'SVHN':
        import datasets.svhn as dataset
    elif dataset_name == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(dataset_name)

    cfg.conv_info = dataset.get_conv_info()
    cfg.deconv_info = dataset.get_deconv_info()

    dataset_train, dataset_test = dataset.create_default_splits()

    # Probe one training sample to record the shapes of its two arrays.
    first_id = dataset_train.ids[0]
    image, label = dataset_train.get_data(first_id)
    sample_shapes = [np.asarray(image.shape), np.asarray(label.shape)]
    cfg.data_info = np.concatenate(sample_shapes)

    trainer = Trainer(cfg, dataset_train, dataset_test)
    log.warning("dataset: %s, learning_rate: %f", cfg.dataset, cfg.learning_rate)
    trainer.train(dataset_train)
示例3: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse training options, load the chosen dataset, and start training.

    The module selected by ``--dataset`` supplies ``data_info``,
    ``conv_info`` and ``visualize_shape`` (stored on the parsed options) and
    the default train/test splits handed to ``Trainer``.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', type=int, default=64)
    arg_parser.add_argument('--prefix', type=str, default='default')
    arg_parser.add_argument('--checkpoint', type=str, default=None)
    arg_parser.add_argument(
        '--dataset', type=str, default='MNIST',
        choices=['MNIST', 'SVHN', 'CIFAR10'])
    arg_parser.add_argument('--learning_rate', type=float, default=1e-4)
    arg_parser.add_argument(
        '--lr_weight_decay', action='store_true', default=False)
    arg_parser.add_argument(
        '--activation', type=str, default='selu',
        choices=['relu', 'lrelu', 'selu'])
    cfg = arg_parser.parse_args()

    # Import only the dataset module that was actually requested.
    dataset_name = cfg.dataset
    if dataset_name == 'MNIST':
        import datasets.mnist as dataset
    elif dataset_name == 'SVHN':
        import datasets.svhn as dataset
    elif dataset_name == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(dataset_name)

    # Dataset-specific settings travel with the parsed options.
    cfg.data_info = dataset.get_data_info()
    cfg.conv_info = dataset.get_conv_info()
    cfg.visualize_shape = dataset.get_vis_info()

    dataset_train, dataset_test = dataset.create_default_splits()

    trainer = Trainer(cfg, dataset_train, dataset_test)
    log.warning("dataset: %s, learning_rate: %f", cfg.dataset, cfg.learning_rate)
    trainer.train()
示例4: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Read the command line, pick a dataset module, and run the trainer.

    ``data_info``, ``conv_info`` and ``visualize_shape`` are obtained from
    the selected dataset module and attached to the parsed options before
    ``Trainer`` is built from the module's default train/test splits.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--batch_size', type=int, default=64)
    cli.add_argument('--prefix', type=str, default='default')
    cli.add_argument('--checkpoint', type=str, default=None)
    cli.add_argument(
        '--dataset',
        type=str,
        default='MNIST',
        choices=['MNIST', 'SVHN', 'CIFAR10'],
    )
    cli.add_argument('--learning_rate', type=float, default=1e-4)
    cli.add_argument('--lr_weight_decay', action='store_true', default=False)
    cli.add_argument(
        '--activation',
        type=str,
        default='selu',
        choices=['relu', 'lrelu', 'selu'],
    )
    options = cli.parse_args()

    # Only the requested dataset module is imported.
    chosen = options.dataset
    if chosen == 'MNIST':
        import datasets.mnist as dataset
    elif chosen == 'SVHN':
        import datasets.svhn as dataset
    elif chosen == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(chosen)

    options.data_info = dataset.get_data_info()
    options.conv_info = dataset.get_conv_info()
    options.visualize_shape = dataset.get_vis_info()

    dataset_train, dataset_test = dataset.create_default_splits()

    trainer = Trainer(options, dataset_train, dataset_test)
    log.warning(
        "dataset: %s, learning_rate: %f",
        options.dataset,
        options.learning_rate,
    )
    trainer.train()
示例5: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse evaluation options, load the chosen dataset, and run the evaluator.

    The module selected by ``--dataset`` supplies ``data_info``,
    ``conv_info`` and ``deconv_info`` (stored on the parsed options) plus the
    default train/test splits; only the test split is passed to ``Evaler``.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--output_file', type=str, default=None)
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument(
        '--dataset', type=str, default='CIFAR10',
        choices=['MNIST', 'Fashion', 'SVHN', 'CIFAR10'])
    parser.add_argument('--max_steps', type=int, default=1)
    config = parser.parse_args()

    # Each comparison must use the exact spelling listed in ``choices``:
    # argparse matches choices case-sensitively, so a lowercase 'mnist' test
    # could never succeed and ``--dataset MNIST`` would hit the ValueError.
    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.data_info = dataset.get_data_info()
    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()

    # The training split is not needed for evaluation here.
    _, dataset_test = dataset.create_default_splits()

    evaler = Evaler(config, dataset_test)
    log.warning("dataset: %s", config.dataset)
    evaler.eval_run()
示例6: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse training options, load the chosen dataset, and start training.

    The module selected by ``--dataset`` supplies ``data_info``,
    ``conv_info`` and ``deconv_info`` (stored on the parsed options) and the
    default train/test splits handed to ``Trainer``.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', type=int, default=64)
    arg_parser.add_argument('--prefix', type=str, default='default')
    arg_parser.add_argument('--checkpoint', type=str, default=None)
    arg_parser.add_argument(
        '--dataset', type=str, default='CIFAR10',
        choices=['MNIST', 'Fashion', 'SVHN', 'CIFAR10'])
    arg_parser.add_argument('--learning_rate', type=float, default=1e-4)
    arg_parser.add_argument('--update_rate', type=int, default=5)
    arg_parser.add_argument(
        '--lr_weight_decay', action='store_true', default=False)
    cfg = arg_parser.parse_args()

    # Import only the dataset module that was actually requested.
    dataset_name = cfg.dataset
    if dataset_name == 'MNIST':
        import datasets.mnist as dataset
    elif dataset_name == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif dataset_name == 'SVHN':
        import datasets.svhn as dataset
    elif dataset_name == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(dataset_name)

    # Dataset-specific settings travel with the parsed options.
    cfg.data_info = dataset.get_data_info()
    cfg.conv_info = dataset.get_conv_info()
    cfg.deconv_info = dataset.get_deconv_info()

    dataset_train, dataset_test = dataset.create_default_splits()

    trainer = Trainer(cfg, dataset_train, dataset_test)
    log.warning("dataset: %s, learning_rate: %f", cfg.dataset, cfg.learning_rate)
    trainer.train()
示例7: main
# 需要导入模块: from datasets import mnist [as 别名]
# 或者: from datasets.mnist import create_default_splits [as 别名]
def main():
    """Parse training options, load the chosen dataset, and start training.

    The module selected by ``--dataset`` provides the default train/test
    splits.  ``data_info`` is the concatenation of the shapes of the two
    arrays returned for the first training id, and is stored on the parsed
    options before ``Trainer`` is built.

    Raises:
        ValueError: If the dataset name matches none of the known modules.
    """
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--batch_size', type=int, default=64)
    arg_parser.add_argument('--max_batch_size', type=int, default=64)
    arg_parser.add_argument(
        '--prefix', type=str, default='default',
        help='the nickname of this training job')
    arg_parser.add_argument('--checkpoint', type=str, default=None)
    arg_parser.add_argument(
        '--dataset', type=str, default='MNIST',
        choices=['MNIST', 'Fashion', 'SVHN',
                 'CIFAR10', 'ImageNet', 'TinyImageNet'])
    arg_parser.add_argument(
        '--norm_type', type=str, default='batch',
        choices=['batch', 'group'])

    # Step-count options
    arg_parser.add_argument('--max_training_step', type=int, default=100000)
    arg_parser.add_argument('--log_step', type=int, default=10)
    arg_parser.add_argument('--test_sample_step', type=int, default=10)
    arg_parser.add_argument('--write_summary_step', type=int, default=10)
    arg_parser.add_argument('--ckpt_save_step', type=int, default=1000)

    # Learning-rate options
    arg_parser.add_argument('--learning_rate', type=float, default=1e-5)
    arg_parser.add_argument(
        '--no_adjust_learning_rate', action='store_true', default=False)
    cfg = arg_parser.parse_args()

    # Import only the dataset module that was actually requested.
    dataset_name = cfg.dataset
    if dataset_name == 'MNIST':
        import datasets.mnist as dataset
    elif dataset_name == 'Fashion':
        import datasets.fashion_mnist as dataset
    elif dataset_name == 'SVHN':
        import datasets.svhn as dataset
    elif dataset_name == 'CIFAR10':
        import datasets.cifar10 as dataset
    elif dataset_name == 'TinyImageNet':
        import datasets.tiny_imagenet as dataset
    elif dataset_name == 'ImageNet':
        import datasets.imagenet as dataset
    else:
        raise ValueError(dataset_name)

    dataset_train, dataset_test = dataset.create_default_splits()

    # Probe one training sample to record the shapes of its two arrays.
    first_id = dataset_train.ids[0]
    sample_image, sample_label = dataset_train.get_data(first_id)
    sample_shapes = [
        np.asarray(sample_image.shape),
        np.asarray(sample_label.shape),
    ]
    cfg.data_info = np.concatenate(sample_shapes)

    trainer = Trainer(cfg, dataset_train, dataset_test)
    log.warning("dataset: %s, learning_rate: %f", cfg.dataset, cfg.learning_rate)
    trainer.train()