

Python mnist.training method: code examples

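This article collects typical usage examples of the Python method tensorflow.examples.tutorials.mnist.mnist.training. If you want to know what mnist.training does, how to call it, or what real code that uses it looks like, the examples selected here may help. You can also look further into usage examples of tensorflow.examples.tutorials.mnist.mnist, the module in which this method is defined.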


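Eight code examples are shown below, sorted by popularity by default. Only Example 1 calls mnist.training directly; the others are helper functions (feed-dict construction, input pipelines, cluster setup) taken from training scripts that use the same mnist module.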

Example 1: main

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
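import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Print, ...)
from tensorflow.examples.tutorials.mnist import mnist
from tensorflow.examples.tutorials.mnist.input_data import read_data_sets

# FLAGS and device_and_target() (Example 7) are defined elsewhere in this file.


def main(unused_argv):
  if not FLAGS.log_dir:
    raise ValueError("Must specify an explicit `log_dir`")
  if not FLAGS.data_dir:
    raise ValueError("Must specify an explicit `data_dir`")

  device, target = device_and_target()
  with tf.device(device):
    # Flattened 28x28 = 784-pixel images.
    images = tf.placeholder(tf.float32, [None, 784], name='image_input')
    # Integer class labels 0-9 (one_hot=False below); mnist.loss casts them to int64.
    labels = tf.placeholder(tf.int32, [None], name='label_input')
    data = read_data_sets(FLAGS.data_dir, one_hot=False, fake_data=False)
    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels)
    # Print the current loss each time the op is evaluated.
    loss = tf.Print(loss, [loss], message="Loss = ")
    # Gradient-descent step that also increments the global step.
    train_op = mnist.training(loss, FLAGS.learning_rate)

  # Treat single-machine runs (job_name unset) as chief regardless of
  # task_index; a non-chief session would wait for another task to
  # initialize the variables.
  is_chief = FLAGS.job_name is None or FLAGS.task_index == 0
  with tf.train.MonitoredTrainingSession(
      master=target,
      is_chief=is_chief,
      checkpoint_dir=FLAGS.log_dir) as sess:
    # Loops until a hook (e.g. tf.train.StopAtStepHook) or an error stops it.
    while not sess.should_stop():
      xs, ys = data.train.next_batch(FLAGS.batch_size, fake_data=False)
      sess.run(train_op, feed_dict={images: xs, labels: ys})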
Developer ID: tensorport, Project: mnist, Lines of code: 27, Source file: mnist.py
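As a rough illustration of what the train_op built by mnist.training does: in the TensorFlow 1.x tutorial module, mnist.training(loss, learning_rate) creates a tf.train.GradientDescentOptimizer and a non-trainable global_step variable, so each sess.run(train_op) applies one update w <- w - learning_rate * grad and increments global_step. As written, the loop in Example 1 only ends when a hook (for instance tf.train.StopAtStepHook) or an error stops the session. The pure-Python sketch below mimics those semantics on a toy one-parameter loss (w - 3)^2; train_step, train and last_step are hypothetical names, not TensorFlow APIs.

```python
def train_step(w, global_step, learning_rate):
    """One gradient-descent step on the toy loss (w - 3)**2.

    Mirrors what the train_op does: update the parameter, then
    increment the global step counter.
    """
    grad = 2.0 * (w - 3.0)          # d/dw of (w - 3)**2
    w = w - learning_rate * grad    # w <- w - lr * grad
    return w, global_step + 1


def train(learning_rate=0.1, last_step=100):
    """Run train_step until global_step reaches last_step (like StopAtStepHook)."""
    w, global_step = 0.0, 0
    while global_step < last_step:  # stands in for `not sess.should_stop()`
        w, global_step = train_step(w, global_step, learning_rate)
    return w, global_step


w, step = train()
print(f"w = {w:.4f} after {step} steps")
```

With learning_rate=0.1 each step shrinks the distance to the minimum w = 3 by a factor of 0.8, so 100 steps bring w to within about 1e-9 of 3.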

Example 2: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Fills the feed_dict for training the given step.
    A feed_dict takes the form of:
    feed_dict = {
            <placeholder>: <tensor of values to be passed for placeholder>,
            ....
    }
    Args:
        data_set: The set of images and labels, from input_data.read_data_sets()
        images_pl: The images placeholder, from placeholder_inputs().
        labels_pl: The labels placeholder, from placeholder_inputs().
    Returns:
        feed_dict: The feed dictionary mapping from placeholders to values.
    """
    # Create the feed_dict for the placeholders, filled with the next
    # `batch_size` examples.
    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                   FLAGS.fake_data)
    feed_dict = {
        images_pl: images_feed,
        labels_pl: labels_feed,
    }
    return feed_dict 
Developer ID: gradientzoo, Project: python-gradientzoo, Lines of code: 25, Source file: tensorflow_mnist.py
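fill_feed_dict relies on data_set.next_batch(batch_size, fake_data) to hand back the next slice of images and labels; the tutorial's DataSet class serves consecutive slices of its examples and shuffles them again once an epoch is used up. The pure-Python class below is a simplified sketch of that behaviour, not the TensorFlow implementation: ToyDataSet is a hypothetical name, the fake_data argument is omitted, and at an epoch boundary this sketch simply skips any leftover examples (the real class may handle that boundary differently).

```python
import random


class ToyDataSet:
    """Hypothetical, simplified stand-in for the tutorial's DataSet class."""

    def __init__(self, images, labels, seed=0):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self._images = list(images)
        self._labels = list(labels)
        self._index = 0               # position of the next unread example
        self.epochs_completed = 0
        self._rng = random.Random(seed)

    def next_batch(self, batch_size):
        """Return the next (images, labels) lists of up to batch_size items."""
        if self._index + batch_size > len(self._images):
            # Not enough unread examples left: finish the epoch, reshuffle
            # images and labels with the same permutation, and start over.
            # (Leftover examples are skipped in this sketch.)
            self.epochs_completed += 1
            order = list(range(len(self._images)))
            self._rng.shuffle(order)
            self._images = [self._images[i] for i in order]
            self._labels = [self._labels[i] for i in order]
            self._index = 0
        start = self._index
        self._index += batch_size
        return self._images[start:self._index], self._labels[start:self._index]


# Usage: ten one-pixel "images" whose pixel value equals their label.
ds = ToyDataSet([[float(i)] for i in range(10)], list(range(10)))
for _ in range(3):
    xs, ys = ds.next_batch(4)
    print(ys, "epochs completed:", ds.epochs_completed)
```

Shuffling one index permutation and applying it to both lists keeps every image paired with its label, which is the property the feed dict depends on.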

Example 3: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
  """Fills the feed_dict for training the given step.

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().
    batch_size: Batch size of data to feed.

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders, filled with the next
  # `batch_size` examples.
  images_feed, labels_feed = data_set.next_batch(batch_size, FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict 
Developer ID: rashmitripathi, Project: DeepLearning_VirtualReality_BigData_Project, Lines of code: 22, Source file: mnist.py
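Example 3 differs from Examples 2 and 4 only in taking batch_size as an argument instead of reading FLAGS.batch_size, so one helper can feed training batches and differently sized evaluation batches. The sketch below exercises that signature without TensorFlow: StubDataSet and count_full_batches are hypothetical, and the strings "images" and "labels" stand in for the placeholder tensors. The pass over num_examples // batch_size full batches follows the pattern used by the do_eval function in TensorFlow's fully_connected_feed.py tutorial.

```python
class StubDataSet:
    """Hypothetical data set that serves labels 0, 1, 2, ... in a fixed cycle."""

    def __init__(self, num_examples):
        self.num_examples = num_examples
        self._next = 0

    def next_batch(self, batch_size, fake_data=False):
        # fake_data is accepted only to match the call signature used above.
        labels = [(self._next + i) % self.num_examples for i in range(batch_size)]
        self._next = (self._next + batch_size) % self.num_examples
        images = [[0.0] * 784 for _ in labels]  # dummy flattened 28x28 images
        return images, labels


def fill_feed_dict(data_set, images_pl, labels_pl, batch_size, fake_data=False):
    """Example 3's logic, with fake_data passed in instead of read from FLAGS."""
    images_feed, labels_feed = data_set.next_batch(batch_size, fake_data)
    return {images_pl: images_feed, labels_pl: labels_feed}


def count_full_batches(data_set, batch_size):
    """Feed one pass of full batches; return how many examples were fed."""
    steps = data_set.num_examples // batch_size  # a trailing partial batch is skipped
    fed = 0
    for _ in range(steps):
        feed = fill_feed_dict(data_set, "images", "labels", batch_size)
        fed += len(feed["labels"])
    return fed


# The same helper serves a training-sized and an evaluation-sized batch.
print(count_full_batches(StubDataSet(1000), batch_size=100))  # 1000
print(count_full_batches(StubDataSet(1000), batch_size=300))  # 900
```

With batch_size=300, the last 100 of the 1,000 examples are never fed, which is why an evaluation batch size that divides the data set size gives exact coverage.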

Example 4: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict 
Developer ID: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 28, Source file: task.py

Example 5: fill_feed_dict

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
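  # Create the feed_dict for the placeholders, filled with the next
  # `batch_size` examples. Unlike the stock DataSet.next_batch(), which
  # returns (images, labels), this project's data set returns three values;
  # the first one is discarded.
  _, images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                    FLAGS.fake_data)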
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict 
Developer ID: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 28, Source file: task.py
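In Example 5, next_batch returns three values and the first is thrown away with _, so this project's data set must differ from the stock DataSet, whose next_batch returns an (images, labels) pair. The source does not show what the first value is; one plausible design, assumed here purely for illustration, is a per-example key that lets predictions be matched back to their inputs. KeyedDataSet and its keys are hypothetical, and the strings "images" and "labels" stand in for the placeholder tensors.

```python
class KeyedDataSet:
    """Hypothetical data set whose next_batch returns (keys, images, labels)."""

    def __init__(self, keys, images, labels):
        self._rows = list(zip(keys, images, labels))
        self._index = 0  # position of the next row, wrapping around at the end

    def next_batch(self, batch_size, fake_data=False):
        # fake_data is accepted only to match Example 5's call signature.
        n = len(self._rows)
        batch = [self._rows[(self._index + i) % n] for i in range(batch_size)]
        self._index = (self._index + batch_size) % n
        keys = [row[0] for row in batch]
        images = [row[1] for row in batch]
        labels = [row[2] for row in batch]
        return keys, images, labels


def fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
    # The keys are not model inputs, so they are discarded, as in Example 5.
    _, images_feed, labels_feed = data_set.next_batch(batch_size)
    return {images_pl: images_feed, labels_pl: labels_feed}


ds = KeyedDataSet(["a", "b", "c"], [[0.1], [0.2], [0.3]], [7, 8, 9])
print(fill_feed_dict(ds, "images", "labels", batch_size=2))
```

Only the images and labels reach the feed dict; a prediction job could keep the keys to label its outputs.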

Example 6: inputs

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
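import os

import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline API

# FLAGS, TRAIN_FILE, VALIDATION_FILE and read_and_decode() are defined
# elsewhere in the same file.


def inputs(train, batch_size, num_epochs):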
  """Reads input data num_epochs times.
  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.
  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph; its threads
    must be started, e.g. with tf.train.start_queue_runners().
  """
  # Treat 0 the same as None: read the input data indefinitely.
  if not num_epochs:
    num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)

    return images, sparse_labels 
Developer ID: cheyang, Project: mnist-examples, Lines of code: 40, Source file: mnist_train.py
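tf.train.shuffle_batch in Example 6 holds examples in a RandomShuffleQueue and only dequeues a batch while at least min_after_dequeue examples would remain behind to mix with; capacity bounds how many examples the queue may hold (with batch_size=100, capacity=1000 + 3 * batch_size is 1300). The pure-Python generator below is a simplified, single-threaded sketch of that buffering idea, not TensorFlow's implementation: shuffle_batches is a hypothetical name, the buffer never exceeds min_after_dequeue + batch_size (within the capacity above), and once the input runs out the sketch drains the remaining full batches and drops a final partial one.

```python
import random


def shuffle_batches(examples, batch_size, min_after_dequeue, seed=0):
    """Yield shuffled batches from a bounded buffer (simplified sketch).

    A batch is only taken while at least min_after_dequeue examples
    would stay behind in the buffer to be mixed with later arrivals.
    """
    rng = random.Random(seed)
    buffer = []

    def take_batch():
        batch = []
        for _ in range(batch_size):
            # Swap a random element to the end, then pop it (O(1) removal).
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        return batch

    for example in examples:
        buffer.append(example)  # enqueue
        if len(buffer) >= min_after_dequeue + batch_size:
            yield take_batch()  # leaves exactly min_after_dequeue behind
    # Input exhausted: drain the remaining full batches; drop a partial one.
    while len(buffer) >= batch_size:
        yield take_batch()


batches = list(shuffle_batches(range(50), batch_size=5, min_after_dequeue=10))
print(len(batches), batches[0])
```

The first batch can only draw from the first min_after_dequeue + batch_size examples, so a larger min_after_dequeue gives better mixing at the cost of memory and start-up time.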

Example 7: device_and_target

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
import tensorflow as tf  # TensorFlow 1.x distributed API

# FLAGS is defined elsewhere in the same file.


def device_and_target():
  # If FLAGS.job_name is not set, we're running single-machine TensorFlow.
  # Don't set a device.
  if FLAGS.job_name is None:
    print("Running single-machine training")
    return (None, "")

  # Otherwise we're running distributed TensorFlow. Validate the flags before
  # using them (formatting a None task_index with %d would raise a TypeError).
  if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")
  if not FLAGS.ps_hosts:
    raise ValueError("Must specify an explicit `ps_hosts`")
  if not FLAGS.worker_hosts:
    raise ValueError("Must specify an explicit `worker_hosts`")
  print("%s.%d -- Running distributed training" % (FLAGS.job_name, FLAGS.task_index))

  cluster_spec = tf.train.ClusterSpec({
      "ps": FLAGS.ps_hosts.split(","),
      "worker": FLAGS.worker_hosts.split(","),
  })
  server = tf.train.Server(
      cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
  if FLAGS.job_name == "ps":
    # Parameter servers only serve variables; join() blocks forever.
    server.join()

  worker_device = "/job:worker/task:{}".format(FLAGS.task_index)
  # The device setter automatically places Variable ops on the parameter
  # servers (ps), round-robin by default; all other ops stay on this worker.
  return (
      tf.train.replica_device_setter(
          worker_device=worker_device,
          cluster=cluster_spec),
      server.target,
  )
Developer ID: tensorport, Project: mnist, Lines of code: 36, Source file: mnist.py
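In Example 7, tf.train.replica_device_setter pins Variable ops to the ps job and leaves every other op on worker_device; when no ps_strategy is given, it spreads the variables over the ps tasks in round-robin order. The plain-Python sketch below reproduces only that bookkeeping, not TensorFlow's implementation: it builds the same {"ps": [...], "worker": [...]} mapping that Example 7 passes to tf.train.ClusterSpec and then assigns variables to ps tasks. The host names, the variable names and the helpers build_cluster and place_variables are all hypothetical.

```python
import itertools


def build_cluster(ps_hosts, worker_hosts):
    """Split comma-separated host lists, as Example 7 does before ClusterSpec."""
    return {"ps": ps_hosts.split(","), "worker": worker_hosts.split(",")}


def place_variables(variable_names, num_ps_tasks):
    """Assign each variable to a ps task in round-robin order."""
    task_ids = itertools.cycle(range(num_ps_tasks))
    return {name: "/job:ps/task:{}".format(next(task_ids))
            for name in variable_names}


cluster = build_cluster("ps0:2222,ps1:2222",
                        "worker0:2222,worker1:2222,worker2:2222")
placement = place_variables(
    ["hidden1/weights", "hidden1/biases", "hidden2/weights", "hidden2/biases"],
    num_ps_tasks=len(cluster["ps"]))
for name, device in placement.items():
    print(name, "->", device)
```

Note that "".split(",") returns [""] rather than an empty list, which is one reason Example 7 rejects empty ps_hosts and worker_hosts before splitting them.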

Example 8: inputs

# Module to import: from tensorflow.examples.tutorials.mnist import mnist [as alias]
# Or: from tensorflow.examples.tutorials.mnist.mnist import training [as alias]
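import os

import tensorflow as tf  # TensorFlow 1.x queue-based input pipeline API

# FLAGS, TRAIN_FILE, VALIDATION_FILE and read_and_decode() are defined
# elsewhere in the same file.


def inputs(train, batch_size, num_epochs):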
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph; its threads
    must be started, e.g. with tf.train.start_queue_runners().
  """
  # Treat 0 the same as None: read the input data indefinitely.
  if not num_epochs:
    num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Collect the examples into batch_size batches. Unlike shuffle_batch,
    # tf.train.batch uses a FIFO queue and does not shuffle; ten enqueuing
    # threads keep the queue filled, and capacity=60000 is enough to buffer
    # the full 60,000-image MNIST training set.
    images, sparse_labels = tf.train.batch(
        [image, label], batch_size=batch_size, num_threads=10, capacity=60000)

    return images, sparse_labels 
Developer ID: yaroslavvb, Project: stuff, Lines of code: 39, Source file: fully_connected_reader.py
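Example 8 swaps Example 6's tf.train.shuffle_batch for tf.train.batch, which collects examples through a FIFO queue and does no shuffling of its own; with the default allow_smaller_final_batch=False, a final batch smaller than batch_size is not returned once the input runs out. The sketch below models the resulting order in plain Python under simplifying assumptions: a single reader (Example 8 uses ten enqueuing threads, so real batches may interleave), an in-memory list instead of TFRecord files, and the hypothetical name fifo_batches.

```python
def fifo_batches(examples, batch_size, num_epochs=None):
    """Yield in-order batches over num_epochs passes of the list `examples`.

    Simplified sketch of tf.train.batch with allow_smaller_final_batch=False:
    no shuffling, batches may span an epoch boundary, and a final partial
    batch is dropped.
    """
    if not num_epochs:
        num_epochs = None  # as in Example 8: 0 or None means "read forever"
    if not examples:
        return  # nothing to read; avoids an endless loop over an empty list

    def stream():
        epoch = 0
        while num_epochs is None or epoch < num_epochs:
            yield from examples  # one full pass over the data
            epoch += 1

    batch = []
    for example in stream():
        batch.append(example)
        if len(batch) == batch_size:
            yield batch
            batch = []


print(list(fifo_batches([1, 2, 3, 4, 5], batch_size=2, num_epochs=2)))
# [[1, 2], [3, 4], [5, 1], [2, 3], [4, 5]]
```

With num_epochs=1 the same call yields only [[1, 2], [3, 4]]: the leftover 5 never fills a batch, mirroring the point at which the TensorFlow pipeline raises OutOfRangeError.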


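Note: The tensorflow.examples.tutorials.mnist.mnist.training examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets come from open-source projects contributed by various developers, and copyright in the source code remains with the original authors. Please consult each project's license before distributing or using the code, and do not republish without permission.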