本文整理汇总了Python中tensorflow.python.framework.graph_util.convert_variables_to_constants函数的典型用法代码示例。如果您正苦于以下问题：Python convert_variables_to_constants函数的具体用法？Python convert_variables_to_constants怎么用？Python convert_variables_to_constants使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了convert_variables_to_constants函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testConvertVariablesToConsts
def testConvertVariablesToConsts(self):
    """Freezes a small graph and checks whitelist/blacklist handling.

    Builds ``output_node = variable_node * 2`` and converts its variables to
    constants three ways: with a whitelist, without one, and with a blacklist.
    It then re-imports the whitelisted frozen graph and checks that it has no
    variable ops and still evaluates to 2.0.
    """
    with ops.Graph().as_default():
        variable_node = variables.Variable(1.0, name="variable_node")
        # Not an ancestor of output_node; initialized only later.
        _ = variables.Variable(1.0, name="unused_variable_node")
        output_node = math_ops_lib.multiply(
            variable_node, 2.0, name="output_node")
        with session.Session() as sess:
            # `initialize_variables` is deprecated; `variables_initializer`
            # is its drop-in replacement.
            init = variables.variables_initializer([variable_node])
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
            variable_graph_def = sess.graph.as_graph_def()
            # First get the constant_graph_def when variable_names_whitelist is
            # set. Note that if variable_names_whitelist is not set, an error
            # will be thrown because unused_variable_node is not initialized.
            constant_graph_def = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def, ["output_node"],
                variable_names_whitelist=set(["variable_node"]))
            # Then initialize the unused variable, and get another
            # constant_graph_def when variable_names_whitelist is not set.
            sess.run(variables.global_variables_initializer())
            constant_graph_def_without_variable_whitelist = (
                graph_util.convert_variables_to_constants(
                    sess, variable_graph_def, ["output_node"]))
            # The unused variable should be cleared, so the two graphs should
            # be equivalent.
            self.assertEqual(
                str(constant_graph_def),
                str(constant_graph_def_without_variable_whitelist))
            # Test the variable name blacklist: the blacklisted variable must
            # stay a variable op instead of becoming a Const. All variables
            # are already initialized at this point, so no re-init is needed.
            constant_graph_def_with_blacklist = (
                graph_util.convert_variables_to_constants(
                    sess,
                    variable_graph_def, ["output_node"],
                    variable_names_blacklist=set(["variable_node"])))
            variable_node = None
            for node in constant_graph_def_with_blacklist.node:
                if node.name == "variable_node":
                    variable_node = node
            self.assertIsNotNone(variable_node)
            self.assertEqual(variable_node.op, "VariableV2")
    # Now make sure the variable is a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
        _ = importer.import_graph_def(constant_graph_def, name="")
        self.assertEqual(4, len(constant_graph_def.node))
        for node in constant_graph_def.node:
            self.assertNotEqual("Variable", node.op)
            self.assertNotEqual("VariableV2", node.op)
        with session.Session() as sess:
            output_node = sess.graph.get_tensor_by_name("output_node:0")
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
示例2: freeze_graph
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes,
                 variable_names_blacklist=None):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path to the GraphDef file (binary or text proto).
      input_saver: Optional path to a SaverDef file. If empty, the op named
        ``restore_op_name`` is run, with ``filename_tensor_name`` fed the
        checkpoint path.
      input_binary: True if ``input_graph``/``input_saver`` are binary protos,
        False for text format.
      input_checkpoint: Checkpoint path (may be a Saver V2 prefix).
      output_node_names: Comma-separated names of the nodes to keep.
      restore_op_name: Restore op used when ``input_saver`` is empty.
      filename_tensor_name: Tensor fed with the checkpoint path in that case.
      output_graph: Path where the frozen binary GraphDef is written.
      clear_devices: If true, strip device placements from every node.
      initializer_nodes: If truthy, run with ``sess.run`` after restoring.
      variable_names_blacklist: Optional comma-separated names of variables
        that must NOT be converted. When None, falls back to
        ``FLAGS.variable_names_blacklist`` (the previous behaviour).

    Returns:
      -1 if a required input is missing; otherwise None after writing
      ``output_graph``.
    """

    def _read_text(f):
        # GFile opened in text mode returns `str` on Python 3, which has no
        # .decode(); decode only when the read actually produced bytes.
        data = f.read()
        return data.decode("utf-8") if isinstance(data, bytes) else data

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(_read_text(f), input_graph_def)

    # Remove all explicit device specifications to make the graph portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = tf.import_graph_def(input_graph_def, name="")
    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(_read_text(f), saver_def)
            saver = tf.train.Saver(saver_def=saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        if initializer_nodes:
            sess.run(initializer_nodes)

        if variable_names_blacklist is None:
            # Backward compatibility: previously this was read only from FLAGS.
            variable_names_blacklist = FLAGS.variable_names_blacklist
        blacklist = (variable_names_blacklist.split(",")
                     if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","),
            variable_names_blacklist=blacklist)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
示例3: graph_def_from_checkpoint
def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
    """Converts checkpoint data to GraphDef.

    Reads the latest checkpoint data and produces a GraphDef in which the
    variables have been converted to constants.

    Args:
      checkpoint_dir: Path to the checkpoints.
      output_node_names: List of name strings for the result nodes of the graph.

    Returns:
      A GraphDef from the latest checkpoint.

    Raises:
      ValueError: if no checkpoint is found.
    """
    checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise ValueError('Could not find a checkpoint at: {0}.'
                         .format(checkpoint_dir))
    # Import into a private graph. Importing into the process-wide default
    # graph would make repeated calls stack duplicate (renamed) nodes.
    with ops.Graph().as_default() as graph:
        saver_for_restore = saver_lib.import_meta_graph(
            checkpoint_path + '.meta', clear_devices=True)
        with session.Session(graph=graph) as sess:
            saver_for_restore.restore(sess, checkpoint_path)
            graph_def = graph.as_graph_def()
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, graph_def, output_node_names)
    return output_graph_def
示例4: freeze_graph_with_def_protos
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: GraphDef to freeze. It is not modified.
      input_saver_def: Optional SaverDef. If falsy, a Saver is built from
        every checkpoint key that has a matching ``<key>:0`` tensor in the graph.
      input_checkpoint: Checkpoint path (may be a Saver V2 prefix).
      output_node_names: Comma-separated names of the nodes to keep.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      clear_devices: If true, strip device placements from the frozen graph.
      initializer_nodes: If truthy, run with ``sess.run`` after restoring.
      variable_names_blacklist: Comma-separated names of variables that must
        not be converted; empty means none.

    Returns:
      The frozen GraphDef.

    Raises:
      ValueError: if the checkpoint does not exist or no output node is given.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint "' + input_checkpoint + '" does not exist!')
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')
    # Remove all explicit device specifications to make the graph portable.
    # Work on a copy so the caller's GraphDef is not mutated as a side effect.
    if clear_devices:
        graph_def_copy = type(input_graph_def)()
        graph_def_copy.CopyFrom(input_graph_def)
        input_graph_def = graph_def_copy
        for node in input_graph_def.node:
            node.device = ''
    _ = importer.import_graph_def(input_graph_def, name='')
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ':0')
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element), so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)
        if initializer_nodes:
            sess.run(initializer_nodes)
        variable_names_blacklist = (variable_names_blacklist.split(',') if
                                    variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(','),
            variable_names_blacklist=variable_names_blacklist)
    return output_graph_def
示例5: from_session
def from_session(cls,
                 sess,
                 input_tensors,
                 output_tensors,
                 freeze_variables=False):
    """Creates a TocoConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.get_shape()` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      freeze_variables: Boolean indicating whether the variables need to be
        converted into constants. The values currently held in ``sess`` are
        baked in, so the variables must already be initialized or restored.
        (default False)

    Returns:
      TocoConverter class.
    """
    if freeze_variables:
        # Do NOT run global_variables_initializer() here. It would reset every
        # variable to its initial value and discard trained/restored weights
        # right before they are frozen into the graph.
        output_arrays = [tensor_name(tensor) for tensor in output_tensors]
        graph_def = tf_graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_arrays)
    else:
        graph_def = sess.graph_def
    return cls(graph_def, input_tensors, output_tensors)
示例6: testConvertVariablesToConstsWithEmbeddings
def testConvertVariablesToConstsWithEmbeddings(self):
    """Freezes a Keras graph with an Embedding layer and compares outputs."""
    # The previous `np.array(np.random.random_sample([1, 1]), dtype=np.int32)`
    # always truncated to index 0. Draw an index in [0, input_dim) from a
    # seeded generator instead, so the test is deterministic and is not
    # restricted to embedding row 0.
    input_data = np.random.RandomState(0).randint(
        0, 100, size=(1, 1)).astype(np.int32)
    # Make model.
    state_input = keras.layers.Input(
        shape=(1,), name="state_input", dtype="int32")
    output = keras.layers.Embedding(
        output_dim=16, input_dim=100, input_length=1, name="state")(
            state_input)
    model = keras.models.Model(inputs=[state_input], outputs=[output])
    model.compile(
        loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")
    # Get the session Keras uses for this model.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)
    # Ensure the frozen graph has no variable ops left.
    for node in constant_graph_def.node:
        self.assertNotIn(
            node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
    # Compare the value of the graphs.
    expected_value = model.predict(input_data)
    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
                                            model.outputs, [input_data])
    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
示例7: freeze_session
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The list passed
                        in is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        global_vars = tf.global_variables()
        freeze_var_names = list(
            set(v.op.name for v in global_vars).difference(keep_var_names or []))
        # Build a new list. `output_names += ...` on the caller's list would
        # extend it in place and leak variable names back to the caller.
        output_names = list(output_names or [])
        # Every variable op is also listed as an output, so variables excluded
        # via keep_var_names are not pruned from the result.
        output_names += [v.op.name for v in global_vars]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        # 4th positional argument is variable_names_whitelist.
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
示例8: freeze_graph
def freeze_graph(sess, input_tensors, output_tensors):
    """Returns a frozen GraphDef.

    If OpHint output nodes are present, the graph is converted through the
    OpHint path instead, because Grappler function inlining would break the
    OpHints transformation.
    If the session's graph is already frozen, ``sess.graph_def`` is returned
    unchanged. Otherwise a Grappler pass inlines functions, and the variables
    of the optimized graph are converted to constants.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors.
      output_tensors: List of output tensors (only .name is used from this).

    Returns:
      Frozen GraphDef.
    """
    hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
    if len(hinted_outputs_nodes) > 0:  # pylint: disable=g-explicit-length-test
        return _convert_op_hints_if_present(sess, output_tensors)

    if is_frozen_graph(sess):
        # The Grappler result was previously computed here and then discarded;
        # skip the unused optimization pass.
        return sess.graph_def

    # Run a Grappler pass in order to inline any functions in the graph.
    config = get_grappler_config(function_only=True)
    graph_def = run_graph_optimizations(
        sess.graph_def, input_tensors, output_tensors, config, graph=sess.graph)
    output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
    return tf_graph_util.convert_variables_to_constants(sess, graph_def,
                                                        output_arrays)
示例9: testConvertVariablesToConstsWithFunctions
def testConvertVariablesToConstsWithFunctions(self):
    """Freezing must preserve the graph's function library unchanged.

    ``plus_one`` is a Defun used on the path to ``output_node``. After
    ``variable_node`` is converted to a constant, the frozen GraphDef's
    ``library`` must equal the original one.
    """
    @function.Defun(dtypes.float32)
    def plus_one(x):
        return x + 1.0

    with ops.Graph().as_default():
        variable_node = variables.Variable(1.0, name="variable_node")
        # Never initialized; not an ancestor of output_node.
        _ = variables.Variable(1.0, name="unused_variable_node")
        defun_node = plus_one(variable_node)
        output_node = math_ops_lib.multiply(
            defun_node, 2.0, name="output_node")
        with session.Session() as sess:
            # `initialize_variables` is deprecated; `variables_initializer`
            # is its drop-in replacement.
            init = variables.variables_initializer([variable_node])
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(4.0, output, 0.00001)
            variable_graph_def = sess.graph.as_graph_def()
            # First get the constant_graph_def when variable_names_whitelist is
            # set. Note that if variable_names_whitelist is not set, an error
            # will be thrown because unused_variable_node is not initialized.
            constant_graph_def = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def, ["output_node"],
                variable_names_whitelist=set(["variable_node"]))
            self.assertEqual(variable_graph_def.library,
                             constant_graph_def.library)
示例10: _freeze_graph_with_def_protos
def _freeze_graph_with_def_protos(input_graph_def, output_node_names,
                                  initializer_names, shared_init_op_name,
                                  input_saver_def, input_checkpoint):
    """Freezes a graph from a checkpoint while keeping initializer nodes.

    Certain initializer nodes (e.g. table initializers) must survive freezing.
    Rather than working out which dependencies of the shared initializer node
    (e.g. group_deps) to keep, the named initializers are kept as extra
    outputs. Their links to the shared node are rebuilt after freezing.

    Args:
      input_graph_def: A GraphDef proto to be frozen.
      output_node_names: Names of output nodes.
      initializer_names: Names of initializer nodes to keep.
      shared_init_op_name: The name of the shared initializer node to connect
        the nodes in initializer_names to.
      input_saver_def: A SaverDef proto used for restoring a checkpoint.
      input_checkpoint: A path to a checkpoint to restore.

    Returns:
      A frozen GraphDef.
    """
    with _ops.Graph().as_default():
        _importer.import_graph_def(input_graph_def, name='')
        with _session.Session() as sess:
            restorer = _saver_lib.Saver(saver_def=input_saver_def)
            restorer.restore(sess, input_checkpoint)
            frozen_def = _graph_util.convert_variables_to_constants(
                sess, input_graph_def, output_node_names + initializer_names)
    # Re-attach the kept initializers to the shared init op.
    _connect_to_shared_init_op(frozen_def, shared_init_op_name,
                               initializer_names)
    return frozen_def
示例11: freeze_graph_def
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Rewrites ref-typed training ops in place, then freezes the graph.

    ``RefSwitch`` becomes ``Switch``, and its inputs containing ``'moving_'``
    are redirected to ``<input>/read``. ``AssignSub``/``AssignAdd`` become
    ``Sub``/``Add`` with their ``use_locking`` attr dropped. Only variables
    whose node name starts with one of the model prefixes below are
    converted to constants.

    Args:
      sess: Session holding the variable values.
      input_graph_def: GraphDef to freeze. It is modified in place.
      output_node_names: Comma-separated names of the nodes to keep.

    Returns:
      The frozen GraphDef.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # `range` instead of the Python-2-only `xrange`.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
    # Get the list of important nodes. str.startswith accepts a tuple.
    whitelist_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(whitelist_prefixes)]
    # Replace all the variables in the graph with constants of the same values.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
示例12: saveData
def saveData(self, step):
    """Saves a checkpoint, the graph and a frozen .pb, then zips and announces them.

    Args:
      step: Training step. Used in the checkpoint's ``global_step``-independent
        file names (``train.<step>.pb``, the zip name) and passed to the zip
        helper and the notification payload.
    """
    print('{} Saving checkpoint file to: {}'.format(
        datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
        self.output_dir))
    # Save the graph's weights (checkpoint).
    self.saver.save(
        self.sess, self.ckpt_file, global_step=self.global_step)
    model_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION, 'model')
    # Save the graph structure as text.
    tf.train.write_graph(self.sess.graph_def, model_dir, 'train.pbtxt')
    # Bake the weights into the graph to produce a .pb usable on Android.
    # Use the session's own graph: tf.get_default_graph() is not guaranteed to
    # be the graph self.sess runs, and the output names come from self.sess.
    graph_def = self.sess.graph.as_graph_def()
    print("global_variables are")
    variables_to_save = []
    for node in graph_def.node:
        print("{}:{}".format(str(node.name), type(node)))
        variables_to_save.append(str(node.name).split(':')[0])
    print("--------------")
    # Every node is listed as an output, so nothing is pruned.
    # TODO: list only the real output node(s) (e.g. "predictions") instead.
    output_graph_def = graph_util.convert_variables_to_constants(
        self.sess,
        graph_def,
        variables_to_save)
    # Previously 'train.' + step + str(step) + '.pb'. That raised TypeError
    # for an int step and repeated the step twice for a str step.
    pb_path = os.path.join(model_dir, 'train.' + str(step) + '.pb')
    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))

    freezetime = datetime.datetime.now().strftime('%m-%d-%H-%M-%S')
    zu = ZipUtil()
    zipfilename = cfg.DATA_UploadZipFileName + '.' + str(step) + '.' + freezetime
    # `step` is passed so only this training step's files are compressed,
    # not the whole output directory.
    zu.zip_dir(os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION),
               step,
               zipfilename)
    uploader = Uploader()
    # SECURITY(review): credentials are hard-coded in source. Move them to
    # config/environment and rotate these keys.
    uploader.setQiniuKEY('mMQxjyif6Uk8nSGIn9ZD3I19MBMEK3IUGngcX8_p',
                         'J5gFhdpQ-1O1rkCnlqYnzPiH3XTst2Szlv9GlmQM')
    #uploader.upload2qiniu(cfg.DATA_UploadZipFileName + '.' + freezetime,zipfilename).start()
    sendData = {"state": "prepared",
                "filename": str(zipfilename),
                "filepath": os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION, zipfilename),
                "step": step,
                }
    uploader.notifyForTrans(sendData)
示例13: train_network
def train_network(graph, batch_size, num_epochs, pb_file_path,
                  train_size=12, val_size=3):
    """Trains one example per step, reports metrics periodically, then freezes.

    Reads the module-level ``x_train``/``y_train``/``x_val``/``y_val``.

    Args:
      graph: Mapping with keys 'x', 'y', 'optimize', 'correct_times_in_batch'
        and 'cost'.
      batch_size: Multiplier applied to each run's cost and to the accuracy
        denominator. Note that each run is fed exactly one example.
      num_epochs: Number of passes over the first ``train_size`` examples.
      pb_file_path: Where the frozen GraphDef (keeping node "output") is written.
      train_size: Number of leading training examples used (default 12).
      val_size: Number of leading validation examples used (default 3).
    """

    def _feed(images, labels, i):
        # One 224x224x3 image as a batch of one; label 0 -> [1, 0], else [0, 1].
        return {
            graph['x']: np.reshape(images[i], (1, 224, 224, 3)),
            graph['y']: ([[1, 0]] if labels[i] == 0 else [[0, 1]]),
        }

    def _evaluate(sess, images, labels, count):
        # Returns (runs, correct_total, scaled_cost_total) over `count` examples.
        # Both metrics are fetched in one run so they share a forward pass.
        correct_total = 0
        cost_total = 0.
        for i in range(count):
            correct_in_batch, mean_cost_in_batch = sess.run(
                [graph['correct_times_in_batch'], graph['cost']],
                feed_dict=_feed(images, labels, i))
            correct_total += correct_in_batch
            cost_total += (mean_cost_in_batch * batch_size)
        return count, correct_total, cost_total

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        epoch_delta = 2
        for epoch_index in range(num_epochs):
            for i in range(train_size):
                sess.run([graph['optimize']],
                         feed_dict=_feed(x_train, y_train, i))
            if epoch_index % epoch_delta == 0:
                (total_batches_in_train_set,
                 total_correct_times_in_train_set,
                 total_cost_in_train_set) = _evaluate(
                     sess, x_train, y_train, train_size)
                (total_batches_in_test_set,
                 total_correct_times_in_test_set,
                 total_cost_in_test_set) = _evaluate(
                     sess, x_val, y_val, val_size)
                acy_on_test = total_correct_times_in_test_set / float(total_batches_in_test_set * batch_size)
                acy_on_train = total_correct_times_in_train_set / float(total_batches_in_train_set * batch_size)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                    epoch_index,
                    acy_on_test * 100.0,
                    total_correct_times_in_test_set,
                    total_batches_in_test_set * batch_size,
                    total_cost_in_test_set,
                    acy_on_train * 100.0,
                    total_correct_times_in_train_set,
                    total_batches_in_train_set * batch_size,
                    total_cost_in_train_set))
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["output"])
        with tf.gfile.GFile(pb_file_path, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
示例14: save_graph_to_file
def save_graph_to_file(graph, graph_file_name, model_info, class_count):
    """Freezes a freshly built evaluation graph and writes it as a binary proto.

    Args:
      graph: Unused; kept for backward compatibility. The graph that is saved
        is the one owned by the session from ``build_eval_session``.
      graph_file_name: Destination path of the frozen GraphDef.
      model_info: Passed through to ``build_eval_session``.
      class_count: Passed through to ``build_eval_session``.
    """
    del graph  # Previously overwritten immediately; never used.
    sess, _, _, _, _ = build_eval_session(model_info, class_count)
    try:
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [FLAGS.final_tensor_name])
    finally:
        # The session is not returned to the caller, so release it here.
        sess.close()
    with gfile.FastGFile(graph_file_name, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
示例15: _convert_op_hints_if_present
def _convert_op_hints_if_present(sess, output_tensors):
    """Freezes an unfrozen graph and replaces its OpHint regions with stubs.

    The hinted output nodes are kept alongside the requested outputs during
    freezing. The frozen graph then goes through ``convert_op_hints_to_stubs``
    and ``remove_training_nodes``.

    Raises:
      ValueError: if the session's graph is already frozen.
    """
    if is_frozen_graph(sess):
        raise ValueError("Try to convert op hints, needs unfrozen graph.")
    hinted_names = find_all_hinted_output_nodes(sess)
    requested_names = [get_tensor_name(t) for t in output_tensors]
    frozen_def = tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, requested_names + hinted_names)
    stubbed_def = convert_op_hints_to_stubs(graph_def=frozen_def)
    return tf_graph_util.remove_training_nodes(stubbed_def)