

Python tensorrt.create_inference_graph Method Code Examples

This article collects typical usage examples of the tensorflow.contrib.tensorrt.create_inference_graph method in Python. If you have been wondering how exactly to use Python's tensorrt.create_inference_graph, how to call it, or what working examples look like, the curated code examples below may help. You can also explore further usage examples of the module it belongs to, tensorflow.contrib.tensorrt.


The following presents 9 code examples of the tensorrt.create_inference_graph method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
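Before the individual examples, here is a minimal usage sketch of the pattern most of them share: load a frozen GraphDef, pass it to create_inference_graph together with the output node names, and import the returned graph for inference. This sketch is not taken from any of the projects below; the file name frozen_model.pb and the output node name logits are placeholders.

# Minimal sketch of the TF 1.x contrib API; 'frozen_model.pb' and 'logits' are placeholder names.
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# Load an existing frozen graph.
frozen_graph = tf.GraphDef()
with open('frozen_model.pb', 'rb') as f:
    frozen_graph.ParseFromString(f.read())

# Replace TensorRT-compatible subgraphs with TRTEngineOp nodes.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=['logits'],  # placeholder output node name
    max_batch_size=1,
    max_workspace_size_bytes=1 << 30,
    precision_mode='FP16',
    minimum_segment_size=3)

# Import the optimized graph and run inference as usual.
with tf.Graph().as_default():
    tf.import_graph_def(trt_graph, name='')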

Example 1: __init__

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def __init__(self, graph, batch_size, precision):
        tftrt_graph = tftrt.create_inference_graph(
            graph.frozen,
            outputs=graph.y_name,
            max_batch_size=batch_size,
            max_workspace_size_bytes=1 << 30,
            precision_mode=precision,
            minimum_segment_size=2)

        self.tftrt_graph = tftrt_graph
        self.graph = graph

        # Deep copy causes issues with the latest graph (apparently it contains an RLock).
        # Passing this by reference seems to work, but more investigation is needed.
        # opt_graph = copy.deepcopy(graph)

        opt_graph = graph
        opt_graph.frozen = tftrt_graph
        super(MobileDetectnetTFTRTEngine, self).__init__(opt_graph)
        self.batch_size = batch_size 
Developer: csvance, Project: keras-mobile-detectnet, Lines: 22, Source: model.py

Example 2: freeze_graph

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def freeze_graph(model_path, use_trt=False, trt_max_batch_size=8,
                 trt_precision='fp32'):
    output_names = ['policy_output', 'value_output']

    n = DualNetwork(model_path)
    out_graph = tf.graph_util.convert_variables_to_constants(
        n.sess, n.sess.graph.as_graph_def(), output_names)

    if use_trt:
        import tensorflow.contrib.tensorrt as trt
        out_graph = trt.create_inference_graph(
            input_graph_def=out_graph,
            outputs=output_names,
            max_batch_size=trt_max_batch_size,
            max_workspace_size_bytes=1 << 29,
            precision_mode=trt_precision)

    metadata = make_model_metadata({
        'engine': 'tf',
        'use_trt': bool(use_trt),
    })

    minigo_model.write_graph_def(out_graph, metadata, model_path + '.minigo') 
Developer: mlperf, Project: training, Lines: 25, Source: dual_net.py

Example 3: main

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def main():
  parser = argparse.ArgumentParser()
  parser.add_argument('--model', help='.pb model path')
  parser.add_argument(
      '--downgrade',
      help='Downgrades the model for use with TensorFlow 1.14 '
      '(There may be some quality degradation.)',
      action='store_true')
  args = parser.parse_args()

  filename, extension = os.path.splitext(args.model)
  output_file_path = '{}_trt{}'.format(filename, extension)

  frozen_graph = tf.GraphDef()
  with open(args.model, 'rb') as f:
    frozen_graph.ParseFromString(f.read())

  if args.downgrade:
    downgrade_equal_op(frozen_graph)
    downgrade_nmv5_op(frozen_graph)

  is_lstm = check_lstm(frozen_graph)
  if is_lstm:
    print('Converting LSTM model.')

  trt_graph = trt.create_inference_graph(
      input_graph_def=frozen_graph,
      outputs=[
          'detection_boxes', 'detection_classes', 'detection_scores',
          'num_detections'
      ] + ([
          'raw_outputs/lstm_c', 'raw_outputs/lstm_h', 'raw_inputs/init_lstm_c',
          'raw_inputs/init_lstm_h'
      ] if is_lstm else []),
      max_batch_size=1,
      max_workspace_size_bytes=1 << 25,
      precision_mode='FP16',
      minimum_segment_size=50)

  with open(output_file_path, 'wb') as f:
    f.write(trt_graph.SerializeToString()) 
Developer: google, Project: automl-video-ondevice, Lines: 43, Source: trt_compiler.py

Example 4: get_trt_graph

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def get_trt_graph(graph_name, graph_def, precision_mode, output_dir,
                  output_node, batch_size=128, workspace_size=2<<10):
  """Create and save inference graph using the TensorRT library.

  Args:
    graph_name: string, name of the graph to be used for saving.
    graph_def: GraphDef, the Frozen Graph to be converted.
    precision_mode: string, the precision that TensorRT should convert into.
      Options- FP32, FP16, INT8.
    output_dir: string, the path to where files should be written.
    output_node: string, the name of the output node that will
      be returned during inference.
    batch_size: int, the number of examples that will be predicted at a time.
    workspace_size: int, size in megabytes that can be used during conversion.

  Returns:
    GraphDef for the TensorRT inference graph.
  """
  trt_graph = trt.create_inference_graph(
      graph_def, [output_node], max_batch_size=batch_size,
      max_workspace_size_bytes=workspace_size<<20,
      precision_mode=precision_mode)

  write_graph_to_file(graph_name, trt_graph, output_dir)

  return trt_graph 
Developer: itsamitgoel, Project: Gun-Detector, Lines: 28, Source: tensorrt.py
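A hypothetical call to the helper above might look as follows; the graph name, output node, directory, and sizes are placeholders, graph_def is assumed to be an already loaded frozen GraphDef, and write_graph_to_file is defined elsewhere in the same source file.

# Hypothetical usage of get_trt_graph; all names and paths are placeholders.
# workspace_size is given in megabytes and shifted to bytes inside the helper.
trt_graph = get_trt_graph(
    graph_name='resnet50_fp16',
    graph_def=graph_def,
    precision_mode='FP16',
    output_dir='/tmp/trt_graphs',
    output_node='softmax_tensor',
    batch_size=64,
    workspace_size=2048)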

Example 5: getFP32

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def getFP32(input_graph, out_tensor, precision, batch_size, workspace_size):
  graph_prefix = input_graph.split('.pb')[0]
  output_graph = graph_prefix + "_tftrt_" + precision + ".pb"
  #print("output graph is ", output_graph)
  tftrt_graph = trt.create_inference_graph(
      getFrozenGraph(input_graph), [out_tensor],
      max_batch_size=batch_size,
      max_workspace_size_bytes=workspace_size,
      precision_mode=precision)  # Get optimized graph
  with gfile.FastGFile(output_graph, 'wb') as f:
    f.write(tftrt_graph.SerializeToString()) 
Developer: didi, Project: delta, Lines: 13, Source: convert_frozen_pb_to_tftrt.py

Example 6: convert_saved_model_to_tensorrt

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def convert_saved_model_to_tensorrt(
        saved_model_dir: str,
        tensorrt_config: TensorrtConfig = None,
        session_config: Optional[tf.ConfigProto] = None
) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor], tf.GraphDef]:
    """
    Convert saved model to tensorrt.

    Uses default tag and signature_def

    Parameters
    ----------
    saved_model_dir
        directory with saved model inside
    tensorrt_config
        tensorrt config which holds all the tensorrt parameters
    session_config
        session config to use

    Returns
    -------
    input_tensors
        dict holding input tensors from saved model signature_def
    output_tensors
        dict holding output tensors from saved model signature_def
    trt_graph
        graph_def with tensorrt graph with variables

    Raises
    ------
    ImportError
        if tensorrt import was unsuccessful
    """
    if trt is None:
        raise ImportError(
            "No tensorrt is found under tensorflow.contrib.tensorrt")
    tensorrt_kwargs = (
        tensorrt_config._asdict() if tensorrt_config is not None else {})
    tensorrt_kwargs.pop("use_tensorrt", None)
    (input_tensors, output_tensors, frozen_graph_def
     ) = _load_saved_model_as_frozen_graph(saved_model_dir)
    output_tensors_list = list(output_tensors.values())

    trt_graph = trt.create_inference_graph(
        input_graph_def=frozen_graph_def,
        outputs=output_tensors_list,
        session_config=session_config,
        **tensorrt_kwargs)
    return input_tensors, output_tensors, trt_graph 
Developer: audi, Project: nucleus7, Lines: 51, Source: tensorrt_utils.py
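As a quick illustration, a hypothetical call to the function above could look like this; the SavedModel directory is a placeholder, and passing tensorrt_config=None simply falls back to the create_inference_graph defaults.

# Hypothetical usage; '/models/export/1' is a placeholder SavedModel directory.
input_tensors, output_tensors, trt_graph_def = convert_saved_model_to_tensorrt(
    saved_model_dir='/models/export/1',
    tensorrt_config=None)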

Example 7: createModel

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def createModel(config_path, checkpoint_path, graph_path):
    """ Create a TensorRT Model.
    config_path (string) - The path to the model config file.
    checkpoint_path (string) - The path to the model checkpoint file(s).
    graph_path (string) - The path to the model graph.
    returns (Model) - The TRT model built or loaded from the input files.
    """

    global build_graph, prev_classes

    trt_graph = None
    input_names = None
    
    if build_graph:
        frozen_graph, input_names, output_names = build_detection_graph(
            config=config_path,
            checkpoint=checkpoint_path
        )
    
        trt_graph = trt.create_inference_graph(
            input_graph_def=frozen_graph,
            outputs=output_names,
            max_batch_size=1,
            max_workspace_size_bytes=1 << 25,
            precision_mode='FP16',
            minimum_segment_size=50
        )

        with open(graph_path, 'wb') as f:
            f.write(trt_graph.SerializeToString())

        with open('config.txt', 'r+') as json_file:  
            data = json.load(json_file)
            data['model'] = []
            data['model'] = [{'input_names': input_names}]
            json_file.seek(0)
            json_file.truncate()
            json.dump(data, json_file)

    else:
        with open(graph_path, 'rb') as f:
            trt_graph = tf.GraphDef()
            trt_graph.ParseFromString(f.read())
        with open('config.txt') as json_file:  
            data = json.load(json_file)
            input_names = data['model'][0]['input_names']

    return Model(trt_graph, input_names) 
Developer: NVIDIA-AI-IOT, Project: GreenMachine, Lines: 50, Source: GreenMachine.py
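For orientation, a hypothetical call could look like the line below; the three paths are placeholders, and the module-level build_graph flag is assumed to be set beforehand (True rebuilds and serializes the TensorRT graph, False loads it from graph_path).

# Hypothetical usage of createModel; the paths are placeholders.
model = createModel('pipeline.config', 'checkpoints/model.ckpt', 'trt_graph.pb')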

Example 8: load_model

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def load_model(model, input_map=None):
    # Check if the model is a model directory (containing a metagraph and a checkpoint file)
    #  or if it is a protobuf file with a frozen graph
    model_exp = os.path.expanduser(model)
    if (os.path.isfile(model_exp)):
        print('Model filename: %s' % model_exp)
        with gfile.FastGFile(model_exp,'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            #JJia TensorRT enable
            print('TensorRT Enabled')
            trt_graph = trt.create_inference_graph(
                input_graph_def=graph_def,
                outputs=['embeddings:0'],
                max_batch_size=1,
                max_workspace_size_bytes=500000000,  # 500MB assigned to TRT
                precision_mode="FP16",  # Precision: "FP32", "FP16" or "INT8"
                minimum_segment_size=1)
            # trt_graph = trt.calib_graph_to_infer_graph(trt_graph)
            # tf.import_graph_def(trt_graph, input_map=input_map, name='')
            return trt_graph  # "return graph_def" for TRT disabled, "return trt_graph" for TRT enabled

    else:
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)
        
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
      
        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
        #JJia TensorRT enable
        print('TensorRT Enabled', 1<<20)
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            tf.get_default_session(),
            tf.get_default_graph().as_graph_def(),
            output_node_names=["embeddings"])
        for node in frozen_graph.node:
          if node.op == 'RefSwitch':
            node.op = 'Switch'
          elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']
        trt_graph = trt.create_inference_graph(
            input_graph_def=frozen_graph,
            outputs=["embeddings"],
            max_batch_size=1,
            max_workspace_size_bytes=1 << 20,
            precision_mode="FP16",
            minimum_segment_size=1)
        # tf.import_graph_def(trt_graph, return_elements=["embeddings:0"])
        return trt_graph  # "return frozen_graph" for TRT disabled, "return trt_graph" for TRT enabled
Developer: JerryJiaGit, Project: facenet_trt, Lines: 54, Source: facenet.py

Example 9: main

# Required import: from tensorflow.contrib import tensorrt [as an alias]
# Or: from tensorflow.contrib.tensorrt import create_inference_graph [as an alias]
def main(argv):
  del argv  # Unused.

  original_saved_model_dir = FLAGS.saved_model_dir.rstrip('/')
  tensorrt_saved_model_dir = '{}_trt'.format(original_saved_model_dir)

  # Converts `SavedModel` to TensorRT inference graph.
  trt.create_inference_graph(
      None,
      None,
      input_saved_model_dir=original_saved_model_dir,
      output_saved_model_dir=tensorrt_saved_model_dir)
  print('Model conversion completed.')

  # Gets the image.
  get_image_response = requests.get(FLAGS.image_url)
  number = FLAGS.number
  saved_model_dirs = [original_saved_model_dir, tensorrt_saved_model_dir]
  latencies = {}
  for saved_model_dir in saved_model_dirs:
    with tf.Graph().as_default():
      with tf.Session() as sess:

        # Loads the saved model.
        loader.load(sess, [tag_constants.SERVING], saved_model_dir)
        print('Model loaded {}'.format(saved_model_dir))

        def _run_inf(session=sess, n=1):
          """Runs inference repeatedly."""
          for _ in range(n):
            session.run(
                FLAGS.model_outputs,
                feed_dict={
                    FLAGS.model_input: [get_image_response.content]})

        # Run inference once to perform XLA compile step.
        _run_inf(sess, 1)

        start = time.time()
        _run_inf(sess, number)
        end = time.time()
        latencies[saved_model_dir] = end - start

  print('Time to run {} predictions:'.format(number))
  for saved_model_dir, latency in latencies.items():
    print('* {} seconds for {} runs for {}'.format(
        latency, number, saved_model_dir)) 
Developer: artyompal, Project: tpu_models, Lines: 49, Source: retinanet_tensorrt.py


Note: The tensorflow.contrib.tensorrt.create_inference_graph method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects contributed by various developers, and the copyright of the source code belongs to the original authors. For distribution and use, please refer to the license of the corresponding project; do not reproduce without permission.