本文整理汇总了Python中tensorflow.contrib.learn.python.learn.datasets.mnist.read_data_sets方法的典型用法代码示例。如果您正苦于以下问题：Python mnist.read_data_sets方法的具体用法是什么？Python mnist.read_data_sets怎么用？有哪些Python mnist.read_data_sets的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.learn.python.learn.datasets.mnist的用法示例。
在下文中一共展示了mnist.read_data_sets方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：tensorflow.contrib 已在 TensorFlow 2.x 中移除，运行以下示例需要 TensorFlow 1.x 环境。
示例1: _maybe_download_and_extract
# 需要导入模块: from tensorflow.contrib.learn.python.learn.datasets import mnist [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as 别名]
def _maybe_download_and_extract(self):
    """Download and extract MNIST, then write any missing TFRecord files.

    The dataset is obtained via ``mnist.read_data_sets`` into
    ``self._data_dir`` with uint8 pixels and reshaping disabled;
    ``self._num_examples_per_epoch_for_eval`` is passed as the validation
    split size. Each of the train/validation/test splits is converted with
    ``convert_to_tfrecords`` only if its ``.tfrecords`` file is not already
    present in ``self._data_dir``.
    """
    datasets = mnist.read_data_sets(
        self._data_dir,
        dtype=tf.uint8,
        reshape=False,
        validation_size=self._num_examples_per_epoch_for_eval,
    )
    # (split name, output file name, split data), in the original order.
    targets = (
        ('train', 'train.tfrecords', datasets.train),
        ('validation', 'validation.tfrecords', datasets.validation),
        ('test', 'test.tfrecords', datasets.test),
    )
    for split_name, file_name, split_data in targets:
        if tf.gfile.Exists(os.path.join(self._data_dir, file_name)):
            # Already converted by an earlier run; leave it untouched.
            continue
        convert_to_tfrecords(split_data, split_name, self._data_dir)
示例2: main
# 需要导入模块: from tensorflow.contrib.learn.python.learn.datasets import mnist [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as 别名]
def main(unused_argv):
    """Load MNIST from ``FLAGS.directory`` and write each split as TFRecords.

    Args:
      unused_argv: Command-line arguments; not used.
    """
    # uint8 pixels with reshaping disabled; FLAGS.validation_size sets the
    # size of the validation split.
    datasets = mnist.read_data_sets(
        FLAGS.directory,
        dtype=tf.uint8,
        reshape=False,
        validation_size=FLAGS.validation_size,
    )
    # Convert every split to Examples and write it out, train first.
    for split_name, split in (('train', datasets.train),
                              ('validation', datasets.validation),
                              ('test', datasets.test)):
        convert_to(split, split_name)
示例3: build_input_pipeline
# 需要导入模块: from tensorflow.contrib.learn.python.learn.datasets import mnist [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as 别名]
def build_input_pipeline(data_dir, batch_size, heldout_size, mnist_type):
    """Builds an Iterator switching between train and heldout data.

    Args:
      data_dir: Directory passed to `mnist.read_data_sets` (THRESHOLD) or to
        `load_bernoulli_mnist_dataset` (BERNOULLI); unused for FAKE_DATA.
      batch_size: Number of examples per training batch.
      heldout_size: Number of heldout examples taken; they are all emitted
        together as a single batch.
      mnist_type: A `MnistType` value selecting the data source.

    Returns:
      A tuple `(images, labels, handle, training_iterator, heldout_iterator)`:
        images: Tensor reshaped to `[-1] + IMAGE_SHAPE`; for FAKE_DATA and
          THRESHOLD it is binarized to int32 (pixel > 0.5).
        labels: Labels tensor from whichever iterator `handle` selects.
        handle: Scalar string placeholder; feed it an iterator's
          `string_handle()` to choose the training or heldout stream.
        training_iterator: One-shot iterator over repeated training batches.
        heldout_iterator: One-shot iterator over the repeated heldout batch.

    Raises:
      ValueError: If `mnist_type` is not FAKE_DATA, THRESHOLD or BERNOULLI.
    """
    if mnist_type in [MnistType.FAKE_DATA, MnistType.THRESHOLD]:
        if mnist_type == MnistType.FAKE_DATA:
            mnist_data = build_fake_data()
        else:
            mnist_data = mnist.read_data_sets(data_dir)
        # Labels are cast to int32 so both sources share one element type.
        training_dataset = tf.data.Dataset.from_tensor_slices(
            (mnist_data.train.images, np.int32(mnist_data.train.labels)))
        heldout_dataset = tf.data.Dataset.from_tensor_slices(
            (mnist_data.validation.images,
             np.int32(mnist_data.validation.labels)))
    elif mnist_type == MnistType.BERNOULLI:
        training_dataset = load_bernoulli_mnist_dataset(data_dir, "train")
        heldout_dataset = load_bernoulli_mnist_dataset(data_dir, "valid")
    else:
        raise ValueError("Unknown MNIST type.")

    # Build an iterator over training batches.
    training_batches = training_dataset.repeat().batch(batch_size)
    training_iterator = tf.compat.v1.data.make_one_shot_iterator(
        training_batches)

    # Build an iterator over the heldout set with batch_size=heldout_size,
    # i.e., return the entire heldout set as a constant.
    heldout_frozen = (heldout_dataset.take(heldout_size)
                      .repeat().batch(heldout_size))
    heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_frozen)

    # Combine these into a feedable iterator that can switch between training
    # and validation inputs. `Dataset.output_types` / `.output_shapes` only
    # exist on TF1-style (V1) datasets; the tf.compat.v1.data accessors work
    # for both V1 and V2 datasets.
    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
        handle,
        tf.compat.v1.data.get_output_types(training_batches),
        tf.compat.v1.data.get_output_shapes(training_batches))
    images, labels = feedable_iterator.get_next()

    # Reshape as a pixel image and binarize pixels.
    images = tf.reshape(images, shape=[-1] + IMAGE_SHAPE)
    if mnist_type in [MnistType.FAKE_DATA, MnistType.THRESHOLD]:
        # NOTE(review): the 0.5 threshold assumes intensities scaled to
        # [0, 1]; confirm this also holds for build_fake_data().
        images = tf.cast(images > 0.5, dtype=tf.int32)
    return images, labels, handle, training_iterator, heldout_iterator
示例4: main
# 需要导入模块: from tensorflow.contrib.learn.python.learn.datasets import mnist [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as 别名]
def main(unused_argv=None):
    """Load MNIST from ``FLAGS.directory`` and write train/validation/test TFRecords.

    Args:
      unused_argv: Ignored. It is optional so this entry point works both
        when called directly as ``main()`` and when launched through
        ``tf.app.run`` / ``absl.app.run``. Those launchers call
        ``main(argv)``, so the original zero-argument signature raised
        TypeError there.
    """
    # Get the data: uint8 pixels, reshaping disabled, and
    # FLAGS.validation_size examples for the validation split.
    data_sets = mnist.read_data_sets(FLAGS.directory,
                                     dtype=tf.uint8,
                                     reshape=False,
                                     validation_size=FLAGS.validation_size)
    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test')