当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.GraphDef方法代码示例

本文整理汇总了Python中tensorflow.GraphDef方法的典型用法代码示例。如果您想了解tensorflow.GraphDef的具体用法、调用方式或实际使用案例,下面精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


在下文中一共展示了tensorflow.GraphDef方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def load_graph(frozen_graph_filename):
    """Load a frozen GraphDef protobuf file into a brand-new tf.Graph.

    Args:
        frozen_graph_filename: Path to a binary-serialized GraphDef (.pb).

    Returns:
        A new tf.Graph holding the imported nodes under their original names
        (no name-scope prefix).
    """
    # Read the serialized protobuf from disk and parse it into a GraphDef.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh Graph instead of the process-wide default graph.
    # name="" keeps node names unprefixed. Only non-default arguments are
    # passed: the deprecated `op_dict` argument emits a deprecation warning
    # whenever it is given explicitly, even as None.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
开发者ID:TobiasGruening,项目名称:ARU-Net,代码行数:21,代码来源:util.py

示例2: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def build_from_pb(self):
    """Build the network from a frozen protobuf plus its JSON metadata.

    Imports the GraphDef at ``self.FLAGS.pbLoad`` into the current default
    graph, loads ``self.meta`` from the JSON file at ``self.FLAGS.metaLoad``,
    creates ``self.framework``, looks up the ``input:0`` / ``output:0``
    tensors and finally calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # No graph context is entered, so this imports into the default graph.
    tf.import_graph_def(graph_def, name="")
    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders
    graph = tf.get_default_graph()
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
开发者ID:AmeyaWagh,项目名称:Traffic_sign_detection_YOLO,代码行数:21,代码来源:build.py

示例3: build_from_pb

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def build_from_pb(self):
    """Build the network from a frozen protobuf plus its JSON metadata.

    Imports the GraphDef at ``self.FLAGS.pbLoad`` into the current default
    graph, loads ``self.meta`` from the JSON file at ``self.FLAGS.metaLoad``,
    creates ``self.framework``, looks up the ``input:0`` / ``output:0``
    tensors and finally calls ``self.setup_meta_ops()``.
    """
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(self.FLAGS.pbLoad, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # No graph context is entered, so this imports into the default graph.
    tf.import_graph_def(graph_def, name="")
    with open(self.FLAGS.metaLoad, 'r') as fp:
        self.meta = json.load(fp)
    self.framework = create_framework(self.meta, self.FLAGS)

    # Placeholders
    graph = tf.get_default_graph()
    self.inp = graph.get_tensor_by_name('input:0')
    self.feed = dict()  # other placeholders
    self.out = graph.get_tensor_by_name('output:0')

    self.setup_meta_ops()
开发者ID:MahmudulAlam,项目名称:Automatic-Identification-and-Counting-of-Blood-Cells,代码行数:21,代码来源:build.py

示例4: worker

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def worker(input_q, output_q):
    """Run object detection on frames pulled from ``input_q``.

    Loads the frozen model at ``PATH_TO_CKPT`` into a private tf.Graph, then
    repeatedly takes a BGR frame from ``input_q``, converts it to RGB, runs
    ``detect_objects`` on it and puts the result on ``output_q``.

    Putting ``None`` on ``input_q`` stops the loop. The FPS counter is stopped
    and the TensorFlow session is closed however the loop ends, including on
    an exception.
    """
    # Load a (frozen) Tensorflow model into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

        sess = tf.Session(graph=detection_graph)

    fps = FPS().start()
    try:
        while True:
            fps.update()
            frame = input_q.get()
            if frame is None:
                # Shutdown sentinel: lets the worker exit so cleanup runs.
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_q.put(detect_objects(frame_rgb, sess, detection_graph))
    finally:
        # Without a loop exit this cleanup would never be reached.
        fps.stop()
        sess.close()
开发者ID:datitran,项目名称:object_detector_app,代码行数:23,代码来源:object_detection_multithreading.py

示例5: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def __init__(self, checkpoint_filename, input_name="images",
             output_name="features"):
    """Load a frozen encoder graph and locate its input/output tensors.

    The GraphDef stored in ``checkpoint_filename`` is imported into the
    current default graph under the ``net`` name scope. ``self.session`` is
    a tf.Session created on that default graph.

    Args:
        checkpoint_filename: Path to a binary-serialized GraphDef.
        input_name: Name of the input op inside the imported graph.
        output_name: Name of the output op inside the imported graph.

    Sets ``feature_dim`` to the last dimension of the output tensor and
    ``image_shape`` to the input shape without its leading dimension.
    Asserts that the output tensor has rank 2 and the input tensor rank 4.
    """
    self.session = tf.Session()

    with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
        serialized = file_handle.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name="net")

    default_graph = tf.get_default_graph()
    self.input_var = default_graph.get_tensor_by_name("net/%s:0" % input_name)
    self.output_var = default_graph.get_tensor_by_name("net/%s:0" % output_name)

    output_shape = self.output_var.get_shape()
    input_shape = self.input_var.get_shape()
    assert len(output_shape) == 2
    assert len(input_shape) == 4
    self.feature_dim = output_shape.as_list()[-1]
    self.image_shape = input_shape.as_list()[1:]
开发者ID:nwojke,项目名称:deep_sort,代码行数:18,代码来源:generate_detections.py

示例6: create_inception_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def create_inception_graph():
  """Create a graph from the saved Inception GraphDef file.

  Reads ``classify_image_graph_def.pb`` from ``FLAGS.model_dir`` and imports
  it, without a name prefix, into a new tf.Graph.

  Returns:
    A tuple ``(graph, bottleneck_tensor, jpeg_data_tensor,
    resized_input_tensor)``: the Graph holding the trained Inception network
    and the elements named by BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME
    and RESIZED_INPUT_TENSOR_NAME.
  """
  model_filename = os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb')
  # gfile.FastGFile is deprecated; gfile.GFile is the drop-in replacement.
  with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

  with tf.Graph().as_default() as graph:
    bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
        tf.import_graph_def(graph_def, name='', return_elements=[
            BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
            RESIZED_INPUT_TENSOR_NAME]))
  return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
开发者ID:ArunMichaelDsouza,项目名称:tensorflow-image-detection,代码行数:20,代码来源:retrain.py

示例7: _import_graph_and_run_inference

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def _import_graph_and_run_inference(self, tflite_graph_file, num_channels=3):
    """Imports a tflite graph, runs single inference and returns outputs.

    Args:
      tflite_graph_file: Path to a binary-serialized GraphDef.
      num_channels: Channel count of the random 1x10x10 input image fed to
        ``normalized_input_image_tensor:0``.

    Returns:
      Tuple ``(box_encodings_np, class_predictions_np)`` holding the evaluated
      ``raw_outputs/box_encodings:0`` and
      ``raw_outputs/class_predictions:0`` tensors.
    """
    graph = tf.Graph()
    with graph.as_default():
      graph_def = tf.GraphDef()
      # Binary mode is required: ParseFromString needs bytes, and the default
      # text mode of tf.gfile.Open returns str under Python 3.
      with tf.gfile.Open(tflite_graph_file, 'rb') as f:
        graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')
      input_tensor = graph.get_tensor_by_name('normalized_input_image_tensor:0')
      box_encodings = graph.get_tensor_by_name('raw_outputs/box_encodings:0')
      class_predictions = graph.get_tensor_by_name(
          'raw_outputs/class_predictions:0')
      with self.test_session(graph) as sess:
        [box_encodings_np, class_predictions_np] = sess.run(
            [box_encodings, class_predictions],
            feed_dict={input_tensor: np.random.rand(1, 10, 10, num_channels)})
    return box_encodings_np, class_predictions_np
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:19,代码来源:export_tflite_ssd_graph_lib_test.py

示例8: fromGraphDef

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def fromGraphDef(cls, graph_def, feed_names, fetch_names):
    """
    Construct a TFInputGraph from a tf.GraphDef object.

    :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and
                       computation units of the TensorFlow graph.
    :param feed_names: list, names of the input tensors.
    :param fetch_names: list, names of the output tensors.
    :return: the result of ``_build_with_feeds_fetches`` for a session on a
             fresh graph into which ``graph_def`` was imported.
    :raises AssertionError: if ``graph_def`` is not a ``tf.GraphDef``
             (the check is an ``assert`` and is skipped under ``python -O``).
    """
    # A single formatted string (not a tuple) keeps the AssertionError
    # message readable.
    assert isinstance(graph_def, tf.GraphDef), \
        'expect tf.GraphDef type but got {}'.format(type(graph_def))

    graph = tf.Graph()
    # Entering the session also makes `graph` the default graph, so the
    # import below lands in `graph`.
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name='')
        return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names,
                                         fetch_names=fetch_names)
开发者ID:databricks,项目名称:spark-deep-learning,代码行数:19,代码来源:input.py

示例9: inference

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def inference():
  """Translate one image with a frozen generator model.

  Reads the image file at ``FLAGS.input``, decodes it as a 3-channel JPEG,
  resizes it to ``FLAGS.image_size`` x ``FLAGS.image_size`` and converts it
  with ``utils.convert2float``. The frozen model at ``FLAGS.model`` is
  imported under the ``output`` name scope, with its ``input_image`` mapped
  to that tensor. The evaluated ``output_image:0`` is written to
  ``FLAGS.output``.
  """
  graph = tf.Graph()

  with graph.as_default():
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(FLAGS.input, 'rb') as f:
      image_data = f.read()
    input_image = tf.image.decode_jpeg(image_data, channels=3)
    input_image = tf.image.resize_images(
        input_image, size=(FLAGS.image_size, FLAGS.image_size))
    input_image = utils.convert2float(input_image)
    input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

    with tf.gfile.GFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name='output')

  with tf.Session(graph=graph) as sess:
    generated = sess.run(output_image)
    # NOTE(review): written in binary mode, so this assumes the model's
    # output_image is an already-encoded image byte string; confirm.
    with open(FLAGS.output, 'wb') as f:
      f.write(generated)
开发者ID:vanhuyz,项目名称:CycleGAN-TensorFlow,代码行数:25,代码来源:inference.py

示例10: load_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def load_graph(frozen_graph_file):
    """Load a frozen GraphDef protobuf file into a brand-new tf.Graph.

    Args:
        frozen_graph_file: Path to a binary-serialized GraphDef (.pb).

    Returns:
        A new tf.Graph holding the imported nodes under their original names
        (no name-scope prefix).
    """
    # Parse the serialized GraphDef from disk.
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh Graph (not the process-wide default graph). Only
    # non-default arguments are passed: the deprecated `op_dict` argument
    # emits a deprecation warning whenever it is given explicitly, even as None.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph
开发者ID:DetectionTeamUCAS,项目名称:R2CNN_Faster-RCNN_Tensorflow,代码行数:19,代码来源:test_exportPb.py

示例11: load_inference_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def load_inference_graph():
    """Load the frozen hand-detection model and open a session on it.

    Reads the GraphDef at ``PATH_TO_CKPT`` into a private tf.Graph and
    creates a tf.Session bound to that graph. The session is not closed here.

    Returns:
        Tuple ``(detection_graph, sess)``.
    """
    print("> ====== loading HAND frozen graph into memory")
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as model_file:
        raw_bytes = model_file.read()

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(raw_bytes)
        tf.import_graph_def(graph_def, name='')
        session = tf.Session(graph=detection_graph)

    print(">  ====== Hand Inference graph loaded.")
    return detection_graph, session


# draw the detected bounding boxes on the images
# You can modify this to also draw a label. 
开发者ID:akshaybahadur21,项目名称:Emojinator,代码行数:19,代码来源:detector_utils.py

示例12: create_inception_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def create_inception_graph():
  """Import the saved Inception GraphDef into the current default graph.

  Reads ``classify_image_graph_def.pb`` from ``FLAGS.model_dir`` and imports
  it, without a name prefix, into ``tf.get_default_graph()``.

  Returns:
    A tuple ``(graph, bottleneck_tensor, jpeg_data_tensor,
    resized_input_tensor)``: the default Graph now holding the trained
    Inception network, and the elements named by BOTTLENECK_TENSOR_NAME,
    JPEG_DATA_TENSOR_NAME and RESIZED_INPUT_TENSOR_NAME.
  """
  # A tf.Session() built without a graph argument is bound to the default
  # graph, so take that graph directly; no Session is needed to load a graph.
  # NOTE(review): callers presumably use these tensors through a default
  # Session; verify before switching to a private tf.Graph().
  graph = tf.get_default_graph()
  model_filename = os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb')
  # gfile.FastGFile is deprecated; gfile.GFile is the drop-in replacement.
  with gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

  with graph.as_default():
    bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
        tf.import_graph_def(graph_def, name='', return_elements=[
            BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
            RESIZED_INPUT_TENSOR_NAME]))
  return graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
开发者ID:javathunderman,项目名称:diabetic-retinopathy-screening,代码行数:20,代码来源:retrain.py

示例13: load_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def load_model(model, input_map=None):
    """Load a model into the current default graph.

    ``model`` is either a frozen GraphDef protobuf file or a model directory.
    A directory is resolved via ``get_model_filenames`` to a metagraph file
    and a checkpoint file.

    Args:
        model: Path to a .pb file or a model directory; ``~`` is expanded.
        input_map: Optional mapping forwarded to ``tf.import_graph_def`` or
            ``tf.train.import_meta_graph``.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        # Frozen protobuf file.
        print('Model filename: %s' % model_exp)
        # gfile.FastGFile is deprecated; gfile.GFile is the replacement.
        with gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')
    else:
        # Model directory holding a metagraph and a checkpoint.
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)

        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)

        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
        # NOTE(review): restores into tf.get_default_session(); if no default
        # session is active this passes None to restore, so callers presumably
        # load inside `with tf.Session().as_default():`; confirm.
        saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
开发者ID:GaoangW,项目名称:TNT,代码行数:21,代码来源:facenet.py

示例14: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def __init__(self, PATH_TO_CKPT):
    """Tensorflow detector.

    Loads the frozen detection graph at ``PATH_TO_CKPT`` into a private
    tf.Graph and opens ``self.sess`` on it with GPU memory growth enabled.

    Args:
        PATH_TO_CKPT: Path to a binary-serialized frozen GraphDef.
    """
    self.detection_graph = tf.Graph()
    with self.detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

    config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory on demand instead of reserving it all.
    config.gpu_options.allow_growth = True
    # Keep the session open for the detector's lifetime. Binding it as a
    # `with ... as self.sess` target would close it as soon as __init__
    # returned, leaving self.sess unusable for later runs.
    # NOTE(review): nothing in this block closes it; callers should call
    # self.sess.close() when finished.
    self.sess = tf.Session(graph=self.detection_graph, config=config)

    self.windowNotSet = True
开发者ID:Seymour-Lee,项目名称:face-detection-ssd-mobilenet,代码行数:21,代码来源:detect_face.py

示例15: load_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import GraphDef [as 别名]
def load_model(self):
    """Load the frozen detection graph and build the category index.

    Imports the GraphDef at ``self._path_to_ckpt`` into
    ``self._detection_graph``. It then loads the label map at
    ``self._path_to_labels`` and converts it to categories, passing
    ``self._num_classes`` as ``max_num_classes`` and using display names.
    The result is stored in ``self.category_index``.
    """
    with tf.gfile.GFile(self._path_to_ckpt, 'rb') as ckpt_file:
        serialized = ckpt_file.read()

    with self._detection_graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(serialized)
        tf.import_graph_def(graph_def, name='')

    label_map = label_map_util.load_labelmap(self._path_to_labels)
    categories = label_map_util.convert_label_map_to_categories(
        label_map,
        max_num_classes=self._num_classes,
        use_display_name=True,
    )
    self.category_index = label_map_util.create_category_index(categories)
开发者ID:cagbal,项目名称:ros_people_object_detection_tensorflow,代码行数:23,代码来源:detector.py


注:本文中的tensorflow.GraphDef方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。