本文整理匯總了Python中tensorflow.python.lib.io.file_io.get_matching_files方法的典型用法代碼示例。如果您正苦於以下問題：Python file_io.get_matching_files方法的具體用法？Python file_io.get_matching_files怎麼用？Python file_io.get_matching_files使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.lib.io.file_io的用法示例。
在下文中一共展示了file_io.get_matching_files方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: checkpoint_exists
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def checkpoint_exists(checkpoint_prefix):
    """Report whether a V2 or V1 checkpoint exists for `checkpoint_prefix`.

    V1 and V2 checkpoints are laid out under different file names, which is
    why this helper is the recommended existence check: it first globs for the
    V2 path derived from the prefix and, only if that matches nothing, globs
    for the prefix itself (the V1 layout).

    Args:
      checkpoint_prefix: Prefix of a V1 or V2 checkpoint, V2 taking priority.
        Typically the value returned by `Saver.save()` or by
        `tf.train.latest_checkpoint()`, whether sharded or not.

    Returns:
      True if either glob matches at least one file, False otherwise.
    """
    v2_pattern = _prefix_to_checkpoint_path(checkpoint_prefix,
                                            saver_pb2.SaverDef.V2)
    # Order matters: the V2 layout is checked before the bare (V1) prefix.
    for pattern in (v2_pattern, checkpoint_prefix):
        if file_io.get_matching_files(pattern):
            return True
    return False
示例2: get_all_checkpoints
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def get_all_checkpoints(output_dir):
    """List the checkpoints recorded for `output_dir` that still have files.

    Reads the checkpoint state of `output_dir` and keeps each entry of
    `all_model_checkpoint_paths` whose V2 or V1 on-disk path matches at least
    one file (V2 is checked first). Entries with no matching files are logged
    as errors and left out.

    Args:
      output_dir: Directory whose checkpoint state is read.

    Returns:
      A list of checkpoint prefixes, or None if no checkpoint state is found.
    """
    state = cm.get_checkpoint_state(output_dir, None)
    if not state:
        return None
    found = []
    for prefix in state.all_model_checkpoint_paths:
        # A checkpoint is kept if either its V2 or its V1 path matches a file.
        candidates = (
            cm._prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V2),
            cm._prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V1),
        )
        if any(file_io.get_matching_files(path) for path in candidates):
            found.append(prefix)
        else:
            tf.logging.error("Couldn't match files for checkpoint %s", prefix)
    return found
示例3: read_patch_dimensions
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def read_patch_dimensions():
    """Reads the dimensions of the input patches from disk.

    Walks the files matched by the --train_input_patches glob in order and
    parses the first tf.Example record it finds. That record must have int64
    "height" and "width" features. Files containing no records are skipped.

    Returns:
      Tuple of (height, width) read from disk.

    Raises:
      ValueError: If no file matched by --train_input_patches contains a
        record, so no dimensions can be read.
    """
    for filename in file_io.get_matching_files(FLAGS.train_input_patches):
        # Only the first record is needed. An empty file yields no record, so
        # the inner loop does nothing and we move on to the next file.
        for record in tf_record.tf_record_iterator(filename):
            example = tf.train.Example.FromString(record)
            features = example.features.feature
            # Convert long (int64) to int, necessary for use in feature
            # columns in Python 2.
            patch_height = int(features['height'].int64_list.value[0])
            patch_width = int(features['width'].int64_list.value[0])
            return patch_height, patch_width
    # Falling through here used to return None implicitly, contradicting the
    # documented (height, width) return value; fail loudly instead.
    raise ValueError(
        'No records found in any file matching --train_input_patches=%r'
        % (FLAGS.train_input_patches,))
示例4: list_tf_records
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def list_tf_records(paths, default_schema):
    """Yield (tfrecords_file, schema) for every ".tfrecords" file in `paths`.

    Each entry of `paths` is expanded with `file_io.get_matching_files`, and
    only matches ending in ".tfrecords" are kept. Each file's schema is
    obtained from `resolve_schema(<file's directory>, default_schema)`.

    This is a generator, so globbing and validation happen lazily, one path
    at a time, as the caller iterates.

    Args:
      paths: Iterable of paths or glob patterns.
      default_schema: Passed through to `resolve_schema` for every file.

    Yields:
      (file_path, schema) tuples.

    Raises:
      ValueError: If a path/glob matches no ".tfrecords" file. ValueError
        subclasses Exception, so existing `except Exception` handlers still
        catch it.
    """
    for p in paths:
        files = [f for f in file_io.get_matching_files(p)
                 if f.endswith(".tfrecords")]
        if not files:
            raise ValueError(
                "Couldn't find any .tfrecords file in path or glob [{}]".format(p))
        for f in files:
            yield f, resolve_schema(os.path.dirname(f), default_schema)
示例5: copy_data_to_tmp
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def copy_data_to_tmp(input_files):
    """Copies GCS data to a fresh directory under /tmp/ and returns a glob for it.

    Every comma-separated path/glob in each element of `input_files` is
    expanded with `file_io.get_matching_files`. If every match is on GCS
    ("gs://" prefix), the matches are copied with `gsutil -m -q cp -r` into a
    new uniquely-named directory under /tmp/. A single-element list holding a
    glob over that directory is then returned.

    Args:
      input_files: Iterable of strings, each a comma-separated list of paths
        or globs.

    Returns:
      `[<tmp_dir>/*]` after a copy. Otherwise `input_files` itself, which
      happens when any match is not on GCS or when nothing matched at all.

    Raises:
      subprocess.CalledProcessError: If the gsutil copy fails.
    """
    files = []
    for entry in input_files:
        for path in entry.split(','):
            files.extend(file_io.get_matching_files(path))
    # With no matches there is nothing to stage. Running gsutil with no
    # sources would fail with a usage error and leave an empty directory in
    # /tmp, so hand the input back unchanged instead.
    if not files:
        return input_files
    # Non-GCS data is used where it is; only all-GCS inputs are staged locally.
    if not all(path.startswith('gs://') for path in files):
        return input_files
    tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
    os.makedirs(tmp_path)
    # NOTE(review): matches are flattened into tmp_path, so two sources with
    # the same basename in different GCS directories would collide.
    subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
    return [os.path.join(tmp_path, '*')]
示例6: _delete_file_if_exists
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def _delete_file_if_exists(self, filespec):
    """Delete every file matching the glob `filespec`; do nothing if none match."""
    matches = file_io.get_matching_files(filespec)
    for match in matches:
        file_io.delete_file(match)
示例7: latest_checkpoint
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def latest_checkpoint(checkpoint_dir, latest_filename=None):
    """Return the path of the most recently saved checkpoint in `checkpoint_dir`.

    The checkpoint state returned by `get_checkpoint_state` names the latest
    checkpoint prefix. That prefix is returned only if files for it exist in
    either the V2 or the V1 layout (V2 is checked first).

    Args:
      checkpoint_dir: Directory where the variables were saved.
      latest_filename: Optional name of the protocol buffer file holding the
        list of most recent checkpoint filenames. See the argument of the same
        name on `Saver.save()`.

    Returns:
      The full path to the latest checkpoint. Returns None when no state or
      prefix is recorded, or when no file matches the recorded prefix.
    """
    state = get_checkpoint_state(checkpoint_dir, latest_filename)
    if not (state and state.model_checkpoint_path):
        return None
    prefix = state.model_checkpoint_path
    v2_pattern = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V2)
    v1_pattern = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V1)
    has_files = (file_io.get_matching_files(v2_pattern)
                 or file_io.get_matching_files(v1_pattern))
    if has_files:
        return prefix
    logging.error("Couldn't match files for checkpoint %s", prefix)
    return None
示例8: get_checkpoint_mtimes
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def get_checkpoint_mtimes(checkpoint_prefixes):
    """Returns the mtimes (modification timestamps) of the checkpoints.

    Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
    exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
    that priority. The V2 path is tried first, and the bare prefix (the V1
    layout) is tried only if the V2 glob matched nothing.

    This is the recommended way to get the mtimes, since it takes into account
    the naming difference between V1 and V2 formats.

    Args:
      checkpoint_prefixes: a list of checkpoint paths, typically the results of
        `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
        sharded/non-sharded or V1/V2.

    Returns:
      A list of mtimes in seconds (floats) of the found checkpoints. Prefixes
      for which no file matches are skipped, so the list may be shorter than
      `checkpoint_prefixes`. When a glob matches several files, the mtime of
      the first match is used.
    """
    def _first_match_mtime(pathname):
        """Return the mtime (seconds) of the first file matching `pathname`, or None."""
        fnames = file_io.get_matching_files(pathname)
        if not fnames:
            return None
        # `mtime_nsec` is in nanoseconds; dividing by 1e9 yields seconds.
        return file_io.stat(fnames[0]).mtime_nsec / 1e9

    mtimes = []
    for checkpoint_prefix in checkpoint_prefixes:
        # Try V2's metadata file first.
        mtime = _first_match_mtime(
            _prefix_to_checkpoint_path(checkpoint_prefix, saver_pb2.SaverDef.V2))
        if mtime is None:
            # Otherwise try V1, where the prefix is the complete pathname.
            mtime = _first_match_mtime(checkpoint_prefix)
        if mtime is not None:
            mtimes.append(mtime)
    return mtimes
示例9: _run_batch_prediction
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def _run_batch_prediction(self, output_dir, use_target):
    """Runs reglinear batch prediction into `output_dir` and checks its outputs.

    Args:
      output_dir: Directory the prediction job writes to.
      use_target: If True, predict on the eval CSV in 'evaluation' mode and
        expect a 'target' column in the output schema. Otherwise predict on
        the prediction CSV in 'prediction' mode.
    """
    reglinear.batch_predict(
        training_dir=self._train_output,
        prediction_input_file=(self._csv_eval_filename if use_target
                               else self._csv_predict_filename),
        output_dir=output_dir,
        mode='evaluation' if use_target else 'prediction',
        batch_size=4,
        output_format='csv')

    # Exactly one errors file must exist, and it must be empty.
    errors = file_io.get_matching_files(os.path.join(output_dir, 'errors*'))
    self.assertEqual(len(errors), 1)
    if os.path.getsize(errors[0]):
        with open(errors[0]) as errors_file:
            self.fail(msg=errors_file.read())

    # At least one predictions file must exist, and the first must be
    # non-empty. Assert the match before indexing, so a missing file fails
    # the test with a clear message instead of an IndexError.
    predictions = file_io.get_matching_files(
        os.path.join(output_dir, 'predictions*'))
    self.assertGreater(len(predictions), 0,
                       'No predictions* files found in %s' % output_dir)
    self.assertGreater(os.path.getsize(predictions[0]), 0)

    # The schema lists 'key' and 'predicted', plus 'target' in evaluation mode.
    schema_file = os.path.join(output_dir, 'csv_schema.json')
    self.assertTrue(os.path.isfile(schema_file))
    schema = json.loads(file_io.read_file_to_string(schema_file))
    self.assertEqual(schema[0]['name'], 'key')
    self.assertEqual(schema[1]['name'], 'predicted')
    if use_target:
        self.assertEqual(schema[2]['name'], 'target')
        self.assertEqual(len(schema), 3)
    else:
        self.assertEqual(len(schema), 2)
示例10: get_train_eval_files
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def get_train_eval_files(input_dir):
    """Return (train_files, eval_files) from the latest preprocessed data directory.

    The directory is chosen by `_get_latest_data_dir(input_dir)`. Within it,
    training shards match 'train*.tfrecord.gz' and eval shards match
    'eval*.tfrecord.gz'.
    """
    data_dir = _get_latest_data_dir(input_dir)

    def _glob(pattern):
        return file_io.get_matching_files(os.path.join(data_dir, pattern))

    return _glob('train*.tfrecord.gz'), _glob('eval*.tfrecord.gz')
示例11: _run_transform
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def _run_transform(self):
    """Runs transform.py to make tf.example files, then checks they exist.

    Only the train file goes through DataFlow. The eval file runs Beam locally
    to save time.
    """
    # Hard-coded toggle: when True, the train transform is submitted with
    # --cloud and the Dataflow job options below.
    cloud = True
    extra_args = []
    if cloud:
        extra_args = ['--cloud',
                      '--job-name=test-mltoolbox-df-%s' % uuid.uuid4().hex,
                      '--project-id=%s' % self._get_default_project_id(),
                      '--num-workers=3']
    transform_script = os.path.join(CODE_PATH, 'transform.py')

    # Commands are run as argument lists without a shell, so paths containing
    # spaces or shell metacharacters are passed through intact.
    cmd = ['python', transform_script,
           '--csv=' + self._csv_train_filename,
           '--analysis=' + self._analysis_output,
           '--prefix=features_train',
           '--output=' + self._transform_output,
           '--shuffle'] + extra_args
    self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd))
    subprocess.check_call(cmd)

    # Don't waste time running a 2nd DF job; run it locally.
    cmd = ['python', transform_script,
           '--csv=' + self._csv_eval_filename,
           '--analysis=' + self._analysis_output,
           '--prefix=features_eval',
           '--output=' + self._transform_output]
    self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd))
    subprocess.check_call(cmd)

    # Check the files were made.
    train_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_train*'))
    eval_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_eval*'))
    self.assertNotEqual([], train_files)
    self.assertNotEqual([], eval_files)
示例12: _batch_predict
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def _batch_predict(args, cell):
if args['cloud_config'] and not args['cloud']:
raise ValueError('"cloud_config" is provided but no "--cloud". '
'Do you want local run or cloud run?')
if args['cloud']:
job_request = {
'data_format': 'TEXT',
'input_paths': file_io.get_matching_files(args['data']['csv']),
'output_path': args['output'],
}
if args['model'].startswith('gs://'):
job_request['uri'] = args['model']
else:
parts = args['model'].split('.')
if len(parts) != 2:
raise ValueError('Invalid model name for cloud prediction. Use "model.version".')
version_name = ('projects/%s/models/%s/versions/%s' %
(Context.default().project_id, parts[0], parts[1]))
job_request['version_name'] = version_name
cloud_config = args['cloud_config'] or {}
job_id = cloud_config.pop('job_id', None)
job_request.update(cloud_config)
job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
_show_job_link(job)
else:
print('local prediction...')
_local_predict.local_batch_predict(args['model'],
args['data']['csv'],
args['output'],
args['format'],
args['batch_size'])
print('done.')
# Helper classes for the explainer. Each one is for a combination
# of algorithm (LIME, IG) and input type (text, image, tabular).
# ===========================================================
示例13: local_batch_predict
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def local_batch_predict(model_dir, csv_file_pattern, output_dir, output_format, batch_size=100):
    """Batch Predict with a specified model.

    It does batch prediction, saves results to output files and also creates
    an output schema file. The output file names are the input file names
    (without extension) prefixed with 'predict_results_' and suffixed with
    '.' + output_format. The schema is written to
    'predict_results_schema.json' in `output_dir`.

    Args:
      model_dir: The model directory containing a SavedModel (usually saved_model.pb).
      csv_file_pattern: a pattern of csv files as batch prediction source.
      output_dir: the path of the output directory.
      output_format: csv or json.
      batch_size: Larger batch_size improves performance but may
          cause more memory usage.

    Raises:
      ValueError: If `csv_file_pattern` matches no files.
    """
    file_io.recursive_create_dir(output_dir)
    csv_files = file_io.get_matching_files(csv_file_pattern)
    if not csv_files:
        raise ValueError('No files found given ' + csv_file_pattern)

    with tf.Graph().as_default(), tf.Session() as sess:
        input_alias_map, output_alias_map = _tf_load_model(sess, model_dir)
        # Assumes the model's first input alias accepts a batch of raw CSV
        # lines (that is what is fed below); confirm for multi-input models.
        csv_tensor_name = list(input_alias_map.values())[0]
        output_schema = _get_output_schema(sess, output_alias_map)
        for csv_file in csv_files:
            # NOTE(review): inputs sharing a basename in different directories
            # map to the same output file and overwrite each other.
            output_file = os.path.join(
                output_dir,
                'predict_results_' +
                os.path.splitext(os.path.basename(csv_file))[0] + '.' + output_format)
            with file_io.FileIO(output_file, 'w') as f:
                for batch in _batch_csv_reader(csv_file, batch_size):
                    # Strip trailing whitespace/newlines, and drop blank or
                    # whitespace-only lines. Filtering after stripping keeps
                    # lines like '\n' from reaching the model as empty rows.
                    batch = [line.rstrip() for line in batch if line.strip()]
                    if not batch:
                        # Nothing to predict; avoid an empty feed and a stray
                        # blank line in the output.
                        continue
                    predict_results = sess.run(fetches=output_alias_map,
                                               feed_dict={csv_tensor_name: batch})
                    formatted_results = _format_results(output_format, output_schema,
                                                        predict_results)
                    f.write('\n'.join(formatted_results) + '\n')

        file_io.write_string_to_file(os.path.join(output_dir, 'predict_results_schema.json'),
                                     json.dumps(output_schema, indent=2))
示例14: create_object_test
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def create_object_test():
    """Verifies file_io's object manipulation methods.

    Creates a directory under FLAGS.gcs_bucket_url named with the current time
    in milliseconds and writes one file into it. It then checks that
    `get_matching_files` returns exactly that file, and deletes the file and
    the directory. The directory is removed even when a check fails, so failed
    runs do not leave test objects behind.

    Raises:
      AssertionError: If the glob does not return exactly the created file.
        The error is raised explicitly rather than with `assert`, so the checks
        still run under `python -O`.
    """
    starttime = int(round(time.time() * 1000))
    dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime)
    print("Creating dir %s." % dir_name)
    file_io.create_dir(dir_name)
    try:
        # Create a file in this directory.
        file_name = "%s/test_file.txt" % dir_name
        print("Creating file %s." % file_name)
        file_io.write_string_to_file(file_name, "test file creation.")

        list_files_pattern = "%s/test_file*.txt" % dir_name
        print("Getting files matching pattern %s." % list_files_pattern)
        files_list = file_io.get_matching_files(list_files_pattern)
        print(files_list)
        if len(files_list) != 1:
            raise AssertionError(
                "Expected exactly 1 match, got %d: %r" % (len(files_list), files_list))
        if files_list[0] != file_name:
            raise AssertionError(
                "Expected match %r, got %r" % (file_name, files_list[0]))

        # Cleanup test files.
        print("Deleting file %s." % file_name)
        file_io.delete_file(file_name)
    finally:
        # Delete directory (along with anything left in it if a check failed).
        print("Deleting directory %s." % dir_name)
        file_io.delete_recursively(dir_name)
示例15: main
# 需要導入模塊: from tensorflow.python.lib.io import file_io [as 別名]
# 或者: from tensorflow.python.lib.io.file_io import get_matching_files [as 別名]
def main(argv):
    """Run OMR on every file matched by the PNG globs in argv[1:], then write the result.

    The output is written to FLAGS.output. It is MusicXML when
    FLAGS.output_type is 'MusicXML', a text-format proto when
    FLAGS.text_format is set, and a serialized proto otherwise.

    Raises:
      ValueError: If FLAGS.output_type is not in VALID_OUTPUT_TYPES, if no
        glob is given, or if any glob matches no files.
    """
    if FLAGS.output_type not in VALID_OUTPUT_TYPES:
        raise ValueError('output_type "%s" not in allowed types: %s' %
                         (FLAGS.output_type, VALID_OUTPUT_TYPES))

    # argv[0] is the current binary; every remaining argument is a glob.
    patterns = argv[1:]
    if not patterns:
        raise ValueError('PNG file glob(s) must be specified')

    input_paths = []
    for pattern in patterns:
        matched = file_io.get_matching_files(pattern)
        if not matched:
            raise ValueError('Pattern "%s" failed to match any files' % pattern)
        input_paths += matched

    wants_notesequence = FLAGS.output_type == 'NoteSequence'
    start = time.time()
    output = run(input_paths, FLAGS.glyphs_saved_model,
                 output_notesequence=wants_notesequence)
    elapsed = time.time() - start
    sys.stderr.write('OMR elapsed time: %.2f\n' % elapsed)

    if FLAGS.output_type == 'MusicXML':
        output_bytes = conversions.score_to_musicxml(output)
    elif FLAGS.text_format:
        output_bytes = text_format.MessageToString(output).encode('utf-8')
    else:
        output_bytes = output.SerializeToString()
    file_io.write_string_to_file(FLAGS.output, output_bytes)