本文整理汇总了Python中tensorflow.python.framework.graph_io.write_graph函数的典型用法代码示例。如果您正苦于以下问题：Python write_graph函数的具体用法？Python write_graph怎么用？Python write_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了write_graph函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
def main(unused_args):
    """Run optimize_for_inference on the graph at FLAGS.input and save it.

    The input GraphDef is parsed as a binary proto when ``FLAGS.frozen_graph``
    is set and as a UTF-8 text proto otherwise. It is optimized with the
    input/output node names, placeholder type and TOCO-compatibility flag
    taken from FLAGS. The result is written to ``FLAGS.output``: as a
    serialized binary proto for frozen graphs, otherwise via
    ``graph_io.write_graph``.

    Args:
        unused_args: Positional command-line arguments (ignored).

    Returns:
        0 on success, -1 if the input graph file does not exist.
    """
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
    if FLAGS.frozen_graph:
        input_graph_def.ParseFromString(data)
    else:
        text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum,
        FLAGS.toco_compatible)

    if FLAGS.frozen_graph:
        # SerializeToString() returns bytes, so open in binary mode; the
        # context manager guarantees the file is flushed and closed (the
        # previous handle was never closed).
        with gfile.FastGFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        # os.path.dirname() is "" for a bare filename; pass "." so
        # write_graph is never handed an empty directory.
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output) or ".",
                             os.path.basename(FLAGS.output))
    return 0
示例2: strip_pruning_vars
def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename):
    """Remove pruning-related auxiliary variables and ops from the graph.

    Accepts training checkpoints and produces a GraphDef in which the pruning
    vars and ops have been removed. The result is written as a binary
    GraphDef to ``os.path.join(output_dir, filename)``.

    Args:
        checkpoint_dir: Path to the checkpoints.
        output_node_names: The names of the output nodes, comma separated.
            Spaces and empty entries (e.g. from a trailing comma) are
            ignored.
        output_dir: Directory where to write the graph.
        filename: Output GraphDef file name.

    Returns:
        None

    Raises:
        ValueError: if no (non-empty) output node names are provided.
    """
    # Parse first, then validate: "a,b," or " , " would otherwise pass the
    # emptiness check and request a node literally named "".
    node_names = [
        name
        for name in (output_node_names or '').replace(' ', '').split(',')
        if name
    ]
    if not node_names:
        raise ValueError(
            'Need to specify atleast 1 output node through output_node_names flag')

    initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint(
        checkpoint_dir, node_names)
    final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
        initial_graph_def, node_names)
    graph_io.write_graph(final_graph_def, output_dir, filename, as_text=False)
    logging.info('\nFinal graph written to %s', os.path.join(
        output_dir, filename))
示例3: testStripUnusedMultipleInputs
def testStripUnusedMultipleInputs(self):
    """strip_unused keeps two named inputs and drops everything unused.

    Builds a graph computing (1 - 3) * (2 - 5) with an extra downstream
    ``later_node``. It strips that graph down to
    ``input_node1``/``input_node2`` -> ``output_node`` and then checks four
    things: three nodes remain, no Add/Sub op survives, each input node
    carries a ``shape`` attr, and feeding 10 and -5 yields -50.
    """
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that multiplies two input nodes.
    with ops.Graph().as_default():
        constant_node1 = constant_op.constant(1.0, name="constant_node1")
        constant_node2 = constant_op.constant(2.0, name="constant_node2")
        input_node1 = math_ops.subtract(constant_node1, 3.0, name="input_node1")
        input_node2 = math_ops.subtract(constant_node2, 5.0, name="input_node2")
        output_node = math_ops.multiply(
            input_node1, input_node2, name="output_node")
        math_ops.add(output_node, 2.0, name="later_node")
        # Close the session on exit instead of leaking it.
        with session.Session() as sess:
            output = sess.run(output_node)
            self.assertNear(6.0, output, 0.00001)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

    # We save out the graph to disk, and then call the strip_unused routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_binary = False
    input_node_names = "input_node1,input_node2"
    input_node_types = [
        dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum
    ]
    output_binary = True
    output_node_names = "output_node"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
                                             output_graph_path, output_binary,
                                             input_node_names,
                                             output_node_names,
                                             input_node_types)

    # Now make sure the unused nodes are gone, the inputs carry a shape attr,
    # and the graph still produces the expected result.
    with ops.Graph().as_default():
        output_graph_def = graph_pb2.GraphDef()
        with open(output_graph_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = importer.import_graph_def(output_graph_def, name="")

        self.assertEqual(3, len(output_graph_def.node))
        # input_node_names is comma-separated, so compare against its parts.
        # Comparing a node name with the whole string never matched, which
        # silently skipped the shape check.
        wanted_inputs = input_node_names.split(",")
        for node in output_graph_def.node:
            self.assertNotEqual("Add", node.op)
            self.assertNotEqual("Sub", node.op)
            if node.name in wanted_inputs:
                self.assertIn("shape", node.attr)

        with session.Session() as sess:
            input_node1 = sess.graph.get_tensor_by_name("input_node1:0")
            input_node2 = sess.graph.get_tensor_by_name("input_node2:0")
            output_node = sess.graph.get_tensor_by_name("output_node:0")
            output = sess.run(output_node,
                              feed_dict={input_node1: [10.0],
                                         input_node2: [-5.0]})
            self.assertNear(-50.0, output, 0.00001)
示例4: main
def main(unused_argv):
    """Build the Inception-ResNet-v2 base network graph and write it out.

    The graph has a single float32 placeholder named ``input_image`` with
    shape (1, None, None, 3) feeding ``inception_resnet_v2_base``. The
    GraphDef is written with write_graph's default format to
    ``cmd_args.graph_dir`` / ``cmd_args.graph_filename``.

    Args:
        unused_argv: Command-line arguments (ignored).
    """
    model_graph = ops.Graph()
    with model_graph.as_default():
        input_image = array_ops.placeholder(
            dtypes.float32,
            shape=(1, None, None, 3),
            name='input_image',
        )
        inception.inception_resnet_v2_base(input_image)
        model_graph_def = model_graph.as_graph_def()
    graph_io.write_graph(
        model_graph_def, cmd_args.graph_dir, cmd_args.graph_filename)
示例5: _WriteGraph
def _WriteGraph(self, run_params, gdef, graph_state):
    """Write ``gdef`` to a ``.pbtxt`` file named after the test and state.

    The file name is ``<TestClassName>_<run_params.test_name>_<label>.pbtxt``,
    where ``label`` is derived from ``graph_state``. The directory comes from
    the TRT_TEST_TMPDIR environment variable, falling back to
    ``self.get_temp_dir()``. Nothing is written if it resolves to "".

    Args:
        run_params: Run parameters; only ``test_name`` is read here.
        gdef: The GraphDef to write.
        graph_state: GraphState.ORIGINAL, GraphState.CALIBRATE or
            GraphState.INFERENCE.

    Raises:
        ValueError: if ``graph_state`` is none of the states above. Previously
            this surfaced as an UnboundLocalError on ``label``.
    """
    if graph_state == GraphState.ORIGINAL:
        label = "Original"
    elif graph_state == GraphState.CALIBRATE:
        label = "CalibEngine"
    elif graph_state == GraphState.INFERENCE:
        label = "InferEngine"
    else:
        raise ValueError("Unknown graph_state: %r" % (graph_state,))
    graph_name = (
        self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
        ".pbtxt")
    temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
    if temp_dir:
        logging.info("Writing graph to %s/%s", temp_dir, graph_name)
        graph_io.write_graph(gdef, temp_dir, graph_name)
示例6: freeze_graph
def freeze_graph(sess, ckpt, output, cache_dir="_Cache"):
    """Restore ``ckpt`` into ``sess``, then export and freeze the session graph.

    The restored variables are re-saved to ``<cache_dir>/Model`` and the
    graph is written as a binary ``<cache_dir>/Model.pb``. freeze_graph is
    then run on both, producing ``Frozen.pb`` in the current working
    directory.

    Args:
        sess: Session whose graph the checkpoint is restored into.
        ckpt: Checkpoint path passed to ``Saver.restore``.
        output: Output node name(s); passed as freeze_graph's
            ``output_node_names`` argument.
        cache_dir: Scratch directory for the intermediate checkpoint and
            graph. Defaults to "_Cache", the previously hard-coded location.
    """
    # This function's own name shadows the module-level ``freeze_graph``, so
    # ``freeze_graph.freeze_graph`` looked up an attribute on this very
    # function and failed. Import the tool module under a distinct name.
    from tensorflow.python.tools import freeze_graph as freeze_graph_lib

    print("Loading checkpoint...")
    saver = tf.train.Saver()
    saver.restore(sess, ckpt)
    print("Writing graph...")
    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)
    model_prefix = os.path.join(cache_dir, "Model")
    saver.save(sess, model_prefix)
    # as_text=False: binary graph, matching input_binary=True below.
    graph_io.write_graph(sess.graph, cache_dir, "Model.pb", False)
    print("Freezing graph...")
    freeze_graph_lib.freeze_graph(
        os.path.join(cache_dir, "Model.pb"),
        "", True, model_prefix,
        output, "save/restore_all", "save/Const:0", "Frozen.pb", True, ""
    )
    print("Done")
示例7: saveModel
def saveModel(self, sess, outputDirectory=""):
    """Write the session graph plus a frozen (model and weights) version.

    Writes ``tfModel.pb`` under ``outputDirectory``; it is read back below
    with ``input_binary=False``. freeze_graph is then called with the
    checkpoint ``<outputDirectory>/models/model.ckpt`` and output node
    ``y_ph``, producing ``<outputDirectory>/tfModel_frozen.pb``.

    Args:
        sess: Session whose graph is exported.
        outputDirectory: Directory for all paths. It may be given with or
            without a trailing separator; "" means the current directory.
    """
    import os
    from tensorflow.python.framework import graph_io
    from tensorflow.python.tools import freeze_graph

    # Join paths rather than concatenating strings. "out" and "out/" now
    # both work; concatenation turned "out" into "outtfModel.pb".
    input_graph_path = os.path.join(outputDirectory, "tfModel.pb")
    graph_io.write_graph(sess.graph, outputDirectory or "./", "tfModel.pb")

    # Create frozen version of graph for distribution.
    input_saver_def_path = ""
    input_binary = False
    checkpoint_path = os.path.join(outputDirectory, "models", "model.ckpt")
    output_node_names = "y_ph"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(outputDirectory, "tfModel_frozen.pb")
    clear_devices = False
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path, output_node_names,
                              restore_op_name, filename_tensor_name,
                              output_graph_path, clear_devices, "")
    print("Frozen model (model and weights) saved in file: %s" % output_graph_path)
示例8: _testFreezeGraph
def _testFreezeGraph(self, saver_write_version):
    """Freeze a one-variable graph and check the frozen result.

    Saves a checkpoint with the given saver ``write_version`` and writes the
    GraphDef; it is read back with ``input_binary=False``. freeze_graph is
    run on both, and the frozen graph is checked for three things: four
    nodes, no Variable/VariableV2 ops, and ``output_node`` still evaluating
    to 2.0.

    Args:
        saver_write_version: Passed to ``saver_lib.Saver(write_version=...)``
            and as ``checkpoint_version`` to freeze_graph.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that has a single variable containing 1.0,
    # and that then multiplies it by 2.
    with ops.Graph().as_default():
        variable_node = variables.VariableV1(1.0, name="variable_node")
        output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
        # Close the session on exit instead of leaking it.
        with session.Session() as sess:
            init = variables.global_variables_initializer()
            sess.run(init)
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
            saver = saver_lib.Saver(write_version=saver_write_version)
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

    # We save out the graph to disk, and then call the const conversion
    # routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    output_node_names = "output_node"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
    clear_devices = False

    freeze_graph.freeze_graph(
        input_graph_path,
        input_saver_def_path,
        input_binary,
        checkpoint_path,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph_path,
        clear_devices,
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
        output_graph_def = graph_pb2.GraphDef()
        with open(output_graph_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = importer.import_graph_def(output_graph_def, name="")

        self.assertEqual(4, len(output_graph_def.node))
        for node in output_graph_def.node:
            self.assertNotEqual("VariableV2", node.op)
            self.assertNotEqual("Variable", node.op)

        with session.Session() as sess:
            output_node = sess.graph.get_tensor_by_name("output_node:0")
            output = sess.run(output_node)
            self.assertNear(2.0, output, 0.00001)
示例9: testSinglePartitionedVariable
def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph."""
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partition variables. When weights are partitioned into
    # a single partition, the weights variable is followed by a identity ->
    # identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
        with variable_scope.variable_scope("part", partitioner=partitioner):
            batch_size, height, width, depth = 5, 128, 128, 3
            input1 = array_ops.zeros(
                (batch_size, height, width, depth), name="input1")
            input2 = array_ops.zeros(
                (batch_size, height, width, depth), name="input2")

            num_nodes = depth
            filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
            filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
            conv = nn.conv2d(
                input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
            node = math_ops.add(conv, input2, name="test/add")
            node = nn.relu6(node, name="test/relu6")

        # Save graph and checkpoints. The session is closed on exit instead of
        # being leaked.
        with session.Session() as sess:
            sess.run(variables.global_variables_initializer())

            saver = saver_lib.Saver()
            checkpoint_path = saver.save(
                sess,
                checkpoint_prefix,
                global_step=0,
                latest_filename=checkpoint_state_name)
            graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

            # Ensure this graph has partition variables.
            self.assertTrue([
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ])

            # Test freezing graph doesn't make it crash.
            output_node_names = "save/restore_all"
            output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

            return_value = freeze_graph.freeze_graph_with_def_protos(
                input_graph_def=sess.graph_def,
                input_saver_def=None,
                input_checkpoint=checkpoint_path,
                output_node_names=output_node_names,
                restore_op_name="save/restore_all",  # default value
                filename_tensor_name="save/Const:0",  # default value
                output_graph=output_graph_path,
                clear_devices=False,
                initializer_nodes="")
            # The old `assertTrue(return_value, -1)` passed -1 as the failure
            # *message*, so any truthy result passed. Per the docstring,
            # freezing should fail cleanly. This assumes the clean-failure
            # signal is a -1 return value; confirm against
            # freeze_graph_with_def_protos.
            self.assertEqual(-1, return_value)
示例10: export_scoped_meta_graph
#.........这里部分代码省略.........
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
strip_default_attrs: Set to true if default valued attributes must be
removed while exporting the GraphDef.
**kwargs: Optional keyed arguments, including meta_info_def and
collection_list.
Returns:
A `MetaGraphDef` proto and dictionary of `Variables` in the exported
name scope.
Raises:
ValueError: When the `GraphDef` is larger than 2GB.
"""
if context.executing_eagerly():
raise ValueError("Exporting/importing meta graphs is not supported when "
"Eager Execution is enabled.")
graph = graph or ops.get_default_graph()
exclude_nodes = None
unbound_inputs = []
if export_scope or clear_extraneous_savers or clear_devices:
if graph_def:
new_graph_def = graph_pb2.GraphDef()
new_graph_def.versions.CopyFrom(graph_def.versions)
new_graph_def.library.CopyFrom(graph_def.library)
if clear_extraneous_savers:
exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)
for node_def in graph_def.node:
if _should_include_node(node_def.name, export_scope, exclude_nodes):
new_node_def = _node_def(node_def, export_scope, unbound_inputs,
clear_devices=clear_devices)
new_graph_def.node.extend([new_node_def])
graph_def = new_graph_def
else:
# Only do this complicated work if we want to remove a name scope.
graph_def = graph_pb2.GraphDef()
# pylint: disable=protected-access
graph_def.versions.CopyFrom(graph.graph_def_versions)
bytesize = 0
if clear_extraneous_savers:
exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
saver_def)
for key in sorted(graph._nodes_by_id):
if _should_include_node(graph._nodes_by_id[key].name,
export_scope,
exclude_nodes):
value = graph._nodes_by_id[key]
# pylint: enable=protected-access
node_def = _node_def(value.node_def, export_scope, unbound_inputs,
clear_devices=clear_devices)
graph_def.node.extend([node_def])
if value.outputs:
assert "_output_shapes" not in graph_def.node[-1].attr
graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
output.get_shape().as_proto() for output in value.outputs])
bytesize += value.node_def.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
graph._copy_functions_to_graph_def(graph_def, bytesize) # pylint: disable=protected-access
# It's possible that not all the inputs are in the export_scope.
# If we would like such information included in the exported meta_graph,
# add them to a special unbound_inputs collection.
if unbound_inputs_col_name:
# Clears the unbound_inputs collections.
graph.clear_collection(unbound_inputs_col_name)
for k in unbound_inputs:
graph.add_to_collection(unbound_inputs_col_name, k)
var_list = {}
variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=export_scope)
for v in variables:
if _should_include_node(v, export_scope, exclude_nodes):
var_list[ops.strip_name_scope(v.name, export_scope)] = v
scoped_meta_graph_def = create_meta_graph_def(
graph_def=graph_def,
graph=graph,
export_scope=export_scope,
exclude_nodes=exclude_nodes,
clear_extraneous_savers=clear_extraneous_savers,
saver_def=saver_def,
strip_default_attrs=strip_default_attrs,
**kwargs)
if filename:
graph_io.write_graph(
scoped_meta_graph_def,
os.path.dirname(filename),
os.path.basename(filename),
as_text=as_text)
return scoped_meta_graph_def, var_list
示例11: load_model
# Load a Keras model, freeze it with freeze_session and write the result as a
# binary GraphDef, then begin a SavedModel export with a predict signature.
# NOTE(review): input_fld, weight_file, output_fld, output_graph_name, osp, K,
# load_model and freeze_session are defined outside this excerpt.
weight_file_path = osp.join(input_fld, weight_file)
# NOTE(review): presumably learning phase 0 selects Keras' test/inference
# behaviour for the exported graph; confirm.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)
print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)
# NOTE(review): `sess` is not used in this excerpt; the next line calls
# K.get_session() again instead of reusing it.
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io
# as_text=False: write the frozen graph as a binary proto.
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
# --------
# SavedModel export: a predict signature mapping 'images' -> model input and
# 'scores' -> model output, built for export_path.
# NOTE(review): utils, tag_constants, signature_constants, build_signature_def
# and exporter (tensorflow.contrib.session_bundle) are not used in this
# excerpt.
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def, predict_signature_def
from tensorflow.contrib.session_bundle import exporter
export_path = 'folder_to_export'
builder = saved_model_builder.SavedModelBuilder(export_path)
signature = predict_signature_def(inputs={'images': net_model.input},
outputs={'scores': net_model.output})
示例12: print
# [optional] write graph definition in ascii
# In[ ]:
# NOTE(review): args, output_fld, pred_node_names, K, tf and Path come from
# earlier in the script/notebook, outside this excerpt.
sess = K.get_session()
if args.graph_def:
f = args.output_graphdef_file
# NOTE(review): uses tf.train.write_graph here but graph_io.write_graph below;
# presumably the same function — confirm.
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))
# convert variables to constants and save
# In[ ]:
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
# With --quantize, first apply the "quantize_weights" and "quantize_nodes"
# graph transforms (no input names, pred_node_names as outputs), then freeze
# the transformed graph; otherwise freeze the session graph directly.
if args.quantize:
from tensorflow.tools.graph_transforms import TransformGraph
transforms = ["quantize_weights", "quantize_nodes"]
transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
# as_text=False: the frozen graph is written as a binary proto.
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
示例13: testStripUnused
def testStripUnused(self):
    """strip_unused cuts the graph down to wanted_input_node -> output_node.

    Also checks that an unknown input name raises KeyError and that a tensor
    name with a ":0" port suffix raises ValueError.
    """
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Build a graph computing (1.0 - 3.0) * 2.0, plus a downstream
    # later_node that strip_unused should drop.
    with ops.Graph().as_default():
        constant_node = constant_op.constant(1.0, name="constant_node")
        wanted_input_node = math_ops.subtract(constant_node,
                                              3.0,
                                              name="wanted_input_node")
        output_node = math_ops.multiply(
            wanted_input_node, 2.0, name="output_node")
        math_ops.add(output_node, 2.0, name="later_node")
        # Close the session on exit instead of leaking it.
        with session.Session() as sess:
            output = sess.run(output_node)
            self.assertNear(-4.0, output, 0.00001)
            graph_io.write_graph(sess.graph, self.get_temp_dir(),
                                 input_graph_name)

    # We save out the graph to disk, and then call the strip_unused routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_binary = False
    output_binary = True
    output_node_names = "output_node"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    def strip(input_node_names):
        strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
                                                 output_graph_path, output_binary,
                                                 input_node_names,
                                                 output_node_names,
                                                 dtypes.float32.as_datatype_enum)

    with self.assertRaises(KeyError):
        strip("does_not_exist")

    with self.assertRaises(ValueError):
        strip("wanted_input_node:0")

    input_node_names = "wanted_input_node"
    strip(input_node_names)

    # Check the unused nodes are gone, the input carries a shape attr, and the
    # graph still produces the expected result when the input is fed.
    with ops.Graph().as_default():
        output_graph_def = graph_pb2.GraphDef()
        with open(output_graph_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = importer.import_graph_def(output_graph_def, name="")

        self.assertEqual(3, len(output_graph_def.node))
        for node in output_graph_def.node:
            self.assertNotEqual("Add", node.op)
            self.assertNotEqual("Sub", node.op)
            if node.name == input_node_names:
                self.assertIn("shape", node.attr)

        with session.Session() as sess:
            input_node = sess.graph.get_tensor_by_name("wanted_input_node:0")
            output_node = sess.graph.get_tensor_by_name("output_node:0")
            output = sess.run(output_node, feed_dict={input_node: [10.0]})
            self.assertNear(20.0, output, 0.00001)
示例14: main
def main(argv):
    """Build the TF graph for a config and optionally export it.

    Flags (parsed from ``argv[1:]``) control the following:

    - train, eval and search mode;
    - an optional name for a merged-summaries identity tensor;
    - an output file for the graph: a GraphDef for .pb/.pbtxt or a
      MetaGraphDef for .meta/.metatxt, written as text when the extension
      ends in "txt";
    - line-based output files listing model-param and state-var names.

    Args:
        argv: Full argument vector; ``argv[0]`` (the program name) is skipped.
    """
    argparser = argparse.ArgumentParser(description='Compile some op')
    argparser.add_argument('config', help="filename to config-file")
    argparser.add_argument('--train', type=int, default=0, help='0 disable (default), 1 enable, -1 dynamic')
    argparser.add_argument('--eval', type=int, default=0, help='calculate losses. 0 disable (default), 1 enable')
    argparser.add_argument('--search', type=int, default=0, help='beam search. 0 disable (default), 1 enable')
    argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
    argparser.add_argument("--summaries_tensor_name")
    argparser.add_argument("--output_file", help='output pb, pbtxt or meta, metatxt file')
    argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params")
    argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars")
    args = argparser.parse_args(argv[1:])

    # Validate flags up front, before the expensive graph construction, and
    # with argparser.error rather than `assert` (stripped under `python -O`).
    # --train documents -1 as "dynamic" and the code below handles it, but the
    # previous check only admitted 0, 1 and 2, so that branch was unreachable.
    # 2 is still accepted so existing invocations keep working.
    if args.train not in (-1, 0, 1, 2):
        argparser.error("--train must be 0, 1 or -1 (dynamic), got %r" % args.train)
    if args.eval not in (0, 1):
        argparser.error("--eval must be 0 or 1, got %r" % args.eval)
    if args.search not in (0, 1):
        argparser.error("--search must be 0 or 1, got %r" % args.search)
    output_ext = None
    if args.output_file:
        output_ext = os.path.splitext(args.output_file)[1]
        if output_ext not in (".pb", ".pbtxt", ".meta", ".metatxt"):
            argparser.error('filename %r extension invalid' % args.output_file)

    init(config_filename=args.config, log_verbosity=args.verbosity)
    with tf.Graph().as_default() as graph:
        assert isinstance(graph, tf.Graph)
        print("Create graph...")
        # See :func:`Engine._init_network`.
        tf.set_random_seed(42)
        if args.train < 0:
            from TFUtil import get_global_train_flag_placeholder
            train_flag = get_global_train_flag_placeholder()
        else:
            train_flag = bool(args.train)
        eval_flag = bool(args.eval)
        search_flag = bool(args.search)
        network = create_graph(train_flag=train_flag, eval_flag=eval_flag, search_flag=search_flag)

        # For every layer with a time axis, add a batch-major identity named
        # "output_batch_major" inside the layer's scope.
        from TFNetworkLayer import LayerBase
        for layer in network.layers.values():
            assert isinstance(layer, LayerBase)
            if layer.output.time_dim_axis is None:
                continue
            with layer.cls_layer_scope(layer.name):
                tf.identity(layer.output.get_placeholder_as_batch_major(), name="output_batch_major")

        tf.group(*network.get_post_control_dependencies(), name="post_control_dependencies")

        if args.summaries_tensor_name:
            summaries_tensor = tf.summary.merge_all()
            assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?"
            tf.identity(summaries_tensor, name=args.summaries_tensor_name)

        if output_ext in (".meta", ".metatxt"):
            # https://www.tensorflow.org/api_guides/python/meta_graph
            saver = tf.train.Saver(
                var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1)
            graph_def = saver.export_meta_graph()
        else:
            graph_def = graph.as_graph_def(add_shapes=True)

        print("Graph collection keys:", graph.get_all_collection_keys())
        print("Graph num operations:", len(graph.get_operations()))
        print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

        if args.output_file:
            filename = args.output_file
            print("Write graph to file:", filename)
            graph_io.write_graph(
                graph_def,
                # dirname() is "" for a bare filename; pass "." instead.
                logdir=os.path.dirname(filename) or ".",
                name=os.path.basename(filename),
                as_text=output_ext.endswith("txt"))
        else:
            print("Use --output_file if you want to store the graph.")

        if args.output_file_model_params_list:
            print("Write model param list to:", args.output_file_model_params_list)
            with open(args.output_file_model_params_list, "w") as f:
                for param in network.get_params_list():
                    assert param.name[-2:] == ":0"
                    f.write("%s\n" % param.name[:-2])

        if args.output_file_state_vars_list:
            print("Write state var list to:", args.output_file_state_vars_list)
            from TFUtil import CollectionKeys
            with open(args.output_file_state_vars_list, "w") as f:
                for param in tf.get_collection(CollectionKeys.STATE_VARS):
                    assert param.name[-2:] == ":0"
                    f.write("%s\n" % param.name[:-2])
示例15: single_worker_inference
def single_worker_inference(infer_model,
ckpt,
inference_input_file,
inference_output_file,
hparams):
"""Inference with a single worker."""
output_infer = inference_output_file
# Read data
infer_data = load_data(inference_input_file, hparams)
print ("Batch size type:", type(hparams.infer_batch_size))
with tf.Session(
graph=infer_model.graph, config=utils.get_config_proto()) as sess:
# revo debug
# sess = tf_debug.TensorBoardDebugWrapperSession(sess, 'xy:6064')
# initi table
# sess.run(infer_model.insert_op[0])
# sess.run(infer_model.insert_op[1])
# sess.run(infer_model.insert_op[2])
#
loaded_infer_model = model_helper.load_model(
infer_model.model, ckpt, sess, "infer", infer_model.insert_op)
sess.run(
infer_model.iterator.initializer,
feed_dict={
infer_model.src_placeholder: infer_data,
infer_model.batch_size_placeholder: hparams.infer_batch_size
})
# Debug By Revo
# value = sess.run(infer_model.iterator.source)
# print ("Value:", value)
# print ("Value,len:", len(value))
# print ("Value,Type:", type(value))
# print ("Value,Shape:", value.shape)
# tmp_i = sess.run(infer_model.iterator)
# print ("iterator:", tmp_i)
# print ("iterator shape:", tmp_i.shape())
# sys.exit()
# print ("TEST")
# # Initialize keys and values.
# keys = tf.constant([1, 2, 3], dtype=tf.int64)
# vals = tf.constant([1, 2, 3], dtype=tf.int64)
# # Initialize hash table.
# table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1,
# empty_key=0)
# # Insert values to hash table and run the op.
# insert_op = table.insert(keys, vals)
# sess.run(insert_op)
# # Print hash table lookups.
# print(sess.run(table.lookup(keys)))
# print("HERE2")
# Saving Decoder model
# Decode
utils.print_out("# Start decoding ff3")
# print ("indices:", hparams.inference_indices)
if hparams.inference_indices:
_decode_inference_indices(
loaded_infer_model,
sess,
output_infer=output_infer,
output_infer_summary_prefix=output_infer,
inference_indices=hparams.inference_indices,
tgt_eos=hparams.eos,
subword_option=hparams.subword_option)
else:
nmt_utils.decode_and_evaluate(
"infer",
loaded_infer_model,
sess,
output_infer,
ref_file=None,
metrics=hparams.metrics,
subword_option=hparams.subword_option,
beam_width=hparams.beam_width,
tgt_eos=hparams.eos,
num_translations_per_input=hparams.num_translations_per_input)
# saving model
OUTPUT_FOLDER = '7.19'
utils.print_out("Ouput Folder : " + OUTPUT_FOLDER)
utils.print_out("# Saving Decoder model (Normal,ckpt) By Revo")
loaded_infer_model.saver.save(sess, OUTPUT_FOLDER+"/current.ckpt")
# save pb file
graph_io.write_graph(sess.graph_def, OUTPUT_FOLDER, "current.graphdef")
tf.train.export_meta_graph(filename=OUTPUT_FOLDER + '/current.meta')
writer = tf.summary.FileWriter(OUTPUT_FOLDER, sess.graph)
writer.close()
# Frozen graph saving
OUTPUT_FROZEN_FILE = 'nmt.pb'
# OUTPUT_NAMES = ['index_to_string_Lookup', 'table_init', 'batch_iter_init']
# maybe it is not utf8 as output
OUTPUT_NODES = ['reverse_table_Lookup']
utils.print_out("# Saving Decoder model (Frozen) By Revo")
# extract method try
# new_graph_def = tf.graph_util.extract_sub_graph(sess.graph_def, ["hash_table_2_Lookup"])
#
# remove train node
#.........这里部分代码省略.........