

Python mnist.read_data_sets Code Examples

This article collects typical usage examples of the Python method tensorflow.contrib.learn.python.learn.datasets.mnist.read_data_sets. If you are unsure what mnist.read_data_sets does or how to call it, the curated examples below may help; you can also explore other members of the tensorflow.contrib.learn.python.learn.datasets.mnist module. Note that tf.contrib was removed in TensorFlow 2.x, so these examples target TensorFlow 1.x.


The four code examples of mnist.read_data_sets below are sorted by popularity by default.
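Before the examples, it helps to know the shape of what read_data_sets returns: a Datasets namedtuple with train, validation, and test members, each exposing .images, .labels, and .next_batch(batch_size). The sketch below mimics that structure with a plain-NumPy stand-in (FakeDataSet is illustrative, not the real tf.contrib class) so it runs without TensorFlow installed:

```python
from collections import namedtuple
import numpy as np

# read_data_sets returns a Datasets namedtuple with three DataSet
# members: train, validation and test; each exposes .images, .labels
# and .next_batch(batch_size). The stand-in below mimics that shape
# with plain NumPy so the sketch runs without TensorFlow installed.
Datasets = namedtuple('Datasets', ['train', 'validation', 'test'])

class FakeDataSet:
    """Illustrative stand-in for mnist.DataSet (not the real class)."""
    def __init__(self, num_examples, reshape=False):
        # With reshape=False (as in the examples below) images keep
        # shape [N, 28, 28, 1]; the default reshape=True flattens
        # them to [N, 784].
        shape = (num_examples, 784) if reshape else (num_examples, 28, 28, 1)
        self.images = np.zeros(shape, dtype=np.uint8)
        self.labels = np.zeros((num_examples,), dtype=np.int64)
        self.num_examples = num_examples

    def next_batch(self, batch_size):
        idx = np.random.randint(0, self.num_examples, size=batch_size)
        return self.images[idx], self.labels[idx]

# validation_size examples are carved out of the 60000-image train
# split, so validation_size=5000 leaves 55000 training examples.
data_sets = Datasets(train=FakeDataSet(55000),
                     validation=FakeDataSet(5000),
                     test=FakeDataSet(10000))
images, labels = data_sets.train.next_batch(32)
```

The real loader additionally downloads and caches the MNIST files under the directory passed as its first argument.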

Example 1: _maybe_download_and_extract

# Required import: from tensorflow.contrib.learn.python.learn.datasets import mnist [as alias]
# Or: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as alias]
def _maybe_download_and_extract(self):
        """Download and extract the MNIST dataset"""
        data_sets = mnist.read_data_sets(
            self._data_dir,
            dtype=tf.uint8,
            reshape=False,
            validation_size=self._num_examples_per_epoch_for_eval)

        # Convert to Examples and write the result to TFRecords.
        if not tf.gfile.Exists(os.path.join(self._data_dir, 'train.tfrecords')):
            convert_to_tfrecords(data_sets.train, 'train', self._data_dir)

        if not tf.gfile.Exists(
                os.path.join(self._data_dir, 'validation.tfrecords')):
            convert_to_tfrecords(data_sets.validation, 'validation',
                                 self._data_dir)

        if not tf.gfile.Exists(os.path.join(self._data_dir, 'test.tfrecords')):
            convert_to_tfrecords(data_sets.test, 'test', self._data_dir) 
Developer: galeone, Project: dynamic-training-bench, Lines: 21, Source: MNIST.py
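The tf.gfile.Exists checks in Example 1 make the conversion idempotent: each split is written to TFRecords only if the file is not already present. A TensorFlow-free sketch of the same check-then-write logic using only the standard library (the file names and payload are illustrative):

```python
import os
import tempfile

def maybe_convert(data_dir, split, payload):
    """Write <split>.tfrecords only if it is not already present."""
    path = os.path.join(data_dir, split + '.tfrecords')
    if os.path.exists(path):
        return False  # already converted on a previous run, skip
    with open(path, 'wb') as f:
        f.write(payload)
    return True

data_dir = tempfile.mkdtemp()
first = maybe_convert(data_dir, 'train', b'example-bytes')
second = maybe_convert(data_dir, 'train', b'example-bytes')
# first is True (file written), second is False (skipped)
```

This guard matters because _maybe_download_and_extract may be called on every run; without it, the (slow) conversion would repeat each time.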

Example 2: main

# Required import: from tensorflow.contrib.learn.python.learn.datasets import mnist [as alias]
# Or: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as alias]
def main(unused_argv):
  # Get the data.
  data_sets = mnist.read_data_sets(FLAGS.directory,
                                   dtype=tf.uint8,
                                   reshape=False,
                                   validation_size=FLAGS.validation_size)

  # Convert to Examples and write the result to TFRecords.
  convert_to(data_sets.train, 'train')
  convert_to(data_sets.validation, 'validation')
  convert_to(data_sets.test, 'test') 
Developer: GoogleCloudPlatform, Project: cloudml-dist-mnist-example, Lines: 13, Source: create_records.py

Example 3: build_input_pipeline

# Required import: from tensorflow.contrib.learn.python.learn.datasets import mnist [as alias]
# Or: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as alias]
def build_input_pipeline(data_dir, batch_size, heldout_size, mnist_type):
  """Builds an Iterator switching between train and heldout data."""
  # Build an iterator over training batches.
  if mnist_type in [MnistType.FAKE_DATA, MnistType.THRESHOLD]:
    if mnist_type == MnistType.FAKE_DATA:
      mnist_data = build_fake_data()
    else:
      mnist_data = mnist.read_data_sets(data_dir)
    training_dataset = tf.data.Dataset.from_tensor_slices(
        (mnist_data.train.images, np.int32(mnist_data.train.labels)))
    heldout_dataset = tf.data.Dataset.from_tensor_slices(
        (mnist_data.validation.images,
         np.int32(mnist_data.validation.labels)))
  elif mnist_type == MnistType.BERNOULLI:
    training_dataset = load_bernoulli_mnist_dataset(data_dir, "train")
    heldout_dataset = load_bernoulli_mnist_dataset(data_dir, "valid")
  else:
    raise ValueError("Unknown MNIST type.")

  training_batches = training_dataset.repeat().batch(batch_size)
  training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)

  # Build an iterator over the heldout set with batch_size=heldout_size,
  # i.e., return the entire heldout set as a constant.
  heldout_frozen = (heldout_dataset.take(heldout_size).
                    repeat().batch(heldout_size))
  heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_frozen)

  # Combine these into a feedable iterator that can switch between training
  # and validation inputs.
  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
      handle, training_batches.output_types, training_batches.output_shapes)
  images, labels = feedable_iterator.get_next()
  # Reshape as a pixel image and binarize pixels.
  images = tf.reshape(images, shape=[-1] + IMAGE_SHAPE)
  if mnist_type in [MnistType.FAKE_DATA, MnistType.THRESHOLD]:
    images = tf.cast(images > 0.5, dtype=tf.int32)

  return images, labels, handle, training_iterator, heldout_iterator 
Developer: GoogleCloudPlatform, Project: ml-on-gcp, Lines: 42, Source: vq_vae.py
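The feedable-iterator pattern in Example 3 lets one graph switch between the training and heldout pipelines by feeding a string handle, rather than building separate graphs. The concept (not the TF1 API) can be sketched in plain Python: a registry maps handle strings to one-shot iterators, and get_next pulls the next batch from whichever source the handle names:

```python
import itertools

def make_one_shot_iterator(data, batch_size, repeat=False):
    """Yield successive batches from data, optionally cycling forever."""
    src = itertools.cycle(data) if repeat else iter(data)
    while True:
        batch = list(itertools.islice(src, batch_size))
        if not batch:
            return
        yield batch

# Registry of handles -> iterators, analogous to the string handles
# fed to tf.compat.v1.data.Iterator.from_string_handle.
iterators = {
    'train': make_one_shot_iterator(range(10), 3, repeat=True),
    'heldout': make_one_shot_iterator(range(100, 104), 4, repeat=True),
}

def get_next(handle):
    """Pull the next batch from the source named by the handle."""
    return next(iterators[handle])
```

As in Example 3, the heldout iterator batches the entire heldout set at once (batch_size equals its size), so each pull returns the full set.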

Example 4: main

# Required import: from tensorflow.contrib.learn.python.learn.datasets import mnist [as alias]
# Or: from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets [as alias]
def main():
    # Get the data.
    data_sets = mnist.read_data_sets(FLAGS.directory, dtype=tf.uint8, reshape=False,
                                     validation_size=FLAGS.validation_size)
    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test') 
Developer: IsaacChanghau, Project: AmusingPythonCodes, Lines: 10, Source: mnist_to_tfrecords.py


Note: the tensorflow.contrib.learn.python.learn.datasets.mnist.read_data_sets examples in this article were collected by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets come from open-source projects contributed by their respective authors; copyright remains with the original authors, and distribution or reuse is subject to each project's license. Do not republish without permission.