当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.import_graph_def方法代码示例

本文整理汇总了Python中tensorflow.import_graph_def方法的典型用法代码示例。如果您想了解tensorflow.import_graph_def方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


在下文中一共展示了tensorflow.import_graph_def方法的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。

示例1: load_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(frozen_graph_filename):
    """Load a frozen TensorFlow GraphDef from disk into a new ``tf.Graph``.

    Args:
        frozen_graph_filename: Path to a serialized (binary) ``GraphDef`` file.

    Returns:
        A freshly created ``tf.Graph`` containing the imported nodes. Node
        names are kept unchanged because the import prefix is empty.
    """
    # Read the serialized protobuf and parse it into a GraphDef message.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a brand-new Graph (not the caller's default graph).
    # Only `name` is passed: every other argument was an explicit default,
    # and passing the deprecated `op_dict` keyword (even as None) emits a
    # deprecation warning in TF 1.x.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
开发者ID:TobiasGruening,项目名称:ARU-Net,代码行数:21,代码来源:util.py

示例2: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def build_from_pb(self):
    """Build the network from a frozen ``.pb`` graph plus its JSON metadata.

    Parses the GraphDef at ``self.FLAGS.pbLoad`` and imports it into the
    current default graph with no name prefix. Loads ``self.meta`` from the
    JSON file at ``self.FLAGS.metaLoad`` and creates ``self.framework``. Binds
    ``self.inp`` / ``self.out`` to the ``input:0`` / ``output:0`` tensors, then
    calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.FastGFile is deprecated in TF 1.x; GFile is its replacement.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders (looked up in the graph the pb was imported into).
    graph = tf.get_default_graph()
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
开发者ID:AmeyaWagh,项目名称:Traffic_sign_detection_YOLO,代码行数:21,代码来源:build.py

示例3: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def build_from_pb(self):
    """Build the network from a frozen ``.pb`` graph plus its JSON metadata.

    Parses the GraphDef at ``self.FLAGS.pbLoad`` and imports it into the
    current default graph with no name prefix. Loads ``self.meta`` from the
    JSON file at ``self.FLAGS.metaLoad`` and creates ``self.framework``. Binds
    ``self.inp`` / ``self.out`` to the ``input:0`` / ``output:0`` tensors, then
    calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.FastGFile is deprecated in TF 1.x; GFile is its replacement.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders (looked up in the graph the pb was imported into).
    graph = tf.get_default_graph()
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
开发者ID:MahmudulAlam,项目名称:Automatic-Identification-and-Counting-of-Blood-Cells,代码行数:21,代码来源:build.py

示例4: worker

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def worker(input_q, output_q):
    """Run object detection on frames taken from ``input_q``.

    Loads the frozen graph at ``PATH_TO_CKPT`` once and opens a session on it.
    It then repeatedly takes a BGR frame from ``input_q``, converts it to RGB
    and puts ``detect_objects(...)``'s result on ``output_q``.

    Putting ``None`` on ``input_q`` stops the worker. The session is closed
    and the FPS counter stopped on every exit path.
    """
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    fps = FPS().start()
    try:
        while True:
            fps.update()
            frame = input_q.get()
            if frame is None:
                # Sentinel so producers can shut the worker down cleanly;
                # a None frame used to crash inside cv2.cvtColor.
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_q.put(detect_objects(frame_rgb, sess, detection_graph))
    finally:
        # The original infinite loop made this cleanup unreachable, leaking
        # the session; run it whether we stop via sentinel or an exception.
        fps.stop()
        sess.close()
开发者ID:datitran,项目名称:object_detector_app,代码行数:23,代码来源:object_detection_multithreading.py

示例5: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Load a frozen feature-extraction graph and locate its I/O tensors.

    Args:
        checkpoint_filename: Path to a serialized (binary) GraphDef.
        input_name: Input tensor name, without the ``net/`` scope or ``:0``.
        output_name: Output tensor name, same convention as ``input_name``.
    """
    self.session = tf.Session()
    with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(file_handle.read())
    # Imported nodes are placed under the "net" scope: "net/<node>".
    tf.import_graph_def(graph_def, name="net")

    default_graph = tf.get_default_graph()
    self.input_var = default_graph.get_tensor_by_name(
        "net/%s:0" % input_name)
    self.output_var = default_graph.get_tensor_by_name(
        "net/%s:0" % output_name)

    # The output must be rank 2 and the input rank 4.
    out_shape = self.output_var.get_shape()
    in_shape = self.input_var.get_shape()
    assert len(out_shape) == 2
    assert len(in_shape) == 4
    self.feature_dim = out_shape.as_list()[-1]
    self.image_shape = in_shape.as_list()[1:]
开发者ID:nwojke,项目名称:deep_sort,代码行数:18,代码来源:generate_detections.py

示例6: load_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(graph_path, tensorboard=False, **kwargs):
    '''
    Load a frozen TensorFlow graph from a .pb file into a new tf.Graph.

    :param graph_path: the path of the pb file
    :param tensorboard: if True, also write the graph to "log/" for TensorBoard
    :param kwargs: accepted for call-site compatibility; not used
    :return: tensorflow graph
    '''
    # FastGFile is deprecated in TF 1.x; GFile is the drop-in replacement.
    with gfile.GFile(graph_path, 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # close() flushes pending events to disk and releases the file;
            # without it the graph may never be written to the event file.
            writer.close()

    return graph
开发者ID:bill9800,项目名称:speech_separation,代码行数:19,代码来源:utils.py

示例7: create_inception_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def create_inception_graph():
  """Creates a graph from saved GraphDef file and returns a Graph object.

  Reads 'classify_image_graph_def.pb' from FLAGS.model_dir and imports it into
  a new graph with no name prefix.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating: the bottleneck tensor, the JPEG-data tensor and the
    resized-input tensor, in that order.
  """
  with tf.Graph().as_default() as graph:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    # FastGFile is deprecated in TF 1.x; GFile is the drop-in replacement.
    with gfile.GFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
    # return_elements yields the tensors in the same order as the names.
    bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
        tf.import_graph_def(graph_def, name='', return_elements=[
            BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
            RESIZED_INPUT_TENSOR_NAME]))
  return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
开发者ID:ArunMichaelDsouza,项目名称:tensorflow-image-detection,代码行数:20,代码来源:retrain.py

示例8: _import_graph_and_run_inference

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
  """Imports a tflite graph, runs single inference and returns outputs.

  Args:
    tflite_graph_file: path to the binary GraphDef produced by the exporter.
    num_channels: channel count of the random 1x10x10 input image.

  Returns:
    A (box_encodings, class_predictions) pair of numpy arrays.
  """
  graph = tf.Graph()
  with graph.as_default():
    graph_def = tf.GraphDef()
    # A GraphDef is a binary protobuf: open in 'rb' so read() returns bytes.
    # The default text mode yields str under Python 3, which ParseFromString
    # cannot parse.
    with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    input_tensor = graph.get_tensor_by_name('normalized_input_image_tensor:0')
    box_encodings = graph.get_tensor_by_name('raw_outputs/box_encodings:0')
    class_predictions = graph.get_tensor_by_name(
        'raw_outputs/class_predictions:0')
    with self.test_session(graph) as sess:
      [box_encodings_np, class_predictions_np] = sess.run(
          [box_encodings, class_predictions],
          feed_dict={input_tensor: np.random.rand(1, 10, 10, num_channels)})
  return box_encodings_np, class_predictions_np
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:19,代码来源:export_tflite_ssd_graph_lib_test.py

示例9: _check_output

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _check_output(gin, tf_input, expected):
    """
    Run the graph held by a TFInputGraph and compare against an expected result.

    The graph is expected to contain tensors named ``_tensor_input_name`` (fed
    with ``tf_input``) and ``_tensor_output_name`` (fetched). The fetched value
    must equal ``expected`` element-wise.
    """
    target_graph = tf.Graph()
    serialized = gin.graph_def
    with tf.Session(graph=target_graph) as session:
        # Entering the session makes target_graph the default graph, so the
        # import below lands in it.
        tf.import_graph_def(serialized, name="")
        feed_tensor = tfx.get_tensor(_tensor_input_name, target_graph)
        fetch_tensor = tfx.get_tensor(_tensor_output_name, target_graph)
        actual = session.run(fetch_tensor, feed_dict={feed_tensor: tf_input})
    # Working on integers, the calculation should be exact
    assert np.all(actual == expected), (actual, expected)


# TODO: we could factorize with _check_output, but this is not worth the time doing it. 
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:20,代码来源:test_import.py

示例10: _check_output_2

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def _check_output_2(gin, tf_input1, tf_input2, expected):
    """
    Run a two-input graph held by a TFInputGraph and compare against ``expected``.

    ``tf_input1`` / ``tf_input2`` are fed to the tensors named
    ``_tensor_input_name`` / ``_tensor_input_name_2``. The tensor named
    ``_tensor_output_name`` is fetched and must equal ``expected`` element-wise.
    """
    target_graph = tf.Graph()
    serialized = gin.graph_def
    with tf.Session(graph=target_graph) as session:
        # Entering the session makes target_graph the default graph, so the
        # import below lands in it.
        tf.import_graph_def(serialized, name="")
        first_feed = tfx.get_tensor(_tensor_input_name, target_graph)
        second_feed = tfx.get_tensor(_tensor_input_name_2, target_graph)
        fetch_tensor = tfx.get_tensor(_tensor_output_name, target_graph)
        feeds = {first_feed: tf_input1, second_feed: tf_input2}
        actual = session.run(fetch_tensor, feed_dict=feeds)
    # Working on integers, the calculation should be exact
    assert np.all(actual == expected), (actual, expected)
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:18,代码来源:test_import.py

示例11: fromGraphDef

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
    """
    Construct a TFInputGraph from a tf.GraphDef object.

    :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                       computation units of the TensorFlow graph.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    """
    is_graph_def = isinstance(graph_def, tf.GraphDef)
    assert is_graph_def, ('expect tf.GraphDef type but got', type(graph_def))

    target_graph = tf.Graph()
    with tf.Session(graph=target_graph) as session:
        # The session context makes target_graph the default graph, so the
        # nodes are imported there without a name prefix.
        tf.import_graph_def(graph_def, name='')
        built = _build_with_feeds_fetches(sess=session, graph=target_graph,
                                          feed_names=feed_names,
                                          fetch_names=fetch_names)
    return built
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:19,代码来源:input.py

示例12: initialize

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def initialize():
    """Load the model graph and label list into module-level state.

    Parses the file ``filename`` into the module-level ``graph_def`` and
    imports it into the default graph. Sets ``network_input_size`` from the
    shape of ``input_node``. Reads ``labels_filename`` into ``labels``, one
    stripped line per label.
    """
    global network_input_size, labels

    # flush=True: with end='' the text would otherwise sit in the stdout
    # buffer until the next newline, i.e. until loading has finished.
    print('Loading model...', end='', flush=True)
    with open(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # Retrieving 'network_input_size' from shape of 'input_node'
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name(input_node).shape.as_list()

    # Rank-4 input whose dims 1 and 2 match (presumably a square NHWC image
    # batch -- confirm against the exported model).
    assert len(input_tensor_shape) == 4
    assert input_tensor_shape[1] == input_tensor_shape[2]

    network_input_size = input_tensor_shape[1]

    print('Success!')
    print('Loading labels...', end='', flush=True)
    with open(labels_filename, 'rt') as lf:
        labels = [line.strip() for line in lf]
    print(len(labels), 'found. Success!')
开发者ID:Azure-Samples,项目名称:Custom-vision-service-iot-edge-raspberry-pi,代码行数:24,代码来源:predict-amd64.py

示例13: inference

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def inference():
  """Translate one image with a frozen generator graph.

  Reads the JPEG at FLAGS.input and resizes it to FLAGS.image_size square.
  Converts it with utils.convert2float and substitutes it for the model's
  'input_image' node. Evaluates 'output_image:0' and writes the resulting
  bytes to FLAGS.output.
  """
  graph = tf.Graph()

  with graph.as_default():
    # FastGFile is deprecated in TF 1.x; GFile is the drop-in replacement.
    with tf.gfile.GFile(FLAGS.input, 'rb') as f:
      image_data = f.read()
    input_image = tf.image.decode_jpeg(image_data, channels=3)
    input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
    input_image = utils.convert2float(input_image)
    input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

    with tf.gfile.GFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    # Splice the preprocessed image in for the model's 'input_image' node;
    # imported nodes get the 'output/' name prefix.
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name='output')

  with tf.Session(graph=graph) as sess:
    generated = sess.run(output_image)
  # Written verbatim; presumably the exported graph already encodes the
  # image to bytes (e.g. JPEG) -- TODO confirm against the export script.
  with open(FLAGS.output, 'wb') as f:
    f.write(generated)
开发者ID:vanhuyz,项目名称:CycleGAN-TensorFlow,代码行数:25,代码来源:inference.py

示例14: load_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_graph(frozen_graph_file):
    """Load a frozen TensorFlow GraphDef from disk into a new ``tf.Graph``.

    Args:
        frozen_graph_file: Path to a serialized (binary) ``GraphDef`` file.

    Returns:
        A freshly created ``tf.Graph`` containing the imported nodes, with
        their original names (empty import prefix).
    """
    # we parse the graph_def file
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Load the graph_def into a brand-new graph (not the caller's default).
    # Only `name` is passed: the other arguments were explicit defaults, and
    # passing the deprecated `op_dict` keyword (even as None) emits a
    # deprecation warning in TF 1.x.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
开发者ID:DetectionTeamUCAS,项目名称:R2CNN_Faster-RCNN_Tensorflow,代码行数:19,代码来源:test_exportPb.py

示例15: load_inference_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import import_graph_def [as 别名]
def load_inference_graph():
    """Load the frozen hand-detection model and open a session on it.

    Returns:
        A ``(detection_graph, sess)`` pair: the ``tf.Graph`` holding the model
        imported from ``PATH_TO_CKPT`` and a ``tf.Session`` bound to it. The
        session is not closed here; the caller owns it.
    """
    print("> ====== loading HAND frozen graph into memory")
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
            tf.import_graph_def(graph_def, name='')
        session = tf.Session(graph=graph)
    print(">  ====== Hand Inference graph loaded.")
    return graph, session


# draw the detected bounding boxes on the images
# You can modify this to also draw a label. 
开发者ID:akshaybahadur21,项目名称:Emojinator,代码行数:19,代码来源:detector_utils.py


注:本文中的tensorflow.import_graph_def方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。