本文整理汇总了Python中tensorflow.compat.v1.GraphDef类的典型用法代码示例。如果您正苦于以下问题：Python v1.GraphDef的具体用法？Python v1.GraphDef怎么用？Python v1.GraphDef使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解GraphDef所在模块tensorflow.compat.v1的用法示例。
在下文中一共展示了v1.GraphDef的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _load_frozen_graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def _load_frozen_graph(self, frozen_graph_path):
    """Load a frozen GraphDef from disk and open an interactive session on it.

    The graph is imported with the default ``import/`` name prefix, and the
    ``probabilities:0`` tensor becomes the single entry of
    ``self._output_nodes``. The sliding window, frame counter and last
    annotations are reset as well.

    Args:
        frozen_graph_path: Local filesystem path of the serialized GraphDef.
    """
    graph_def = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as pb_file:
        graph_def.ParseFromString(pb_file.read())

    self.graph = tf.Graph()
    with self.graph.as_default():
        self.output_node = tf.import_graph_def(
            graph_def, return_elements=['probabilities:0'])
    self.session = tf.InteractiveSession(graph=self.graph)

    # import_graph_def was called without `name`, so imported tensors live
    # under the "import/" prefix.
    probabilities = self.graph.get_tensor_by_name('import/probabilities:0')
    self._output_nodes = [probabilities]

    # Reset inference-related state.
    self.sliding_window = None
    self.frames_since_last_inference = self.config.inference_rate
    self.last_annotations = []
示例2: load
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def load(self, saved_model_dir_or_frozen_graph: Text):
    """Load the model from a SavedModel directory or a frozen graph file.

    A session is created lazily via ``self._build_session()`` when none
    exists, and the tensor-name mapping ``self.signitures`` is (re)set.

    Args:
        saved_model_dir_or_frozen_graph: Either a SavedModel directory (loaded
            into ``self.sess`` with the 'serve' tag) or the path of a
            serialized GraphDef (imported with no name prefix).

    Returns:
        The result of ``tf.saved_model.load`` for a directory, otherwise the
        result of ``tf.import_graph_def``.
    """
    if not self.sess:
        self.sess = self._build_session()
    # Attribute name 'signitures' (sic): other code may read it by this name.
    self.signitures = {
        'image_files': 'image_files:0',
        'image_arrays': 'image_arrays:0',
        'prediction': 'detections:0',
    }

    model_path = saved_model_dir_or_frozen_graph
    if not tf.io.gfile.isdir(model_path):
        # Not a directory: treat the path as a frozen GraphDef file.
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        return tf.import_graph_def(graph_def, name='')
    return tf.saved_model.load(self.sess, ['serve'], model_path)
示例3: __init__
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def __init__(self, weight_path):
    """Build a callable LPIPS network from a frozen GraphDef.

    Args:
        weight_path: Path of the serialized LPIPS GraphDef. It is first passed
            to ``helpers.ensure_lpips_weights_exist`` and then read from disk.
    """
    helpers.ensure_lpips_weights_exist(weight_path)

    # Pack LPIPS network into a tf function.
    graph_def = tf.GraphDef()
    with open(weight_path, "rb") as weights_file:
        graph_def.ParseFromString(weights_file.read())

    def _import_into_wrapped_graph():
        tf.graph_util.import_graph_def(graph_def, name="")

    wrapped = tf.wrap_function(_import_into_wrapped_graph, [])
    to_element = wrapped.graph.as_graph_element
    # Prune the wrapped function to a function mapping the two inputs
    # ("0:0", "1:0") to the "Reshape_10:0" output.
    pruned = wrapped.prune(
        tf.nest.map_structure(to_element, ("0:0", "1:0")),
        tf.nest.map_structure(to_element, "Reshape_10:0"))
    self._lpips_func = tf.function(pruned)
示例4: create_model_graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def create_model_graph(model_info):
    """Creates a graph from a saved GraphDef file and returns a Graph object.

    The file ``FLAGS.model_dir/<model_info['model_file_name']>`` is parsed and
    imported into a fresh graph with no name prefix.

    Args:
        model_info: Dictionary containing information about the model
            architecture. The keys read are 'model_file_name',
            'bottleneck_tensor_name' and 'resized_input_tensor_name'.

    Returns:
        A tuple ``(graph, bottleneck_tensor, resized_input_tensor)``: the
        graph holding the trained network, and the two tensors looked up from
        it by the names given in ``model_info``.
    """
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
    graph_def = tf.GraphDef()
    # GFile is the non-deprecated replacement for FastGFile.
    with gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        bottleneck_tensor, resized_input_tensor = tf.import_graph_def(
            graph_def,
            name='',
            return_elements=[
                model_info['bottleneck_tensor_name'],
                model_info['resized_input_tensor_name'],
            ])
    return graph, bottleneck_tensor, resized_input_tensor
示例5: _import_graph_and_run_inference
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
    """Import a TFLite-compatible graph and run one forward pass on random data.

    Args:
        tflite_graph_file: Path of the serialized GraphDef.
        num_channels: Channel count C of the random 1x10x10xC input image.

    Returns:
        A ``(box_encodings, class_predictions)`` pair of numpy arrays.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        lookup = graph.get_tensor_by_name
        input_tensor = lookup('normalized_input_image_tensor:0')
        fetches = [
            lookup('raw_outputs/box_encodings:0'),
            lookup('raw_outputs/class_predictions:0'),
        ]
        with self.test_session(graph) as sess:
            random_image = np.random.rand(1, 10, 10, num_channels)
            box_encodings_np, class_predictions_np = sess.run(
                fetches, feed_dict={input_tensor: random_image})
    return box_encodings_np, class_predictions_np
示例6: test_export_tflite_graph_with_postprocess_op_and_additional_tensors
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def test_export_tflite_graph_with_postprocess_op_and_additional_tensors(self):
    """The exported TFLite graph contains the postprocess op and extra tensors."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.use_moving_averages = False
    ssd_config = pipeline_config.model.ssd
    ssd_config.post_processing.score_converter = (
        post_processing_pb2.PostProcessing.SIGMOID)
    resizer = ssd_config.image_resizer.fixed_shape_resizer
    resizer.height = 10
    resizer.width = 10
    ssd_config.num_classes = 2

    tflite_graph_file = self._export_graph_with_postprocessing_op(
        pipeline_config, additional_output_tensors=['UnattachedTensor'])
    self.assertTrue(os.path.exists(tflite_graph_file))

    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
            graph_def.ParseFromString(f.read())
        op_names = [graph_node.name for graph_node in graph_def.node]
        for expected in ('TFLite_Detection_PostProcess', 'UnattachedTensor'):
            self.assertIn(expected, op_names)
示例7: _load_frozen_graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def _load_frozen_graph(self, frozen_graph_path):
    """Load a frozen detection graph and open an interactive session on it.

    ``self._check_lstm`` decides whether the model also exposes LSTM state
    tensors. The graph is imported with the default ``import/`` prefix.
    ``self._output_nodes`` is filled in the following order: scores, boxes,
    classes, num_detections, then lstm_c and lstm_h for LSTM models. For LSTM
    models, ``self.lstm_c`` and ``self.lstm_h`` are set to ones of shape
    (1, 8, 8, 320).

    Args:
        frozen_graph_path: Local path of the serialized GraphDef.
    """
    trt_graph = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as f:
        trt_graph.ParseFromString(f.read())

    self._is_lstm = self._check_lstm(trt_graph)
    is_lstm = self._is_lstm
    if is_lstm:
        print('Loading an LSTM model.')

    return_names = ['detection_boxes:0', 'detection_classes:0',
                    'detection_scores:0', 'num_detections:0']
    if is_lstm:
        return_names += ['raw_outputs/lstm_c:0', 'raw_outputs/lstm_h:0']

    self.graph = tf.Graph()
    with self.graph.as_default():
        self.output_node = tf.import_graph_def(
            trt_graph, return_elements=return_names)
    self.session = tf.InteractiveSession(graph=self.graph)

    fetch_names = ['detection_scores', 'detection_boxes',
                   'detection_classes', 'num_detections']
    if is_lstm:
        fetch_names += ['raw_outputs/lstm_c', 'raw_outputs/lstm_h']
    self._output_nodes = [
        self.graph.get_tensor_by_name('import/{}:0'.format(fetch_name))
        for fetch_name in fetch_names
    ]

    if is_lstm:
        self.lstm_c = np.ones((1, 8, 8, 320))
        self.lstm_h = np.ones((1, 8, 8, 320))
示例8: main
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def main():
    """Convert a frozen detection graph into a TensorRT inference graph.

    The command-line arguments are:
      --model: path of the input .pb GraphDef (required).
      --downgrade: also apply ``downgrade_equal_op`` and ``downgrade_nmv5_op``
        to the graph before conversion.

    The graph is converted with ``trt.create_inference_graph`` (FP16, batch
    size 1) and written next to the input as ``<name>_trt<ext>``.
    """
    parser = argparse.ArgumentParser()
    # Required: without it, os.path.splitext(None) below would fail with an
    # unhelpful TypeError instead of an argparse usage error.
    parser.add_argument('--model', required=True, help='.pb model path')
    parser.add_argument(
        '--downgrade',
        help='Downgrades the model for use with Tensorflow 1.14 '
        '(There maybe some quality degradation.)',
        action='store_true')
    args = parser.parse_args()

    filename, extension = os.path.splitext(args.model)
    output_file_path = '{}_trt{}'.format(filename, extension)

    frozen_graph = tf.GraphDef()
    with open(args.model, 'rb') as f:
        frozen_graph.ParseFromString(f.read())

    if args.downgrade:
        downgrade_equal_op(frozen_graph)
        downgrade_nmv5_op(frozen_graph)

    is_lstm = check_lstm(frozen_graph)
    if is_lstm:
        print('Converting LSTM model.')

    outputs = [
        'detection_boxes', 'detection_classes', 'detection_scores',
        'num_detections'
    ]
    if is_lstm:
        # LSTM models additionally keep their state input/output tensors.
        outputs += [
            'raw_outputs/lstm_c', 'raw_outputs/lstm_h',
            'raw_inputs/init_lstm_c', 'raw_inputs/init_lstm_h'
        ]

    trt_graph = trt.create_inference_graph(
        input_graph_def=frozen_graph,
        outputs=outputs,
        max_batch_size=1,
        max_workspace_size_bytes=1 << 25,  # 32 MiB
        precision_mode='FP16',
        minimum_segment_size=50)

    with open(output_file_path, 'wb') as f:
        f.write(trt_graph.SerializeToString())
示例9: create_graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def create_graph():
    """Import the saved classify_image GraphDef into the current default graph.

    Reads ``classify_image_graph_def.pb`` from ``FLAGS.model_dir`` and imports
    it with no name prefix. Returns nothing; the imported ops are reached
    through the default graph afterwards.
    """
    model_path = os.path.join(FLAGS.model_dir, 'classify_image_graph_def.pb')
    graph_def = tf.GraphDef()
    # GFile is the non-deprecated replacement for FastGFile.
    with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-TensorFlow-Second-Edition,代码行数:10,代码来源:classify_image.py
示例10: run_inference_on_image
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def run_inference_on_image(image):
    """Runs inference on an image and prints the top predictions.

    Args:
        image: Image file name. Its raw bytes are fed to
            'DecodeJpeg/contents:0'.

    Returns:
        Nothing. Prints the top ``FLAGS.num_top_predictions`` labels with
        their scores, best first.
    """
    if not tf.gfile.Exists(image):
        # NOTE(review): confirm whether tf.logging.fatal aborts the process;
        # if it only logs, the read below raises for the missing file.
        tf.logging.fatal('File does not exist %s', image)
    # Context manager so the file handle is closed right after reading.
    with tf.gfile.GFile(image, 'rb') as f:
        image_data = f.read()

    # Creates graph from saved GraphDef.
    create_graph()

    with tf.Session() as sess:
        # Some useful tensors:
        # 'softmax:0': normalized prediction across 1000 labels.
        # 'pool_3:0': next-to-last layer, a 2048-float description of the image.
        # 'DecodeJpeg/contents:0': string tensor holding the JPEG-encoded image.
        # Runs the softmax tensor by feeding the image_data as input to the graph.
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)

        # Creates node ID --> English string lookup.
        node_lookup = NodeLookup()
        # Indices of the highest scores, best first.
        top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
        for node_id in top_k:
            human_string = node_lookup.id_to_string(node_id)
            score = predictions[node_id]
            print('%s (score = %.5f)' % (human_string, score))
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-TensorFlow-Second-Edition,代码行数:40,代码来源:classify_image.py
示例11: load_frozen_model
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def load_frozen_model(pb_path, prefix='', print_nodes=False):
    """Load a frozen model (.pb file) for testing.

    After loading, operators can be accessed with
    ``graph.get_tensor_by_name('<prefix>/<op_name>')``.

    Args:
        pb_path: Path of the frozen model; any path ``tf.io.gfile`` accepts.
        prefix: Name prefix added to the imported operators ('' for none).
        print_nodes: Whether to print the name of every imported operator.

    Returns:
        A new ``tf.Graph`` containing the imported model.

    If ``pb_path`` does not exist, a message is printed and ``sys.exit(-1)``
    is called (raising ``SystemExit``).
    """
    import sys

    # Use gfile for the existence check as well, so it agrees with the
    # gfile-based reader below (os.path.exists rejects non-local paths).
    if not tf.io.gfile.exists(pb_path):
        print('Model file does not exist', pb_path)
        # sys.exit is always available; the builtin exit() comes from `site`.
        sys.exit(-1)

    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
    if print_nodes:
        for op in graph.get_operations():
            print(op.name)
    return graph
示例12: load_frozen_model
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def load_frozen_model(pb_path, prefix='', print_nodes=False):
    """Load a frozen model (.pb file) for testing.

    After loading, operators can be accessed with
    ``graph.get_tensor_by_name('<prefix>/<op_name>')``.

    Args:
        pb_path: Path of the frozen model; any path ``tf.io.gfile`` accepts.
        prefix: Name prefix added to the imported operators ('' for none).
        print_nodes: Whether to print the name of every imported operator.

    Returns:
        A new ``tf.Graph`` containing the imported model.

    If ``pb_path`` does not exist, a message is printed and ``sys.exit(-1)``
    is called (raising ``SystemExit``).
    """
    import sys

    # Use gfile for the existence check as well, so it agrees with the
    # gfile-based reader below (os.path.exists rejects non-local paths).
    if not tf.io.gfile.exists(pb_path):
        print('Model file does not exist', pb_path)
        # sys.exit is always available; the builtin exit() comes from `site`.
        sys.exit(-1)

    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
    if print_nodes:
        for op in graph.get_operations():
            print(op.name)
    return graph
示例13: __init__
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def __init__(self, graph_pb_path=None, graph_def=None):
    """Hold a GraphDef that is either read from a file or supplied directly.

    Args:
        graph_pb_path: Optional path of a serialized GraphDef. When it is not
            None it takes precedence and ``graph_def`` is ignored.
        graph_def: GraphDef stored as-is when ``graph_pb_path`` is None.
    """
    if graph_pb_path is None:
        self.graph = graph_def
    else:
        with tf.compat.v1.gfile.GFile(graph_pb_path, 'rb') as pb_file:
            self.graph = tf.compat.v1.GraphDef()
            self.graph.ParseFromString(pb_file.read())
    # Attribute name 'summray_dict' (sic): other methods may use this name.
    self.summray_dict = {}
示例14: graph
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def graph(self, graph):
    """Store ``graph`` as ``self.graph_def``; a None value is ignored.

    Raises:
        ValueError: If ``graph`` is neither None nor a GraphDef instance.
    """
    if graph is None:
        return
    if not isinstance(graph, GraphDef):
        raise ValueError("graph({}) should be type of GraphDef.".format(type(graph)))
    self.graph_def = graph
示例15: _load_graph_def
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import GraphDef [as 别名]
def _load_graph_def(pb_file):
    """Load a GraphDef from a binary (.pb) or text (.pbtxt) protobuf file.

    Args:
        pb_file: Either an existing ``tf.GraphDef`` (returned as-is) or a path
            string ending in ``.pb`` or ``.pbtxt``.

    Returns:
        ``(graph_def, graph_name)``. For a path, ``graph_name`` is the file's
        base name without its extension. For a GraphDef input, it is
        ``"tf_graph_"`` followed by ``random_str(6)``.

    Raises:
        TypeError: If ``pb_file`` is neither a GraphDef nor a string.
        ValueError: If the file extension is not ``.pb`` or ``.pbtxt``.
    """
    if isinstance(pb_file, tf.GraphDef):
        return pb_file, "tf_graph_{}".format(random_str(6))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(pb_file, six.string_types):
        raise TypeError(
            "pb_file must be a GraphDef or a path string, got {}".format(
                type(pb_file)))
    graph_name, ext = os.path.splitext(os.path.basename(pb_file))
    graph_def = tf.GraphDef()
    if ext == ".pb":
        with open(pb_file, "rb") as fid:
            graph_def.ParseFromString(fid.read())
    elif ext == ".pbtxt":
        with open(pb_file, "r") as fid:
            text_format.Parse(fid.read(), graph_def)
    else:
        raise ValueError("unknown file format: %s" % pb_file)
    return graph_def, graph_name