当前位置: 首页>>代码示例>>Python>>正文


Python graph_transforms.TransformGraph方法代码示例

本文整理汇总了Python中tensorflow.tools.graph_transforms.TransformGraph方法的典型用法代码示例。如果您正苦于以下问题：Python graph_transforms.TransformGraph方法的具体用法？Python graph_transforms.TransformGraph怎么用？Python graph_transforms.TransformGraph使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.tools.graph_transforms的用法示例。


在下文中一共展示了graph_transforms.TransformGraph方法的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: writeTextGraph

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def writeTextGraph(modelPath, outputPath, outNodes):
    """Write a text (.pbtxt) version of a binary TensorFlow GraphDef.

    OpenCV's ``cv.dnn.writeTextGraph`` is tried first. If importing OpenCV or
    that call raises an ordinary exception, a TensorFlow fallback is used:
    the GraphDef is reordered with the ``sort_by_execution_order`` transform,
    the raw ``tensor_content`` of every ``Const`` node is cleared, and the
    result is written as text.

    Args:
        modelPath: Path of the binary GraphDef to read.
        outputPath: Destination path of the text graph.
        outNodes: Output node names handed to ``TransformGraph``; the input
            node is hard-coded as ``'image_tensor'``.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # propagate. Ordinary failures (e.g. ImportError when OpenCV is
        # missing, AttributeError if this build lacks dnn.writeTextGraph)
        # still fall through to the TensorFlow path below.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        # Drop the serialized weight bytes of Const nodes; presumably only the
        # graph topology is wanted in the text file (keeps it small).
        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                    node.attr['value'].tensor.tensor_content = b''

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
开发者ID:jing-vision,项目名称:lightnet,代码行数:23,代码来源:tf_text_graph_common.py

示例2: optimize_graph

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def optimize_graph(model_dir, graph_filename, transforms, input_name, output_names, outname='optimized_model.pb'):
  """Run Graph Transform Tool passes over a stored graph and save the result.

  Args:
    model_dir: Directory holding the source graph; the optimized graph is
      written into this same directory.
    graph_filename: File name of the source graph inside ``model_dir``.
    transforms: Transform specifications passed to ``TransformGraph``.
    input_name: Name of the single input node of the graph.
    output_names: Names of the graph's output nodes.
    outname: File name of the optimized graph, written in binary form.
  """
  source_path = os.path.join(model_dir, graph_filename)
  source_graph_def = get_graph_def_from_file(source_path)

  # TransformGraph takes a list of inputs; only one is supported here.
  result_graph_def = TransformGraph(source_graph_def, [input_name], output_names, transforms)

  tf.train.write_graph(result_graph_def, logdir=model_dir, as_text=False, name=outname)
  print('Graph optimized!')
开发者ID:PINTO0309,项目名称:PINTO_model_zoo,代码行数:15,代码来源:01_freeze_the_saved_model_v1.py

示例3: transform

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def transform(self, ugraph):
    """Quantize a TensorFlow uGraph and parse the quantized GraphDef back.

    Applies the ``quantize_weights`` and ``quantize_nodes`` transforms (with
    no declared inputs) to ``ugraph.graph_def`` and re-parses the result with
    ``GraphDefParser``.

    Args:
        ugraph: Graph wrapper exposing ``lib_name``, ``graph_def`` and
            ``output_nodes``.

    Returns:
        Whatever ``GraphDefParser(config={}).parse`` returns for the
        quantized GraphDef (presumably a new uGraph — confirm).

    Raises:
        ValueError: if ``ugraph.lib_name`` is not ``'tensorflow'``.
        RuntimeError: if ``TransformGraph`` is ``None`` (unavailable).
    """
    if ugraph.lib_name != 'tensorflow':
        raise ValueError('only support tensorflow graph')
    source_def = ugraph.graph_def
    if TransformGraph is None:
        raise RuntimeError("quantization is temporary not supported")

    quant_steps = ["quantize_weights", "quantize_nodes"]
    quantized_def = TransformGraph(
        input_graph_def=source_def,
        inputs=[],
        outputs=ugraph.output_nodes,
        transforms=quant_steps,
    )

    parser = GraphDefParser(config={})
    return parser.parse(quantized_def, output_nodes=ugraph.output_nodes)
开发者ID:uTensor,项目名称:utensor_cgen,代码行数:16,代码来源:quantize.py

示例4: convert_to_pb

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):
  """Freeze a trained model into an inference-only binary .pb file.

  The weights in ``path`` are loaded into ``model``. The model is re-saved to
  a temporary checkpoint ("model-tmp.tfl") with the TRAIN_OPS collection
  emptied, and that checkpoint is re-imported into a new session. Variables
  are converted to constants for ``output_layer_name``. The graph then goes
  through ``optimize_for_inference`` and a ``sort_by_execution_order``
  transform, and the serialized GraphDef is written to ``pbfilename``.
  Finally, the temporary checkpoint files and the "checkpoint" file are
  removed from the current working directory.

  Args:
    model: Model object exposing ``load(path, weights_only=...)`` and
      ``save(path)`` (presumably a TFLearn model — confirm with caller).
    path: Path of the saved weights to load.
    input_layer_name: Name of the graph's input node.
    output_layer_name: Name of the graph's output node.
    pbfilename: Destination file for the frozen binary GraphDef.
    verbose: If true, print every operation of the re-imported graph and
      write the final graph to the "logs" directory for TensorBoard.
  """
  model.load(path, weights_only=True)
  print("[INFO] Loaded CNN network weights from " + path + " ...")

  print("[INFO] Re-export model ...")
  # Empty the TRAIN_OPS collection in place before saving; presumably so the
  # re-exported meta graph does not carry the training ops.
  del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
  model.save("model-tmp.tfl")

  # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow

  print("[INFO] Re-import model ...")

  input_checkpoint = "model-tmp.tfl"
  saver = tf.train.import_meta_graph(input_checkpoint + '.meta', True)

  # The session is only needed until the variables are frozen; the context
  # manager guarantees it is closed (it previously leaked).
  with tf.Session() as sess:
    saver.restore(sess, input_checkpoint)

    if verbose:
      # Print out all operations to find the name of the output node. A plain
      # loop replaces a side-effect list comprehension whose trailing ``[1]``
      # index raised IndexError on graphs with fewer than two operations.
      for operation in sess.graph.get_operations():
        print(operation.values())

    print("[INFO] Freeze model to " + pbfilename + " ...")

    # Freeze: replace variables with constants holding their restored values.
    minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name])

  # Remove nodes not needed for feed-forward prediction, then order the nodes.
  graph_def = optimize_for_inference_lib.optimize_for_inference(minimal_graph, [input_layer_name], [output_layer_name], tf.float32.as_datatype_enum)
  graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"])

  with tf.gfile.GFile(pbfilename, 'wb') as f:
    f.write(graph_def.SerializeToString())

  # Write the model to the logs dir so it can be visualized with:
  # tensorboard --logdir="logs"
  if verbose:
    writer = tf.summary.FileWriter('logs', graph_def)
    writer.close()

  # Tidy up the temporary checkpoint files.
  for tmp_file in glob.glob("model-tmp.tfl*"):
    os.remove(tmp_file)

  os.remove('checkpoint')

################################################################################
# convert a  binary .pb protocol buffer format model to tflite format

# e.g. for FireNet
#    pbfilename = "firenet.pb"
#    input_layer_name = 'InputData/X'                  # input layer of network
#    output_layer_name= 'FullyConnected_2/Softmax'     # output layer of network 
开发者ID:tobybreckon,项目名称:fire-detection-cnn,代码行数:59,代码来源:converter.py

示例5: export_tensorflow_model

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def export_tensorflow_model(self, output_fld, output_model_file=None,
                            output_graphdef_file=None,
                            num_output=None,
                            quantize=False,
                            save_output_graphdef_file=False,
                            output_node_prefix=None):
        """Freeze the Keras model's graph and write it as a binary .pb file.

        Named ``tf.identity`` nodes (``<output_node_prefix><i>``) are added for
        the first ``num_output`` model outputs. The session graph is frozen
        for those nodes, optionally quantized, and written to
        ``output_fld/output_model_file``.

        Args:
            output_fld: Directory the graph files are written to.
            output_model_file: File name of the frozen graph; defaults to
                ``Cifar10AudioClassifier.model_name + '.pb'``.
            output_graphdef_file: File name of the optional text GraphDef;
                defaults to ``'model.ascii'``.
            num_output: Number of model outputs to export; defaults to 1.
            quantize: If true, apply the ``quantize_weights`` and
                ``quantize_nodes`` transforms to the frozen graph.
            save_output_graphdef_file: If true, also write the unfrozen
                session graph as text before freezing.
            output_node_prefix: Prefix of the exported output node names;
                defaults to ``'output_node'``.
        """
        K.set_learning_phase(0)

        if output_model_file is None:
            output_model_file = Cifar10AudioClassifier.model_name + '.pb'
        if output_graphdef_file is None:
            output_graphdef_file = 'model.ascii'
        if num_output is None:
            num_output = 1
        if output_node_prefix is None:
            output_node_prefix = 'output_node'

        pred_node_names = [output_node_prefix + str(i) for i in range(num_output)]
        # tf.identity registers a named node in the graph; the returned tensors
        # themselves are not needed. Indexing (rather than zip) keeps the
        # IndexError when num_output exceeds the number of model outputs.
        for i, node_name in enumerate(pred_node_names):
            tf.identity(self.model.outputs[i], name=node_name)
        print('output nodes names are: ', pred_node_names)

        sess = K.get_session()

        if save_output_graphdef_file:
            tf.train.write_graph(sess.graph.as_graph_def(), output_fld, output_graphdef_file, as_text=True)
            print('saved the graph definition in ascii format at: ', output_graphdef_file)

        from tensorflow.python.framework import graph_util
        from tensorflow.python.framework import graph_io
        from tensorflow.tools.graph_transforms import TransformGraph

        # Freeze first: variables become Const nodes holding their current
        # values.
        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
        if quantize:
            # quantize_weights only rewrites float Const ops, so it must run on
            # the frozen graph. Quantizing before freezing (as done previously)
            # left the weights, still variables at that point, unquantized.
            transforms = ["quantize_weights", "quantize_nodes"]
            constant_graph = TransformGraph(constant_graph, [], pred_node_names, transforms)
        graph_io.write_graph(constant_graph, output_fld, output_model_file, as_text=False)
        print('saved the freezed graph (ready for inference) at: ', output_model_file)
开发者ID:chen0040,项目名称:keras-audio,代码行数:45,代码来源:cifar10.py

示例6: optimize

# 需要导入模块: from tensorflow.tools import graph_transforms [as 别名]
# 或者: from tensorflow.tools.graph_transforms import TransformGraph [as 别名]
def optimize(self, sess, dataset, path, device):
        """Freeze, optimize for inference, and save the current model.

        The session graph is written to ``<name>.pbtxt`` and frozen together
        with the ``<name>.ckpt`` checkpoint into ``<name>.pb``, keeping the
        "output" node. The temporary .pbtxt is then deleted. The frozen graph
        is passed through the Graph Transform Tool (Identity removal,
        duplicate merging, unused-node stripping, constant folding) and
        written back over ``<name>.pb``. Here ``<name>`` is
        ``model_<dataset>_<device>`` inside ``path``.

        Args:
            sess (object): The current TF training session.
            dataset: Dataset identifier, formatted into the model file name.
            path (str): The directory used for saving the model.
            device (str): Represents either "cpu" or "gpu".

        .. seealso:: https://bit.ly/2VBBdqQ and https://bit.ly/2W7YqBa
        """
        model_name = "model_%s_%s" % (dataset, device)
        # Join the same way tf.train.write_graph joins (logdir, name). Plain
        # string concatenation pointed at the wrong files when ``path`` lacked
        # a trailing separator; with one, the result is unchanged.
        model_path = os.path.join(path, model_name)

        tf.train.write_graph(sess.graph.as_graph_def(),
                             path, model_name + ".pbtxt")

        freeze_graph.freeze_graph(model_path + ".pbtxt", "", False,
                                  model_path + ".ckpt", "output",
                                  "save/restore_all", "save/Const:0",
                                  model_path + ".pb", True, "")

        os.remove(model_path + ".pbtxt")

        graph_def = tf.GraphDef()

        with tf.gfile.Open(model_path + ".pb", "rb") as graph_file:
            graph_def.ParseFromString(graph_file.read())

        transforms = ["remove_nodes(op=Identity)",
                      "merge_duplicate_nodes",
                      "strip_unused_nodes",
                      "fold_constants(ignore_errors=true)"]

        optimized_graph_def = TransformGraph(graph_def,
                                             ["input"],
                                             ["output"],
                                             transforms)

        # Overwrites the frozen graph written above with the optimized one.
        tf.train.write_graph(optimized_graph_def,
                             logdir=path,
                             as_text=False,
                             name=model_name + ".pb")
开发者ID:alexanderkroner,项目名称:saliency,代码行数:46,代码来源:model.py


注：本文中的tensorflow.tools.graph_transforms.TransformGraph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。