本文整理汇总了 Python 中 tensorflow.import_graph_def 方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.import_graph_def 方法的具体用法?Python tensorflow.import_graph_def 怎么用?Python tensorflow.import_graph_def 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow 的用法示例。
在下文中一共展示了tensorflow.import_graph_def方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(frozen_graph_filename):
    """Load a serialized GraphDef from disk into a fresh ``tf.Graph``.

    Args:
        frozen_graph_filename: Path of the binary protobuf file to read.

    Returns:
        A new ``tf.Graph`` containing the imported nodes. Because the import
        uses ``name=""``, node names are not given an extra prefix.
    """
    # Read the protobuf file from disk and parse it into a GraphDef message.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a brand-new graph rather than the process-wide default one.
    # The deprecated ``op_dict`` argument is intentionally not passed, and the
    # remaining optional arguments are left at their ``None`` defaults.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
示例2: build_from_pb
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def build_from_pb(self):
    """Build the network from a frozen ``.pb`` graph and its JSON metadata.

    Reads the GraphDef at ``self.FLAGS.pbLoad`` and imports it into the
    current default graph without a name prefix, loads the JSON file at
    ``self.FLAGS.metaLoad`` into ``self.meta``, creates ``self.framework``,
    looks up the ``input:0`` / ``output:0`` tensors, and finally calls
    ``self.setup_meta_ops()``.
    """
    flags = self.FLAGS

    frozen_def = tf.GraphDef()
    with tf.gfile.FastGFile(flags.pbLoad, "rb") as pb_file:
        frozen_def.ParseFromString(pb_file.read())
    tf.import_graph_def(frozen_def, name="")

    with open(flags.metaLoad, 'r') as meta_file:
        self.meta = json.load(meta_file)
    self.framework = create_framework(self.meta, flags)

    # Placeholders: the imported graph is addressed by fixed tensor names.
    default_graph = tf.get_default_graph()
    self.inp = default_graph.get_tensor_by_name('input:0')
    self.feed = {}  # other placeholders
    self.out = default_graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
示例3: build_from_pb
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def build_from_pb(self):
    """Restore the model from a frozen protobuf plus a JSON meta file.

    The GraphDef stored at ``self.FLAGS.pbLoad`` is imported (with an empty
    name prefix) into the default graph; ``self.FLAGS.metaLoad`` is parsed as
    JSON into ``self.meta``; ``self.framework`` is created from both; the
    tensors named ``input:0`` and ``output:0`` are stored on the instance;
    then ``self.setup_meta_ops()`` is invoked.
    """
    options = self.FLAGS

    parsed_graph = tf.GraphDef()
    with tf.gfile.FastGFile(options.pbLoad, "rb") as graph_file:
        serialized = graph_file.read()
    parsed_graph.ParseFromString(serialized)
    tf.import_graph_def(parsed_graph, name="")

    with open(options.metaLoad, 'r') as meta_fp:
        self.meta = json.load(meta_fp)
    self.framework = create_framework(self.meta, options)

    # Placeholders
    active_graph = tf.get_default_graph()
    self.inp = active_graph.get_tensor_by_name('input:0')
    self.feed = {}  # other placeholders
    self.out = active_graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
示例4: worker
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def worker(input_q, output_q):
    """Run object detection forever on frames taken from a queue.

    Loads the frozen model at ``PATH_TO_CKPT`` into its own graph, opens a
    session on it, then loops: take a frame from ``input_q``, convert it from
    BGR to RGB, and put the result of ``detect_objects`` on ``output_q``.

    Args:
        input_q: Queue-like object; ``get()`` must return a frame accepted by
            ``cv2.cvtColor`` (presumably a BGR image array - confirm with the
            producer).
        output_q: Queue-like object receiving each detection result via
            ``put()``.

    The loop has no exit condition, so this function only returns by raising.
    The FPS counter and the session are released in a ``finally`` block so
    that cleanup runs when that happens; previously those calls sat after the
    infinite loop and were unreachable.
    """
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    fps = FPS().start()
    try:
        while True:
            fps.update()
            frame = input_q.get()
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_q.put(detect_objects(frame_rgb, sess, detection_graph))
    finally:
        # Reached only via an exception (including KeyboardInterrupt).
        fps.stop()
        sess.close()
示例5: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Load a frozen graph under the ``net`` scope and locate its I/O tensors.

    Args:
        checkpoint_filename: Path of the serialized GraphDef to load.
        input_name: Name (without scope or ``:0``) of the input tensor.
        output_name: Name (without scope or ``:0``) of the output tensor.

    After construction the instance exposes ``session``, ``input_var``,
    ``output_var``, ``feature_dim`` (last output dimension) and
    ``image_shape`` (input shape without its first dimension). The output
    tensor must be rank 2 and the input tensor rank 4.
    """
    self.session = tf.Session()

    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(checkpoint_filename, "rb") as pb_file:
        frozen_def.ParseFromString(pb_file.read())
    tf.import_graph_def(frozen_def, name="net")

    default_graph = tf.get_default_graph()
    self.input_var = default_graph.get_tensor_by_name(
        "net/%s:0" % input_name)
    self.output_var = default_graph.get_tensor_by_name(
        "net/%s:0" % output_name)

    output_shape = self.output_var.get_shape()
    assert len(output_shape) == 2
    input_shape = self.input_var.get_shape()
    assert len(input_shape) == 4

    self.feature_dim = output_shape.as_list()[-1]
    self.image_shape = input_shape.as_list()[1:]
示例6: load_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(graph_path,tensorboard=False,**kwargs):
    '''
    Load a serialized GraphDef into a new TensorFlow graph.

    :param graph_path: the path of the pb file
    :param tensorboard: when true, also write the graph to the ``log/``
        directory with a summary FileWriter
    :param kwargs: accepted for call compatibility; not used here
    :return: tensorflow graph
    '''
    with gfile.FastGFile(graph_path,'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def,name="")
    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # Close the writer so the event file is flushed and its handle
            # released; previously the writer was simply dropped.
            writer.close()
    return graph
示例7: create_inception_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def create_inception_graph():
    """Creates a graph from saved GraphDef file and returns a Graph object.

    The GraphDef is read from ``classify_image_graph_def.pb`` inside
    ``FLAGS.model_dir`` and imported without a name prefix.

    Returns:
      A 4-tuple ``(graph, bottleneck_tensor, jpeg_data_tensor,
      resized_input_tensor)``: the new graph plus the three elements looked
      up by ``BOTTLENECK_TENSOR_NAME``, ``JPEG_DATA_TENSOR_NAME`` and
      ``RESIZED_INPUT_TENSOR_NAME``.
    """
    wanted_elements = [
        BOTTLENECK_TENSOR_NAME,
        JPEG_DATA_TENSOR_NAME,
        RESIZED_INPUT_TENSOR_NAME,
    ]
    with tf.Graph().as_default() as graph:
        model_filename = os.path.join(
            FLAGS.model_dir, 'classify_image_graph_def.pb')
        with gfile.FastGFile(model_filename, 'rb') as model_file:
            inception_def = tf.GraphDef()
            inception_def.ParseFromString(model_file.read())
            imported = tf.import_graph_def(
                inception_def, name='', return_elements=wanted_elements)
            bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = imported
    return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
示例8: _import_graph_and_run_inference
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
    """Imports a tflite graph, runs single inference and returns outputs.

    Args:
      tflite_graph_file: Path of the serialized GraphDef to import.
      num_channels: Size of the last dimension of the random input fed to
        ``normalized_input_image_tensor:0``; the fed array has shape
        ``(1, 10, 10, num_channels)``.

    Returns:
      A tuple ``(box_encodings_np, class_predictions_np)`` holding the
      evaluated ``raw_outputs/box_encodings:0`` and
      ``raw_outputs/class_predictions:0`` tensors.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        # The file is a binary protobuf: open it in binary mode so read()
        # yields bytes instead of attempting a text decode.
        with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        input_tensor = graph.get_tensor_by_name('normalized_input_image_tensor:0')
        box_encodings = graph.get_tensor_by_name('raw_outputs/box_encodings:0')
        class_predictions = graph.get_tensor_by_name(
            'raw_outputs/class_predictions:0')
        with self.test_session(graph) as sess:
            [box_encodings_np, class_predictions_np] = sess.run(
                [box_encodings, class_predictions],
                feed_dict={input_tensor: np.random.rand(1, 10, 10, num_channels)})
    return box_encodings_np, class_predictions_np
示例9: _check_output
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _check_output(gin, tf_input, expected):
    """Run ``gin.graph_def`` on one input and assert the exact result.

    The graph definition is imported into a fresh graph, the tensors named by
    ``_tensor_input_name`` and ``_tensor_output_name`` are resolved through
    ``tfx.get_tensor``, and the fetched value is compared element-wise with
    ``expected``.

    :param gin: object exposing a ``graph_def`` attribute (a TFInputGraph,
        per the original description).
    :param tf_input: value fed to the input tensor.
    :param expected: value every element of the output must equal.
    """
    target_graph = tf.Graph()
    imported_def = gin.graph_def
    with tf.Session(graph=target_graph) as session:
        tf.import_graph_def(imported_def, name="")
        feed_tensor = tfx.get_tensor(_tensor_input_name, target_graph)
        fetch_tensor = tfx.get_tensor(_tensor_output_name, target_graph)

        # Evaluate the imported graph on the supplied input.
        feeds = {feed_tensor: tf_input}
        actual = session.run(fetch_tensor, feed_dict=feeds)

        # Exact comparison is intended: the test data are integers.
        assert np.all(actual == expected), (actual, expected)
# TODO: we could factorize with _check_output, but this is not worth the time doing it.
示例10: _check_output_2
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _check_output_2(gin, tf_input1, tf_input2, expected):
    """Run ``gin.graph_def`` on two inputs and assert the exact result.

    The graph definition is imported into a fresh graph; the tensors named by
    ``_tensor_input_name``, ``_tensor_input_name_2`` and
    ``_tensor_output_name`` are resolved through ``tfx.get_tensor``; the
    fetched value is compared element-wise with ``expected``.

    :param gin: object exposing a ``graph_def`` attribute (a TFInputGraph,
        per the original description).
    :param tf_input1: value fed to the first input tensor.
    :param tf_input2: value fed to the second input tensor.
    :param expected: value every element of the output must equal.
    """
    target_graph = tf.Graph()
    imported_def = gin.graph_def
    with tf.Session(graph=target_graph) as session:
        tf.import_graph_def(imported_def, name="")
        first_feed = tfx.get_tensor(_tensor_input_name, target_graph)
        second_feed = tfx.get_tensor(_tensor_input_name_2, target_graph)
        fetch_tensor = tfx.get_tensor(_tensor_output_name, target_graph)

        # Evaluate the imported graph on both supplied inputs.
        feeds = {first_feed: tf_input1, second_feed: tf_input2}
        actual = session.run(fetch_tensor, feed_dict=feeds)

        # Exact comparison is intended: the test data are integers.
        assert np.all(actual == expected), (actual, expected)
示例11: fromGraphDef
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
    """
    Build a TFInputGraph from an in-memory ``tf.GraphDef``.

    The definition is imported, without a name prefix, into a new graph that
    is bound to a temporary session; construction is then delegated to
    ``_build_with_feeds_fetches``.

    :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                      computation units of the TensorFlow graph.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    :return: whatever ``_build_with_feeds_fetches`` returns.
    """
    assert isinstance(graph_def, tf.GraphDef), \
        ('expect tf.GraphDef type but got', type(graph_def))

    fresh_graph = tf.Graph()
    with tf.Session(graph=fresh_graph) as session:
        tf.import_graph_def(graph_def, name='')
        built = _build_with_feeds_fetches(
            sess=session,
            graph=fresh_graph,
            feed_names=feed_names,
            fetch_names=fetch_names,
        )
        return built
示例12: initialize
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def initialize():
    """Load the model and the label list into module-level state.

    Parses the file ``filename`` into the module-level ``graph_def``, imports
    it into the default graph, and derives the global ``network_input_size``
    from the shape of the tensor named ``input_node`` (which must be rank 4
    with equal second and third dimensions). Then reads ``labels_filename``
    and stores the stripped lines in the global ``labels``. Progress is
    printed to stdout.
    """
    global network_input_size
    global labels

    print('Loading model...', end='')
    with open(filename, 'rb') as model_file:
        graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name='')

    # Retrieving 'network_input_size' from shape of 'input_node'
    with tf.compat.v1.Session() as sess:
        input_tensor = sess.graph.get_tensor_by_name(input_node)
        input_tensor_shape = input_tensor.shape.as_list()
        assert len(input_tensor_shape) == 4
        assert input_tensor_shape[1] == input_tensor_shape[2]
        network_input_size = input_tensor_shape[1]
    print('Success!')

    print('Loading labels...', end='')
    with open(labels_filename, 'rt') as labels_file:
        labels = [line.strip() for line in labels_file]
    print(len(labels), 'found. Success!')
示例13: inference
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def inference():
    """Run the frozen model on one JPEG and write the result to disk.

    Reads the image at ``FLAGS.input``, decodes it as a 3-channel JPEG,
    resizes it to ``FLAGS.image_size`` squared, converts it with
    ``utils.convert2float`` and wires it to the imported graph's
    ``input_image`` via ``input_map``. The graph at ``FLAGS.model`` is
    imported under the ``output`` scope, and the evaluated ``output_image:0``
    element is written as raw bytes to ``FLAGS.output``.
    """
    graph = tf.Graph()

    with graph.as_default():
        with tf.gfile.FastGFile(FLAGS.input, 'rb') as image_file:
            jpeg_bytes = image_file.read()

        # Pre-processing ops are built in the same graph the model is
        # imported into, so they can be mapped onto the model's input.
        target_size = (FLAGS.image_size, FLAGS.image_size)
        input_image = tf.image.decode_jpeg(jpeg_bytes, channels=3)
        input_image = tf.image.resize_images(input_image, size=target_size)
        input_image = utils.convert2float(input_image)
        input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

        with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
            model_def = tf.GraphDef()
            model_def.ParseFromString(model_file.read())

        imported = tf.import_graph_def(
            model_def,
            input_map={'input_image': input_image},
            return_elements=['output_image:0'],
            name='output',
        )
        [output_image] = imported

    with tf.Session(graph=graph) as sess:
        generated = output_image.eval()
        with open(FLAGS.output, 'wb') as out_file:
            out_file.write(generated)
示例14: load_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(frozen_graph_file):
    """Parse a serialized GraphDef file and import it into a new graph.

    Args:
        frozen_graph_file: Path of the binary protobuf file to read.

    Returns:
        A new ``tf.Graph`` holding the imported nodes; ``name=""`` means no
        extra prefix is added to node names.
    """
    # Parse the graph_def file.
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph_def into a freshly created graph. The deprecated
    # ``op_dict`` argument is no longer passed, and the other optional
    # arguments are left at their ``None`` defaults.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
示例15: load_inference_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_inference_graph():
    """Load the frozen model at ``PATH_TO_CKPT`` and open a session on it.

    Returns:
        A tuple ``(detection_graph, sess)``: the graph holding the imported
        model (no name prefix) and a ``tf.Session`` bound to that graph. The
        session is left open for the caller.
    """
    print("> ====== loading HAND frozen graph into memory")

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        frozen_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as pb_file:
            frozen_def.ParseFromString(pb_file.read())
        tf.import_graph_def(frozen_def, name='')
        sess = tf.Session(graph=detection_graph)

    print("> ====== Hand Inference graph loaded.")
    return detection_graph, sess
# draw the detected bounding boxes on the images
# You can modify this to also draw a label.