Python tensorflow.matching_files方法代码示例

本文整理汇总了 Python 中 tensorflow.matching_files 方法的典型用法代码示例。如果您想了解 tensorflow.matching_files 方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow 的其他用法示例。（注：tf.matching_files 返回一个包含所有与给定 glob 模式匹配的文件名的一维字符串张量；在 TensorFlow 2.x 中请改用 tf.io.matching_files。）


示例1: input_fn

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def input_fn(self, name, csv_path=None):
        """Creates a dataset object for the model to consume. Input function for estimator.

        Args:
                name : string, Name of the data [Train or Eval]
                csv_path : The path of the csv on any storage system

        Returns:
                feats, labs : the next batch from a one-shot iterator over the
                        parsed CSV lines. Their structure is whatever
                        ``self.parse_csv`` returns for one line (it is unpacked
                        here into two parts, presumably features and labels --
                        TODO confirm against parse_csv).
        """
        pattern = self._get_pattern(name, csv_path)
        tf.logging.info('The Pattern of files is : %s', pattern)
        filenames = tf.matching_files(pattern=pattern)
        # Skip the header line of *each* matched file. Building a single
        # TextLineDataset over all files and calling .skip(1) drops only the
        # first line of the concatenated stream, so the headers of every other
        # file would be handed to parse_csv as if they were data rows.
        dataset = tf.data.Dataset.from_tensor_slices(filenames).flat_map(
            lambda filename: tf.data.TextLineDataset(filename).skip(1))
        dataset = dataset.map(self.parse_csv, num_parallel_calls=cpu_count())
        dataset = dataset.shuffle(buffer_size=self.batch_size * 100)
        # Silently drop elements whose processing raised an error (e.g. rows
        # parse_csv cannot decode) instead of failing the whole pipeline.
        dataset = dataset.apply(tf.contrib.data.ignore_errors())
        dataset = dataset.repeat(self.num_epochs)
        dataset = dataset.batch(self.batch_size)  # TODO: determine the ideal number
        dataset = dataset.prefetch(self.buffer_size)
        iterator = dataset.make_one_shot_iterator()
        feats, labs = iterator.get_next()
        return feats, labs

示例2: get_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def get_dataset(tfrecords_dir, subset, batch_size):
    """Read TFRecords files and turn them into a TFRecordDataset.

    Args:
        tfrecords_dir: directory holding the TFRecord shards.
        subset: shard-name prefix; files matching ``<subset>-*`` are read.
            ``'train'`` additionally enables training-mode parsing.
        batch_size: number of parsed examples per batch; also used as the
            prefetch buffer size.

    Returns:
        An endlessly repeating tf.data.Dataset of parsed, batched examples.
    """
    files = tf.matching_files(os.path.join(tfrecords_dir, '%s-*' % subset))
    shards = tf.data.Dataset.from_tensor_slices(files)
    # Shuffle the shard order with a buffer as large as the number of shards.
    shards = shards.shuffle(tf.cast(tf.shape(files)[0], tf.int64))
    # repeat() with no count repeats indefinitely.
    shards = shards.repeat()
    dataset = shards.interleave(tf.data.TFRecordDataset, cycle_length=4)
    dataset = dataset.shuffle(buffer_size=8192)
    parser = partial(_parse_fn, is_training=(subset == 'train'))
    # Parse each serialized record, then group into batches. The parallelism
    # mirrors the interleave cycle_length above; tune as needed.
    dataset = dataset.map(parser, num_parallel_calls=4)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(batch_size)
    return dataset

示例3: read_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def read_dataset(file_read_func, decode_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths or glob patterns to read; every pattern
      is expanded with tf.matching_files.
    config: A input_reader_builder.InputReader object.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                        0)
  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  # num_epochs of 0/None becomes None, i.e. repeat indefinitely.
  filename_dataset = filename_dataset.repeat(config.num_epochs or None)

  # sloppy=True lets parallel_interleave emit records out of order, trading
  # determinism for throughput.
  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func, cycle_length=config.num_readers, sloppy=True))
  if config.shuffle:
    # Dataset transformations return a new dataset; the result must be
    # reassigned, otherwise the record-level shuffle is silently discarded.
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
  tensor_dataset = records_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return tensor_dataset.prefetch(config.prefetch_size)

示例4: read_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def read_dataset(file_read_func, decode_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths or glob patterns to read; every pattern
      is expanded with tf.matching_files.
    config: A input_reader_builder.InputReader object.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                        0)
  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  elif config.num_readers > 1:
    # Several concurrent readers with sloppy interleaving still reorder
    # records, so the user asked for no shuffle but gets a little anyway.
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')

  # num_epochs of 0/None becomes None, i.e. repeat indefinitely.
  filename_dataset = filename_dataset.repeat(config.num_epochs or None)

  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func, cycle_length=config.num_readers,
          block_length=config.read_block_length, sloppy=True))
  if config.shuffle:
    # Dataset transformations return a new dataset; the result must be
    # reassigned, otherwise the record-level shuffle is silently discarded.
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
  tensor_dataset = records_dataset.map(
      decode_func, num_parallel_calls=config.num_parallel_map_calls)
  return tensor_dataset.prefetch(config.prefetch_size)

示例5: testMatchingFiles

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def testMatchingFiles(self):
    """Checks tf.matching_files on exact names and on glob wildcards."""
    cases = ['ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH',
             'AB4DEF.GH', 'ABDEF.GH', 'XYZ']
    files = [tempfile.NamedTemporaryFile(
        prefix=c, dir=self.get_temp_dir()) for c in cases]

    try:
      with self.test_session():
        # Test exact match without wildcards: each name matches only itself.
        for i, f in enumerate(files):
          self.assertEqual(set(tf.matching_files(f.name).eval()),
                           self._subset(files, [i]))

        # We will look for files matching "ABxDEF.GH*" where "x" is some wildcard.
        pos = files[0].name.find(cases[0])
        pattern = files[0].name[:pos] + 'AB%sDEF.GH*'

        # 'z' is a literal character, '?' matches exactly one character,
        # '*' matches any run (including empty), '[...]' is a class/range.
        self.assertEqual(set(tf.matching_files(pattern % 'z').eval()),
                         self._subset(files, [1]))
        self.assertEqual(set(tf.matching_files(pattern % '?').eval()),
                         self._subset(files, [0, 1, 3, 4]))
        self.assertEqual(set(tf.matching_files(pattern % '*').eval()),
                         self._subset(files, [0, 1, 2, 3, 4, 5]))
        self.assertEqual(set(tf.matching_files(pattern % '[cxz]').eval()),
                         self._subset(files, [0, 1]))
        self.assertEqual(set(tf.matching_files(pattern % '[0-9]').eval()),
                         self._subset(files, [3, 4]))
    finally:
      # Closing a NamedTemporaryFile also deletes it; without this the open
      # handles (and the temp files) linger until garbage collection.
      for f in files:
        f.close()

示例6: input_fn

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def input_fn(input_dir, mode, batch_size, num_epochs, label_name=None,
             shuffle_buffer_size=10000, feature_spec=None):
    """Reads TFRecords and returns the features and labels.

    Args:
        input_dir: root directory. Records are read from files matching
            ``<input_dir>/data/<m>/<m>*.tfrecord``, where ``<m>`` is the
            lower-cased string form of ``mode``. When ``feature_spec`` is None,
            the tf.Transform output in ``<input_dir>/transformed_metadata``
            supplies it.
        mode: a tf.estimator.ModeKeys value. TRAIN enables shuffling; PREDICT
            returns features only.
        batch_size: number of serialized examples per batch.
        num_epochs: passed to Dataset.repeat (None repeats indefinitely).
        label_name: feature key popped out of the parsed features as the
            label. Required unless mode is PREDICT.
        shuffle_buffer_size: shuffle buffer size used in TRAIN mode.
        feature_spec: parsing spec for tf.parse_example; loaded from the
            tf.Transform output when None.

    Returns:
        ``features`` in PREDICT mode, otherwise ``(features, label)``, where
        ``features`` is the dict of parsed tensors for one batch.
    """
    if feature_spec is None:
        tf_transform_output = tft.TFTransformOutput(
            os.path.join(input_dir, 'transformed_metadata'))
        feature_spec = tf_transform_output.transformed_feature_spec()
    prefix = str(mode).lower()
    suffix = '.tfrecord'
    num_cpus = multiprocessing.cpu_count()

    file_pattern = os.path.join(input_dir, 'data', prefix, prefix + '*' + suffix)
    filenames = tf.matching_files(file_pattern)
    # buffer_size=None lets TFRecordDataset choose its default read buffer;
    # the matched files are read in parallel, one reader per CPU.
    dataset = tf.data.TFRecordDataset(filenames=filenames, buffer_size=None,
                                      num_parallel_reads=num_cpus)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(shuffle_buffer_size)

    dataset = dataset.repeat(num_epochs)
    # Batch first, then parse the whole batch of serialized examples with a
    # single vectorized tf.parse_example call.
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(
        lambda examples: tf.parse_example(examples, feature_spec))
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    if mode == tf.estimator.ModeKeys.PREDICT:
        return features

    label = features.pop(label_name)
    return features, label

示例7: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def __init__(self, in_pattern, batch_size, num_buckets=0, num_epochs=None):
        """Store batching settings and build the input-pipeline op.

        Args:
                in_pattern: either a directory (its ``examples.proto`` file is
                        read) or a file glob pattern passed to tf.matching_files.
                batch_size: examples per batch.
                num_buckets: forwarded to ``self.input_pipeline``.
                num_epochs: forwarded to ``self.input_pipeline``.
        """
        self._batch_size = batch_size
        self.num_buckets = num_buckets
        self._epoch = 0
        self._step = 1.
        self.num_epochs = num_epochs
        # A directory means "read its examples.proto"; anything else is used
        # as a glob pattern unchanged. os.path.join avoids a doubled separator
        # when the directory already ends with one.
        if os.path.isdir(in_pattern):
            file_pattern = os.path.join(in_pattern, 'examples.proto')
        else:
            file_pattern = in_pattern
        filenames = tf.matching_files(file_pattern)
        self.next_batch_op = self.input_pipeline(
                filenames, self._batch_size, self.num_buckets, self.num_epochs)

示例8: input_pipeline

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def input_pipeline(self, file_pattern, batch_size, num_epochs=None, num_threads=10):
        """Return tensors that dequeue padded batches of parsed examples.

        Files matching ``file_pattern`` feed a shuffled, queue-runner based
        filename queue (TF1). ``self.example_parser`` turns that queue into one
        parsed example, and ``tf.train.batch`` groups examples into batches,
        padding variable-length tensors (``dynamic_pad``) and allowing a final
        batch smaller than ``batch_size``.

        Args:
                file_pattern: glob pattern of the input files.
                batch_size: examples per batch.
                num_epochs: passes over the file list; None cycles forever.
                num_threads: threads enqueuing into the batch queue.

        Returns:
                The batched tensors produced by ``tf.train.batch``.
        """
        # Only used to size the batching queue: tf.train.batch itself takes no
        # min_after_dequeue (that parameter belongs to tf.train.shuffle_batch).
        dequeue_floor = 10000
        queue_capacity = dequeue_floor + 12 * batch_size

        matched = tf.matching_files(file_pattern)
        queue = tf.train.string_input_producer(
                matched, shuffle=True, num_epochs=num_epochs)
        example = self.example_parser(queue)
        return tf.train.batch(
                example,
                batch_size=batch_size,
                capacity=queue_capacity,
                num_threads=num_threads,
                dynamic_pad=True,
                allow_smaller_final_batch=True,
        )

示例9: read_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def read_dataset(
    file_read_func, decode_func, input_files, config, num_workers=1,
    worker_index=0):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.data.Dataset.interleave, to read
      every individual file into a tf.data.Dataset.
    decode_func: Function to apply to all records.
    input_files: A list of file paths or glob patterns to read; every pattern
      is expanded with tf.matching_files.
    config: A input_reader_builder.InputReader object.
    num_workers: Number of workers / shards.
    worker_index: Id for the current worker.

  Returns:
    A tf.data.Dataset based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                        0)
  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  # Each worker keeps every num_workers-th filename, starting at worker_index.
  dataset = dataset.shard(num_workers, worker_index)
  # num_epochs of 0/None becomes None, i.e. repeat indefinitely.
  dataset = dataset.repeat(config.num_epochs or None)
  if config.shuffle:
    dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
                              reshuffle_each_iteration=True)

  # Read file records and shuffle them.
  # If cycle_length is larger than the number of files, more than one reader
  # will be assigned to the same file, leading to repetition.
  # NOTE(review): tf.size(filenames) counts all files, not only this worker's
  # shard, so with num_workers > 1 the cap can exceed the per-worker count.
  cycle_length = tf.cast(
      tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
  # TODO: find the optimal block_length.
  dataset = dataset.interleave(
      file_read_func, cycle_length=cycle_length, block_length=1)

  if config.shuffle:
    dataset = dataset.shuffle(config.shuffle_buffer_size,
                              reshuffle_each_iteration=True)

  dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
  return dataset.prefetch(config.prefetch_buffer_size)

示例10: read_dataset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import matching_files [as 别名]
def read_dataset(
        file_read_func, decode_func, input_files, config, num_workers=1,
        worker_index=0):
    """Reads a dataset, and handles repetition and shuffling.

    Args:
      file_read_func: Function to use in tf.data.Dataset.interleave, to read
        every individual file into a tf.data.Dataset.
      decode_func: Function to apply to all records.
      input_files: A list of file paths or glob patterns to read; every
        pattern is expanded with tf.matching_files.
      config: A input_reader_builder.InputReader object.
      num_workers: Number of workers / shards.
      worker_index: Id for the current worker.

    Returns:
      A tf.data.Dataset based on config.
    """
    # Shard, shuffle, and read files.
    filenames = tf.concat([tf.matching_files(pattern) for pattern in input_files],
                          0)
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Each worker keeps every num_workers-th filename, starting at worker_index.
    dataset = dataset.shard(num_workers, worker_index)
    # num_epochs of 0/None becomes None, i.e. repeat indefinitely.
    dataset = dataset.repeat(config.num_epochs or None)
    if config.shuffle:
        dataset = dataset.shuffle(config.filenames_shuffle_buffer_size,
                                  reshuffle_each_iteration=True)

    # Read file records and shuffle them.
    # If cycle_length is larger than the number of files, more than one reader
    # will be assigned to the same file, leading to repetition.
    # NOTE(review): tf.size(filenames) counts all files, not only this worker's
    # shard, so with num_workers > 1 the cap can exceed the per-worker count.
    cycle_length = tf.cast(
        tf.minimum(config.num_readers, tf.size(filenames)), tf.int64)
    # TODO: find the optimal block_length.
    dataset = dataset.interleave(
        file_read_func, cycle_length=cycle_length, block_length=1)

    if config.shuffle:
        dataset = dataset.shuffle(config.shuffle_buffer_size,
                                  reshuffle_each_iteration=True)

    dataset = dataset.map(decode_func, num_parallel_calls=config.num_readers)
    return dataset.prefetch(config.prefetch_buffer_size)
