本文整理汇总了Python中tensorflow.python.tools.optimize_for_inference_lib.optimize_for_inference方法的典型用法代码示例。如果您正苦于以下问题:Python optimize_for_inference_lib.optimize_for_inference方法的具体用法?Python optimize_for_inference_lib.optimize_for_inference怎么用?Python optimize_for_inference_lib.optimize_for_inference使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.tools.optimize_for_inference_lib的用法示例。
在下文中一共展示了optimize_for_inference_lib.optimize_for_inference方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _optimize_graph
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def _optimize_graph(basename, output_dir):
    """Optimize a frozen graph for inference and save it next to the input.

    Reads ``<output_dir>/<name>_frozen.pb`` (``name`` is ``basename`` without
    its extension), runs ``optimize_for_inference`` on it and writes the
    result as a binary ``<name>_optimized.pb`` into ``output_dir``.

    Args:
        basename: file name whose stem identifies the frozen graph.
        output_dir: directory holding the frozen graph; the optimized graph
            is written there as well.
    """
    name, _ = os.path.splitext(basename)
    frozen_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
    graph_def = load_graph_def(frozen_graph_filename)
    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['conv6_interp/ResizeBilinear'],
        toco_compatible=True,
    )
    # Build the output name from the stem rather than with
    # str.replace('frozen', 'optimized'), which would also rewrite any
    # 'frozen' occurring inside the stem itself.
    optimized_graph_filename = '%s_optimized.pb' % os.path.basename(name)
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s',
                os.path.join(output_dir, optimized_graph_filename))
示例2: _optimize_graph
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def _optimize_graph(basename, output_dir):
    """Optimize a frozen graph for inference and save it next to the input.

    Reads ``<output_dir>/<name>_frozen.pb`` (``name`` is ``basename`` without
    its extension), runs ``optimize_for_inference`` on it and writes the
    result as a binary ``<name>_optimized.pb`` into ``output_dir``.

    Args:
        basename: file name whose stem identifies the frozen graph.
        output_dir: directory holding the frozen graph; the optimized graph
            is written there as well.
    """
    name, _ = os.path.splitext(basename)
    frozen_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
    graph_def = load_graph_def(frozen_graph_filename)
    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['deprocess_stylized_image_1/mul'],
        toco_compatible=True,
    )
    # Build the output name from the stem rather than with
    # str.replace('frozen', 'optimized'), which would also rewrite any
    # 'frozen' occurring inside the stem itself.
    optimized_graph_filename = '%s_optimized.pb' % os.path.basename(name)
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s',
                os.path.join(output_dir, optimized_graph_filename))
示例3: freeze
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def freeze(saved_model_dir, input_nodes, output_nodes, save_file):
    """Freeze a SavedModel, optimize it for inference and write it to disk.

    Args:
        saved_model_dir: directory of a SavedModel loaded with the
            ``SERVING`` tag.
        input_nodes: input node name, or a list of input node names.
        output_nodes: output node name, or a list of output node names.
        save_file: path of the binary ``GraphDef`` file to write.
    """
    # Accept a single node name as well as a list; the TF helpers below
    # expect lists of names.
    if isinstance(input_nodes, str):
        input_nodes = [input_nodes]
    if isinstance(output_nodes, str):
        output_nodes = [output_nodes]

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
        # Replace variables with constants holding their current values.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_nodes,
        )

    # Optimization works on the GraphDef alone; the session is not needed.
    frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph_def,
        input_nodes,
        output_nodes,
        tf.float32.as_datatype_enum,
    )
    with open(save_file, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
示例4: _optimize_for_inference
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def _optimize_for_inference(self):
    """Prune the input graph down to what the mapped outputs need.

    Returns:
        The ``GraphDef`` produced by ``infr_opt.optimize_for_inference`` for
        this object's input graph, with nodes not used to compute the
        output-mapping tensors stripped away.
    """
    source_graph_def = self.getTFInputGraph().graph_def
    # Data types of the graph's input placeholders.
    placeholder_dtypes = self._get_placeholder_types(source_graph_def)

    # Input-mapping pairs hold the tensor name second, output-mapping pairs
    # hold it first; the optimizer wants op names rather than tensor names.
    input_op_names = []
    for _, in_tensor in self.getInputMapping():
        input_op_names.append(tfx.op_name(in_tensor))
    output_op_names = []
    for out_tensor, _ in self.getOutputMapping():
        output_op_names.append(tfx.op_name(out_tensor))

    return infr_opt.optimize_for_inference(
        source_graph_def, input_op_names, output_op_names, placeholder_dtypes)
示例5: main
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def main(unused_args):
    """Command-line entry point: optimize a graph file for inference.

    Reads ``FLAGS.input`` (binary ``GraphDef`` when ``FLAGS.frozen_graph`` is
    set, text format otherwise), runs ``optimize_for_inference`` with the
    comma-separated ``FLAGS.input_names`` / ``FLAGS.output_names`` and
    ``FLAGS.placeholder_type_enum``, and writes the result to
    ``FLAGS.output`` in the same format.

    Returns:
        0 on success, -1 if the input file does not exist.
    """
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
    if FLAGS.frozen_graph:
        input_graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum)

    if FLAGS.frozen_graph:
        # Serialized protobufs are bytes: write in binary mode and close the
        # file deterministically so the data is flushed.
        with gfile.GFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
示例6: main
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def main(unused_args):
    """Command-line entry point: optimize a graph file for inference.

    Reads ``FLAGS.input`` (binary ``GraphDef`` when ``FLAGS.frozen_graph`` is
    set, text format otherwise), runs ``optimize_for_inference`` with the
    comma-separated ``FLAGS.input_names`` / ``FLAGS.output_names`` and
    ``FLAGS.placeholder_type_enum``, and writes the result to
    ``FLAGS.output`` in the same format.

    Returns:
        0 on success, -1 if the input file does not exist.
    """
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    # Read bytes: ParseFromString needs bytes, and the text branch decodes
    # explicitly. Text mode ("r") would return str, which has no .decode().
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
    if FLAGS.frozen_graph:
        input_graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum)

    if FLAGS.frozen_graph:
        # Serialized protobufs are bytes: write in binary mode and close the
        # file deterministically so the data is flushed.
        with gfile.GFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
示例7: load_graph
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def load_graph(self):
    """Load ``self.args.input_graph`` and import an inference-optimized copy.

    The binary ``GraphDef`` is parsed and passed through
    ``optimize_for_inference`` with ``self.input_layer`` as the single
    ``uint8`` input and ``self.output_layers`` as outputs, with TOCO
    compatibility off. The result is imported with no name prefix into a
    fresh graph stored on ``self.infer_graph``.
    """
    graph_path = self.args.input_graph
    print('load graph from: ' + graph_path)
    self.infer_graph = tf.Graph()
    with self.infer_graph.as_default():
        raw_graph_def = tf.compat.v1.GraphDef()
        with tf.compat.v1.gfile.FastGFile(graph_path, 'rb') as graph_file:
            raw_graph_def.ParseFromString(graph_file.read())
        pruned_graph_def = optimize_for_inference(
            raw_graph_def,
            [self.input_layer],
            self.output_layers,
            dtypes.uint8.as_datatype_enum,
            False,  # toco_compatible
        )
        tf.import_graph_def(pruned_graph_def, name='')
示例8: model_freeze
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def model_freeze(path, MODEL_NAME='model'):
    """Freeze a checkpointed model and write an inference-optimized copy.

    Reads the text graph ``<path><MODEL_NAME>.pbtxt`` and the checkpoint
    ``<path>model_ckpt``. Writes ``<path>frozen_<MODEL_NAME>.pb`` and
    ``<path>optimized_<MODEL_NAME>.pb``. ``path`` is concatenated as a plain
    string prefix, so it normally needs a trailing path separator.

    Args:
        path: prefix (usually a directory ending in '/') for every file.
        MODEL_NAME: base name of the graph files.
    """
    # Freeze the graph: replace checkpoint variables with constants.
    input_graph_path = path + MODEL_NAME + '.pbtxt'
    checkpoint_path = path + 'model_ckpt'
    input_saver_def_path = ""
    input_binary = False
    output_node_names = 'positive_sentiment_probability'
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_frozen_graph_name = path + 'frozen_' + MODEL_NAME + '.pb'
    output_optimized_graph_name = path + 'optimized_' + MODEL_NAME + '.pb'
    clear_devices = True
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_frozen_graph_name, clear_devices, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        ["inputs/X"],               # an array of the input node(s)
        [output_node_names],        # an array of the output node(s)
        tf.int32.as_datatype_enum,  # dtype of the input placeholder(s)
    )

    # Save the optimized graph: serialized bytes, so binary mode, and a
    # context manager so the file is flushed and closed.
    with tf.gfile.GFile(output_optimized_graph_name, "wb") as f:
        f.write(output_graph_def.SerializeToString())
示例9: main
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def main(unused_args):
    """Command-line entry point: optimize a graph file for inference.

    Reads ``FLAGS.input`` (binary ``GraphDef`` when ``FLAGS.frozen_graph`` is
    set, text format otherwise), runs ``optimize_for_inference`` with the
    comma-separated ``FLAGS.input_names`` / ``FLAGS.output_names`` and
    ``FLAGS.placeholder_type_enum``, and writes the result to
    ``FLAGS.output`` in the same format.

    Returns:
        0 on success, -1 if the input file does not exist.
    """
    if not tf.gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = tf.GraphDef()
    # Read bytes: ParseFromString needs bytes, and the text branch decodes
    # explicitly. Text mode ("r") would return str, which has no .decode().
    with tf.gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
    if FLAGS.frozen_graph:
        input_graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum)

    if FLAGS.frozen_graph:
        # Serialized protobufs are bytes: write in binary mode and close the
        # file deterministically so the data is flushed.
        with tf.gfile.GFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        tf.train.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
示例10: export_model
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def export_model(saver, model, input_node_names, output_node_name):
    """Save, freeze and inference-optimize the current Keras session graph.

    Writes into ``out/``: the text graph ``<model_name>_graph.pbtxt``, the
    checkpoint ``<model_name>.chkp``, the frozen graph
    ``frozen_<model_name>.bytes`` and the optimized graph
    ``opt_<model_name>.bytes``. ``model_name`` is a name looked up outside
    this function; the ``model`` argument itself is not used.

    Args:
        saver: object whose ``save(session, path)`` writes the checkpoint.
        model: unused; kept so existing call sites keep working.
        input_node_names: list of input node names for
            ``optimize_for_inference``.
        output_node_name: name of the single output node.
    """
    out_dir = 'out'
    # exist_ok avoids the race between an exists() check and mkdir().
    os.makedirs(out_dir, exist_ok=True)

    graph_file = os.path.join(out_dir, model_name + '_graph.pbtxt')
    checkpoint_file = os.path.join(out_dir, model_name + '.chkp')
    frozen_file = os.path.join(out_dir, 'frozen_' + model_name + '.bytes')
    optimized_file = os.path.join(out_dir, 'opt_' + model_name + '.bytes')

    sess = K.get_session()
    tf.train.write_graph(sess.graph_def, out_dir, model_name + '_graph.pbtxt')
    saver.save(sess, checkpoint_file)

    freeze_graph.freeze_graph(graph_file, None, False,
                              checkpoint_file, output_node_name,
                              "save/restore_all", "save/Const:0",
                              frozen_file, True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_file, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile(optimized_file, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")
########################################################################################################################
# Main program
示例11: export_model
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def export_model(model_output_dir, input_node_names, output_node_name):
    """Export the model so we can use it later.

    Two Protocol Buffer files are written to ``model_output_dir``:
    ``frozen_<MODEL_NAME>.pb`` is the graph with all learned weights and
    biases stored as constants, and ``optimized_<MODEL_NAME>.pb`` is that
    graph additionally optimized for inference-only use. The inputs are
    ``<MODEL_NAME>.pbtxt`` and the ``<MODEL_NAME>.chkp`` checkpoint in the
    same directory.
    """
    def in_output_dir(file_name):
        return os.path.join(model_output_dir, file_name)

    name_base = in_output_dir(MODEL_NAME)
    frozen_graph_file = in_output_dir('frozen_' + MODEL_NAME + '.pb')
    optimized_graph_file = in_output_dir('optimized_' + MODEL_NAME + '.pb')

    freeze_graph.freeze_graph(
        name_base + '.pbtxt', None, False, name_base + '.chkp',
        output_node_name, "save/restore_all", "save/Const:0",
        frozen_graph_file, True, "")

    frozen_def = tf.GraphDef()
    with tf.gfile.Open(frozen_graph_file, "rb") as frozen_in:
        frozen_def.ParseFromString(frozen_in.read())

    inference_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.GFile(optimized_graph_file, "wb") as optimized_out:
        optimized_out.write(inference_def.SerializeToString())
    print("Inference optimized graph saved at: " + optimized_graph_file)
示例12: main
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def main(unused_args):
    """Command-line entry point: optimize a graph file for inference.

    Reads ``FLAGS.input`` (binary ``GraphDef`` when ``FLAGS.frozen_graph`` is
    set, text format otherwise), runs ``optimize_for_inference`` with the
    comma-separated ``FLAGS.input_names`` / ``FLAGS.output_names``,
    ``FLAGS.placeholder_type_enum`` and ``FLAGS.toco_compatible``, and writes
    the result to ``FLAGS.output`` in the same format.

    Returns:
        0 on success, -1 if the input file does not exist.
    """
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
    if FLAGS.frozen_graph:
        input_graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum,
        FLAGS.toco_compatible)

    if FLAGS.frozen_graph:
        # Serialized protobufs are bytes: write in binary mode and close the
        # file deterministically so the data is flushed.
        with gfile.GFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
示例13: optimize_graph
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def optimize_graph(frozen_graph_filename, suffix='optimized'):
    """Optimize a TensorFlow graph for inference.

    Optimized graphs are saved to the same directory as the input frozen
    graph. The output file name is the input's base name with every
    ``'frozen'`` replaced by ``suffix``. If the name contains no
    ``'frozen'``, ``'_<suffix>'`` is inserted before the extension instead,
    so the input file is never overwritten.

    Args:
        frozen_graph_filename (str): the filename of a frozen graph.
        suffix (optional, str): a suffix to append to the optimized graph file.

    Returns:
        optimized_graph_filename (str): the base name (not the full path) of
        the saved optimized graph, which lives in the input graph's directory.
    """
    output_dir, basename = os.path.split(frozen_graph_filename)
    graph_def = load_graph_def(frozen_graph_filename)
    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['deprocess_stylized_image_1/mul'],
        toco_compatible=True,
    )
    if 'frozen' in basename:
        optimized_graph_filename = basename.replace('frozen', suffix)
    else:
        # replace() would be a no-op here and the optimized graph would
        # overwrite the input file; append the suffix to the stem instead.
        stem, ext = os.path.splitext(basename)
        optimized_graph_filename = '%s_%s%s' % (stem, suffix, ext)
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s',
                os.path.join(output_dir, optimized_graph_filename))
    return optimized_graph_filename
示例14: convert_to_pb
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):
    """Re-export a saved model as a frozen, inference-optimized ``.pb`` file.

    Loads weights from ``path`` into ``model``, saves it to temporary
    ``model-tmp.tfl*`` files in the working directory, re-imports and
    restores that checkpoint, freezes it for ``output_layer_name``, runs
    ``optimize_for_inference`` and ``TransformGraph``, and writes the result
    to ``pbfilename``. The temporary files and ``checkpoint`` are removed
    afterwards.

    Args:
        model: model object exposing ``load(path, weights_only=...)`` and
            ``save(path)`` (presumably a TFLearn model, judging by ``.tfl``).
        path: weights file to load.
        input_layer_name: name of the graph's input node.
        output_layer_name: name of the graph's output node.
        pbfilename: destination of the binary ``GraphDef``.
        verbose: print every op's outputs and write the graph to ``logs``
            for TensorBoard.
    """
    model.load(path, weights_only=True)
    print("[INFO] Loaded CNN network weights from " + path + " ...")

    print("[INFO] Re-export model ...")
    # Empty the TRAIN_OPS collection before saving (presumably so training
    # ops are not exported with the graph).
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    model.save("model-tmp.tfl")

    # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow
    print("[INFO] Re-import model ...")
    input_checkpoint = "model-tmp.tfl"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', True)

    # The session is only needed to restore and freeze the weights; the
    # context manager makes sure it is closed afterwards.
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # print out all layers to find name of output
        if verbose:
            for operation in sess.graph.get_operations():
                print(operation.values())

        print("[INFO] Freeze model to " + pbfilename + " ...")
        # freeze and remove nodes which are not related to feedforward prediction
        minimal_graph = convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [output_layer_name])

    graph_def = optimize_for_inference_lib.optimize_for_inference(
        minimal_graph, [input_layer_name], [output_layer_name],
        tf.float32.as_datatype_enum)
    graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name],
                               ["sort_by_execution_order"])

    with tf.gfile.GFile(pbfilename, 'wb') as f:
        f.write(graph_def.SerializeToString())

    # write model to logs dir so we can visualize it as:
    # tensorboard --logdir="logs"
    if verbose:
        writer = tf.summary.FileWriter('logs', graph_def)
        writer.close()

    # tidy up tmp files
    for tmp_file in glob.glob("model-tmp.tfl*"):
        os.remove(tmp_file)
    # Cleanup is best-effort: a missing 'checkpoint' file should not fail
    # an export that has already been written successfully.
    if os.path.exists('checkpoint'):
        os.remove('checkpoint')
################################################################################
# convert a binary .pb protocol buffer format model to tflite format
# e.g. for FireNet
# pbfilename = "firenet.pb"
# input_layer_name = 'InputData/X' # input layer of network
# output_layer_name= 'FullyConnected_2/Softmax' # output layer of network
示例15: export_compact
# 需要导入模块: from tensorflow.python.tools import optimize_for_inference_lib [as 别名]
# 或者: from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference [as 别名]
def export_compact(self, filename, optimize=True, toco_compatible=False):
    """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

    Args:
        filename (str): path to the output graph
        optimize (bool): whether to use TensorFlow's `optimize_for_inference`
            to prune and optimize the graph. This does not work on all types of graphs.
        toco_compatible (bool): See TensorFlow's
            `optimize_for_inference
            <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
            for details. Only available after TF 1.8.
    """
    if toco_compatible:
        assert optimize, "toco_compatible is only effective when optimize=True!"
    self.graph = self.config._maybe_create_graph()
    with self.graph.as_default():
        placeholder_input = PlaceholderInput()
        placeholder_input.setup(self.config.input_signature)
        with PredictTowerContext(''):
            self.config.tower_func(*placeholder_input.get_input_tensors())

        input_tensors = get_tensors_by_names(self.config.input_names)
        output_tensors = get_tensors_by_names(self.config.output_names)
        # Use each tensor's producing op name. Stripping the last two
        # characters of the tensor name ("op:0" -> "op") breaks for output
        # indices >= 10.
        input_op_names = [t.op.name for t in input_tensors]
        output_op_names = [t.op.name for t in output_tensors]

        self.config.session_init._setup_graph()
        # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
        sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
        try:
            self.config.session_init._run_init(sess)

            # freeze variables to constants
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                self.graph.as_graph_def(),
                output_op_names,
                variable_names_whitelist=None,
                variable_names_blacklist=None)
        finally:
            # The session is only needed to read variable values; release it.
            sess.close()

        # prune unused nodes from graph
        if optimize:
            toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
            frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
                frozen_graph_def,
                input_op_names,
                output_op_names,
                [t.dtype.as_datatype_enum for t in input_tensors],
                *toco_args)

        with gfile.FastGFile(filename, "wb") as f:
            f.write(frozen_graph_def.SerializeToString())
            logger.info("Output graph written to {}.".format(filename))