当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.import_graph_def函数代码示例

本文整理汇总了Python中tensorflow.import_graph_def函数的典型用法代码示例。如果您正苦于以下问题：Python import_graph_def函数的具体用法是什么？Python import_graph_def怎么用？Python import_graph_def有哪些使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了import_graph_def函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: loadmodle

def loadmodle():
    """Reload a serialized graph and restore variable values from a checkpoint.

    Reads the binary GraphDef at ``/tmp/load/test.pb``, imports it (without a
    name prefix) into the session's graph, adds the tensor ``saved1_result:0``
    to the ``tf.GraphKeys.VARIABLES`` collection, restores values from
    ``checkpoint.data`` and prints the evaluated result.

    Raises:
      Whatever ``tf.train.Saver`` raises if it cannot be constructed; the error
      text is printed first.
    """
    # The literal is already unicode; calling .decode() on it is redundant and
    # fails in Python 2 for non-ASCII text ("step 2: model loading test").
    print(u"step2:模型加载测试")
    with tf.Session() as persisted_sess:
        # as_default() is a context manager; it has no effect unless entered.
        with persisted_sess.graph.as_default():
            print("---1:load graph")  # load the computation graph
            with gfile.FastGFile("/tmp/load/test.pb", 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')  # import the graph definition

            print("---2,map variables")
            # Fetch the tensor and add it to the collection of variables to restore.
            persisted_result = persisted_sess.graph.get_tensor_by_name("saved1_result:0")
            tf.add_to_collection(tf.GraphKeys.VARIABLES, persisted_result)

            # Restore the data.
            print("---3,load data")
            try:
                saver = tf.train.Saver(tf.all_variables())  # 'Saver' misnomer! Better: Persister!
            except Exception as e:
                # Report and re-raise: without a Saver there is nothing to
                # restore, and continuing would only fail with a NameError.
                print(str(e))
                raise
            # Load the checkpointed values back into the graph's variables.
            saver.restore(persisted_sess, "checkpoint.data")

            # Re-run the computation.
            print(persisted_result.eval())
            print("DONE")
开发者ID:tuling56,项目名称:Python,代码行数:26,代码来源:model_save_restore.py

示例2: run_graph_def

def run_graph_def(graph_def, input_map, outputs):
    """Import ``graph_def`` into a brand-new graph and evaluate ``outputs``.

    Args:
      graph_def: GraphDef to execute; imported without a name prefix.
      input_map: used as the ``feed_dict`` of ``Session.run`` (the import
        itself is given an empty ``input_map``).
      outputs: fetches to evaluate.

    Returns:
      The value ``Session.run`` returns for ``outputs``.
    """
    fresh_graph = tf.Graph()
    with fresh_graph.as_default():
        tf.import_graph_def(graph_def, input_map={}, name="")
    with tf.Session(graph=fresh_graph) as session:
        fetched = session.run(outputs, feed_dict=input_map)
    return fetched
开发者ID:DavidNemeskey,项目名称:tensorflow,代码行数:7,代码来源:quantize_graph_test.py

示例3: testInvalidInputForInputMap

 def testInvalidInputForInputMap(self):
   """import_graph_def rejects a non-dict input_map with a descriptive TypeError."""
   with tf.Graph().as_default():
     graph_def = self._MakeGraphDef('')
     non_dict_input_map = [tf.constant(5.0)]
     with self.assertRaises(TypeError) as ctx:
       tf.import_graph_def(graph_def, input_map=non_dict_input_map)
     expected = ('input_map must be a dictionary mapping strings to '
                 'Tensor objects.')
     self.assertEqual(expected, str(ctx.exception))
开发者ID:yevgeniyfrenkel,项目名称:tensorflow,代码行数:7,代码来源:importer_test.py

示例4: graphdef_to_pbtxt

def graphdef_to_pbtxt(filename, logdir='pbtxt/', name='protobuf.pbtxt'):
    """Convert a binary GraphDef file into a text-format protobuf file.

    The GraphDef is imported into a throwaway ``tf.Graph`` only to check that
    it is well-formed.  Importing into the process-wide default graph instead
    would add another copy of the model to it on every call.

    Args:
      filename: path of the binary GraphDef to read.
      logdir: directory the text file is written to (default ``'pbtxt/'``).
      name: file name of the text output (default ``'protobuf.pbtxt'``).
    """
    with gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name='')
    tf.train.write_graph(graph_def, logdir, name, as_text=True)
开发者ID:chrhansen,项目名称:tensorflow.rb,代码行数:7,代码来源:converter.py

示例5: __init__

    def __init__(self):
        # Build a dedicated graph for this instance and populate it with the
        # Inception model read from a serialized (binary protobuf) GraphDef.
        self.graph = tf.Graph()

        with self.graph.as_default():
            model_path = os.path.join(data_dir, path_graph_def)
            with tf.gfile.FastGFile(model_path, 'rb') as model_file:
                graph_def = tf.GraphDef()
                serialized = model_file.read()
                graph_def.ParseFromString(serialized)
                # Inside this ``with`` block self.graph is the default graph,
                # so the import (no name prefix) lands there.
                tf.import_graph_def(graph_def, name='')

            # Tensor through which input images are fed to the model.
            self.input = self.graph.get_tensor_by_name(self.tensor_name_input_image)

            # Output tensor ("<name>:0") of every layer listed in self.layer_names.
            self.layer_tensors = []
            for layer_name in self.layer_names:
                self.layer_tensors.append(self.graph.get_tensor_by_name(layer_name + ":0"))
开发者ID:Hvass-Labs,项目名称:TensorFlow-Tutorials,代码行数:34,代码来源:inception5h.py

示例6: strip_and_freeze_until

def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
    """
    Create a static view of the graph by

    * Converting all variables into constants
    * Removing graph elements not reachable from `fetches`

    :param fetches: list, graph elements representing the outputs of the graph
    :param graph: tf.Graph, the graph to be frozen
    :param sess: tf.Session, optional session whose variable values are frozen;
                 when omitted a temporary session on `graph` is created and
                 always closed before returning (even on error)
    :param return_graph: bool, if set True, return a new tf.Graph into which the
                         frozen GraphDef has been imported
    :return: GraphDef (or tf.Graph if `return_graph`), the GraphDef object with
             cleanup procedure applied
    """
    graph = validated_graph(graph)
    should_close_session = not sess
    if should_close_session:
        sess = tf.Session(graph=graph)

    try:
        gdef_frozen = tf.graph_util.convert_variables_to_constants(
            sess,
            graph.as_graph_def(add_shapes=True),
            [op_name(graph, tnsr) for tnsr in fetches])
    finally:
        # Close a session we created ourselves even if freezing fails.
        if should_close_session:
            sess.close()

    if not return_graph:
        return gdef_frozen

    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(gdef_frozen, name='')
    return g
开发者ID:seanpquig,项目名称:spark-deep-learning,代码行数:33,代码来源:utils.py

示例7: __init__

	def __init__(self, proxy_map):
		super(SpecificWorker, self).__init__(proxy_map)
		self.timer.timeout.connect(self.compute)
		self.Period = 100
		self.timer.start(self.Period)

		# SIFT feature extractor
		self.feature_extractor = cv2.xfeatures2d.SIFT_create()

		# Dense grid of keypoints: positions 5, 17, 29, ... below IMAGE_SIZE on
		# both axes, each with size 12.
		self.keypoints = [
			cv2.KeyPoint(i, j, 12)
			for i in range(5, IMAGE_SIZE, 12)
			for j in range(5, IMAGE_SIZE, 12)
		]

		# Create a TensorFlow session (no explicit graph, so it uses the current default graph).
		self.sess = tf.Session()

		# Read the frozen graph from the model file.
		with gfile.FastGFile(MODEL_FILE, 'rb') as f:
			graph_def = tf.GraphDef()
			graph_def.ParseFromString(f.read())

		# as_default() is a context manager and does nothing unless entered;
		# enter it so the import explicitly targets the session's graph.
		with self.sess.graph.as_default():
			tf.import_graph_def(graph_def, name='')

		# Get input and output tensors from the graph.
		self.x_input = self.sess.graph.get_tensor_by_name("input:0")
		self.output = self.sess.graph.get_tensor_by_name("output:0")
		self.dsift = self.sess.graph.get_tensor_by_name("sift:0")
开发者ID:robocomp,项目名称:robocomp-robolab,代码行数:29,代码来源:specificworker.py

示例8: Import

def Import(sess, path="../models/producttype/graph.pb"):
    """Import a binary GraphDef into ``sess.graph`` without a name prefix.

    Args:
      sess: tf.Session whose graph receives the imported nodes.
      path: GraphDef file to load; defaults to the original hard-coded
        ``../models/producttype/graph.pb`` (relative to the working directory).
    """
    with gfile.FastGFile(path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # as_default() only takes effect inside ``with``; without it the nodes go
    # to the global default graph, which need not be sess.graph.
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
开发者ID:daizhen,项目名称:ImagesCategory,代码行数:7,代码来源:import_model.py

示例9: _get_expected_result

def _get_expected_result(gin, local_features):
    """
    Run the graph in the :py:obj:`TFInputGraph` object and compute the expected results.

    :param gin: a :py:obj:`TFInputGraph`; its ``graph_def`` is imported without a prefix
    :param local_features: iterable of rows; ``row[colname]`` is read for every
                           column name in ``_input_mapping``
    :return: expected results in NumPy array (each row's fetched values are
             flattened, then all rows are concatenated with ``np.hstack``)
    """
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess, graph.as_default():
        # Build test graph and transformers from here
        tf.import_graph_def(gin.graph_def, name='')

        # The tensors to fetch/feed do not depend on the row: resolve them once.
        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in _output_mapping]
        feed_tensors = [(colname, tfx.get_tensor(tnsr_name, graph))
                        for colname, tnsr_name in _input_mapping.items()]

        # Build the results
        _results = []
        for row in local_features:
            # Prepend a batch axis of size 1 to every input value.
            feed_dict = {tnsr: np.array(row[colname])[np.newaxis, :]
                         for colname, tnsr in feed_tensors}
            curr_res = sess.run(fetches, feed_dict=feed_dict)
            _results.append(np.ravel(curr_res))

        expected = np.hstack(_results)

    return expected
开发者ID:pawanrana,项目名称:spark-deep-learning,代码行数:27,代码来源:tf_transformer_test.py

示例10: main

def main(_):
    """Classify the image named by ``sys.argv[1]`` with the retrained graph.

    Reads labels (one per line) from ``FLAGS.output_labels``, imports the
    binary GraphDef ``FLAGS.output_graph`` into the default graph, feeds the
    raw image bytes to ``DecodeJpeg/contents:0`` and prints the score of every
    label from ``final_result:0``.  Displays the image if ``FLAGS.show_image``.
    """
    with tf.gfile.GFile(FLAGS.output_labels) as label_file:
        labels = [line.rstrip() for line in label_file]

    with tf.gfile.FastGFile(FLAGS.output_graph, 'rb') as fp:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fp.read())
    tf.import_graph_def(graph_def, name='')

    # Read the image bytes through a context manager so the handle is closed.
    with tf.gfile.FastGFile(sys.argv[1], 'rb') as image_file:
        image = image_file.read()

    with tf.Session() as sess:
        logits = sess.graph.get_tensor_by_name('final_result:0')
        prediction = sess.run(logits, {'DecodeJpeg/contents:0': image})

    # Print every label with its score as a percentage
    # (the header means "=== prediction results ===").
    print('=== 예측 결과 ===')
    for i, name in enumerate(labels):
        score = prediction[0][i]
        print('%s (%.2f%%)' % (name, score * 100))

    if FLAGS.show_image:
        img = mpimg.imread(sys.argv[1])
        plt.imshow(img)
        plt.show()
开发者ID:superhg2012,项目名称:TensorFlow-Tutorials,代码行数:29,代码来源:predict.py

示例11: classify

    def classify(self, path, resize_height, resize_width):
        """Resize the image at ``path`` and estimate its VP using the frozen
        graph stored in ``self.filename``.

        The graph is imported (default ``import/`` prefix) into a fresh
        ``tf.Graph`` on each call, so repeated calls do not keep adding model
        copies to the global default graph.

        :param path: image file, read with ``cv.imread(path, 1)``
        :param resize_height: height the image is resized to before inference
        :param resize_width: width the image is resized to before inference
        :return: model output rounded to ``int``; entries ``[0][0]`` and
                 ``[0][1]`` are logged as the predicted point
        :raises IOError: if the image cannot be read
        """
        self.info("Manually classifying the image in " + str(path))
        # Load the frozen graph from file into a graph owned by this call.
        graph_def = tf.GraphDef()
        with open(self.filename, 'rb') as f:
            graph_def.ParseFromString(f.read())
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def)

        with tf.Session(graph=graph) as sess:
            # Output node used for predictions.
            output_node_processed = sess.graph.get_tensor_by_name('import/output_processed:0')
            # cv.imread returns None instead of raising when it cannot read the file.
            img = cv.imread(path, 1)
            if img is None:
                raise IOError("Could not read image: " + str(path))
            # Preprocess: resize, convert to float32, scale by 1/256, flatten.
            img_pred = imresize(img, [resize_height, resize_width], 'bilinear')
            img_pred = img_pred.astype(np.float32)
            img_pred = np.multiply(img_pred, 1.0 / 256.0)
            img_pred = img_pred.flatten()
            # Compute the prediction point; keep_prob is fed 1.0
            # (presumably a dropout keep probability - confirm against the model).
            predictions = output_node_processed.eval(
                feed_dict={
                    'import/input_images:0': img_pred,
                    'import/keep_prob:0': 1.0
                }
            )
            predictions = np.round(predictions).astype(int)
            self.info('Predicted Point Processed: (' + str(int(predictions[0][0])) + ', ' + str(int(predictions[0][1])) + ')')
        return predictions
开发者ID:se-research-studies,项目名称:2016-itsc,代码行数:34,代码来源:VPClassifier.py

示例12: __init__

    def __init__(self):
        """Fetch (if needed) and load the detection model and its label map.

        The archive ``config.SSD_INCEPTION_FILENAME`` is obtained via
        ``get_file`` (``cache_dir`` = absolute ``config.WEIGHT_PATH``,
        ``cache_subdir`` = ``'models'``), extracted only if its extracted
        directory is missing, and the GraphDef ``self.PB_NAME`` inside it is
        imported without a prefix into ``self.graph``.
        """
        logger.info('Loading Tensorflow Detection API')

        weights_path = get_file(config.SSD_INCEPTION_FILENAME, config.SSD_INCEPTION_URL,
                                cache_dir=os.path.abspath(config.WEIGHT_PATH),
                                cache_subdir='models')

        # The extracted directory is the archive path without ".tar.gz".
        extract_path = weights_path.replace('.tar.gz', '')
        if not os.path.exists(extract_path):
            # The context manager closes the archive even if extraction fails.
            # NOTE(review): extractall() trusts member paths inside the archive;
            # only point config.SSD_INCEPTION_URL at a trusted source.
            with tarfile.open(weights_path, "r:gz") as tar:
                tar.extractall(path=os.path.join(config.WEIGHT_PATH, 'models'))
        pb_path = os.path.join(extract_path, self.PB_NAME)

        self.graph = tf.Graph()
        with self.graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(pb_path, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')

        self.label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        self.categories = label_map_util.convert_label_map_to_categories(
            self.label_map,
            max_num_classes=self.NUM_CLASSES,
            use_display_name=True)
        self.category_index = label_map_util.create_category_index(self.categories)
开发者ID:mohamed-akram,项目名称:pretrained.ml,代码行数:27,代码来源:models.py

示例13: build_prepro_graph

def build_prepro_graph(inception_path):
    """Import the Inception graph and append JPEG/PNG preprocessing ops.

    Side effects: imports the GraphDef at ``inception_path`` (default
    ``import/`` prefix) into the current default graph and sets the module
    globals ``input_layer`` and ``output_layer``.

    Returns:
      (input_file, output_jpg, output_png): the file-name placeholder and the
      [1, 299, 299, 3] tensors obtained by decoding that file as JPEG / PNG.
    """
    global input_layer, output_layer

    graph_def = tf.GraphDef()
    with open(inception_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def)
    graph = tf.get_default_graph()

    input_layer = graph.get_tensor_by_name("import/InputImage:0")
    output_layer = graph.get_tensor_by_name(
        "import/InceptionV4/Logits/AvgPool_1a/AvgPool:0")

    def _to_batch(decoded, name):
        # Resize to 299x299, divide by 255 and reshape to a batch of one.
        resized = tf.image.resize_images(decoded, [299, 299]) / 255.0
        return tf.reshape(resized, [1, 299, 299, 3], name=name)

    input_file = tf.placeholder(dtype=tf.string, name="InputFile")
    contents = tf.read_file(input_file)
    jpg = tf.image.decode_jpeg(contents, channels=3)
    png = tf.image.decode_png(contents, channels=3)
    output_jpg = _to_batch(jpg, "Preprocessed_JPG")
    output_png = _to_batch(png, "Preprocessed_PNG")
    return input_file, output_jpg, output_png
开发者ID:suryawanshishantanu6,项目名称:image-caption-generator,代码行数:27,代码来源:convfeatures.py

示例14: __init__

 def __init__(self, name, input, model_path="models/vgg16.tfmodel"):
     """Import the VGG16 GraphDef into the default graph under prefix ``name``.

     :param name: name prefix for the imported ops; also stored as ``self.name``
     :param input: tensor substituted for the graph's ``"images"`` input
     :param model_path: binary GraphDef file to load; defaults to the relative
                        path ``models/vgg16.tfmodel`` used previously
     """
     self.name = name
     with open(model_path, mode='rb') as f:
         file_content = f.read()
     graph_def = tf.GraphDef()
     graph_def.ParseFromString(file_content)
     tf.import_graph_def(graph_def, input_map={"images": input}, name=self.name)
开发者ID:fgeorg,项目名称:texture-networks,代码行数:7,代码来源:vgg_network.py

示例15: load_graph

def load_graph(path):
    """Load the binary GraphDef at ``path`` into a new ``tf.Graph``.

    Imported ops carry the name prefix ``prefix``, so tensors are looked up as
    ``"prefix/<op_name>:<index>"``.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, mode='rb') as f:
        serialized = f.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="prefix")
    return graph
开发者ID:forin-xyz,项目名称:FoolNLTK,代码行数:7,代码来源:model.py


注:本文中的tensorflow.import_graph_def函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。