This page collects typical usage examples of the Python method tensorflow.contrib.tensorrt.create_inference_graph. If you are wondering exactly how to use tensorrt.create_inference_graph, how to call it, or what it looks like in real code, the curated examples below may help. You can also browse further usage examples from the containing module, tensorflow.contrib.tensorrt.
Nine code examples of the tensorrt.create_inference_graph method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
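Before the individual examples, here is a minimal sketch of the typical call pattern in TensorFlow 1.x (the tensorflow.contrib.tensorrt module does not exist in TensorFlow 2.x). The file name frozen_model.pb and the output node name logits are placeholders, not taken from any example below:

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# Load a frozen GraphDef from disk.
frozen_graph = tf.GraphDef()
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:  # placeholder path
    frozen_graph.ParseFromString(f.read())

# Replace TensorRT-compatible subgraphs with TRTEngineOp nodes.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=['logits'],                # placeholder output node name(s)
    max_batch_size=1,                  # largest batch expected at inference time
    max_workspace_size_bytes=1 << 30,  # scratch memory TensorRT may allocate
    precision_mode='FP16')             # 'FP32', 'FP16' or 'INT8'

# The result is an ordinary GraphDef and can be imported as usual.
tf.import_graph_def(trt_graph, name='')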
Example 1: __init__
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def __init__(self, graph, batch_size, precision):
tftrt_graph = tftrt.create_inference_graph(
graph.frozen,
outputs=graph.y_name,
max_batch_size=batch_size,
max_workspace_size_bytes=1 << 30,
precision_mode=precision,
minimum_segment_size=2)
self.tftrt_graph = tftrt_graph
self.graph = graph
        # Deep copy causes issues with the latest graph (apparently it contains an RLock).
        # Passing it by reference seems to work, but more investigation is needed.
# opt_graph = copy.deepcopy(graph)
opt_graph = graph
opt_graph.frozen = tftrt_graph
super(MobileDetectnetTFTRTEngine, self).__init__(opt_graph)
self.batch_size = batch_size
Example 2: freeze_graph
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def freeze_graph(model_path, use_trt=False, trt_max_batch_size=8,
trt_precision='fp32'):
output_names = ['policy_output', 'value_output']
n = DualNetwork(model_path)
out_graph = tf.graph_util.convert_variables_to_constants(
n.sess, n.sess.graph.as_graph_def(), output_names)
if use_trt:
import tensorflow.contrib.tensorrt as trt
out_graph = trt.create_inference_graph(
input_graph_def=out_graph,
outputs=output_names,
max_batch_size=trt_max_batch_size,
max_workspace_size_bytes=1 << 29,
precision_mode=trt_precision)
metadata = make_model_metadata({
'engine': 'tf',
'use_trt': bool(use_trt),
})
minigo_model.write_graph_def(out_graph, metadata, model_path + '.minigo')
Example 3: main
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model', help='.pb model path')
parser.add_argument(
'--downgrade',
        help='Downgrades the model for use with TensorFlow 1.14 '
        '(There may be some quality degradation.)',
action='store_true')
args = parser.parse_args()
filename, extension = os.path.splitext(args.model)
output_file_path = '{}_trt{}'.format(filename, extension)
frozen_graph = tf.GraphDef()
with open(args.model, 'rb') as f:
frozen_graph.ParseFromString(f.read())
if args.downgrade:
downgrade_equal_op(frozen_graph)
downgrade_nmv5_op(frozen_graph)
is_lstm = check_lstm(frozen_graph)
if is_lstm:
print('Converting LSTM model.')
trt_graph = trt.create_inference_graph(
input_graph_def=frozen_graph,
outputs=[
'detection_boxes', 'detection_classes', 'detection_scores',
'num_detections'
] + ([
'raw_outputs/lstm_c', 'raw_outputs/lstm_h', 'raw_inputs/init_lstm_c',
'raw_inputs/init_lstm_h'
] if is_lstm else []),
max_batch_size=1,
max_workspace_size_bytes=1 << 25,
precision_mode='FP16',
minimum_segment_size=50)
with open(output_file_path, 'wb') as f:
f.write(trt_graph.SerializeToString())
Example 4: get_trt_graph
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
output_node, batch_size=128, workspace_size=2<<10):
"""Create and save inference graph using the TensorRT library.
Args:
graph_name: string, name of the graph to be used for saving.
graph_def: GraphDef, the Frozen Graph to be converted.
precision_mode: string, the precision that TensorRT should convert into.
Options- FP32, FP16, INT8.
output_dir: string, the path to where files should be written.
    output_node: string, the name of the output node that will be returned
      during inference.
batch_size: int, the number of examples that will be predicted at a time.
workspace_size: int, size in megabytes that can be used during conversion.
Returns:
GraphDef for the TensorRT inference graph.
"""
trt_graph = trt.create_inference_graph(
graph_def, [output_node], max_batch_size=batch_size,
max_workspace_size_bytes=workspace_size<<20,
precision_mode=precision_mode)
write_graph_to_file(graph_name, trt_graph, output_dir)
return trt_graph
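A hypothetical invocation of the get_trt_graph function above; the frozen-graph path and the output node name softmax_tensor are illustrative placeholders, not part of the original example:

import tensorflow as tf

# Read a frozen GraphDef to hand to get_trt_graph (Example 4).
graph_def = tf.GraphDef()
with tf.gfile.GFile('resnet50_frozen.pb', 'rb') as f:  # placeholder path
    graph_def.ParseFromString(f.read())

trt_graph = get_trt_graph(
    graph_name='resnet50_fp16',
    graph_def=graph_def,
    precision_mode='FP16',
    output_dir='/tmp/trt_graphs',
    output_node='softmax_tensor',  # placeholder output node name
    batch_size=128,
    workspace_size=2 << 10)        # 2048 MB, shifted to bytes inside the function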
Example 5: getFP32
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def getFP32(input_graph, out_tensor, precision, batch_size, workspace_size):
graph_prefix = input_graph.split('.pb')[0]
output_graph = graph_prefix + "_tftrt_" + precision + ".pb"
#print("output graph is ", output_graph)
tftrt_graph = trt.create_inference_graph(
getFrozenGraph(input_graph), [out_tensor],
max_batch_size=batch_size,
max_workspace_size_bytes=workspace_size,
precision_mode=precision) # Get optimized graph
with gfile.FastGFile(output_graph, 'wb') as f:
f.write(tftrt_graph.SerializeToString())
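Once the optimized graph has been serialized to disk, as in Example 5, it can be loaded back and executed like any frozen graph. A minimal sketch, assuming placeholder tensor names input:0 and logits:0 and a 224x224 RGB input:

import numpy as np
import tensorflow as tf

# Read the serialized TF-TRT graph written by getFP32 above.
trt_graph_def = tf.GraphDef()
with tf.gfile.GFile('model_tftrt_FP16.pb', 'rb') as f:  # placeholder path
    trt_graph_def.ParseFromString(f.read())

with tf.Graph().as_default():
    tf.import_graph_def(trt_graph_def, name='')
    with tf.Session() as sess:
        # Placeholder tensor names; use the real input/output names of your model.
        logits = sess.run(
            'logits:0',
            feed_dict={'input:0': np.zeros((1, 224, 224, 3), np.float32)})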
Example 6: convert_saved_model_to_tensorrt
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def convert_saved_model_to_tensorrt(
saved_model_dir: str,
tensorrt_config: TensorrtConfig = None,
session_config: Optional[tf.ConfigProto] = None
) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor], tf.GraphDef]:
"""
Convert saved model to tensorrt.
Uses default tag and signature_def
Parameters
----------
saved_model_dir
directory with saved model inside
tensorrt_config
tensorrt config which holds all the tensorrt parameters
session_config
session config to use
Returns
-------
input_tensors
dict holding input tensors from saved model signature_def
output_tensors
dict holding output tensors from saved model signature_def
trt_graph
graph_def with tensorrt graph with variables
Raises
------
ValueError
if tensorrt import was unsuccessful
"""
if trt is None:
raise ImportError(
"No tensorrt is found under tensorflow.contrib.tensorrt")
tensorrt_kwargs = (
tensorrt_config._asdict() if tensorrt_config is not None else {})
tensorrt_kwargs.pop("use_tensorrt", None)
(input_tensors, output_tensors, frozen_graph_def
) = _load_saved_model_as_frozen_graph(saved_model_dir)
output_tensors_list = list(output_tensors.values())
trt_graph = trt.create_inference_graph(
input_graph_def=frozen_graph_def,
outputs=output_tensors_list,
session_config=session_config,
**tensorrt_kwargs)
return input_tensors, output_tensors, trt_graph
Example 7: createModel
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def createModel(config_path, checkpoint_path, graph_path):
""" Create a TensorRT Model.
config_path (string) - The path to the model config file.
checkpoint_path (string) - The path to the model checkpoint file(s).
graph_path (string) - The path to the model graph.
returns (Model) - The TRT model built or loaded from the input files.
"""
global build_graph, prev_classes
trt_graph = None
input_names = None
if build_graph:
frozen_graph, input_names, output_names = build_detection_graph(
config=config_path,
checkpoint=checkpoint_path
)
trt_graph = trt.create_inference_graph(
input_graph_def=frozen_graph,
outputs=output_names,
max_batch_size=1,
max_workspace_size_bytes=1 << 25,
precision_mode='FP16',
minimum_segment_size=50
)
with open(graph_path, 'wb') as f:
f.write(trt_graph.SerializeToString())
with open('config.txt', 'r+') as json_file:
data = json.load(json_file)
data['model'] = []
data['model'] = [{'input_names': input_names}]
json_file.seek(0)
json_file.truncate()
json.dump(data, json_file)
else:
with open(graph_path, 'rb') as f:
trt_graph = tf.GraphDef()
trt_graph.ParseFromString(f.read())
with open('config.txt') as json_file:
data = json.load(json_file)
input_names = data['model'][0]['input_names']
return Model(trt_graph, input_names)
Example 8: load_model
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def load_model(model, input_map=None):
# Check if the model is a model directory (containing a metagraph and a checkpoint file)
# or if it is a protobuf file with a frozen graph
model_exp = os.path.expanduser(model)
if (os.path.isfile(model_exp)):
print('Model filename: %s' % model_exp)
with gfile.FastGFile(model_exp,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
#JJia TensorRT enable
print('TensorRT Enabled')
trt_graph = trt.create_inference_graph(input_graph_def=graph_def,
outputs=['embeddings:0'],
max_batch_size = 1,
            max_workspace_size_bytes=500000000,  # 500 MB of memory assigned to TRT
precision_mode="FP16", # Precision "FP32","FP16" or "INT8"
minimum_segment_size=1
)
##trt_graph=trt.calib_graph_to_infer_graph(trt_graph)
#tf.import_graph_def(trt_graph, input_map=input_map, name='')
        return trt_graph  # return graph_def instead to run with TRT disabled
else:
print('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
print('Metagraph file: %s' % meta_file)
print('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
#JJia TensorRT enable
print('TensorRT Enabled', 1<<20)
frozen_graph = tf.graph_util.convert_variables_to_constants(
tf.get_default_session(),
tf.get_default_graph().as_graph_def(),
output_node_names=["embeddings"])
for node in frozen_graph.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr: del node.attr['use_locking']
trt_graph = trt.create_inference_graph(
input_graph_def=frozen_graph,
outputs=["embeddings"],
max_batch_size = 1,
max_workspace_size_bytes= 1 << 20,
precision_mode="FP16",
minimum_segment_size=1)
#tf.import_graph_def(trt_graph,return_elements=["embeddings:0"])
        return trt_graph  # return frozen_graph instead to run with TRT disabled
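Example 8 leaves the INT8 path commented out (trt.calib_graph_to_infer_graph). For reference, INT8 in tensorflow.contrib.tensorrt is a two-step process: build a calibration graph with precision_mode='INT8', run representative data through it so TensorRT can collect activation ranges, then convert it into the final inference graph. A hedged sketch that reuses frozen_graph and trt from Example 8; the tensor names and the calibration_batches iterable are assumed placeholders:

# Step 1: build the calibration graph.
calib_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=['embeddings'],
    max_batch_size=1,
    max_workspace_size_bytes=1 << 20,
    precision_mode='INT8',
    minimum_segment_size=1)

# Step 2: feed representative inputs through the calibration graph.
with tf.Graph().as_default():
    tf.import_graph_def(calib_graph, name='')
    with tf.Session() as sess:
        for batch in calibration_batches:  # assumed iterable of input batches
            sess.run('embeddings:0', feed_dict={'input:0': batch})  # placeholder names

# Step 3: convert the calibrated graph into the final INT8 inference graph.
int8_graph = trt.calib_graph_to_infer_graph(calib_graph)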
Example 9: main
# Required import: from tensorflow.contrib import tensorrt [as alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as alias]
def main(argv):
del argv # Unused.
original_saved_model_dir = FLAGS.saved_model_dir.rstrip('/')
tensorrt_saved_model_dir = '{}_trt'.format(original_saved_model_dir)
# Converts `SavedModel` to TensorRT inference graph.
trt.create_inference_graph(
None,
None,
input_saved_model_dir=original_saved_model_dir,
output_saved_model_dir=tensorrt_saved_model_dir)
print('Model conversion completed.')
# Gets the image.
get_image_response = requests.get(FLAGS.image_url)
number = FLAGS.number
saved_model_dirs = [original_saved_model_dir, tensorrt_saved_model_dir]
latencies = {}
for saved_model_dir in saved_model_dirs:
with tf.Graph().as_default():
with tf.Session() as sess:
# Loads the saved model.
loader.load(sess, [tag_constants.SERVING], saved_model_dir)
print('Model loaded {}'.format(saved_model_dir))
def _run_inf(session=sess, n=1):
"""Runs inference repeatedly."""
for _ in range(n):
session.run(
FLAGS.model_outputs,
feed_dict={
FLAGS.model_input: [get_image_response.content]})
# Run inference once to perform XLA compile step.
_run_inf(sess, 1)
start = time.time()
_run_inf(sess, number)
end = time.time()
latencies[saved_model_dir] = end - start
print('Time to run {} predictions:'.format(number))
for saved_model_dir, latency in latencies.items():
print('* {} seconds for {} runs for {}'.format(
latency, number, saved_model_dir))