

Python input.string_input_producer Function Code Examples

This article collects typical usage examples of the Python function tensorflow.python.training.input.string_input_producer. If you are wondering what string_input_producer does, how to call it, or what real-world usages look like, the curated examples below should help.


The following presents 15 code examples of the string_input_producer function, sorted by popularity by default.
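Before the excerpted examples, it helps to see the pattern they all build on in one place. The following is a minimal, self-contained sketch, not taken from any of the excerpted projects: the CSV file names are hypothetical, and it uses the public TF 1.x alias tf.train.string_input_producer for the same underlying function. The pattern is: produce a queue of file names, attach a reader, and drive the queue runners in a session.

import tensorflow as tf

# Hypothetical input files; string_input_producer enqueues their names as a
# filename queue that readers can dequeue from.
filenames = ["train-00000-of-00002.csv", "train-00001-of-00002.csv"]
filename_queue = tf.train.string_input_producer(
    filenames, num_epochs=1, shuffle=False)

# A reader dequeues file names and yields (key, value) records, here one
# text line at a time.
reader = tf.TextLineReader()
_, line = reader.read(filename_queue)

with tf.Session() as sess:
  # num_epochs creates a local variable, so local variables must be initialized.
  sess.run([tf.global_variables_initializer(),
            tf.local_variables_initializer()])
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      print(sess.run(line))
  except tf.errors.OutOfRangeError:
    pass  # the single epoch is exhausted
  finally:
    coord.request_stop()
    coord.join(threads)

The fifteen examples below are variations on exactly this wiring: different readers, shared or sharded filename queues, and different batching queues downstream of the reader.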

Example 1: _apply_transform

  def _apply_transform(self, transform_input):
    filename_queue = input_ops.string_input_producer(self._work_units,
                                                     shuffle=self.shuffle,
                                                     seed=self._seed)

    if self.shuffle:
      queue = data_flow_ops.RandomShuffleQueue(
          capacity=self.queue_capacity,
          min_after_dequeue=self.min_after_dequeue,
          dtypes=[dtypes.string, dtypes.string],
          shapes=[[], []],
          seed=self.seed)
    else:
      queue = data_flow_ops.FIFOQueue(capacity=self.queue_capacity,
                                      dtypes=[dtypes.string, dtypes.string],
                                      shapes=[[], []])

    enqueue_ops = []
    for _ in range(self.num_threads):
      reader = self._reader_cls(**self._reader_kwargs)
      enqueue_ops.append(queue.enqueue(reader.read(filename_queue)))

    runner = queue_runner.QueueRunner(queue, enqueue_ops)
    queue_runner.add_queue_runner(runner)
    dequeued = queue.dequeue_many(self.batch_size)

    # pylint: disable=not-callable
    return self.return_type(*dequeued)
Developer: 0ruben, Project: tensorflow, Lines: 28, Source file: reader_source.py

Example 2: _apply_transform

  def _apply_transform(self, transform_input):
    filename_queue = input_ops.string_input_producer(self.work_units,
                                                     num_epochs=self.num_epochs,
                                                     shuffle=self.shuffle,
                                                     seed=self.seed)
    reader_ops = []
    for _ in range(self.num_threads):
      reader = self._reader_cls(**self._reader_kwargs)
      reader_ops.append(reader.read_up_to(filename_queue, self.enqueue_size))

    if self.shuffle:
      dequeued = input_ops.shuffle_batch_join(
          reader_ops,
          self.batch_size,
          capacity=self.queue_capacity,
          min_after_dequeue=self.min_after_dequeue,
          seed=self.seed,
          enqueue_many=True,
          shared_name=None,
          name=None)
    else:
      dequeued = input_ops.batch_join(reader_ops,
                                      self.batch_size,
                                      capacity=self.queue_capacity,
                                      enqueue_many=True,
                                      dynamic_pad=False,
                                      shared_name=None,
                                      name=None)

    # pylint: disable=not-callable
    return self.return_type(*dequeued)
Developer: 2020zyc, Project: tensorflow, Lines: 31, Source file: reader_source.py

Example 3: parallel_read

def parallel_read(data_sources,
                  reader_class,
                  num_epochs=None,
                  num_readers=4,
                  reader_kwargs=None,
                  shuffle=True,
                  dtypes=None,
                  capacity=256,
                  min_after_dequeue=128):
  """Reads multiple records in parallel from data_sources using n readers.

  It uses a ParallelReader to read from multiple files in parallel using
  multiple readers created using `reader_class` with `reader_kwargs`.

  If `shuffle` is True, the common_queue will be a RandomShuffleQueue; otherwise
  it will be a FIFOQueue.

  Usage:
      data_sources = ['path_to/train*']
      key, value = parallel_read(data_sources, tf.TextLineReader, num_readers=4)

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /cns/../train@128, /cns/.../train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader
    num_epochs: The number of times each data source is read. If left as None,
        the data will be cycled through indefinitely.
    num_readers: an integer, number of Readers to create.
    reader_kwargs: an optional dict, of kwargs for the reader.
    shuffle: boolean, whether to shuffle the files and the records by using
      RandomShuffleQueue as common_queue.
    dtypes:  A list of types.  The length of dtypes must equal the number
        of elements in each record. If it is None it will default to
        [tf.string, tf.string] for (key, value).
    capacity: integer, capacity of the common_queue.
    min_after_dequeue: integer, minimum number of records in the common_queue
      after dequeue. Needed for a good shuffle.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope('parallel_read'):
    filename_queue = tf_input.string_input_producer(
        data_files, num_epochs=num_epochs, shuffle=shuffle)
    dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
    if shuffle:
      common_queue = data_flow_ops.RandomShuffleQueue(
          capacity=capacity,
          min_after_dequeue=min_after_dequeue,
          dtypes=dtypes)
    else:
      common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes)

    return ParallelReader(reader_class,
                          common_queue,
                          num_readers=num_readers,
                          reader_kwargs=reader_kwargs).read(filename_queue)
Developer: 0ruben, Project: tensorflow, Lines: 58, Source file: parallel_reader.py

Example 4: _verify_read_up_to_out

  def _verify_read_up_to_out(self, shared_queue):
    with self.test_session():
      num_files = 3
      num_records_per_file = 7
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(),
          num_files=num_files,
          num_records_per_file=num_records_per_file)

    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=5)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files, num_epochs=1)
    key, value = p_reader.read_up_to(filename_queue, 4)

    count0 = 0
    count1 = 0
    count2 = 0
    all_keys_count = 0
    all_values_count = 0

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)
      while True:
        try:
          current_keys, current_values = sess.run([key, value])
          self.assertEquals(len(current_keys), len(current_values))
          all_keys_count += len(current_keys)
          all_values_count += len(current_values)
          for current_key in current_keys:
            if '0-of-3' in str(current_key):
              count0 += 1
            if '1-of-3' in str(current_key):
              count1 += 1
            if '2-of-3' in str(current_key):
              count2 += 1
        except errors_impl.OutOfRangeError:
          break

    self.assertEquals(count0, num_records_per_file)
    self.assertEquals(count1, num_records_per_file)
    self.assertEquals(count2, num_records_per_file)
    self.assertEquals(
        all_keys_count,
        num_files * num_records_per_file)
    self.assertEquals(all_values_count, all_keys_count)
    self.assertEquals(
        count0 + count1 + count2,
        all_keys_count)
Developer: 1000sprites, Project: tensorflow, Lines: 51, Source file: parallel_reader_test.py

Example 5: _get_shared_file_name_queue

def _get_shared_file_name_queue(file_names, shuffle, num_epochs, name):
  # Creating a dummy variable so we can put the shared queue in ps if there is
  # a PS and in the worker otherwise. TODO(rohanj): Figure out how to place an
  # op on PS without this hack
  dummy_var = var_ops.Variable(initial_value=0, name='queue_placement_var')
  with ops.device(dummy_var.device):
    shared_file_name_queue = input_ops.string_input_producer(
        constant_op.constant(
            file_names, name='input'),
        shuffle=shuffle,
        num_epochs=num_epochs,
        name=name,
        shared_name=name)
    return shared_file_name_queue
Developer: DavidNemeskey, Project: tensorflow, Lines: 14, Source file: graph_io.py

Example 6: testManagedEndOfInputOneQueue

  def testManagedEndOfInputOneQueue(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from a single queue.
    logdir = self._test_dir("managed_end_of_input_one_queue")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer(
          [data_path], num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        while not sv.should_stop():
          sess.run(rec)
Developer: JonathanRaiman, Project: tensorflow, Lines: 17, Source file: supervisor_test.py

Example 7: single_pass_read

def single_pass_read(data_sources, reader_class, reader_kwargs=None):
  """Reads the data_sources sequentially using the reader, doing a single pass.

  Args:
    data_sources: a list/tuple of files or the location of the data, i.e.
      /path/to/train@128, /path/to/train* or /tmp/.../train*
    reader_class: one of the io_ops.ReaderBase subclasses ex: TFRecordReader.
    reader_kwargs: an optional dict, of kwargs for the reader.

  Returns:
    key, value: a tuple of keys and values from the data_source.
  """
  data_files = get_data_files(data_sources)
  with ops.name_scope("single_pass_read"):
    filename_queue = tf_input.string_input_producer(
        data_files, num_epochs=1, shuffle=False, capacity=1)
    reader_kwargs = reader_kwargs or {}
    return reader_class(**reader_kwargs).read(filename_queue)
Developer: brchiu, Project: tensorflow, Lines: 17, Source file: parallel_reader.py
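The (key, value) tensors returned by single_pass_read still sit behind a queue, so the caller is responsible for starting queue runners. A hypothetical usage sketch (the shard pattern and reader class are assumptions, not part of the excerpted source):

key, value = single_pass_read(
    "/tmp/data/validation-*.tfrecord",  # hypothetical shard pattern
    reader_class=io_ops.TFRecordReader)
# Start the queue runners (e.g. via a Supervisor or tf.train.start_queue_runners)
# before calling sess.run([key, value]); since num_epochs=1, an OutOfRangeError
# is raised once the files have been read exactly once.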

Example 8: testReadFromSameFile

  def testReadFromSameFile(self):
    with self.test_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)
Developer: AnishShah, Project: tensorflow, Lines: 18, Source file: reader_ops_test.py

Example 9: testReadFromFileRepeatedly

  def testReadFromFileRepeatedly(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the lmdb 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = sess.run([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)
Developer: AnishShah, Project: tensorflow, Lines: 19, Source file: reader_ops_test.py

Example 10: _get_filename_queue

  def _get_filename_queue(self, epoch_limit):
    """Constructs a filename queue with an epoch limit.

    `epoch_limit` is intended as an error checking fallback to prevent a reader
    from infinitely looping in its requests for more work items if none are
    available in any file. It should be set high enough that it is never reached
    assuming at least one record exists in some file.

    Args:
      epoch_limit: The maximum number of times to read through the complete list
        of files before throwing an OutOfRangeError.

    Returns:
      A tuple of (filename_queue, epoch_limiter):
        filename_queue: A FIFOQueue with filename work items.
        epoch_limiter: The local variable used for epoch limitation. This should
          be set to zero before a reader is passed `filename_queue` in order to
          reset the epoch limiter's state.
    """
    epoch_limiter = variable_scope.variable(
        initial_value=constant_op.constant(0, dtype=dtypes.int64),
        name="epoch_limiter",
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])
    filenames_tensor = array_ops.reshape(
        ops.convert_to_tensor(self._filenames), [-1])
    # We can't rely on epoch_limiter being initialized, since queue runners are
    # started before local variables are initialized. Instead, we ignore epoch
    # limits before variable initialization. This means that prior to variable
    # initialization, a QueueRunner may cause a reader to enter an un-checked
    # infinite loop. However, as soon as local variables are initialized, we
    # will start incrementing and checking epoch_limiter, which will interrupt
    # any in-progress loops.
    conditional_count_up_to = control_flow_ops.cond(
        state_ops.is_variable_initialized(
            epoch_limiter), lambda: epoch_limiter.count_up_to(epoch_limit),
        lambda: constant_op.constant(0, dtype=dtypes.int64))
    with ops.control_dependencies([conditional_count_up_to]):
      filenames_tensor = array_ops.identity(filenames_tensor)
    filename_queue = input_lib.string_input_producer(
        filenames_tensor, shuffle=False, capacity=1)
    return filename_queue, epoch_limiter
Developer: ahmedsaiduk, Project: tensorflow, Lines: 42, Source file: input_pipeline.py

Example 11: testManagedMainErrorTwoQueues

  def testManagedMainErrorTwoQueues(self):
    # Tests that the supervisor correctly raises a main loop
    # error even when using multiple queues for input.
    logdir = self._test_dir("managed_main_error_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
      with ops.Graph().as_default():
        # Create an input pipeline that reads the file 3 times.
        filename_queue = input_lib.string_input_producer(
            [data_path], num_epochs=3)
        reader = io_ops.TextLineReader()
        _, csv = reader.read(filename_queue)
        rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
        shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
        sv = supervisor.Supervisor(logdir=logdir)
        with sv.managed_session("") as sess:
          for step in range(9):
            if sv.should_stop():
              break
            elif step == 3:
              raise RuntimeError("fail at step 3")
            else:
              sess.run(shuff_rec)
Developer: JonathanRaiman, Project: tensorflow, Lines: 24, Source file: supervisor_test.py

Example 12: _verify_all_data_sources_read

  def _verify_all_data_sources_read(self, shared_queue):
    with self.test_session():
      tfrecord_paths = test_utils.create_tfrecord_files(
          self.get_temp_dir(), num_files=3)

    num_readers = len(tfrecord_paths)
    p_reader = parallel_reader.ParallelReader(
        io_ops.TFRecordReader, shared_queue, num_readers=num_readers)

    data_files = parallel_reader.get_data_files(tfrecord_paths)
    filename_queue = input_lib.string_input_producer(data_files)
    key, value = p_reader.read(filename_queue)

    count0 = 0
    count1 = 0
    count2 = 0

    num_reads = 50

    sv = supervisor.Supervisor(logdir=self.get_temp_dir())
    with sv.prepare_or_wait_for_session() as sess:
      sv.start_queue_runners(sess)

      for _ in range(num_reads):
        current_key, _ = sess.run([key, value])
        if '0-of-3' in str(current_key):
          count0 += 1
        if '1-of-3' in str(current_key):
          count1 += 1
        if '2-of-3' in str(current_key):
          count2 += 1

    self.assertGreater(count0, 0)
    self.assertGreater(count1, 0)
    self.assertGreater(count2, 0)
    self.assertEquals(count0 + count1 + count2, num_reads)
Developer: 1000sprites, Project: tensorflow, Lines: 36, Source file: parallel_reader_test.py

Example 13: read_keyed_batch_examples

def read_keyed_batch_examples(
    file_pattern, batch_size, reader,
    randomize_input=True, num_epochs=None,
    queue_capacity=10000, num_threads=1,
    read_batch_size=1, parse_fn=None,
    name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Use `parse_fn` if you need to do parsing / processing on single examples.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto. If `keep_keys` is True, then
    returns tuple of string `Tensor`s, where first value is the key.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input, num_epochs=num_epochs,
          name=file_name_queue_scope)

    # Create readers, one per thread and set them to read from filename queue.
    with ops.name_scope('read'):
      example_list = []
      for _ in range(num_threads):
        if read_batch_size > 1:
          keys, examples_proto = reader().read_up_to(file_name_queue,
                                                     read_batch_size)
        else:
          keys, examples_proto = reader().read(file_name_queue)
        if parse_fn:
          parsed_examples = parse_fn(examples_proto)
          # Map keys into example map because batch_join doesn't support
          # tuple of Tensor + dict.
          if isinstance(parsed_examples, dict):
#......... remaining code omitted .........
Developer: 285219011, Project: hello-world, Lines: 101, Source file: graph_io.py

Example 14: _read_keyed_batch_examples_helper

def _read_keyed_batch_examples_helper(file_pattern,
                                      batch_size,
                                      reader,
                                      randomize_input=True,
                                      num_epochs=None,
                                      queue_capacity=10000,
                                      num_threads=1,
                                      read_batch_size=1,
                                      parse_fn=None,
                                      setup_shared_queue=False,
                                      name=None):
  """Adds operations to read, queue, batch `Example` protos.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If `None`, cycles through the dataset forever.
      NOTE - If specified, creates a variable that must be initialized, so call
      `tf.initialize_all_variables()` as shown in the tests.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    read_batch_size: An int or scalar `Tensor` specifying the number of
      records to read at once
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    setup_shared_queue: Whether to set up a shared queue for file names.
    name: Name of resulting op.

  Returns:
    Returns tuple of:
    - `Tensor` of string keys.
    - String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  file_names = _get_file_names(file_pattern, randomize_input)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if (read_batch_size is None) or (
      (not isinstance(read_batch_size, ops.Tensor)) and
      (read_batch_size <= 0)):
    raise ValueError('Invalid read_batch_size %s.' % read_batch_size)
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)
  if (num_epochs is not None) and (num_epochs <= 0):
    raise ValueError('Invalid num_epochs %s.' % num_epochs)

  with ops.name_scope(name, 'read_batch_examples', [file_pattern]) as scope:
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      if setup_shared_queue:
        shared_file_name_queue = _get_shared_file_name_queue(
            file_names, randomize_input, num_epochs, file_name_queue_scope)
        file_name_queue = data_flow_ops.FIFOQueue(
            capacity=1, dtypes=[dtypes.string], shapes=[[]])
        enqueue_op = file_name_queue.enqueue(shared_file_name_queue.dequeue())
        queue_runner.add_queue_runner(
            queue_runner.QueueRunner(file_name_queue, [enqueue_op]))
      else:
        file_name_queue = input_ops.string_input_producer(
            constant_op.constant(
                file_names, name='input'),
            shuffle=randomize_input,
            num_epochs=num_epochs,
            name=file_name_queue_scope)

    example_list = _get_examples(file_name_queue, reader, num_threads,
                                 read_batch_size, parse_fn)

    enqueue_many = read_batch_size > 1

    if num_epochs is None:
      allow_smaller_final_batch = False
    else:
      allow_smaller_final_batch = True

    # Setup batching queue given list of read example tensors.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      queued_examples_with_keys = input_ops.shuffle_batch_join(
          example_list, batch_size, capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=enqueue_many, name=scope,
#......... remaining code omitted .........
Developer: DavidNemeskey, Project: tensorflow, Lines: 101, Source file: graph_io.py

Example 15: read_batch_examples

def read_batch_examples(file_pattern, batch_size, reader,
                        randomize_input=True, queue_capacity=10000,
                        num_threads=1, name='dequeue_examples'):
  """Adds operations to read, queue, batch `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
  read `Example` proto using provided `reader`, use batch queue to create
  batches of examples of size `batch_size`.

  All queue runners are added to the queue runners collection, and may be
  started via `start_queue_runners`.

  All ops are added to the default graph.

  Args:
    file_pattern: List of files or pattern of file paths containing
        `Example` records. See `tf.gfile.Glob` for pattern rules.
    batch_size: An int or scalar `Tensor` specifying the batch size to use.
    reader: A function or class that returns an object with
      `read` method, (filename tensor) -> (example tensor).
    randomize_input: Whether the input should be randomized.
    queue_capacity: Capacity for input queue.
    num_threads: The number of threads enqueuing examples.
    name: Name of resulting op.

  Returns:
    String `Tensor` of batched `Example` proto.

  Raises:
    ValueError: for invalid inputs.
  """
  # Retrieve files to read.
  if isinstance(file_pattern, list):
    file_names = file_pattern
    if not file_names:
      raise ValueError('No files given to dequeue_examples.')
  else:
    file_names = list(gfile.Glob(file_pattern))
    if not file_names:
      raise ValueError('No files match %s.' % file_pattern)

  # Sort files so it will be deterministic for unit tests. They'll be shuffled
  # in `string_input_producer` if `randomize_input` is enabled.
  if not randomize_input:
    file_names = sorted(file_names)

  # Check input parameters are given and reasonable.
  if (not queue_capacity) or (queue_capacity <= 0):
    raise ValueError('Invalid queue_capacity %s.' % queue_capacity)
  if (batch_size is None) or (
      (not isinstance(batch_size, ops.Tensor)) and
      (batch_size <= 0 or batch_size > queue_capacity)):
    raise ValueError(
        'Invalid batch_size %s, with queue_capacity %s.' %
        (batch_size, queue_capacity))
  if (not num_threads) or (num_threads <= 0):
    raise ValueError('Invalid num_threads %s.' % num_threads)

  with ops.name_scope(name) as scope:
    # Setup filename queue with shuffling.
    with ops.name_scope('file_name_queue') as file_name_queue_scope:
      file_name_queue = input_ops.string_input_producer(
          constant_op.constant(file_names, name='input'),
          shuffle=randomize_input, name=file_name_queue_scope)

    # Create reader and set it to read from filename queue.
    with ops.name_scope('read'):
      _, example_proto = reader().read(file_name_queue)

    # Setup batching queue.
    if randomize_input:
      if isinstance(batch_size, ops.Tensor):
        min_after_dequeue = int(queue_capacity * 0.4)
      else:
        min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
      examples = input_ops.shuffle_batch(
          [example_proto], batch_size, capacity=queue_capacity,
          num_threads=num_threads, min_after_dequeue=min_after_dequeue,
          name=scope)
    else:
      examples = input_ops.batch(
          [example_proto], batch_size, capacity=queue_capacity,
          num_threads=num_threads, name=scope)

    return examples
Developer: 01-, Project: tensorflow, Lines: 85, Source file: graph_io.py
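As a hypothetical call sketch for read_batch_examples (the file pattern, batch size, and reader class are assumptions, not from the excerpted source):

examples = read_batch_examples(
    file_pattern="/tmp/data/train-*.tfrecord",  # hypothetical pattern
    batch_size=32,
    reader=io_ops.TFRecordReader,
    randomize_input=True)
# `examples` is a batched string Tensor of serialized Example protos; as with
# the other examples, the queue runners must be started before evaluating it.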


Note: The tensorflow.python.training.input.string_input_producer examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors, and any distribution or use should follow the corresponding project's license. Do not reproduce without permission.