当前位置: 首页>>代码示例>>Python>>正文


Python freeze_graph.freeze_graph方法代码示例

本文整理汇总了 Python 中 tensorflow.python.tools.freeze_graph.freeze_graph 方法的典型用法代码示例。如果您正苦于以下问题：Python freeze_graph.freeze_graph 方法的具体用法是什么？Python freeze_graph.freeze_graph 怎么用？有哪些 Python freeze_graph.freeze_graph 的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.python.tools.freeze_graph 的用法示例。


下文一共展示了 freeze_graph.freeze_graph 方法的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: freeze_model

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def freeze_model(saved_model_dir, output_node_names, output_filename):
  """Freeze a SavedModel and write the frozen graph next to it.

  Args:
    saved_model_dir: Directory of the SavedModel to freeze. The frozen graph
      is written into this same directory.
    output_node_names: Names of the output nodes, forwarded unchanged to
      `freeze_graph.freeze_graph`.
    output_filename: File name (joined onto `saved_model_dir`) of the frozen
      graph to produce.
  """
  frozen_graph_path = os.path.join(saved_model_dir, output_filename)
  # Only the SavedModel directory is supplied as model input; the graph,
  # saver and checkpoint arguments are passed as empty values.
  freeze_graph.freeze_graph(
      input_graph=None,
      input_saver=False,
      input_binary=False,
      input_checkpoint=None,
      output_node_names=output_node_names,
      restore_op_name=None,
      filename_tensor_name=None,
      output_graph=frozen_graph_path,
      clear_devices=True,
      initializer_nodes='',
      input_meta_graph=False,
      input_saved_model_dir=saved_model_dir,
      saved_model_tags=tag_constants.SERVING,
  )
开发者ID:PINTO0309,项目名称:PINTO_model_zoo,代码行数:20,代码来源:01_freeze_the_saved_model_v1.py

示例2: export_frozenPB

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def export_frozenPB():
    """Restore the detection checkpoint and export it as a frozen graph.

    Resets the default graph, builds the detection graph, restores the
    weights stored at ``CKPT_PATH``, writes the graph definition to
    ``OUT_DIR/PB_NAME`` and finally freezes it with ``DetResults`` as the
    output node. The frozen file is written to ``OUT_DIR`` under ``PB_NAME``
    with ``.pb`` replaced by ``_Frozen.pb``.
    """
    tf.reset_default_graph()

    # The return value is not needed here; the call is presumably made for
    # its effect of building the detection ops in the default graph.
    build_detection_graph()

    saver = tf.train.Saver()

    graph_def_path = os.path.join(OUT_DIR, PB_NAME)
    frozen_graph_path = os.path.join(OUT_DIR, PB_NAME.replace('.pb', '_Frozen.pb'))

    with tf.Session() as sess:
        saver.restore(sess, CKPT_PATH)
        # Report only after the restore has actually succeeded.
        print("we have restored the weights from =====>>\n", CKPT_PATH)

        tf.train.write_graph(sess.graph_def, OUT_DIR, PB_NAME)
        freeze_graph.freeze_graph(input_graph=graph_def_path,
                                  input_saver='',
                                  input_binary=False,
                                  input_checkpoint=CKPT_PATH,
                                  output_node_names="DetResults",
                                  restore_op_name="save/restore_all",
                                  filename_tensor_name='save/Const:0',
                                  output_graph=frozen_graph_path,
                                  clear_devices=False,
                                  initializer_nodes='')
开发者ID:DetectionTeamUCAS,项目名称:R2CNN_Faster-RCNN_Tensorflow,代码行数:25,代码来源:exportPb.py

示例3: _load_saved_model

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def _load_saved_model(self):
    """Freeze the TensorFlow saved model and return the resulting GraphDef.

    The saved model in ``self._model_dir`` is frozen into
    ``tf_frozen_model.pb`` inside ``self._tmp_dir``; that file is then parsed
    back and passed through ``graph_util.remove_training_nodes`` with
    ``self._outputs`` as protected nodes.

    Returns:
        The ``GraphDef`` returned by ``graph_util.remove_training_nodes``.

    Raises:
        ImportError: If the required tensorflow modules cannot be imported.
    """
    try:
        from tensorflow.python.tools import freeze_graph
        from tensorflow.python.framework import ops
        from tensorflow.python.framework import graph_util
        from tensorflow.core.framework import graph_pb2
    except ImportError:
        raise ImportError(
            "InputConfiguration: Unable to import tensorflow which is "
            "required to restore from saved model.")

    model_dir = self._model_dir
    frozen_pb_path = self._tmp_dir.relpath("tf_frozen_model.pb")
    output_names = self._get_output_names()
    tag_string = ",".join(self._get_tag_set())

    # Only the saved-model directory is given as model input; the graph,
    # saver, checkpoint and meta-graph arguments are passed as empty values.
    freeze_graph.freeze_graph(
        input_graph=None,
        input_saver=False,
        input_binary=False,
        input_checkpoint=None,
        output_node_names=output_names,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=frozen_pb_path,
        clear_devices=True,
        initializer_nodes="",
        variable_names_whitelist="",
        variable_names_blacklist="",
        input_meta_graph=False,
        input_saved_model_dir=model_dir,
        saved_model_tags=tag_string)

    with ops.Graph().as_default():
        frozen_graph_def = graph_pb2.GraphDef()
        with open(frozen_pb_path, "rb") as pb_file:
            frozen_graph_def.ParseFromString(pb_file.read())
        return graph_util.remove_training_nodes(frozen_graph_def,
                                                protected_nodes=self._outputs)
开发者ID:apache,项目名称:incubator-tvm,代码行数:43,代码来源:tensorflow_parser.py

示例4: convert_model

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def convert_model():
    """Freeze the latest checkpoint using the paths in ``hyper_parameters``.

    Reads the checkpoint directory, output directory and the input/output
    graph file names from the module-level ``hyper_parameters`` mapping; the
    output node names are looked up under the ``net_type`` key.
    """
    latest_ckpt = tf.train.latest_checkpoint(hyper_parameters['ckpt_dir'])
    graph_in = os.path.join(hyper_parameters['output_dir'],
                            hyper_parameters['input_graph_name'])
    out_nodes = hyper_parameters[net_type]
    graph_out = os.path.join(hyper_parameters['output_dir'],
                             hyper_parameters['output_graph_name'])

    freeze_graph.freeze_graph(
        input_graph=graph_in,
        input_saver="",
        input_binary=False,
        input_checkpoint=latest_ckpt,
        output_node_names=out_nodes,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=graph_out,
        clear_devices=False,
        initializer_nodes='',
        variable_names_blacklist='')
开发者ID:Enigma-li,项目名称:SketchCNN,代码行数:16,代码来源:freeze_graph_tool.py

示例5: _export_graph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def _export_graph(self):
    """
    Exports latest saved model to .bytes format for Unity embedding.

    The graph is read from ``raw_graph_def.pb`` inside ``self.model_path`` and
    the frozen result is written to the same folder as
    ``<env_name>_<run_id>.bytes``.

    :raises ValueError: if no checkpoint state is available in
        ``self.model_path``.
    """
    target_nodes = ','.join(self._process_graph())
    ckpt = tf.train.get_checkpoint_state(self.model_path)
    # get_checkpoint_state returns None when no checkpoint state is available;
    # fail with a clear message instead of an AttributeError further down.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in model path: %s" % self.model_path)
    freeze_graph.freeze_graph(input_graph=self.model_path + '/raw_graph_def.pb',
                              input_binary=True,
                              input_checkpoint=ckpt.model_checkpoint_path,
                              output_node_names=target_nodes,
                              output_graph=self.model_path + '/' + self.env_name + "_" + self.run_id + '.bytes',
                              clear_devices=True, initializer_nodes="", input_saver="",
                              restore_op_name="save/restore_all", filename_tensor_name="save/Const:0")
开发者ID:ArztSamuel,项目名称:DRL_DeliveryDuel,代码行数:15,代码来源:trainer_controller.py

示例6: model_freeze

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def model_freeze(path, MODEL_NAME='model'):
    """Freeze the checkpointed model and save an inference-optimized copy.

    All file names are built by plain string concatenation onto ``path``, so
    ``path`` must already end with a path separator when it is a directory.

    Args:
        path: Prefix under which ``<MODEL_NAME>.pbtxt`` and ``model_ckpt`` are
            read and ``frozen_<MODEL_NAME>.pb`` / ``optimized_<MODEL_NAME>.pb``
            are written.
        MODEL_NAME: Base name of the graph files.
    """
    input_graph_path = path + MODEL_NAME + '.pbtxt'
    checkpoint_path = path + 'model_ckpt'
    output_node_name = 'positive_sentiment_probability'
    output_frozen_graph_name = path + 'frozen_' + MODEL_NAME + '.pb'
    output_optimized_graph_name = path + 'optimized_' + MODEL_NAME + '.pb'

    # Freeze the graph
    freeze_graph.freeze_graph(
        input_graph=input_graph_path,
        input_saver="",
        input_binary=False,
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_name,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=output_frozen_graph_name,
        clear_devices=True,
        initializer_nodes="")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        ["inputs/X"],                # input node name(s)
        [output_node_name],          # output node name(s)
        tf.int32.as_datatype_enum,   # dtype enum used for the placeholders
    )

    # Save the optimized graph. The payload is serialized bytes, so the file
    # is opened in binary mode, and the context manager guarantees it is
    # flushed and closed.
    with tf.gfile.GFile(output_optimized_graph_name, "wb") as f:
        f.write(output_graph_def.SerializeToString())
开发者ID:PacktPublishing,项目名称:Intelligent-Projects-Using-Python,代码行数:40,代码来源:freeze_code.py

示例7: export_model

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def export_model(saver, model, input_node_names, output_node_name):
    """Checkpoint, freeze and optimize the current Keras session graph.

    Everything is written into the ``out`` directory (created if missing),
    with file names derived from the module-level ``model_name``: the text
    graph, the checkpoint, the frozen graph (``frozen_*.bytes``) and the
    inference-optimized graph (``opt_*.bytes``).

    Args:
        saver: Saver used to write the checkpoint of the Keras session.
        model: Not used by this function.
        input_node_names: Input node names passed to
            ``optimize_for_inference``.
        output_node_name: Single output node name used for freezing and
            optimizing.
    """
    if not path.exists('out'):
        os.mkdir('out')

    graph_txt_name = model_name + '_graph.pbtxt'
    checkpoint_file = 'out/' + model_name + '.chkp'
    frozen_file = 'out/frozen_' + model_name + '.bytes'
    optimized_file = 'out/opt_' + model_name + '.bytes'

    tf.train.write_graph(K.get_session().graph_def, 'out', graph_txt_name)

    saver.save(K.get_session(), checkpoint_file)

    freeze_graph.freeze_graph(
        input_graph='out/' + graph_txt_name,
        input_saver=None,
        input_binary=False,
        input_checkpoint=checkpoint_file,
        output_node_names=output_node_name,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_file,
        clear_devices=True,
        initializer_nodes="")

    frozen_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_file, "rb") as frozen_in:
        frozen_graph_def.ParseFromString(frozen_in.read())

    optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile(optimized_file, "wb") as optimized_out:
        optimized_out.write(optimized_graph_def.SerializeToString())

    print("graph saved!")

########################################################################################################################
# Main program 
开发者ID:jzharris,项目名称:Unity-MNIST,代码行数:30,代码来源:mnist_cnn1.py

示例8: export_model

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def export_model(model_output_dir, input_node_names, output_node_name):
    """Freeze the saved model and write an inference-optimized copy.

    Reads ``<MODEL_NAME>.pbtxt`` and ``<MODEL_NAME>.chkp`` from
    `model_output_dir` and writes two Protocol Buffer files back into it:
    ``frozen_<MODEL_NAME>.pb`` (graph with the checkpoint values frozen in)
    and ``optimized_<MODEL_NAME>.pb`` (the frozen graph after
    ``optimize_for_inference``).

    Args:
        model_output_dir: Directory holding the inputs and receiving outputs.
        input_node_names: Input node names passed to the optimizer.
        output_node_name: Single output node name used for both steps.
    """
    frozen_graph_file = os.path.join(model_output_dir,
                                     'frozen_' + MODEL_NAME + '.pb')
    optimized_graph_file = os.path.join(model_output_dir,
                                        'optimized_' + MODEL_NAME + '.pb')
    name_base = os.path.join(model_output_dir, MODEL_NAME)

    freeze_graph.freeze_graph(
        input_graph=name_base + '.pbtxt',
        input_saver=None,
        input_binary=False,
        input_checkpoint=name_base + '.chkp',
        output_node_names=output_node_name,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=frozen_graph_file,
        clear_devices=True,
        initializer_nodes="",
    )

    frozen_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_graph_file, "rb") as frozen_in:
        frozen_graph_def.ParseFromString(frozen_in.read())

    optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph_def, input_node_names, [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.GFile(optimized_graph_file, "wb") as optimized_out:
        optimized_out.write(optimized_graph_def.SerializeToString())

    print("Inference optimized graph saved at: " + optimized_graph_file)
开发者ID:IBM,项目名称:tensorflow-hangul-recognition,代码行数:34,代码来源:hangul_model.py

示例9: save

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def save(name, data_input_path):
    """Build the converted network `name`, dump its graph and optionally freeze it.

    Args:
        name: Name of the module to import; it must provide a ``CaffeNet``
            class. Also used as the base name of every file written.
        data_input_path: Path of the weights handed to ``net.load``. When it
            is ``None`` only the graph definitions are written and no
            checkpoint or frozen graph is produced.
    """
    # Make the directory three levels above this file importable.
    this_file = osp.realpath(__file__)
    sys.path.append(osp.dirname(osp.dirname(osp.dirname(this_file))))

    # Import the converted model's class
    caffe_net_module = __import__(name)
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        image_input = tf.placeholder(tf.float32, shape=[1, 227, 227, 3], name="data")
        net = caffe_net_module.CaffeNet({'data': image_input})

        # Save the graph definition both as text (.pbtxt) and binary (.pb)
        pb_name = name + '.pb'
        tf.train.write_graph(sess.graph_def, '.', pb_name + 'txt', as_text=True)
        tf.train.write_graph(sess.graph_def, '.', pb_name, as_text=False)

        if data_input_path is None:
            return

        # Load the weights into the session
        sess.run(tf.global_variables_initializer())
        net.load(data_input_path, sess)

        # Write them out again as a checkpoint
        saver = saver_lib.Saver(tf.global_variables())
        checkpoint_path = saver.save(sess, osp.join(osp.curdir, name + '.ckpt'))

        # Freeze the binary graph together with the checkpoint
        freeze_graph.freeze_graph(
            input_graph=pb_name,
            input_saver="",
            input_binary=True,
            input_checkpoint=checkpoint_path,
            output_node_names='fc8/fc8',
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            output_graph=name + '_frozen.pb',
            clear_devices=False,
            initializer_nodes="")
开发者ID:PacktPublishing,项目名称:Machine-Learning-with-TensorFlow-1.x,代码行数:30,代码来源:save_model.py

示例10: freeze_graph_func

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def freeze_graph_func(model_dir, output_node_names, output_dir):
    """Extract the sub graph defined by the output nodes and convert
    all its variables into constant

    The SavedModel is read from a sub-directory of `model_dir`. When several
    sub-directories exist, the lexicographically first one is used so that the
    choice does not depend on the arbitrary order of `os.listdir`.

    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string, containing all the output node's names,
                            comma separated
        output_dir: the folder in which `frozen_model.pb` is written

    Returns:
        -1 if `output_node_names` is empty, otherwise None.

    Raises:
        AssertionError: if `model_dir` does not exist or contains no
            sub-directory.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    sub_dirs = sorted(
        name for name in os.listdir(model_dir)
        if os.path.isdir(os.path.join(model_dir, name)))
    if not sub_dirs:
        # Previously this surfaced as a bare IndexError on sub_dirs[0].
        raise AssertionError(
            "No model sub-directory found in export directory: %s" % model_dir)
    model_dir = os.path.join(model_dir, sub_dirs[0])

    output_graph_filename = os.path.join(output_dir, 'frozen_model.pb')
    freeze_graph(
        input_graph=None,
        input_saver=False,
        input_binary=False,
        input_checkpoint=None,
        output_node_names=output_node_names,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=output_graph_filename,
        clear_devices=True,
        initializer_nodes='',
        input_meta_graph=False,
        input_saved_model_dir=model_dir,
        saved_model_tags=tag_constants.SERVING)
    print('model has been frozen!')
开发者ID:nolanliou,项目名称:mobile-deeplab-v3-plus,代码行数:41,代码来源:freeze.py

示例11: _freeze_graph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def _freeze_graph(model, basename, output_dir):
    """Checkpoint the Keras session and freeze its graph into `output_dir`.

    Args:
        model: Not used by this function.
        basename: File name whose extension is stripped to derive the names
            of the checkpoint, graph definition and frozen graph files.
        output_dir: Directory receiving all generated files.
    """
    name, _ = os.path.splitext(basename)

    saver = tf.train.Saver()

    with keras.backend.get_session() as sess:
        checkpoint_filename = os.path.join(output_dir, '%s.ckpt' % name)
        output_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
        graph_def_name = '%s_graph_def.pbtext' % name

        saver.save(sess, checkpoint_filename)
        tf.train.write_graph(sess.graph_def, output_dir, graph_def_name)

        freeze_graph.freeze_graph(
            input_graph=os.path.join(output_dir, graph_def_name),
            input_saver='',
            input_binary=False,
            input_checkpoint=checkpoint_filename,
            output_node_names='conv6_interp/ResizeBilinear',
            restore_op_name="save/restore_all",
            filename_tensor_name="save/Const:0",
            output_graph=output_graph_filename,
            clear_devices=True,
            initializer_nodes=None
        )
        logger.info('Saved frozen graph to: %s' % output_graph_filename)
开发者ID:fritzlabs,项目名称:fritz-models,代码行数:28,代码来源:convert_to_tfmobile.py

示例12: _freeze_graph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def _freeze_graph(model, basename, output_dir):
    """Save the Keras session as a checkpoint and freeze its graph.

    The frozen graph keeps ``deprocess_stylized_image_1/mul`` as output node.

    Args:
        model: Not used by this function.
        basename: File name; its stem names every file that is produced.
        output_dir: Directory in which the checkpoint, the text graph
            definition and the frozen graph are written.
    """
    stem = os.path.splitext(basename)[0]

    saver = tf.train.Saver()

    with keras.backend.get_session() as sess:
        ckpt_file = os.path.join(output_dir, '%s.ckpt' % stem)
        frozen_file = os.path.join(output_dir, '%s_frozen.pb' % stem)
        saver.save(sess, ckpt_file)

        graph_def_file = '%s_graph_def.pbtext' % stem
        tf.train.write_graph(sess.graph_def, output_dir, graph_def_file)

        freeze_options = dict(
            input_graph=os.path.join(output_dir, graph_def_file),
            input_saver='',
            input_binary=False,
            input_checkpoint=ckpt_file,
            output_graph=frozen_file,
            output_node_names='deprocess_stylized_image_1/mul',
            restore_op_name="save/restore_all",
            filename_tensor_name="save/Const:0",
            clear_devices=True,
            initializer_nodes=None,
        )
        freeze_graph.freeze_graph(**freeze_options)
        logger.info('Saved frozen graph to: %s' % frozen_file)
开发者ID:fritzlabs,项目名称:fritz-models,代码行数:28,代码来源:convert_to_tfmobile.py

示例13: freeze_keras_model_graph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def freeze_keras_model_graph(model, basename, output_dir):
    """Checkpoint the current Keras session and freeze its TensorFlow graph.

    Args:
        model (keras.models.Model): A Keras model. Not read by this function;
            the graph is taken from the Keras backend session.
        basename (str): the basename of the Keras model, e.g. starry_night.h5.
            Its extension is dropped to name the generated files.
        output_dir (str): directory receiving the checkpoint, the text graph
            definition and the frozen graph.

    Returns:
        str: path of the frozen graph, ``<output_dir>/<name>_frozen.pb``.
    """
    name = os.path.splitext(basename)[0]

    saver = tf.train.Saver()

    with keras.backend.get_session() as sess:
        checkpoint_path = os.path.join(output_dir, '%s.ckpt' % name)
        frozen_graph_path = os.path.join(output_dir, '%s_frozen.pb' % name)
        saver.save(sess, checkpoint_path)

        graph_def_basename = '%s_graph_def.pbtext' % name
        tf.train.write_graph(sess.graph_def, output_dir, graph_def_basename)
        graph_def_path = os.path.join(output_dir, graph_def_basename)

        freeze_graph.freeze_graph(
            input_graph=graph_def_path,
            input_saver='',
            input_binary=False,
            input_checkpoint=checkpoint_path,
            output_node_names='deprocess_stylized_image_1/mul',
            restore_op_name="save/restore_all",
            filename_tensor_name="save/Const:0",
            output_graph=frozen_graph_path,
            clear_devices=True,
            initializer_nodes=None
        )
        logger.info('Saved frozen graph to: %s' % frozen_graph_path)
    return frozen_graph_path
开发者ID:fritzlabs,项目名称:fritz-models,代码行数:39,代码来源:tf_utils.py

示例14: export_graph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def export_graph(model_path, env_name="env", target_nodes="action,value_estimate,action_probs"):
    """
    Exports latest saved model to .bytes format for Unity embedding.
    :param model_path: path of model checkpoints.
    :param env_name: Name of associated Learning Environment.
    :param target_nodes: Comma separated string of needed output nodes for embedded graph.
    :raises ValueError: if no checkpoint state is available in model_path.
    """
    ckpt = tf.train.get_checkpoint_state(model_path)
    # get_checkpoint_state returns None when no checkpoint state is available;
    # fail with a clear message instead of an AttributeError further down.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise ValueError("No checkpoint found in model path: %s" % model_path)
    freeze_graph.freeze_graph(input_graph=model_path + '/raw_graph_def.pb',
                              input_binary=True,
                              input_checkpoint=ckpt.model_checkpoint_path,
                              output_node_names=target_nodes,
                              output_graph=model_path + '/' + env_name + '.bytes',
                              clear_devices=True, initializer_nodes="", input_saver="",
                              restore_op_name="save/restore_all", filename_tensor_name="save/Const:0")
开发者ID:llSourcell,项目名称:Unity_ML_Agents,代码行数:17,代码来源:models.py

示例15: _testFreezeGraph

# 需要导入模块: from tensorflow.python.tools import freeze_graph [as 别名]
# 或者: from tensorflow.python.tools.freeze_graph import freeze_graph [as 别名]
def _testFreezeGraph(self, saver_write_version):
  """Round-trip a one-variable graph through `freeze_graph.freeze_graph`.

  A graph holding a single variable (1.0) multiplied by 2.0 is saved together
  with its checkpoint, frozen, and then reloaded to check that no variable
  ops are left and that it still evaluates to 2.0.

  Args:
    saver_write_version: Forwarded as `write_version` to `saver_lib.Saver`.
  """
  checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
  graph_in_name = "input_graph.pb"
  graph_out_name = "output_graph.pb"

  # Build the input graph: one variable containing 1.0, multiplied by 2.
  with ops.Graph().as_default():
    variable_node = variables.Variable(1.0, name="variable_node")
    output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
    sess = session.Session()
    sess.run(variables.global_variables_initializer())
    self.assertNear(2.0, sess.run(output_node), 0.00001)
    saver = saver_lib.Saver(write_version=saver_write_version)
    checkpoint_path = saver.save(
        sess,
        checkpoint_prefix,
        global_step=0,
        latest_filename="checkpoint_state")
    graph_io.write_graph(sess.graph, self.get_temp_dir(), graph_in_name)

  # The graph is on disk now; run the const conversion routine on it.
  graph_in_path = os.path.join(self.get_temp_dir(), graph_in_name)
  graph_out_path = os.path.join(self.get_temp_dir(), graph_out_name)

  freeze_graph.freeze_graph(
      input_graph=graph_in_path,
      input_saver="",
      input_binary=False,
      input_checkpoint=checkpoint_path,
      output_node_names="output_node",
      restore_op_name="save/restore_all",
      filename_tensor_name="save/Const:0",
      output_graph=graph_out_path,
      clear_devices=False,
      initializer_nodes="")

  # The variable must have become a constant, and the graph must still
  # produce the expected result.
  with ops.Graph().as_default():
    frozen_graph_def = graph_pb2.GraphDef()
    with open(graph_out_path, "rb") as frozen_file:
      frozen_graph_def.ParseFromString(frozen_file.read())
      importer.import_graph_def(frozen_graph_def, name="")

    self.assertEqual(4, len(frozen_graph_def.node))
    for frozen_node in frozen_graph_def.node:
      self.assertNotEqual("VariableV2", frozen_node.op)
      self.assertNotEqual("Variable", frozen_node.op)

    with session.Session() as sess:
      result_tensor = sess.graph.get_tensor_by_name("output_node:0")
      self.assertNear(2.0, sess.run(result_tensor), 0.00001)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:60,代码来源:freeze_graph_test.py


注：本文中的 tensorflow.python.tools.freeze_graph.freeze_graph 方法示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。