

Python tf_record.TFRecordWriter Method Code Examples

This article collects typical usage examples of the tensorflow.python.lib.io.tf_record.TFRecordWriter method in Python. If you are unsure how to use tf_record.TFRecordWriter, or are looking for concrete examples of it in practice, the curated code examples below may help. You can also explore further usage examples for the containing module, tensorflow.python.lib.io.tf_record.


The following presents 8 code examples of the tf_record.TFRecordWriter method, sorted by popularity by default.
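
Before the collected examples, a minimal, self-contained sketch of the basic write/read round trip with tf_record.TFRecordWriter may be useful. It is not taken from any of the projects below; the temporary file name and the 'value' feature key are illustrative only.

import os
import tempfile

import tensorflow as tf
from tensorflow.python.lib.io import tf_record

# Write a few tf.train.Example protos to a temporary TFRecord file.
path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
with tf_record.TFRecordWriter(path) as writer:
    for value in range(3):
        example = tf.train.Example()
        example.features.feature['value'].int64_list.value.append(value)
        writer.write(example.SerializeToString())

# Read the serialized records back and parse each one.
for record in tf_record.tf_record_iterator(path):
    example = tf.train.Example.FromString(record)
    print(example.features.feature['value'].int64_list.value[0])

Note that TFRecordWriter.write accepts serialized byte strings, which is why most of the examples below call SerializeToString() on each proto before writing it.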

Example 1: testKmeans

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def testKmeans(self):
    num_features = FLAGS.patch_height * FLAGS.patch_width
    dummy_data = np.random.random((500, num_features))
    with tempfile.NamedTemporaryFile(mode='r') as patches_file:
      with tf_record.TFRecordWriter(patches_file.name) as patches_writer:
        for patch in dummy_data:
          example = example_pb2.Example()
          example.features.feature['features'].float_list.value.extend(patch)
          patches_writer.write(example.SerializeToString())
      clusters = staffline_patches_kmeans_pipeline.train_kmeans(
          patches_file.name,
          NUM_CLUSTERS,
          BATCH_SIZE,
          TRAIN_STEPS,
          min_eval_frequency=0)
      self.assertEqual(clusters.shape, (NUM_CLUSTERS, num_features)) 
Developer: tensorflow, Project: moonlight, Lines: 18, Source: staffline_patches_kmeans_pipeline_test.py

Example 2: write_test_data

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def write_test_data(example_proto,
                    schema,
                    schema_filename="schema.pb"):
    tmp_dir = tf.test.get_temp_dir()
    schema_path = pjoin(tmp_dir, schema_filename)
    with open(schema_path, "wb") as f:
        f.write(schema.SerializeToString())
    data_file = pjoin(tmp_dir, "test.tfrecord")
    with TFRecordWriter(data_file) as f:
        for i in example_proto:
            f.write(i.SerializeToString())
    return data_file, schema_path
Developer: spotify, Project: spotify-tensorflow, Lines: 14, Source: dataset_test.py

Example 3: create_tfrecord_files

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def create_tfrecord_files(output_dir, num_files=3, num_records_per_file=10):
  """Creates TFRecords files.

  The method must be called within an active session.

  Args:
    output_dir: The directory where the files are stored.
    num_files: The number of files to create.
    num_records_per_file: The number of records per file.

  Returns:
    A list of the paths to the TFRecord files.
  """
  tfrecord_paths = []
  for i in range(num_files):
    path = os.path.join(output_dir,
                        'flowers.tfrecord-%d-of-%s' % (i, num_files))
    tfrecord_paths.append(path)

    writer = tf_record.TFRecordWriter(path)
    for _ in range(num_records_per_file):
      _, example = generate_image(image_shape=(10, 10, 3))
      writer.write(example)
    writer.close()

  return tfrecord_paths 
Developer: ryfeus, Project: lambda-packs, Lines: 28, Source: test_utils.py

Example 4: save_rows_to_tf_record_file

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def save_rows_to_tf_record_file(df_rows, make_sequence_example_fn, export_filename):
    tf_record_options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)

    tf_writer = tf_record.TFRecordWriter(export_filename, options=tf_record_options)
    try:
        for index, row in df_rows.iterrows():
            seq_example = make_sequence_example_fn(row)
            tf_writer.write(seq_example.SerializeToString())
    finally:
        tf_writer.close()
        sys.stdout.flush() 
Developer: gabrielspmoreira, Project: chameleon_recsys, Lines: 13, Source: tf_records_management.py

Example 5: save_rows_to_tf_record_file

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def save_rows_to_tf_record_file(rows, make_sequence_example_fn, export_filename):
    tf_record_options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)

    tf_writer = tf_record.TFRecordWriter(export_filename, options=tf_record_options)
    try:
        for row in rows:
            seq_example = make_sequence_example_fn(row)
            tf_writer.write(seq_example.SerializeToString())
    finally:
        tf_writer.close()
        sys.stdout.flush() 
Developer: gabrielspmoreira, Project: chameleon_recsys, Lines: 13, Source: tf_records_management.py
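
A note on Examples 4 and 5: records written with GZIP TFRecordOptions have to be read back with matching options. The following is a minimal sketch, not part of the chameleon_recsys code above; the function name is illustrative only.

from tensorflow.python.lib.io import tf_record

def read_gzip_tf_records(filename):
    # The compression options must match those used when the file was written.
    options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
    # tf_record_iterator yields the raw serialized record bytes.
    return list(tf_record.tf_record_iterator(filename, options=options))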

Example 6: testInputFn

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def testInputFn(self):
    with tempfile.NamedTemporaryFile() as records_file:
      with tf_record.TFRecordWriter(records_file.name) as records_writer:
        flags.FLAGS.augmentation_x_shift_probability = 0
        flags.FLAGS.augmentation_max_rotation_degrees = 0
        example = tf.train.Example()
        height = 5
        width = 3
        example.features.feature['height'].int64_list.value.append(height)
        example.features.feature['width'].int64_list.value.append(width)
        example.features.feature['patch'].float_list.value.extend(
            range(height * width))
        label = 1
        example.features.feature['label'].int64_list.value.append(label)
        for _ in range(3):
          records_writer.write(example.SerializeToString())

      flags.FLAGS.train_input_patches = records_file.name
      batch_tensors = glyph_patches.input_fn(records_file.name)

      with self.test_session() as sess:
        batch = sess.run(batch_tensors)

        self.assertAllEqual(
            batch[0]['patch'],
            np.arange(height * width).reshape(
                (1, height, width)).repeat(3, axis=0))
        self.assertAllEqual(batch[1], [label, label, label]) 
Developer: tensorflow, Project: moonlight, Lines: 30, Source: glyph_patches_test.py

Example 7: main

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def main(_):
  tf.logging.info('Building the pipeline...')
  records_dir = tempfile.mkdtemp(prefix='staffline_kmeans')
  try:
    patch_file_prefix = os.path.join(records_dir, 'patches')
    with pipeline_flags.create_pipeline() as pipeline:
      filenames = file_io.get_matching_files(FLAGS.music_pattern)
      assert filenames, 'Must have matched some filenames'
      if 0 < FLAGS.num_pages < len(filenames):
        filenames = random.sample(filenames, FLAGS.num_pages)
      filenames = pipeline | beam.transforms.Create(filenames)
      patches = filenames | beam.ParDo(
          staffline_patches_dofn.StafflinePatchesDoFn(
              patch_height=FLAGS.patch_height,
              patch_width=FLAGS.patch_width,
              num_stafflines=FLAGS.num_stafflines,
              timeout_ms=FLAGS.timeout_ms,
              max_patches_per_page=FLAGS.max_patches_per_page))
      if FLAGS.num_outputs:
        patches |= combiners.Sample.FixedSizeGlobally(FLAGS.num_outputs)
      patches |= beam.io.WriteToTFRecord(
          patch_file_prefix, beam.coders.ProtoCoder(tf.train.Example))
      tf.logging.info('Running the pipeline...')
    tf.logging.info('Running k-means...')
    patch_files = file_io.get_matching_files(patch_file_prefix + '*')
    clusters = train_kmeans(patch_files, FLAGS.kmeans_num_clusters,
                            FLAGS.kmeans_batch_size, FLAGS.kmeans_num_steps)
    tf.logging.info('Writing the centroids...')
    with tf_record.TFRecordWriter(FLAGS.output_path) as writer:
      for cluster in clusters:
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(cluster)
        example.features.feature['height'].int64_list.value.append(
            FLAGS.patch_height)
        example.features.feature['width'].int64_list.value.append(
            FLAGS.patch_width)
        writer.write(example.SerializeToString())
    tf.logging.info('Done!')
  finally:
    shutil.rmtree(records_dir) 
Developer: tensorflow, Project: moonlight, Lines: 42, Source: staffline_patches_kmeans_pipeline.py

Example 8: do_POST

# Required import: from tensorflow.python.lib.io import tf_record [as alias]
# Or: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as alias]
def do_POST(self):
    post_vars = cgi.parse_qs(
        self.rfile.read(int(self.headers.getheader('content-length'))))
    labels = [
        post_vars['cluster%d' % i][0]
        for i in moves.xrange(self.clusters.shape[0])
    ]
    examples = create_examples(self.clusters, labels)

    with tf_record.TFRecordWriter(self.output_path) as writer:
      for example in examples:
        writer.write(example.SerializeToString())
    self.send_response(http_client.OK)
    self.end_headers()
    self.wfile.write('Success')  # printed in the labeler alert 
Developer: tensorflow, Project: moonlight, Lines: 17, Source: kmeans_labeler_request_handler.py


Note: The tensorflow.python.lib.io.tf_record.TFRecordWriter method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their authors; copyright of the source code remains with the original authors, and redistribution or use should follow the corresponding project's license. Please do not reproduce without permission.