当前位置: 首页>>代码示例>>Python>>正文


Python preprocessing.shuffle_tf_examples方法代码示例

本文整理汇总了Python中preprocessing.shuffle_tf_examples方法的典型用法代码示例。如果您正苦于以下问题:Python preprocessing.shuffle_tf_examples方法的具体用法?Python preprocessing.shuffle_tf_examples怎么用?Python preprocessing.shuffle_tf_examples使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块preprocessing的用法示例。


在下文中一共展示了preprocessing.shuffle_tf_examples方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_serialize_round_trip_no_parse

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def test_serialize_round_trip_no_parse(self):
    """Round-trip shuffled examples through a second file without re-serializing.

    Writes tf.Examples built from `create_random_data(10)` to a temporary
    file, reads them back through `preprocessing.shuffle_tf_examples` in
    batches of 4, writes the flattened batches to a second temporary file
    with `serialize=False`, and asserts that both files yield the same
    data once sorted.
    """
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as start_file, \
            tempfile.NamedTemporaryFile() as rewritten_file:
        preprocessing.write_tf_examples(start_file.name, tfexamples)
        # We want to test that the rewritten, shuffled file contains
        # correctly serialized tf.Examples.
        batch_size = 4
        batches = list(preprocessing.shuffle_tf_examples(
            batch_size, [start_file.name]))
        # 2 batches of 4, 1 incomplete batch of 2.
        self.assertEqual(len(batches), 3)

        # Flatten the list of batches into one list of records.
        all_batches = list(itertools.chain.from_iterable(batches))

        # One write is enough: looping over `batches` here would only pass
        # the same flattened list to the same path again on every pass.
        # NOTE(review): assumes write_tf_examples replaces the target file
        # rather than appending to it -- confirm.
        preprocessing.write_tf_examples(
            rewritten_file.name, all_batches, serialize=False)

        original_data = self.extract_data(start_file.name)
        recovered_data = self.extract_data(rewritten_file.name)

    # The records were shuffled, so sort both sides before comparing.
    def sort_key(nparray_tuple):
        return nparray_tuple[2]

    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)

    self.assertEqualData(original_data, recovered_data)
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:34,代码来源:test_preprocessing.py

示例2: test_serialize_round_trip_no_parse

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def test_serialize_round_trip_no_parse(self):
    """Round-trip shuffled examples through a second file without re-serializing.

    Writes tf.Examples built from `create_random_data(10)` to a temporary
    file, reads them back through `preprocessing.shuffle_tf_examples`
    (first argument 1000, batch size 4), writes the flattened batches to a
    second temporary file with `serialize=False`, and asserts that both
    files yield the same data once sorted.
    """
    np.random.seed(1)
    raw_data = self.create_random_data(10)
    tfexamples = list(map(preprocessing.make_tf_example, *zip(*raw_data)))

    with tempfile.NamedTemporaryFile() as start_file, \
            tempfile.NamedTemporaryFile() as rewritten_file:
        preprocessing.write_tf_examples(start_file.name, tfexamples)
        # We want to test that the rewritten, shuffled file contains
        # correctly serialized tf.Examples.
        batch_size = 4
        batches = list(preprocessing.shuffle_tf_examples(
            1000, batch_size, [start_file.name]))
        # 2 batches of 4, 1 incomplete batch of 2.
        self.assertEqual(len(batches), 3)

        # Flatten the list of batches into one list of records.
        all_batches = list(itertools.chain.from_iterable(batches))

        # One write is enough: looping over `batches` here would only pass
        # the same flattened list to the same path again on every pass.
        # NOTE(review): assumes write_tf_examples replaces the target file
        # rather than appending to it -- confirm.
        preprocessing.write_tf_examples(
            rewritten_file.name, all_batches, serialize=False)

        original_data = self.extract_data(start_file.name)
        recovered_data = self.extract_data(rewritten_file.name)

    # The records were shuffled, so sort both sides before comparing.
    def sort_key(nparray_tuple):
        return nparray_tuple[2]

    original_data = sorted(original_data, key=sort_key)
    recovered_data = sorted(recovered_data, key=sort_key)

    self.assertEqualData(original_data, recovered_data)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:35,代码来源:preprocessing_test.py

示例3: gather

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def gather(
        input_directory: 'where to look for games'='data/selfplay/',
        output_directory: 'where to put collected games'='data/training_chunks/',
        examples_per_record: 'how many tf.examples to gather in each chunk'=EXAMPLES_PER_RECORD):
    """Collect per-model game records into shuffled training chunks.

    Globs `*.tfrecord.zz` files under the last 50 entries of the sorted
    listing of `input_directory`. For every model that has files not yet
    listed in `<output_directory>/meta.txt`, the files are passed through
    `preprocessing.shuffle_tf_examples` and each resulting batch is written
    to `<output_directory>/<model>-<index>.tfrecord.zz`. The meta file is
    rewritten at the end with the sorted set of processed file names.
    """
    qmeas.start_time('gather')
    _ensure_dir_exists(output_directory)

    # Only the tail of the sorted directory listing is scanned.
    recent_dirs = sorted(gfile.ListDirectory(input_directory))[-50:]
    models = [entry.strip('/') for entry in recent_dirs]

    with timer("Finding existing tfrecords..."):
        model_gamedata = {}
        for model in models:
            pattern = os.path.join(input_directory, model, '*.tfrecord.zz')
            model_gamedata[model] = gfile.Glob(pattern)

    print("Found %d models" % len(models))
    for model_name in sorted(model_gamedata):
        file_count = len(model_gamedata[model_name])
        print("    %s: %s files" % (model_name, file_count))

    # meta.txt holds the whitespace-separated names of files handled by
    # earlier runs; a missing file means nothing has been processed yet.
    meta_file = os.path.join(output_directory, 'meta.txt')
    try:
        with gfile.GFile(meta_file, 'r') as f:
            already_processed = set(f.read().split())
    except tf.errors.NotFoundError:
        already_processed = set()

    num_already_processed = len(already_processed)

    for model_name in sorted(model_gamedata):
        record_files = model_gamedata[model_name]
        # Skip models whose files have all been seen before.
        if already_processed.issuperset(record_files):
            continue
        print("Gathering files for %s:" % model_name)
        shuffled_batches = preprocessing.shuffle_tf_examples(
            examples_per_record, record_files)
        for i, example_batch in enumerate(tqdm(shuffled_batches)):
            chunk_name = '{}-{}.tfrecord.zz'.format(model_name, str(i))
            output_record = os.path.join(output_directory, chunk_name)
            preprocessing.write_tf_examples(
                output_record, example_batch, serialize=False)
        already_processed.update(record_files)

    newly_processed = len(already_processed) - num_already_processed
    print("Processed %s new files" % newly_processed)
    with gfile.GFile(meta_file, 'w') as f:
        f.write('\n'.join(sorted(already_processed)))
    qmeas.stop_time('gather')
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:46,代码来源:main.py

示例4: gather

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def gather(selfplay_dir, training_chunk_dir, params):
  """Bundle selfplay records into shuffled training chunks.

  Record files of models already fully listed in
  `<training_chunk_dir>/meta.txt` are skipped; the meta file is rewritten
  at the end with the sorted set of processed file names.

  Args:
    selfplay_dir: Directory whose entries are per-model sub-directories of
      record files (the original note suggests 'base_dir/data/selfplay/').
    training_chunk_dir: Directory that receives the chunk files and
      `meta.txt` (the original note suggests
      'base_dir/data/training_chunks/').
    params: Hyperparameter object; this function reads `gather_generation`,
      `shuffle_buffer_size` and `examples_per_chunk` from it.
  """
  _ensure_dir_exists(training_chunk_dir)

  # Only the last `params.gather_generation` entries of the sorted
  # directory listing are scanned.
  all_model_dirs = sorted(tf.gfile.ListDirectory(selfplay_dir))
  recent_model_dirs = all_model_dirs[-params.gather_generation:]
  models = [entry.strip('/') for entry in recent_model_dirs]

  with utils.logged_timer('Finding existing tfrecords...'):
    model_gamedata = {}
    for model in models:
      pattern = os.path.join(selfplay_dir, model, '*'+_TF_RECORD_SUFFIX)
      model_gamedata[model] = tf.gfile.Glob(pattern)

  print('Found {} models'.format(len(models)))
  for model_name in sorted(model_gamedata):
    file_count = len(model_gamedata[model_name])
    print('    {}: {} files'.format(model_name, file_count))

  # meta.txt holds the whitespace-separated names of files handled by
  # earlier runs; a missing file means nothing has been processed yet.
  meta_file = os.path.join(training_chunk_dir, 'meta.txt')
  try:
    with tf.gfile.GFile(meta_file, 'r') as f:
      already_processed = set(f.read().split())
  except tf.errors.NotFoundError:
    already_processed = set()

  num_already_processed = len(already_processed)

  for model_name in sorted(model_gamedata):
    record_files = model_gamedata[model_name]
    # Skip models whose files have all been seen before.
    if already_processed.issuperset(record_files):
      continue
    print('Gathering files from {}:'.format(model_name))
    tf_examples = preprocessing.shuffle_tf_examples(
        params.shuffle_buffer_size, params.examples_per_chunk, record_files)
    for i, example_batch in enumerate(tf_examples):
      chunk_name = ('{}-{}'+_TF_RECORD_SUFFIX).format(model_name, str(i))
      output_record = os.path.join(training_chunk_dir, chunk_name)
      preprocessing.write_tf_examples(
          output_record, example_batch, serialize=False)
    already_processed.update(record_files)

  newly_processed = len(already_processed) - num_already_processed
  print('Processed {} new files'.format(newly_processed))
  with tf.gfile.GFile(meta_file, 'w') as f:
    f.write('\n'.join(sorted(already_processed)))
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:55,代码来源:minigo.py

示例5: aggregate

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def aggregate():
    """Collect per-model selfplay records into shuffled training chunks.

    Globs `*.zz` files under the last 50 entries of the sorted listing of
    `PATHS.SELFPLAY_DIR`. For every model that has files not yet listed in
    `<PATHS.TRAINING_CHUNK_DIR>/meta.txt`, the files are passed through
    `preprocessing.shuffle_tf_examples` and each resulting batch is written
    to `<PATHS.TRAINING_CHUNK_DIR>/<model>-<index>.tfrecord.zz`. The meta
    file is rewritten at the end with the sorted set of processed names.
    """
    logger.info("Gathering game results")

    os.makedirs(PATHS.TRAINING_CHUNK_DIR, exist_ok=True)
    os.makedirs(PATHS.SELFPLAY_DIR, exist_ok=True)

    # Only the tail of the sorted directory listing is scanned.
    recent_dirs = sorted(gfile.ListDirectory(PATHS.SELFPLAY_DIR))[-50:]
    models = [entry.strip('/') for entry in recent_dirs]

    with timer("Finding existing tfrecords..."):
        model_gamedata = {}
        for model in models:
            pattern = os.path.join(PATHS.SELFPLAY_DIR, model, '*.zz')
            model_gamedata[model] = gfile.Glob(pattern)

    logger.info("Found %d models" % len(models))
    for model_name in sorted(model_gamedata):
        file_count = len(model_gamedata[model_name])
        logger.info("    %s: %s files" % (model_name, file_count))

    # meta.txt holds the whitespace-separated names of files handled by
    # earlier runs; a missing file means nothing has been processed yet.
    meta_file = os.path.join(PATHS.TRAINING_CHUNK_DIR, 'meta.txt')
    try:
        with gfile.GFile(meta_file, 'r') as f:
            already_processed = set(f.read().split())
    except tf.errors.NotFoundError:
        already_processed = set()

    num_already_processed = len(already_processed)

    for model_name in sorted(model_gamedata):
        record_files = model_gamedata[model_name]
        # Skip models whose files have all been seen before.
        if already_processed.issuperset(record_files):
            continue
        logger.info("Gathering files for %s:" % model_name)
        shuffled_batches = preprocessing.shuffle_tf_examples(
            GLOBAL_PARAMETER_STORE.EXAMPLES_PER_RECORD, record_files)
        for i, example_batch in enumerate(tqdm(shuffled_batches)):
            chunk_name = '{}-{}.tfrecord.zz'.format(model_name, str(i))
            output_record = os.path.join(PATHS.TRAINING_CHUNK_DIR, chunk_name)
            preprocessing.write_tf_examples(
                output_record, example_batch, serialize=False)
        already_processed.update(record_files)

    newly_processed = len(already_processed) - num_already_processed
    logger.info("Processed %s new files" % newly_processed)
    with gfile.GFile(meta_file, 'w') as f:
        f.write('\n'.join(sorted(already_processed)))
开发者ID:PacktPublishing,项目名称:Python-Reinforcement-Learning-Projects,代码行数:45,代码来源:controller.py

示例6: gather

# 需要导入模块: import preprocessing [as 别名]
# 或者: from preprocessing import shuffle_tf_examples [as 别名]
def gather(selfplay_dir, training_chunk_dir, params):
  """Bundle selfplay records into shuffled training chunks.

  Record files of models already fully listed in
  `<training_chunk_dir>/meta.txt` are skipped; the meta file is rewritten
  at the end with the sorted set of processed file names.

  Args:
    selfplay_dir: Directory whose entries are per-model sub-directories of
      record files (the original note suggests 'base_dir/data/selfplay/').
    training_chunk_dir: Directory that receives the chunk files and
      `meta.txt` (the original note suggests
      'base_dir/data/training_chunks/').
    params: A MiniGoParams instance; this function reads
      `gather_generation`, `shuffle_buffer_size` and `examples_per_chunk`
      from it.
  """
  _ensure_dir_exists(training_chunk_dir)

  # Only the last `params.gather_generation` entries of the sorted
  # directory listing are scanned.
  all_model_dirs = sorted(tf.gfile.ListDirectory(selfplay_dir))
  recent_model_dirs = all_model_dirs[-params.gather_generation:]
  models = [entry.strip('/') for entry in recent_model_dirs]

  with utils.logged_timer('Finding existing tfrecords...'):
    model_gamedata = {}
    for model in models:
      pattern = os.path.join(selfplay_dir, model, '*'+_TF_RECORD_SUFFIX)
      model_gamedata[model] = tf.gfile.Glob(pattern)

  print('Found {} models'.format(len(models)))
  for model_name in sorted(model_gamedata):
    file_count = len(model_gamedata[model_name])
    print('    {}: {} files'.format(model_name, file_count))

  # meta.txt holds the whitespace-separated names of files handled by
  # earlier runs; a missing file means nothing has been processed yet.
  meta_file = os.path.join(training_chunk_dir, 'meta.txt')
  try:
    with tf.gfile.GFile(meta_file, 'r') as f:
      already_processed = set(f.read().split())
  except tf.errors.NotFoundError:
    already_processed = set()

  num_already_processed = len(already_processed)

  for model_name in sorted(model_gamedata):
    record_files = model_gamedata[model_name]
    # Skip models whose files have all been seen before.
    if already_processed.issuperset(record_files):
      continue
    print('Gathering files from {}:'.format(model_name))
    tf_examples = preprocessing.shuffle_tf_examples(
        params.shuffle_buffer_size, params.examples_per_chunk, record_files)
    for i, example_batch in enumerate(tf_examples):
      chunk_name = ('{}-{}'+_TF_RECORD_SUFFIX).format(model_name, str(i))
      output_record = os.path.join(training_chunk_dir, chunk_name)
      preprocessing.write_tf_examples(
          output_record, example_batch, serialize=False)
    already_processed.update(record_files)

  newly_processed = len(already_processed) - num_already_processed
  print('Processed {} new files'.format(newly_processed))
  with tf.gfile.GFile(meta_file, 'w') as f:
    f.write('\n'.join(sorted(already_processed)))
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:55,代码来源:minigo.py


注:本文中的preprocessing.shuffle_tf_examples方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。