This page collects typical usage examples of the Python method tensorflow.tools.graph_transforms.TransformGraph. If you have been wondering what graph_transforms.TransformGraph does, how to call it, or what it looks like in real code, the hand-picked examples below should help. You can also explore further usage examples for the module it belongs to, tensorflow.tools.graph_transforms.
Six code examples of the graph_transforms.TransformGraph method are shown below, sorted by popularity by default.
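Before the examples, here is a minimal sketch of the call itself under TensorFlow 1.x, assuming a frozen GraphDef already exists on disk; the file name 'frozen_model.pb', the node names 'input' and 'output', and the two transforms are placeholders chosen for illustration:

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

# Load a serialized (frozen) GraphDef from disk.
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:      # placeholder path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Apply a list of named transforms between the given input and output nodes.
optimized_def = TransformGraph(graph_def,
                               inputs=['input'],        # placeholder input node name
                               outputs=['output'],      # placeholder output node name
                               transforms=['strip_unused_nodes',
                                           'fold_constants(ignore_errors=true)'])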
Example 1: writeTextGraph
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def writeTextGraph(modelPath, outputPath, outNodes):
    try:
        # Prefer OpenCV's own text-graph writer when it is available.
        import cv2 as cv
        cv.dnn.writeTextGraph(modelPath, outputPath)
    except:
        # Fall back to TensorFlow: load the frozen graph and sort it topologically.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        # Drop constant weight values so the text representation stays small.
        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                    node.attr['value'].tensor.tensor_content = b''

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
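A hypothetical call with placeholder paths and output node names (the real names depend on the exported detection model):

writeTextGraph('frozen_inference_graph.pb', 'graph.pbtxt',
               ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])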
Example 2: optimize_graph
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def optimize_graph(model_dir, graph_filename, transforms, input_name, output_names, outname='optimized_model.pb'):
    input_names = [input_name]  # change this as per how you have saved the model
    graph_def = get_graph_def_from_file(os.path.join(model_dir, graph_filename))
    optimized_graph_def = TransformGraph(
        graph_def,
        input_names,
        output_names,
        transforms)
    tf.train.write_graph(optimized_graph_def,
                         logdir=model_dir,
                         as_text=False,
                         name=outname)
    print('Graph optimized!')
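The helper get_graph_def_from_file() is not part of this snippet; a minimal sketch of what it presumably does under TensorFlow 1.x, reading a serialized GraphDef from disk:

import tensorflow as tf

def get_graph_def_from_file(graph_filepath):
    # Read a frozen/serialized GraphDef protobuf from the given path.
    with tf.gfile.GFile(graph_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def

A hypothetical call of optimize_graph() itself, with placeholder directory, file and node names:

optimize_graph('models/', 'frozen_model.pb',
               ['strip_unused_nodes', 'fold_constants(ignore_errors=true)'],
               'input', ['output'])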
Example 3: transform
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def transform(self, ugraph):
    if ugraph.lib_name != 'tensorflow':
        raise ValueError('only support tensorflow graph')
    graph_def = ugraph.graph_def
    if TransformGraph is None:
        raise RuntimeError("quantization is temporary not supported")
    quant_graph_def = TransformGraph(input_graph_def=graph_def,
                                     inputs=[],
                                     outputs=ugraph.output_nodes,
                                     transforms=["quantize_weights", "quantize_nodes"])
    return GraphDefParser(config={}).parse(
        quant_graph_def,
        output_nodes=ugraph.output_nodes
    )
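The `TransformGraph is None` check implies a guarded import elsewhere in the module; a plausible sketch of that guard (an assumption for illustration, not the project's actual code):

try:
    from tensorflow.tools.graph_transforms import TransformGraph
except ImportError:
    # graph_transforms is not available in every TensorFlow build (e.g. TF 2.x).
    TransformGraph = None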
Example 4: convert_to_pb
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):
    model.load(path, weights_only=True)
    print("[INFO] Loaded CNN network weights from " + path + " ...")

    print("[INFO] Re-export model ...")
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    model.save("model-tmp.tfl")

    # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow
    print("[INFO] Re-import model ...")
    input_checkpoint = "model-tmp.tfl"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', True)
    sess = tf.Session()
    saver.restore(sess, input_checkpoint)

    # print out all layers to find the name of the output
    if verbose:
        for m in sess.graph.get_operations():
            print(m.values())

    print("[INFO] Freeze model to " + pbfilename + " ...")

    # freeze and remove nodes which are not related to feedforward prediction
    minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name])
    graph_def = optimize_for_inference_lib.optimize_for_inference(minimal_graph, [input_layer_name], [output_layer_name], tf.float32.as_datatype_enum)
    graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"])

    with tf.gfile.GFile(pbfilename, 'wb') as f:
        f.write(graph_def.SerializeToString())

    # write model to logs dir so we can visualize it as:
    # tensorboard --logdir="logs"
    if verbose:
        writer = tf.summary.FileWriter('logs', graph_def)
        writer.close()

    # tidy up tmp files
    for f in glob.glob("model-tmp.tfl*"):
        os.remove(f)
    os.remove('checkpoint')
################################################################################
# convert a binary .pb protocol buffer format model to tflite format
# e.g. for FireNet
# pbfilename = "firenet.pb"
# input_layer_name = 'InputData/X' # input layer of network
# output_layer_name= 'FullyConnected_2/Softmax' # output layer of network
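A possible follow-up for the TFLite conversion mentioned in the comment above, using TensorFlow 1.x; the frozen graph and layer names are taken from the example comments, while the output file name is a placeholder:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    "firenet.pb",                                 # frozen graph written by convert_to_pb()
    input_arrays=['InputData/X'],                 # input layer of network
    output_arrays=['FullyConnected_2/Softmax'])   # output layer of network
tflite_model = converter.convert()
with open("firenet.tflite", "wb") as f:           # placeholder output file name
    f.write(tflite_model)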
Example 5: export_tensorflow_model
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def export_tensorflow_model(self, output_fld, output_model_file=None,
                            output_graphdef_file=None,
                            num_output=None,
                            quantize=False,
                            save_output_graphdef_file=False,
                            output_node_prefix=None):
    K.set_learning_phase(0)
    if output_model_file is None:
        output_model_file = Cifar10AudioClassifier.model_name + '.pb'
    if output_graphdef_file is None:
        output_graphdef_file = 'model.ascii'
    if num_output is None:
        num_output = 1
    if output_node_prefix is None:
        output_node_prefix = 'output_node'

    pred = [None] * num_output
    pred_node_names = [None] * num_output
    for i in range(num_output):
        pred_node_names[i] = output_node_prefix + str(i)
        pred[i] = tf.identity(self.model.outputs[i], name=pred_node_names[i])
    print('output nodes names are: ', pred_node_names)

    sess = K.get_session()

    if save_output_graphdef_file:
        tf.train.write_graph(sess.graph.as_graph_def(), output_fld, output_graphdef_file, as_text=True)
        print('saved the graph definition in ascii format at: ', output_graphdef_file)

    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    from tensorflow.tools.graph_transforms import TransformGraph

    if quantize:
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
        constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)

    graph_io.write_graph(constant_graph, output_fld, output_model_file, as_text=False)
    print('saved the frozen graph (ready for inference) at: ', output_model_file)
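A minimal sketch (TensorFlow 1.x) of loading the frozen graph written above for inference; the file name is a placeholder, the output tensor name 'output_node0:0' follows the default output_node_prefix used in the example, and the input tensor name depends on the underlying Keras model:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('cifar10-audio-classifier.pb', 'rb') as f:  # placeholder file name
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    output_tensor = graph.get_tensor_by_name('output_node0:0')
    with tf.Session(graph=graph) as sess:
        # predictions = sess.run(output_tensor, feed_dict={'<input tensor name>:0': batch})
        pass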
Example 6: optimize
# Required module import: from tensorflow.tools import graph_transforms [as alias]
# Or: from tensorflow.tools.graph_transforms import TransformGraph [as alias]
def optimize(self, sess, dataset, path, device):
    """The best performing model is frozen, optimized for inference
    by removing unneeded training operations, and written to disk.

    Args:
        sess (object): The current TF training session.
        dataset (str): The dataset name, used in the saved model's file name.
        path (str): The path used for saving the model.
        device (str): Represents either "cpu" or "gpu".

    .. seealso:: https://bit.ly/2VBBdqQ and https://bit.ly/2W7YqBa
    """
    model_name = "model_%s_%s" % (dataset, device)
    model_path = path + model_name
    tf.train.write_graph(sess.graph.as_graph_def(),
                         path, model_name + ".pbtxt")
    freeze_graph.freeze_graph(model_path + ".pbtxt", "", False,
                              model_path + ".ckpt", "output",
                              "save/restore_all", "save/Const:0",
                              model_path + ".pb", True, "")
    os.remove(model_path + ".pbtxt")
    graph_def = tf.GraphDef()
    with tf.gfile.Open(model_path + ".pb", "rb") as file:
        graph_def.ParseFromString(file.read())
    transforms = ["remove_nodes(op=Identity)",
                  "merge_duplicate_nodes",
                  "strip_unused_nodes",
                  "fold_constants(ignore_errors=true)"]
    optimized_graph_def = TransformGraph(graph_def,
                                         ["input"],
                                         ["output"],
                                         transforms)
    tf.train.write_graph(optimized_graph_def,
                         logdir=path,
                         as_text=False,
                         name=model_name + ".pb")
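An optional sanity check (assuming the path and model_name from the example above): count the op types in the optimized GraphDef to confirm that the Identity nodes really were removed:

import collections
import os
import tensorflow as tf

check_def = tf.GraphDef()
with tf.gfile.Open(os.path.join(path, model_name + ".pb"), "rb") as f:  # file written above
    check_def.ParseFromString(f.read())
print(collections.Counter(node.op for node in check_def.node))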