本文整理汇总了Python中tensorflow.python.lib.io.tf_record.TFRecordWriter方法的典型用法代码示例。如果您正苦于以下问题:Python tf_record.TFRecordWriter方法的具体用法?Python tf_record.TFRecordWriter怎么用?Python tf_record.TFRecordWriter使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.python.lib.io.tf_record 的用法示例。
在下文中一共展示了tf_record.TFRecordWriter方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testKmeans
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def testKmeans(self):
  """Trains k-means on random patches and checks the centroid matrix shape."""
  feature_count = FLAGS.patch_height * FLAGS.patch_width
  random_patches = np.random.random((500, feature_count))
  with tempfile.NamedTemporaryFile(mode='r') as patches_file:
    # Serialize every patch as a tf.Example record before training.
    with tf_record.TFRecordWriter(patches_file.name) as writer:
      for row in random_patches:
        record = example_pb2.Example()
        record.features.feature['features'].float_list.value.extend(row)
        writer.write(record.SerializeToString())
    clusters = staffline_patches_kmeans_pipeline.train_kmeans(
        patches_file.name,
        NUM_CLUSTERS,
        BATCH_SIZE,
        TRAIN_STEPS,
        min_eval_frequency=0)
    self.assertEqual(clusters.shape, (NUM_CLUSTERS, feature_count))
示例2: write_test_data
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def write_test_data(example_proto,
                    schema,
                    schema_filename="schema.pb"):
  """Writes a schema proto and serialized examples to the test temp dir.

  Args:
    example_proto: Iterable of example protos to serialize into the TFRecord.
    schema: Schema proto written alongside the data.
    schema_filename: Basename for the schema file.

  Returns:
    Tuple of (TFRecord data file path, schema file path).
  """
  temp_dir = tf.test.get_temp_dir()
  schema_path = pjoin(temp_dir, schema_filename)
  with open(schema_path, "wb") as schema_file:
    schema_file.write(schema.SerializeToString())
  data_file = pjoin(temp_dir, "test.tfrecord")
  with TFRecordWriter(data_file) as record_writer:
    for example in example_proto:
      record_writer.write(example.SerializeToString())
  return data_file, schema_path
示例3: create_tfrecord_files
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def create_tfrecord_files(output_dir, num_files=3, num_records_per_file=10):
  """Creates TFRecords files.

  The method must be called within an active session.

  Args:
    output_dir: The directory where the files are stored.
    num_files: The number of files to create.
    num_records_per_file: The number of records per file.

  Returns:
    A list of the paths to the TFRecord files.
  """
  tfrecord_paths = []
  for i in range(num_files):
    path = os.path.join(output_dir,
                        'flowers.tfrecord-%d-of-%s' % (i, num_files))
    tfrecord_paths.append(path)
    # Use the writer as a context manager (as the other examples in this file
    # do) so the file is closed and flushed even if generate_image() or
    # write() raises; the original leaked the writer on error.
    with tf_record.TFRecordWriter(path) as writer:
      for _ in range(num_records_per_file):
        _, example = generate_image(image_shape=(10, 10, 3))
        writer.write(example)
  return tfrecord_paths
示例4: save_rows_to_tf_record_file
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def save_rows_to_tf_record_file(df_rows, make_sequence_example_fn, export_filename):
  """Serializes each DataFrame row to a GZIP-compressed TFRecord file.

  Args:
    df_rows: DataFrame whose rows are converted one at a time.
    make_sequence_example_fn: Callable mapping a row to a SequenceExample proto.
    export_filename: Destination TFRecord path.
  """
  gzip_options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
  record_writer = tf_record.TFRecordWriter(export_filename, options=gzip_options)
  try:
    for _, current_row in df_rows.iterrows():
      record_writer.write(make_sequence_example_fn(current_row).SerializeToString())
  finally:
    # Always close the writer and flush stdout, even on a partial export.
    record_writer.close()
    sys.stdout.flush()
示例5: save_rows_to_tf_record_file
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def save_rows_to_tf_record_file(rows, make_sequence_example_fn, export_filename):
  """Serializes each row of an iterable to a GZIP-compressed TFRecord file.

  Args:
    rows: Iterable of row objects to convert.
    make_sequence_example_fn: Callable mapping a row to a SequenceExample proto.
    export_filename: Destination TFRecord path.
  """
  gzip_options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
  record_writer = tf_record.TFRecordWriter(export_filename, options=gzip_options)
  try:
    for current_row in rows:
      serialized = make_sequence_example_fn(current_row).SerializeToString()
      record_writer.write(serialized)
  finally:
    # Always close the writer and flush stdout, even on a partial export.
    record_writer.close()
    sys.stdout.flush()
示例6: testInputFn
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def testInputFn(self):
  """Checks that input_fn batches all records and decodes patch and label."""
  with tempfile.NamedTemporaryFile() as records_file:
    with tf_record.TFRecordWriter(records_file.name) as writer:
      # Disable random augmentation so the decoded patches are deterministic.
      flags.FLAGS.augmentation_x_shift_probability = 0
      flags.FLAGS.augmentation_max_rotation_degrees = 0
      patch_height = 5
      patch_width = 3
      proto = tf.train.Example()
      proto.features.feature['height'].int64_list.value.append(patch_height)
      proto.features.feature['width'].int64_list.value.append(patch_width)
      proto.features.feature['patch'].float_list.value.extend(
          range(patch_height * patch_width))
      patch_label = 1
      proto.features.feature['label'].int64_list.value.append(patch_label)
      # Write three identical records; the batch should contain all of them.
      for _ in range(3):
        writer.write(proto.SerializeToString())
    flags.FLAGS.train_input_patches = records_file.name
    batch_tensors = glyph_patches.input_fn(records_file.name)
    with self.test_session() as sess:
      batch = sess.run(batch_tensors)
      self.assertAllEqual(
          batch[0]['patch'],
          np.arange(patch_height * patch_width).reshape(
              (1, patch_height, patch_width)).repeat(3, axis=0))
      self.assertAllEqual(batch[1], [patch_label, patch_label, patch_label])
示例7: main
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def main(_):
  """Extracts staffline patches via a Beam pipeline, clusters them with
  k-means, and writes the resulting centroids as tf.Examples.

  Reads pages matching FLAGS.music_pattern, writes intermediate patch
  TFRecords to a temp dir (removed on exit), and writes centroids (with
  their patch height/width) to FLAGS.output_path.
  """
  tf.logging.info('Building the pipeline...')
  # Temp dir for intermediate patch shards; always cleaned up in the
  # finally block below.
  records_dir = tempfile.mkdtemp(prefix='staffline_kmeans')
  try:
    patch_file_prefix = os.path.join(records_dir, 'patches')
    # The Beam pipeline executes when this context manager exits.
    with pipeline_flags.create_pipeline() as pipeline:
      filenames = file_io.get_matching_files(FLAGS.music_pattern)
      assert filenames, 'Must have matched some filenames'
      # Optionally subsample the pages (only when num_pages is a positive
      # number smaller than the match count).
      if 0 < FLAGS.num_pages < len(filenames):
        filenames = random.sample(filenames, FLAGS.num_pages)
      filenames = pipeline | beam.transforms.Create(filenames)
      patches = filenames | beam.ParDo(
          staffline_patches_dofn.StafflinePatchesDoFn(
              patch_height=FLAGS.patch_height,
              patch_width=FLAGS.patch_width,
              num_stafflines=FLAGS.num_stafflines,
              timeout_ms=FLAGS.timeout_ms,
              max_patches_per_page=FLAGS.max_patches_per_page))
      # Optionally cap the total number of emitted patches.
      if FLAGS.num_outputs:
        patches |= combiners.Sample.FixedSizeGlobally(FLAGS.num_outputs)
      patches |= beam.io.WriteToTFRecord(
          patch_file_prefix, beam.coders.ProtoCoder(tf.train.Example))
      tf.logging.info('Running the pipeline...')
    tf.logging.info('Running k-means...')
    # Pipeline has run at this point, so the patch shards exist on disk.
    patch_files = file_io.get_matching_files(patch_file_prefix + '*')
    clusters = train_kmeans(patch_files, FLAGS.kmeans_num_clusters,
                            FLAGS.kmeans_batch_size, FLAGS.kmeans_num_steps)
    tf.logging.info('Writing the centroids...')
    with tf_record.TFRecordWriter(FLAGS.output_path) as writer:
      for cluster in clusters:
        # Each centroid is stored with the patch dimensions so consumers
        # can reshape the flat feature vector.
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(cluster)
        example.features.feature['height'].int64_list.value.append(
            FLAGS.patch_height)
        example.features.feature['width'].int64_list.value.append(
            FLAGS.patch_width)
        writer.write(example.SerializeToString())
    tf.logging.info('Done!')
  finally:
    shutil.rmtree(records_dir)
示例8: do_POST
# 需要导入模块: from tensorflow.python.lib.io import tf_record [as 别名]
# 或者: from tensorflow.python.lib.io.tf_record import TFRecordWriter [as 别名]
def do_POST(self):
  """Handles the labeling form POST: writes labeled cluster examples as
  TFRecords to self.output_path and acknowledges the client.

  Expects one form field per cluster row, named 'cluster<i>'.
  """
  # NOTE(review): headers.getheader() is Python-2-only (mimetools.Message);
  # .get() exists on both the py2 and py3 header objects, matching this
  # file's six-based 2/3 style (moves.xrange, http_client). On py3,
  # parse_qs of the raw bytes body would still yield bytes keys — confirm
  # the intended Python version with the owner.
  post_vars = cgi.parse_qs(
      self.rfile.read(int(self.headers.get('content-length'))))
  labels = [
      post_vars['cluster%d' % i][0]
      for i in moves.xrange(self.clusters.shape[0])
  ]
  examples = create_examples(self.clusters, labels)
  with tf_record.TFRecordWriter(self.output_path) as writer:
    for example in examples:
      writer.write(example.SerializeToString())
  self.send_response(http_client.OK)
  self.end_headers()
  # wfile requires bytes on Python 3; b'Success' is the same str on Python 2.
  self.wfile.write(b'Success')  # printed in the labeler alert