当前位置: 首页>>代码示例>>Python>>正文


Python v1.import_graph_def方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.import_graph_def方法的典型用法代码示例。如果您正苦于以下问题:Python v1.import_graph_def方法的具体用法?Python v1.import_graph_def怎么用?Python v1.import_graph_def使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.import_graph_def方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _load_frozen_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def _load_frozen_graph(self, frozen_graph_path):
    """Loads a frozen classification graph and opens a session on it.

    The serialized GraphDef at ``frozen_graph_path`` is imported into a new
    ``tf.Graph`` under the default ``import/`` name scope, an interactive
    session is created for that graph, and the inference bookkeeping
    attributes are (re)initialised.

    Args:
      frozen_graph_path: Filesystem path of the serialized GraphDef (.pb).
    """
    graph_def = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as graph_file:
      serialized = graph_file.read()
    graph_def.ParseFromString(serialized)

    self.graph = tf.Graph()
    with self.graph.as_default():
      # No `name` is passed, so nodes land under the default 'import/' scope;
      # that is why the lookup below is prefixed with 'import/'.
      self.output_node = tf.import_graph_def(
          graph_def, return_elements=['probabilities:0'])
    self.session = tf.InteractiveSession(graph=self.graph)

    probabilities = self.graph.get_tensor_by_name('import/probabilities:0')
    self._output_nodes = [probabilities]

    # Inference bookkeeping state.
    self.sliding_window = None
    self.frames_since_last_inference = self.config.inference_rate
    self.last_annotations = []
开发者ID:google,项目名称:automl-video-ondevice,代码行数:20,代码来源:tf_shot_classification.py

示例2: load

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def load(self, saved_model_dir_or_frozen_graph: Text):
    """Load the model using saved model or a frozen graph.

    A directory is loaded as a SavedModel with the 'serve' tag into
    ``self.sess``; any other path is parsed as a serialized frozen GraphDef
    and imported, without a name prefix, into the current default graph.

    Args:
      saved_model_dir_or_frozen_graph: SavedModel directory, or path of a
        frozen GraphDef file.

    Returns:
      The value returned by ``tf.saved_model.load`` for a directory,
      otherwise the value returned by ``tf.import_graph_def``.
    """
    if not self.sess:
      self.sess = self._build_session()

    # NOTE(review): 'signitures' is misspelled but is kept because it is an
    # attribute name other code may read.
    self.signitures = dict(
        image_files='image_files:0',
        image_arrays='image_arrays:0',
        prediction='detections:0',
    )

    model_path = saved_model_dir_or_frozen_graph
    if not tf.io.gfile.isdir(model_path):
      # Not a folder: treat it as a frozen graph file.
      graph_def = tf.GraphDef()
      with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
      return tf.import_graph_def(graph_def, name='')

    return tf.saved_model.load(self.sess, ['serve'], model_path)
开发者ID:PINTO0309,项目名称:PINTO_model_zoo,代码行数:22,代码来源:inference.py

示例3: create_model_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def create_model_graph(model_info):
  """Creates a graph from a saved GraphDef file and returns a Graph object.

  The GraphDef file ``FLAGS.model_dir/model_info['model_file_name']`` is
  imported, without a name prefix, into a new ``tf.Graph``.

  Args:
    model_info: Dictionary containing information about the model
      architecture. Must provide the 'model_file_name',
      'bottleneck_tensor_name' and 'resized_input_tensor_name' keys.

  Returns:
    A tuple ``(graph, bottleneck_tensor, resized_input_tensor)``: the graph
    holding the trained network, and the two tensors from it that the
    caller will manipulate.
  """
  with tf.Graph().as_default() as graph:
    model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
    graph_def = tf.GraphDef()
    # GFile replaces the deprecated FastGFile; same read semantics.
    with gfile.GFile(model_path, 'rb') as f:
      graph_def.ParseFromString(f.read())
    # name='' imports nodes without an 'import/' prefix, so the tensor names
    # stored in model_info resolve exactly as written.
    bottleneck_tensor, resized_input_tensor = tf.import_graph_def(
        graph_def,
        name='',
        return_elements=[
            model_info['bottleneck_tensor_name'],
            model_info['resized_input_tensor_name'],
        ])
  return graph, bottleneck_tensor, resized_input_tensor
开发者ID:iamvishnuks,项目名称:AudioNet,代码行数:25,代码来源:retrain.py

示例4: _import_graph_and_run_inference

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
    """Imports a tflite graph, runs single inference and returns outputs.

    Args:
      tflite_graph_file: Path of the serialized GraphDef to import.
      num_channels: Channel count of the random 1x10x10 image fed to
        'normalized_input_image_tensor:0'.

    Returns:
      A ``(box_encodings, class_predictions)`` pair as returned by
      ``sess.run``.
    """
    graph = tf.Graph()
    with graph.as_default():
      graph_def = tf.GraphDef()
      with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
        graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')

      lookup = graph.get_tensor_by_name
      input_tensor = lookup('normalized_input_image_tensor:0')
      fetches = [
          lookup('raw_outputs/box_encodings:0'),
          lookup('raw_outputs/class_predictions:0'),
      ]
      with self.test_session(graph) as sess:
        random_image = np.random.rand(1, 10, 10, num_channels)
        box_encodings_np, class_predictions_np = sess.run(
            fetches, feed_dict={input_tensor: random_image})
    return box_encodings_np, class_predictions_np
开发者ID:tensorflow,项目名称:models,代码行数:19,代码来源:export_tflite_ssd_graph_lib_tf1_test.py

示例5: initialize

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def initialize():
    """Load the frozen model graph and the label list into module globals.

    Parses the file named by the module-level ``filename`` into the
    module-level ``graph_def``, imports it (no name prefix) into the current
    default graph, then appends one entry per line of ``labels_filename`` to
    the module-level ``labels`` list.
    """
    print('Loading model...', end='')
    # GFile replaces the deprecated FastGFile; same read semantics.
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    print('Success!')
    print('Loading labels...', end='')
    with open(labels_filename, 'rt') as lf:
        for line in lf:
            # Strip only the line terminator: the last line may not end with
            # a newline, and slicing off the final character would truncate
            # that label.
            labels.append(line.rstrip('\n'))
    print(len(labels), 'found. Success!')
开发者ID:jamesbannan,项目名称:pluralsight,代码行数:14,代码来源:predict.py

示例6: _load_frozen_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def _load_frozen_graph(self, frozen_graph_path):
    """Loads a frozen object-detection graph and opens a session on it.

    The GraphDef is imported under the default ``import/`` scope. If
    ``self._check_lstm`` reports an LSTM model, the two LSTM state outputs
    are fetched as well and initial state arrays are allocated.

    Args:
      frozen_graph_path: Filesystem path of the serialized GraphDef (.pb).
    """
    trt_graph = tf.GraphDef()
    with open(frozen_graph_path, 'rb') as f:
      trt_graph.ParseFromString(f.read())

    self._is_lstm = self._check_lstm(trt_graph)
    if self._is_lstm:
      print('Loading an LSTM model.')

    lstm_names = (['raw_outputs/lstm_c:0', 'raw_outputs/lstm_h:0']
                  if self._is_lstm else [])
    import_names = [
        'detection_boxes:0', 'detection_classes:0', 'detection_scores:0',
        'num_detections:0'
    ] + lstm_names

    self.graph = tf.Graph()
    with self.graph.as_default():
      self.output_node = tf.import_graph_def(
          trt_graph, return_elements=import_names)
    self.session = tf.InteractiveSession(graph=self.graph)

    # Output order consumed downstream: scores, boxes, classes, count, then
    # the optional LSTM states.
    fetch_names = [
        'detection_scores:0', 'detection_boxes:0', 'detection_classes:0',
        'num_detections:0'
    ] + lstm_names
    self._output_nodes = [
        self.graph.get_tensor_by_name('import/' + name) for name in fetch_names
    ]

    if self._is_lstm:
      # NOTE(review): the (1, 8, 8, 320) state shape is hard-coded; presumably
      # it must match the model's LSTM state tensors -- confirm.
      self.lstm_c = np.ones((1, 8, 8, 320))
      self.lstm_h = np.ones((1, 8, 8, 320))
开发者ID:google,项目名称:automl-video-ondevice,代码行数:36,代码来源:tf_object_detection.py

示例7: create_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def create_graph():
  """Creates a graph from the saved GraphDef file.

  Reads 'classify_image_graph_def.pb' from ``FLAGS.model_dir`` and imports
  it, without a name prefix, into the current default graph. Returns None.
  """
  model_path = os.path.join(FLAGS.model_dir, 'classify_image_graph_def.pb')
  graph_def = tf.GraphDef()
  # GFile replaces the deprecated FastGFile; same read semantics.
  with tf.gfile.GFile(model_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def, name='')
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-TensorFlow-Second-Edition,代码行数:10,代码来源:classify_image.py

示例8: load_frozen_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def load_frozen_model(pb_path, prefix='', print_nodes=False):
    """Load a frozen model (.pb file) for testing.

    After restoring the model, tensors can be accessed with
    ``graph.get_tensor_by_name('<prefix>/<op_name>:0')`` (or
    ``'<op_name>:0'`` when ``prefix`` is empty).

    Args:
        pb_path: Path of the frozen model.
        prefix: Name scope prepended to every imported operator name.
        print_nodes: Whether to print the name of every imported operator.

    Returns:
        The ``tf.Graph`` holding the imported model.

    Raises:
        SystemExit: with status -1 if ``pb_path`` does not exist (a message
            is printed first).
    """
    import sys

    if not os.path.exists(pb_path):
        print('Model file does not exist', pb_path)
        # The builtin exit() is installed by the site module for interactive
        # use and may be missing (e.g. under `python -S`); sys.exit is not.
        sys.exit(-1)

    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
        if print_nodes:
            for op in graph.get_operations():
                print(op.name)
    return graph
开发者ID:luigifreda,项目名称:pyslam,代码行数:30,代码来源:tf.py

示例9: load_frozen_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def load_frozen_model(pb_path, prefix='', print_nodes=False):
    """Load a frozen model (.pb file) for testing.

    After restoring the model, tensors can be accessed with
    ``graph.get_tensor_by_name('<prefix>/<op_name>:0')`` (or
    ``'<op_name>:0'`` when ``prefix`` is empty).

    Args:
        pb_path: Path of the frozen model.
        prefix: Name scope prepended to every imported operator name.
        print_nodes: Whether to print the name of every imported operator.

    Returns:
        The ``tf.Graph`` holding the imported model.

    Raises:
        SystemExit: with status -1 if ``pb_path`` does not exist (a message
            is printed first).
    """
    import sys

    if not os.path.exists(pb_path):
        print('Model file does not exist', pb_path)
        # The builtin exit() is installed by the site module for interactive
        # use and may be missing (e.g. under `python -S`); sys.exit is not.
        sys.exit(-1)

    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
        if print_nodes:
            for op in graph.get_operations():
                print(op.name)
    return graph
开发者ID:luigifreda,项目名称:pyslam,代码行数:29,代码来源:utils_tf.py

示例10: test_dropout_trans_1_1

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def test_dropout_trans_1_1(droput_graph_tuple):
    """DropoutTransformer removes the dropout subgraph without changing outputs.

    The untransformed graph is evaluated with ``rate == 0.0`` (no units
    dropped) and must match the transformed graph's output exactly.
    """
    graph_def, (rate_name, dropout_output_name), output_nodes = droput_graph_tuple
    out_name = output_nodes[0]

    ugraph = GraphDefParser(config={}).parse(graph_def, output_nodes=output_nodes)
    transformer = DropoutTransformer()
    assert transformer.prune_graph
    new_ugraph = transformer.transform(ugraph)
    assert all(op.ugraph for op in new_ugraph.ops_info.values())

    input_names = {str(node.name) for node in new_ugraph.ops_info[out_name].input_nodes}
    assert input_names == {'x', 'bias'}

    # Import the graphs from before and after the transformation.
    graph_before, graph_after = tf.Graph(), tf.Graph()
    with graph_before.as_default():
        tf.import_graph_def(ugraph.graph_def, name='')
    with graph_after.as_default():
        tf.import_graph_def(new_ugraph.graph_def, name='')

    with tf.Session(graph=graph_before):
        rate = graph_before.get_tensor_by_name(rate_name)
        dropout_output = graph_before.get_tensor_by_name(dropout_output_name)
        output = graph_before.get_tensor_by_name(out_name + ":0")
        # All dropout nodes should be gone from the transformed graph.
        assert rate.op.name not in new_ugraph.ops_info
        assert dropout_output.op.name not in new_ugraph.ops_info
        expected = output.eval({rate: 0.0})
    with tf.Session(graph=graph_after):
        actual = graph_after.get_tensor_by_name(out_name + ":0").eval()

    # rate == 0.0 is equivalent to keep_prob == 1.0, so outputs must agree.
    assert (expected == actual).all()
开发者ID:uTensor,项目名称:utensor_cgen,代码行数:34,代码来源:test_dropout_transormer.py

示例11: test_dropout_trans_1_2

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def test_dropout_trans_1_2(droput_graph_tuple):
    """DropoutTransformerV2 removes the dropout subgraph without changing outputs.

    The untransformed graph is evaluated with ``keep_prob == 1.0`` and must
    match the transformed graph's output exactly.
    """
    graph_def, (keep_prob_name, dropout_output_name), output_nodes = droput_graph_tuple
    out_name = output_nodes[0]

    ugraph = GraphDefParser(config={}).parse(graph_def, output_nodes=output_nodes)
    transformer = DropoutTransformerV2()
    assert transformer.prune_graph
    new_ugraph = transformer.transform(ugraph)
    assert all(op.ugraph for op in new_ugraph.ops_info.values())

    input_names = {str(node.name) for node in new_ugraph.ops_info[out_name].input_nodes}
    assert input_names == {'x', 'bias'}

    # Import the graphs from before and after the transformation.
    graph_before, graph_after = tf.Graph(), tf.Graph()
    with graph_before.as_default():
        tf.import_graph_def(ugraph.graph_def, name='')
    with graph_after.as_default():
        tf.import_graph_def(new_ugraph.graph_def, name='')

    with tf.Session(graph=graph_before):
        keep_prob = graph_before.get_tensor_by_name(keep_prob_name)
        dropout_output = graph_before.get_tensor_by_name(dropout_output_name)
        output = graph_before.get_tensor_by_name(out_name + ":0")
        # All dropout nodes should be gone from the transformed graph.
        assert keep_prob.op.name not in new_ugraph.ops_info
        assert dropout_output.op.name not in new_ugraph.ops_info
        expected = output.eval({keep_prob: 1.0})
    with tf.Session(graph=graph_after):
        actual = graph_after.get_tensor_by_name(out_name + ":0").eval()

    # With keep_prob == 1.0 nothing is dropped, so outputs must agree.
    assert (expected == actual).all()
开发者ID:uTensor,项目名称:utensor_cgen,代码行数:34,代码来源:test_dropout_transormer.py

示例12: write_graph_and_checkpoint

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint to disk.

  Clears the device placement of every node of ``inference_graph_def``
  (the argument is modified in place), imports it into a fresh graph,
  restores variables from ``trained_checkpoint_prefix`` and saves them
  again under ``model_path``.

  Args:
    inference_graph_def: GraphDef to export; its nodes' ``device`` fields
      are cleared in place.
    model_path: Checkpoint path prefix to save to.
    input_saver_def: SaverDef used to rebuild the saver for the imported
      graph.
    trained_checkpoint_prefix: Checkpoint path prefix to restore from.
  """
  for graph_node in inference_graph_def.node:
    graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as sess:
      # save_relative_paths=True writes relative paths into the checkpoint
      # state file so the checkpoint directory can be copied and reloaded.
      checkpoint_saver = tf.train.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      checkpoint_saver.restore(sess, trained_checkpoint_prefix)
      checkpoint_saver.save(sess, model_path)
开发者ID:tensorflow,项目名称:models,代码行数:16,代码来源:exporter.py

示例13: _load_inference_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def _load_inference_graph(self, inference_graph_path, is_binary=True):
    """Loads a serialized GraphDef from disk into a new graph.

    Args:
      inference_graph_path: Path of the serialized GraphDef.
      is_binary: If True the file is parsed as a binary protobuf; otherwise
        it is parsed with protobuf text format.

    Returns:
      The ``tf.Graph`` holding the imported nodes (no name prefix).
    """
    graph = tf.Graph()
    with graph.as_default():
      graph_def = tf.GraphDef()
      with tf.gfile.GFile(inference_graph_path, mode='rb') as fid:
        contents = fid.read()
      if is_binary:
        graph_def.ParseFromString(contents)
      else:
        text_format.Parse(contents, graph_def)
      tf.import_graph_def(graph_def, name='')
    return graph
开发者ID:tensorflow,项目名称:models,代码行数:13,代码来源:exporter_tf1_test.py

示例14: predict_image

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def predict_image(image):
    """Classify ``image`` with the module-level frozen graph.

    Resets the default graph, re-imports the module-level ``graph_def``, and
    preprocesses the image:
      * scale the shorter side to the corresponding entry of the
        module-level ``size``;
      * centre-crop to the network input size;
      * reorder RGB -> BGR, subtracting ``mean_values_b_g_r``.

    The result is then fed to ``input_node``.

    Args:
        image: The image to classify. Presumably a PIL Image, since
            ``.size`` is read as ``(width, height)`` and ``.resize`` is
            called; confirm against callers.

    Returns:
        A list of ``{'Tag': label, 'Probability': p}`` dicts, one for every
        class whose probability (rounded to 8 decimals) exceeds 1e-8.
        Returns the string ``'error: crop_center'`` if cropping fails.
    """
    print('Predicting image')
    tf.reset_default_graph()
    tf.import_graph_def(graph_def, name='')

    with tf.Session() as sess:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)

        input_tensor_shape = sess.graph.get_tensor_by_name('Placeholder:0').shape.as_list()
        network_input_size = input_tensor_shape[1]

        w, h = image.size
        print('Image size', w, 'x', h)

        # Scale so the shorter side hits the target size, keeping aspect ratio.
        if w > h:
            new_w, new_h = int((float(size[1]) / h) * w), size[1]
        else:
            new_w, new_h = size[0], int((float(size[0]) / w) * h)

        # Resize only when the target differs from the current dimensions.
        if (new_w, new_h) != (w, h):
            print('Resizing to', new_w, 'x', new_h)
            augmented_image = np.asarray(image.resize((new_w, new_h)))
        else:
            augmented_image = np.asarray(image)

        # Crop the centre to the network's input size.
        try:
            augmented_image = crop_center(augmented_image, network_input_size, network_input_size)
        except Exception:
            # Report crop failures as before, but let KeyboardInterrupt and
            # SystemExit propagate instead of swallowing them.
            return 'error: crop_center'

        augmented_image = augmented_image.astype(float)

        # RGB -> BGR with per-channel mean subtraction.
        red, green, blue = tf.split(axis=2, num_or_size_splits=3, value=augmented_image)
        image_normalized = tf.concat(axis=2, values=[
            blue - mean_values_b_g_r[0],
            green - mean_values_b_g_r[1],
            red - mean_values_b_g_r[2],
        ])
        image_normalized = image_normalized.eval()
        # Add the batch dimension expected by the input node.
        image_normalized = np.expand_dims(image_normalized, axis=0)

        # Unpack the single row of the batch.
        predictions, = sess.run(prob_tensor, {input_node: image_normalized})

        result = []
        for idx, p in enumerate(predictions):
            truncated_probability = np.float64(round(p, 8))
            if truncated_probability > 1e-8:
                result.append({'Tag': labels[idx], 'Probability': truncated_probability})
        print('Results: ', str(result))
        return result
开发者ID:jamesbannan,项目名称:pluralsight,代码行数:63,代码来源:predict.py

示例15: main

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import import_graph_def [as 别名]
def main(m_path, out_dir, light=False, test_out=True):
    """Freeze the latest generator checkpoint into a standalone .pb graph.

    Builds the inference graph (JPEG path -> decoded image -> generator ->
    uint8 image). It restores the latest checkpoint in ``m_path`` and
    converts variables to constants. The frozen GraphDef is written to the
    first unused ``<out_dir>/optimized_graph[_light]_NNNN.pb``.

    Args:
        m_path: Directory searched with ``tf.train.latest_checkpoint``.
        out_dir: Directory that receives the exported .pb file.
        light: Build the light generator variant; also selects the file
            name prefix.
        test_out: Enable debug logging and a smoke test. The smoke test
            reloads the exported graph and runs it on
            'input_images/temple.jpg'.

    Exits the process with status 1 if the weights cannot be loaded
    (``ValueError`` during restore/freeze).
    """
    import sys
    from time import perf_counter

    logger = get_logger("tf1_export", debug=test_out)
    g = Generator(light=light)
    t = tf.placeholder(tf.string, [])
    x = tf.expand_dims(tf.image.decode_jpeg(tf.read_file(t), channels=3), 0)
    # Map [0, 255] to [-1, 1] for the generator, and back afterwards.
    x = (tf.cast(x, tf.float32) / 127.5) - 1
    x = g(x, training=False)
    out = tf.cast((tf.squeeze(x, 0) + 1) * 127.5, tf.uint8)
    in_name, out_name = t.op.name, out.op.name
    try:
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            g.load_weights(tf.train.latest_checkpoint(m_path))
            in_graph_def = tf.get_default_graph().as_graph_def()
            out_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, in_graph_def, [out_name])
        tf.reset_default_graph()
        tf.import_graph_def(out_graph_def, name='')
    except ValueError:
        logger.error("Failed to load specified weight.")
        logger.error("If you trained your model with --light, "
                     "consider adding --light when executing this script; otherwise, "
                     "do not add --light when executing this script.")
        # exit() is a site-module helper for the interactive shell and is not
        # guaranteed to exist in programs; sys.exit() always is.
        sys.exit(1)

    makedirs(out_dir)
    bpath = 'optimized_graph_light' if light else 'optimized_graph'
    # Use the first free numbered file name so earlier exports are kept.
    m_cnt = 0
    out_path = os.path.join(out_dir, f'{bpath}_{m_cnt:04d}.pb')
    while os.path.exists(out_path):
        m_cnt += 1
        out_path = os.path.join(out_dir, f'{bpath}_{m_cnt:04d}.pb')
    with tf.gfile.GFile(out_path, 'wb') as f:
        f.write(out_graph_def.SerializeToString())

    if test_out:
        with tf.Graph().as_default():
            gd = tf.GraphDef()
            with tf.gfile.GFile(out_path, 'rb') as f:
                gd.ParseFromString(f.read())
            tf.import_graph_def(gd, name='')
            tf.get_default_graph().finalize()
            t = tf.get_default_graph().get_tensor_by_name(f"{in_name}:0")
            out = tf.get_default_graph().get_tensor_by_name(f"{out_name}:0")
            # perf_counter is the monotonic clock intended for intervals.
            start = perf_counter()
            with tf.Session() as sess:
                img = Image.fromarray(sess.run(out, {t: "input_images/temple.jpg"}))
                img.show()
            elapsed = perf_counter() - start
            logger.debug(f"{elapsed} sec per img")
    logger.info(f"successfully exported ckpt to {out_path}")
    logger.info(f"input var name: {in_name}:0")
    logger.info(f"output var name: {out_name}:0")
开发者ID:mnicnc404,项目名称:CartoonGan-tensorflow,代码行数:54,代码来源:to_pb.py


注:本文中的tensorflow.compat.v1.import_graph_def方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。