This page collects typical usage examples of the Python attribute fuel.datasets.MNIST. If you have been wondering what exactly datasets.MNIST does, how to use it, or where to find real examples of it, the curated attribute examples below may help. You can also explore further usage examples from fuel.datasets, the module in which this attribute lives.
The following shows 8 code examples of the datasets.MNIST attribute, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
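Before the examples themselves, here is a minimal, self-contained sketch of the pattern most of them follow. It assumes Fuel is installed and that the MNIST data has already been converted to mnist.hdf5 in your Fuel data path (for example with the fuel-download and fuel-convert command-line tools); the variable names are illustrative and not taken from any example below.
from fuel.datasets import MNIST
from fuel.streams import DataStream
from fuel.schemes import SequentialScheme

# Open the training split; data stays on disk unless load_in_memory=True.
dataset = MNIST(('train',))

# default_stream applies the dataset's default transformers, which for
# MNIST rescale the uint8 pixel values to floats in [0, 1].
stream = DataStream.default_stream(
    dataset,
    iteration_scheme=SequentialScheme(dataset.num_examples, 128))

for features, targets in stream.get_epoch_iterator():
    # features: (batch, 1, 28, 28) floats; targets: (batch, 1) ints
    pass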
Example 1: test_mnist_train
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def test_mnist_train():
skip_if_not_available(datasets=['mnist.hdf5'])
dataset = MNIST(('train',), load_in_memory=False)
handle = dataset.open()
data, labels = dataset.get_data(handle, slice(0, 10))
assert data.dtype == 'uint8'
assert data.shape == (10, 1, 28, 28)
assert labels.shape == (10, 1)
known = numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253,
253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0,
0, 0, 0])
assert_allclose(data[0][0][6], known)
assert labels[0][0] == 5
assert dataset.num_examples == 60000
dataset.close(handle)
stream = DataStream.default_stream(
dataset, iteration_scheme=SequentialScheme(10, 10))
data = next(stream.get_epoch_iterator())[0]
assert data.min() >= 0.0 and data.max() <= 1.0
assert data.dtype == config.floatX
Example 2: test_mnist_test
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def test_mnist_test():
skip_if_not_available(datasets=['mnist.hdf5'])
dataset = MNIST(('test',), load_in_memory=False)
handle = dataset.open()
data, labels = dataset.get_data(handle, slice(0, 10))
assert data.dtype == 'uint8'
assert data.shape == (10, 1, 28, 28)
assert labels.shape == (10, 1)
known = numpy.array([0, 0, 0, 0, 0, 0, 84, 185, 159, 151, 60, 36, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
assert_allclose(data[0][0][7], known)
assert labels[0][0] == 7
assert dataset.num_examples == 10000
dataset.close(handle)
stream = DataStream.default_stream(
dataset, iteration_scheme=SequentialScheme(10, 10))
data = next(stream.get_epoch_iterator())[0]
assert data.min() >= 0.0 and data.max() <= 1.0
assert data.dtype == config.floatX
Example 3: parse_args
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=512, type=int,
help='Batch size')
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Initial learning rate. '
                             'Will be decayed until it reaches 1e-5.')
parser.add_argument('--resume_file', default=None, type=str,
help='Name of saved model to continue training')
parser.add_argument('--suffix', default='', type=str,
help='Optional descriptive suffix for model')
parser.add_argument('--output-dir', type=str, default='./',
help='Output directory to store trained models')
parser.add_argument('--ext-every-n', type=int, default=25,
help='Evaluate training extensions every N epochs')
parser.add_argument('--model-args', type=str, default='',
help='Dictionary string to be eval()d containing model arguments.')
parser.add_argument('--dropout_rate', type=float, default=0.,
help='Rate to use for dropout during training+testing.')
parser.add_argument('--dataset', type=str, default='MNIST',
help='Name of dataset to use.')
    # argparse's type=bool treats any non-empty string as True, so a
    # store_true flag is used here instead of type=bool with default=False.
    parser.add_argument('--plot_before_training', action='store_true',
                        help='Save diagnostic plots at epoch 0, before any training.')
args = parser.parse_args()
model_args = eval('dict(' + args.model_args + ')')
    print(model_args)
    if not os.path.exists(args.output_dir):
        raise IOError("Output directory '%s' does not exist." % args.output_dir)
return args, model_args
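For reference, the --model-args string is wrapped in dict(...) and passed to eval(), so it must be a valid Python keyword-argument list. A hypothetical value such as 'n_hidden=500, dropout=0.1' (these parameter names are illustrative, not ones the script defines) is parsed like this:
# Same mechanism parse_args() uses internally.
model_args_str = 'n_hidden=500, dropout=0.1'
model_args = eval('dict(' + model_args_str + ')')
assert model_args == {'n_hidden': 500, 'dropout': 0.1}
Note that eval() executes arbitrary code, so this pattern is only safe for command lines you write yourself.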
Example 4: test_mnist_axes
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def test_mnist_axes():
skip_if_not_available(datasets=['mnist.hdf5'])
dataset = MNIST(('train',), load_in_memory=False)
assert_equal(dataset.axis_labels['features'],
('batch', 'channel', 'height', 'width'))
Example 5: test_mnist_invalid_split
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def test_mnist_invalid_split():
skip_if_not_available(datasets=['mnist.hdf5'])
assert_raises(ValueError, MNIST, ('dummy',))
Example 6: test_in_memory
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def test_in_memory():
skip_if_not_available(datasets=['mnist.hdf5'])
# Load MNIST and get two batches
mnist = MNIST(('train',), load_in_memory=True)
data_stream = DataStream(mnist, iteration_scheme=SequentialScheme(
examples=mnist.num_examples, batch_size=256))
epoch = data_stream.get_epoch_iterator()
for i, (features, targets) in enumerate(epoch):
if i == 1:
break
handle = mnist.open()
known_features, _ = mnist.get_data(handle, slice(256, 512))
mnist.close(handle)
assert numpy.all(features == known_features)
# Pickle the epoch and make sure that the data wasn't dumped
with tempfile.NamedTemporaryFile(delete=False) as f:
filename = f.name
cPickle.dump(epoch, f)
assert os.path.getsize(filename) < 1024 * 1024 # Less than 1MB
# Reload the epoch and make sure that the state was maintained
del epoch
with open(filename, 'rb') as f:
epoch = cPickle.load(f)
features, targets = next(epoch)
handle = mnist.open()
known_features, _ = mnist.get_data(handle, slice(512, 768))
mnist.close(handle)
assert numpy.all(features == known_features)
Example 7: unify_labels
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def unify_labels(y):
""" Work-around for Fuel bug where MNIST and Cifar-10
datasets have different dimensionalities for the targets:
e.g. (50000, 1) vs (60000,) """
yshape = y.shape
y = y.flatten()
assert y.shape[0] == yshape[0]
return y
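As a quick illustration of the shapes involved, using synthetic stand-ins for the datasets' target arrays (this snippet is not part of the original code):
import numpy

y_2d = numpy.zeros((60000, 1), dtype='uint8')  # (N, 1)-style targets
y_1d = numpy.zeros((50000,), dtype='uint8')    # flat (N,)-style targets

assert unify_labels(y_2d).shape == (60000,)
assert unify_labels(y_1d).shape == (50000,)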
Example 8: main
# Required imports: from fuel import datasets [as alias]
# Or: from fuel.datasets import MNIST [as alias]
def main(save_to, num_epochs):
mlp = MLP([Tanh(), Softmax()], [784, 100, 10],
weights_init=IsotropicGaussian(0.01),
biases_init=Constant(0))
mlp.initialize()
x = tensor.matrix('features')
y = tensor.lmatrix('targets')
probs = mlp.apply(x)
cost = CategoricalCrossEntropy().apply(y.flatten(), probs)
error_rate = MisclassificationRate().apply(y.flatten(), probs)
cg = ComputationGraph([cost])
W1, W2 = VariableFilter(roles=[WEIGHT])(cg.variables)
cost = cost + .00005 * (W1 ** 2).sum() + .00005 * (W2 ** 2).sum()
cost.name = 'final_cost'
mnist_train = MNIST(("train",))
mnist_test = MNIST(("test",))
algorithm = GradientDescent(
cost=cost, parameters=cg.parameters,
step_rule=Scale(learning_rate=0.1))
extensions = [Timing(),
FinishAfter(after_n_epochs=num_epochs),
DataStreamMonitoring(
[cost, error_rate],
Flatten(
DataStream.default_stream(
mnist_test,
iteration_scheme=SequentialScheme(
mnist_test.num_examples, 500)),
which_sources=('features',)),
prefix="test"),
TrainingDataMonitoring(
[cost, error_rate,
aggregation.mean(algorithm.total_gradient_norm)],
prefix="train",
after_epoch=True),
Checkpoint(save_to),
Printing()]
if BLOCKS_EXTRAS_AVAILABLE:
extensions.append(Plot(
'MNIST example',
channels=[
['test_final_cost',
'test_misclassificationrate_apply_error_rate'],
['train_total_gradient_norm']]))
main_loop = MainLoop(
algorithm,
Flatten(
DataStream.default_stream(
mnist_train,
iteration_scheme=SequentialScheme(
mnist_train.num_examples, 50)),
which_sources=('features',)),
model=Model(cost),
extensions=extensions)
main_loop.run()
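A hypothetical entry point for the training function above; the checkpoint file name and epoch count are illustrative, not taken from the original script:
if __name__ == '__main__':
    # Checkpoint the main loop to mnist_mlp.tar and train for 10 epochs.
    main('mnist_mlp.tar', num_epochs=10)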