當前位置: 首頁>>代碼示例>>Python>>正文


Python tensorflow.import_graph_def方法代碼示例

本文整理匯總了Python中tensorflow.import_graph_def方法的典型用法代碼示例。如果您正苦於以下問題:Python tensorflow.import_graph_def方法的具體用法?Python tensorflow.import_graph_def怎麼用?Python tensorflow.import_graph_def使用的例子?那麽, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow的用法示例。


在下文中一共展示了tensorflow.import_graph_def方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: load_graph

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def load_graph(frozen_graph_filename):
    """Load a frozen GraphDef protobuf from disk into a brand-new tf.Graph.

    Args:
        frozen_graph_filename: path to the serialized GraphDef file; it is
            read in binary mode.

    Returns:
        A new ``tf.Graph`` holding the imported nodes under their original
        names (``name=""`` means no import-scope prefix is added).
    """
    serialized_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized_def.ParseFromString(pb_file.read())

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        # Import into the fresh graph rather than whatever graph was default
        # for the caller; every optional argument is spelled out explicitly.
        tf.import_graph_def(
            serialized_def,
            input_map=None,
            return_elements=None,
            name="",
            op_dict=None,
            producer_op_list=None,
        )
    return loaded_graph
開發者ID:TobiasGruening,項目名稱:ARU-Net,代碼行數:21,代碼來源:util.py

示例2: build_from_pb

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def build_from_pb(self):
    """Rebuild the network from a frozen .pb graph plus its JSON metadata.

    The protobuf named by ``self.FLAGS.pbLoad`` is parsed and imported into
    the current default graph without a name prefix. ``self.FLAGS.metaLoad``
    is loaded as JSON into ``self.meta`` and used (with ``self.FLAGS``) to
    create ``self.framework``. The ``input:0`` / ``output:0`` tensors are then
    looked up, and ``setup_meta_ops`` is called last.
    """
    pb_path = self.FLAGS.pbLoad
    with tf.gfile.FastGFile(pb_path, "rb") as pb_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(pb_file.read())
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as meta_file:
        self.meta = json.load(meta_file)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders / endpoints of the imported graph
    default_graph = tf.get_default_graph()
    self.inp = default_graph.get_tensor_by_name('input:0')
    self.feed = {}  # other placeholders
    self.out = default_graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
開發者ID:AmeyaWagh,項目名稱:Traffic_sign_detection_YOLO,代碼行數:21,代碼來源:build.py

示例3: build_from_pb

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def build_from_pb(self):
    """Load a frozen protobuf graph and its JSON meta file, then wire up I/O.

    Side effects on ``self``:
      * the graph from ``self.FLAGS.pbLoad`` is imported (unprefixed) into the
        current default graph;
      * ``self.meta`` is the parsed JSON from ``self.FLAGS.metaLoad``;
      * ``self.framework`` is built by ``create_framework``;
      * ``self.inp`` / ``self.out`` are the ``input:0`` / ``output:0``
        tensors and ``self.feed`` starts as an empty dict;
      * ``self.setup_meta_ops()`` is invoked at the end.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(self.FLAGS.pbLoad, "rb") as pb_handle:
        graph_def.ParseFromString(pb_handle.read())
    tf.import_graph_def(graph_def, name="")

    with open(self.FLAGS.metaLoad, 'r') as meta_handle:
        meta = json.load(meta_handle)
    self.meta = meta
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders
    lookup = tf.get_default_graph().get_tensor_by_name
    self.inp = lookup('input:0')
    self.feed = dict()  # other placeholders
    self.out = lookup('output:0')

    self.setup_meta_ops()
開發者ID:MahmudulAlam,項目名稱:Automatic-Identification-and-Counting-of-Blood-Cells,代碼行數:21,代碼來源:build.py

示例4: worker

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def worker(input_q, output_q):
    """Run object detection on frames pulled from ``input_q``.

    Loads the frozen model at ``PATH_TO_CKPT`` into a dedicated graph and
    opens a session on it. It then repeatedly takes a BGR frame from
    ``input_q``, converts it to RGB and puts the result of ``detect_objects``
    on ``output_q``.

    A ``None`` item on ``input_q`` is treated as a stop sentinel: the loop
    exits, the FPS counter is stopped and the session is closed. The cleanup
    also runs if processing raises.

    Args:
        input_q: queue-like object providing ``get()``; items are frames for
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` or ``None`` to stop.
        output_q: queue-like object providing ``put()`` for detection results.
    """
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    fps = FPS().start()
    try:
        while True:
            fps.update()
            frame = input_q.get()
            if frame is None:
                # Stop sentinel. Without an exit the cleanup in `finally`
                # was previously unreachable and the session leaked.
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_q.put(detect_objects(frame_rgb, sess, detection_graph))
    finally:
        fps.stop()
        sess.close()
開發者ID:datitran,項目名稱:object_detector_app,代碼行數:23,代碼來源:object_detection_multithreading.py

示例5: __init__

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Open a session and import a frozen feature-extractor graph.

    The graph stored in ``checkpoint_filename`` is imported under the
    ``net/`` name scope. ``self.input_var`` / ``self.output_var`` are then the
    ``net/<input_name>:0`` and ``net/<output_name>:0`` tensors.

    Raises:
        AssertionError: if the output tensor is not rank 2 or the input
            tensor is not rank 4.
    """
    self.session = tf.Session()

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(checkpoint_filename, "rb") as pb_file:
        graph_def.ParseFromString(pb_file.read())
    tf.import_graph_def(graph_def, name="net")

    current_graph = tf.get_default_graph()
    self.input_var = current_graph.get_tensor_by_name("net/%s:0" % input_name)
    self.output_var = current_graph.get_tensor_by_name(
        "net/%s:0" % output_name)

    assert len(self.output_var.get_shape()) == 2
    assert len(self.input_var.get_shape()) == 4
    # Feature length = last output dim; image shape = input shape minus the
    # leading (first) dimension.
    self.feature_dim = self.output_var.get_shape().as_list()[-1]
    self.image_shape = self.input_var.get_shape().as_list()[1:]
開發者ID:nwojke,項目名稱:deep_sort,代碼行數:18,代碼來源:generate_detections.py

示例6: load_graph

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def load_graph(graph_path, tensorboard=False, **kwargs):
    '''Load a frozen GraphDef (.pb) file into a new TensorFlow graph.

    :param graph_path: the path of the pb file (read in binary mode)
    :param tensorboard: if True, also write the loaded graph to the ``log/``
        directory so it can be inspected with TensorBoard
    :param kwargs: accepted for call compatibility; currently unused
    :return: tensorflow graph containing the imported nodes (no name prefix)
    '''
    with gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # Previously the writer was never closed, so buffered events
            # might not reach disk and the file handle was leaked.
            writer.close()

    return graph
開發者ID:bill9800,項目名稱:speech_separation,代碼行數:19,代碼來源:utils.py

示例7: create_inception_graph

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def create_inception_graph():
  """Creates a graph from saved GraphDef file and returns a Graph object.

  Reads ``classify_image_graph_def.pb`` from ``FLAGS.model_dir`` and imports
  it, without a name prefix, into a fresh graph.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating: the tensors named by BOTTLENECK_TENSOR_NAME,
    JPEG_DATA_TENSOR_NAME and RESIZED_INPUT_TENSOR_NAME, in that order.
  """
  with tf.Graph().as_default() as graph:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
    # The protobuf is fully parsed, so the file can be closed before import.
    bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
        tf.import_graph_def(graph_def, name='', return_elements=[
            BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
            RESIZED_INPUT_TENSOR_NAME]))
  return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
開發者ID:ArunMichaelDsouza,項目名稱:tensorflow-image-detection,代碼行數:20,代碼來源:retrain.py

示例8: _import_graph_and_run_inference

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
  """Imports a tflite graph, runs single inference and returns outputs.

  Args:
    tflite_graph_file: path to a serialized (binary) GraphDef.
    num_channels: channel count of the random 1x10x10 input fed to
      ``normalized_input_image_tensor:0``.

  Returns:
    A ``(box_encodings, class_predictions)`` pair: the values fetched by
    ``sess.run`` for ``raw_outputs/box_encodings:0`` and
    ``raw_outputs/class_predictions:0``.
  """
  graph = tf.Graph()
  with graph.as_default():
    graph_def = tf.GraphDef()
    # The GraphDef is a binary protobuf: ParseFromString needs bytes, so the
    # file must be opened in binary mode (the default mode is text).
    with tf.gfile.Open(tflite_graph_file, mode='rb') as f:
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    input_tensor = graph.get_tensor_by_name('normalized_input_image_tensor:0')
    box_encodings = graph.get_tensor_by_name('raw_outputs/box_encodings:0')
    class_predictions = graph.get_tensor_by_name(
        'raw_outputs/class_predictions:0')
    with self.test_session(graph) as sess:
      [box_encodings_np, class_predictions_np] = sess.run(
          [box_encodings, class_predictions],
          feed_dict={input_tensor: np.random.rand(1, 10, 10, num_channels)})
  return box_encodings_np, class_predictions_np
開發者ID:ahmetozlu,項目名稱:vehicle_counting_tensorflow,代碼行數:19,代碼來源:export_tflite_ssd_graph_lib_test.py

示例9: _check_output

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def _check_output(gin, tf_input, expected):
    """Run ``gin``'s graph on ``tf_input`` and assert it yields ``expected``.

    ``gin`` is a TFInputGraph whose ``graph_def`` must contain the tensors
    named by the module-level ``_tensor_input_name`` (fed) and
    ``_tensor_output_name`` (fetched). The check is exact element-wise
    equality, which suits integer-valued computations.
    """
    graph_def = gin.graph_def
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name="")
        feed_tensor = tfx.get_tensor(_tensor_input_name, graph)
        fetch_tensor = tfx.get_tensor(_tensor_output_name, graph)
        actual = sess.run(fetch_tensor, feed_dict={feed_tensor: tf_input})
    # Integer inputs: the result must match exactly, not approximately.
    assert np.all(actual == expected), (actual, expected)


# TODO: we could factorize with _check_output, but this is not worth the time doing it. 
開發者ID:databricks,項目名稱:spark-deep-learning,代碼行數:20,代碼來源:test_import.py

示例10: _check_output_2

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def _check_output_2(gin, tf_input1, tf_input2, expected):
    """Two-input variant of the graph check: feed both inputs, compare result.

    ``gin.graph_def`` must expose the tensors named by ``_tensor_input_name``
    and ``_tensor_input_name_2`` (fed with ``tf_input1`` / ``tf_input2``) and
    ``_tensor_output_name`` (fetched). The fetched value must equal
    ``expected`` exactly, element-wise.
    """
    graph_def = gin.graph_def
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name="")
        feeds = {
            tfx.get_tensor(_tensor_input_name, graph): tf_input1,
            tfx.get_tensor(_tensor_input_name_2, graph): tf_input2,
        }
        fetch_tensor = tfx.get_tensor(_tensor_output_name, graph)
        actual = sess.run(fetch_tensor, feed_dict=feeds)
    # Integer inputs: the result must match exactly, not approximately.
    assert np.all(actual == expected), (actual, expected)
開發者ID:databricks,項目名稱:spark-deep-learning,代碼行數:18,代碼來源:test_import.py

示例11: fromGraphDef

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
    """Build a TFInputGraph from an in-memory ``tf.GraphDef``.

    The GraphDef is imported (unprefixed) into a fresh graph inside a new
    session, and the result is produced by ``_build_with_feeds_fetches``
    while that session is still open.

    :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                       computation units of the TensorFlow graph.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    """
    assert isinstance(graph_def, tf.GraphDef), \
        ('expect tf.GraphDef type but got', type(graph_def))

    target_graph = tf.Graph()
    with tf.Session(graph=target_graph) as session:
        tf.import_graph_def(graph_def, name='')
        return _build_with_feeds_fetches(
            sess=session,
            graph=target_graph,
            feed_names=feed_names,
            fetch_names=fetch_names,
        )
開發者ID:databricks,項目名稱:spark-deep-learning,代碼行數:19,代碼來源:input.py

示例12: initialize

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def initialize():
    """Load the frozen model and the label list into module-level state.

    Parses the protobuf at ``filename`` into the module-level ``graph_def``
    and imports it, without a name prefix, into the current default graph.
    Reads the shape of ``input_node`` to set the global
    ``network_input_size``. Loads one stripped label per line from
    ``labels_filename`` into the global ``labels``.

    Raises:
        AssertionError: if the input tensor is not rank 4, or if its
            dimensions 1 and 2 differ (a square input is expected).
    """
    global network_input_size, labels

    # flush so the progress text is visible before the (slow) model load;
    # the stray trailing comma that built a discarded tuple is removed.
    print('Loading model...', end='', flush=True)
    with open(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

    # Retrieving 'network_input_size' from shape of 'input_node'
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name(input_node).shape.as_list()

    assert len(input_tensor_shape) == 4
    assert input_tensor_shape[1] == input_tensor_shape[2]

    network_input_size = input_tensor_shape[1]

    print('Success!')
    print('Loading labels...', end='', flush=True)
    with open(labels_filename, 'rt') as lf:
        labels = [line.strip() for line in lf]
    print(len(labels), 'found. Success!')
開發者ID:Azure-Samples,項目名稱:Custom-vision-service-iot-edge-raspberry-pi,代碼行數:24,代碼來源:predict-amd64.py

示例13: inference

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def inference():
  """Run one JPEG through a frozen generator graph and save the output.

  ``FLAGS.input`` is decoded as a 3-channel JPEG and resized to
  ``FLAGS.image_size`` x ``FLAGS.image_size``, then passed through
  ``utils.convert2float``. The resulting tensor replaces the frozen model's
  ``input_image`` node via ``input_map``; the model is imported under the
  ``output`` scope. The evaluated ``output_image:0`` value is written to
  ``FLAGS.output`` in binary mode (presumably already-encoded image bytes;
  confirm against the exporting code).
  """
  graph = tf.Graph()
  side = FLAGS.image_size

  with graph.as_default():
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as image_file:
      image_data = image_file.read()
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    resized = tf.image.resize_images(decoded, size=(side, side))
    input_image = utils.convert2float(resized)
    input_image.set_shape([side, side, 3])

    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    output_image, = tf.import_graph_def(
        graph_def,
        input_map={'input_image': input_image},
        return_elements=['output_image:0'],
        name='output')

  with tf.Session(graph=graph) as sess:
    generated = sess.run(output_image)
  with open(FLAGS.output, 'wb') as out_file:
    out_file.write(generated)
開發者ID:vanhuyz,項目名稱:CycleGAN-TensorFlow,代碼行數:25,代碼來源:inference.py

示例14: load_graph

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def load_graph(frozen_graph_file):
    """Return a new tf.Graph populated from the frozen GraphDef at the path.

    The file is read in binary mode, parsed into a ``tf.GraphDef`` and imported
    with ``name=""``, so node names keep no extra prefix.
    """
    # parse the serialized graph_def
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_file, 'rb') as pb:
        raw_bytes = pb.read()
    graph_def.ParseFromString(raw_bytes)

    # import into a dedicated graph, not the caller's default graph
    import_kwargs = dict(input_map=None, return_elements=None, name="",
                         op_dict=None, producer_op_list=None)
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, **import_kwargs)
    return graph
開發者ID:DetectionTeamUCAS,項目名稱:R2CNN_Faster-RCNN_Tensorflow,代碼行數:19,代碼來源:test_exportPb.py

示例15: load_inference_graph

# 需要導入模塊: import tensorflow [as 別名]
# 或者: from tensorflow import import_graph_def [as 別名]
def load_inference_graph():
    """Load the frozen hand-detection model and open a session on it.

    The GraphDef at ``PATH_TO_CKPT`` is imported, without a name prefix, into
    a new graph.

    Returns:
        ``(detection_graph, sess)``: the populated graph and a ``tf.Session``
        bound to it. The session is returned open; this function never
        closes it.
    """
    print("> ====== loading HAND frozen graph into memory")
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as pb_file:
            od_graph_def.ParseFromString(pb_file.read())
        tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)
    print(">  ====== Hand Inference graph loaded.")
    return detection_graph, sess


# draw the detected bounding boxes on the images
# You can modify this to also draw a label. 
開發者ID:akshaybahadur21,項目名稱:Emojinator,代碼行數:19,代碼來源:detector_utils.py


注:本文中的tensorflow.import_graph_def方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。