当前位置: 首页>>代码示例>>Python>>正文


Python graph_io.write_graph方法代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_io.write_graph方法的典型用法代码示例。如果您正苦于以下问题:Python graph_io.write_graph方法的具体用法?Python graph_io.write_graph怎么用?Python graph_io.write_graph使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.framework.graph_io的用法示例。


在下文中一共展示了graph_io.write_graph方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: export_cnn

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def export_cnn() -> None:
    """Build a small conv -> sigmoid -> relu graph and write it frozen to ``./cnn.pb``.

    The graph takes a ``(1, 1, 3, 3)`` float32 placeholder, convolves it in
    NCHW layout with an all-ones ``3x3`` kernel, and exposes the result under
    the node name ``"output"``. The frozen graph is written in binary form.
    """
    image = tf.placeholder(tf.float32, shape=(1, 1, 3, 3))
    kernel = tf.constant(np.ones((3, 3, 1, 1)), dtype=tf.float32)
    activations = tf.nn.relu(
        tf.nn.sigmoid(
            tf.nn.conv2d(image, kernel, (1, 1, 1, 1), "SAME", data_format="NCHW")
        )
    )

    # Give the final tensor a stable name so it can be used as the freeze target.
    output_name = "output"
    tf.identity(activations, name=output_name)

    with tf.Session() as sess:
        frozen_graph = graph_util.remove_training_nodes(
            graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(), [output_name]
            )
        )

    graph_io.write_graph(frozen_graph, ".", "cnn.pb", as_text=False)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:21,代码来源:convert.py

示例2: export

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def export(x: tf.Tensor, filename: str, sess=None):
    """Freeze the graph that produces ``x`` and write it to ``./<filename>``.

    Args:
        x: Tensor to expose as the graph output; it is aliased to a node
            named ``"output"``.
        filename: Name of the binary ``.pb`` file, written to the current
            directory.
        sess: Session to freeze from. When ``None`` a temporary session is
            created and closed again before returning.

    Returns:
        The value returned by ``graph_io.write_graph`` (the written path).
    """
    owns_session = sess is None
    if owns_session:
        sess = tf.Session()

    try:
        pred_node_names = ["output"]
        tf.identity(x, name=pred_node_names[0])
        graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )
        graph = graph_util.remove_training_nodes(graph)

        return graph_io.write_graph(graph, ".", filename, as_text=False)
    finally:
        # Close only a session we created, and do so even when freezing or
        # writing raises, so the session is not leaked.
        if owns_session:
            sess.close()
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:22,代码来源:convert_test.py

示例3: keras_to_tensorflow

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def keras_to_tensorflow(keras_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a Keras model's session graph and write it as a binary graph file.

    Args:
        keras_model: Keras model whose outputs become the freeze targets.
        output_dir: Directory to write into; created if it does not exist.
        model_name: File name of the written graph inside ``output_dir``.
        out_prefix: Prefix of the alias nodes; output ``i`` (0-based) is
            named ``out_prefix + str(i + 1)``.
        log_tensorboard: When true, the written graph is also passed to
            ``import_pb_to_tensorboard.import_to_tensorboard``.
    """
    if not os.path.exists(output_dir):
        # makedirs also handles nested output paths, unlike mkdir.
        os.makedirs(output_dir)

    out_nodes = []
    # Iterate the ``outputs`` list: indexing ``keras_model.output`` would slice
    # the tensor itself when the model has a single output.
    for i, output_tensor in enumerate(keras_model.outputs):
        node_name = out_prefix + str(i + 1)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)

    # Freeze and write once, after every output alias exists, rather than
    # once per output with a partial node list.
    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
开发者ID:mogoweb,项目名称:aiexamples,代码行数:23,代码来源:image_classifier_tf.py

示例4: convertMetaModelToPbModel

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def convertMetaModelToPbModel(meta_model, pb_model):
    """Restore a checkpoint from its meta graph and write a frozen ``model.pb``.

    Args:
        meta_model: Checkpoint prefix; ``meta_model + '.meta'`` is imported
            and the variables are restored from ``meta_model``.
        pb_model: Output location prefix. The file is written to the
            directory ``pb_model + './'`` under the name ``model.pb``.
    """
    # Step 1: import the meta graph into the default graph.
    saver = tf.train.import_meta_graph(meta_model + '.meta', clear_devices=True)
    graph = tf.get_default_graph()

    # The context manager closes the session even if restore or freezing fails.
    with tf.Session() as sess:
        # Restore the variable values.
        saver.restore(sess, meta_model)

        # Step 2: print every op name so the output node can be identified.
        for op in graph.get_operations():
            print(op.name)

        # Step 3: freeze the variables reachable from the listed nodes.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session
            sess.graph_def,  # input_graph_def is useful for retrieving the nodes
            ["Placeholder", "output/Sigmoid"])

    # Step 4: write the frozen graph in binary form.
    # NOTE(review): the directory is built by plain string concatenation, so
    # pb_model presumably needs a trailing separator - confirm with callers.
    output_fld = './'
    output_model_file = 'model.pb'
    graph_io.write_graph(output_graph_def, pb_model + output_fld, output_model_file, as_text=False)
开发者ID:junqiangchen,项目名称:LiTS---Liver-Tumor-Segmentation-Challenge,代码行数:28,代码来源:util.py

示例5: main

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def main(unused_args):
  """Optimize the graph named by ``FLAGS.input`` for inference and save it.

  The input is parsed as a binary GraphDef when ``FLAGS.frozen_graph`` is set,
  otherwise as a text-format GraphDef. The result is written to
  ``FLAGS.output`` in the same representation.

  Args:
    unused_args: Ignored.

  Returns:
    0 on success, -1 when the input file does not exist.
  """
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    # Serialized protobufs are bytes: open in binary mode, and use a context
    # manager so the file is flushed and closed before returning.
    with gfile.FastGFile(FLAGS.output, "wb") as f:
      f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:28,代码来源:optimize_for_inference.py

示例6: main

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def main(unused_args):
  """Optimize the graph named by ``FLAGS.input`` for inference and save it.

  The input is parsed as a binary GraphDef when ``FLAGS.frozen_graph`` is set,
  otherwise as a text-format GraphDef. The result is written to
  ``FLAGS.output`` in the same representation.

  Args:
    unused_args: Ignored.

  Returns:
    0 on success, -1 when the input file does not exist.
  """
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  # Read raw bytes: ParseFromString needs bytes and the text branch decodes
  # explicitly, so text mode ("r") would break both branches on Python 3.
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    # Serialized protobufs are bytes: open in binary mode, and use a context
    # manager so the file is flushed and closed before returning.
    with gfile.FastGFile(FLAGS.output, "wb") as f:
      f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:28,代码来源:optimize_for_inference.py

示例7: save_model_to_tensorflow

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def save_model_to_tensorflow(self, new_model_folder, new_model_name=""):
    """
    Save the loaded Keras (.h5) model in the Tensorflow (.pb) model format.

    Nothing happens when no model has been loaded yet.

    :param new_model_folder: path to the folder the converted Tensorflow
        model is saved in; it is created when it does not exist
    :param new_model_name: filename for the converted Tensorflow model,
        e.g. 'my_new_model.pb' (defaults to an empty string)
    :return: None
    """
    if not self.__modelLoaded:
        return

    out_prefix = "output_"
    output_dir = new_model_folder
    if not os.path.exists(output_dir):
        # makedirs also handles nested folders, unlike mkdir.
        os.makedirs(output_dir)

    keras_model = self.__model_collection[0]

    out_nodes = []
    # Iterate the ``outputs`` list: indexing ``keras_model.output`` would slice
    # the tensor itself when the model has a single output.
    for i, output_tensor in enumerate(keras_model.outputs):
        node_name = out_prefix + str(i + 1)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)

    sess = K.get_session()

    from tensorflow.python.framework import graph_util, graph_io

    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    # write_graph joins the directory and the name itself, so pass the bare
    # file name; a pre-joined path would repeat a relative folder
    # (folder/folder/name).
    graph_io.write_graph(main_graph, output_dir, name=new_model_name, as_text=False)
    print("Tensorflow Model Saved")
开发者ID:OlafenwaMoses,项目名称:ImageAI,代码行数:40,代码来源:__init__.py

示例8: build_model

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def build_model(self, model_fn, params):
  """Construct the sharded TPU training loop, dump the graph, run initializers.

  Args:
    model_fn: Callable invoked as ``model_fn(features, labels, mode, params)``;
      its result must expose ``loss`` and ``train_op``.
    params: Passed through to ``model_fn`` unchanged.
  """
  tf.logging.info("TrainLowLevelRunner: build_model method")

  def tpu_train_step(loss):
    """One training step: dequeue a batch, run the model, return its loss."""
    del loss  # The incoming value is discarded; the model's loss replaces it.
    dequeued = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
    batch = data_nest.pack_sequence_as(self.feature_structure, dequeued)
    spec = model_fn(batch["features"], batch["labels"],
                    tf.estimator.ModeKeys.TRAIN, params)
    # Tie the returned loss to the train op so fetching it runs the update.
    with tf.control_dependencies([spec.train_op]):
      return tf.identity(spec.loss)

  def train_loop():
    """Repeat the training step ``self.iterations`` times."""
    return tpu.repeat(self.iterations, tpu_train_step, [_INITIAL_LOSS])

  with self.graph.as_default():
    shard_outputs = tpu.shard(
        train_loop,
        inputs=[],
        num_shards=self.hparams.num_shards,
        outputs_from_all_shards=False,
    )
    (self.loss,) = shard_outputs
    init_global = tf.global_variables_initializer()
    init_local = tf.local_variables_initializer()
    graph_def = self.graph.as_graph_def(add_shapes=True)
    graph_io.write_graph(graph_def, self.hparams.out_dir, "graph.pbtxt")
    self.saver = tf.train.Saver()

  for initializer in (init_global, init_local):
    self.sess.run(initializer)
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:39,代码来源:low_level_runner.py

示例9: export_to_pb

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def export_to_pb(sess, x, filename):
    """Freeze the graph ending at ``x`` and write it to ``./<filename>``.

    ``x`` is aliased to a node named ``"output"``. The graph is frozen using
    ``sess``, stripped of training nodes, and written in binary form. The
    path returned by ``write_graph`` is printed.
    """
    output_name = "output"
    tf.identity(x, name=output_name)

    frozen_graph = graph_util.remove_training_nodes(
        graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [output_name]
        )
    )
    saved_path = graph_io.write_graph(frozen_graph, ".", filename, as_text=False)
    print("saved the frozen graph (ready for inference) at: ", saved_path)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:13,代码来源:main.py

示例10: export_to_pb

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def export_to_pb(sess, x, filename):
    """Freeze the graph ending at ``x``, write it to ``./<filename>``, return the path.

    ``x`` is aliased to a node named ``"output"``. The graph is frozen using
    ``sess``, stripped of training nodes, and written in binary form. The
    given ``filename`` is printed, and the value returned by ``write_graph``
    is handed back to the caller.
    """
    output_name = "output"
    tf.identity(x, name=output_name)

    frozen_graph = graph_util.remove_training_nodes(
        graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [output_name]
        )
    )
    saved_path = graph_io.write_graph(frozen_graph, ".", filename, as_text=False)
    print("saved the frozen graph (ready for inference) at: ", filename)

    return saved_path
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:15,代码来源:mnist_deep_cnn.py

示例11: convertMetaModelToPbModel

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def convertMetaModelToPbModel(meta_model, pb_model):
    """Restore a checkpoint from its meta graph and write a frozen ``model.pb``.

    Args:
        meta_model: Checkpoint prefix; ``meta_model + '.meta'`` is imported
            and the variables are restored from ``meta_model``.
        pb_model: Output location prefix. The file is written to the
            directory ``pb_model + './'`` under the name ``model.pb``.
    """
    # Step 1: import the meta graph into the default graph.
    saver = tf.train.import_meta_graph(meta_model + '.meta', clear_devices=True)
    graph = tf.get_default_graph()

    # The context manager closes the session even if restore or freezing fails.
    with tf.Session() as sess:
        # Restore the variable values.
        saver.restore(sess, meta_model)

        # Step 2: print every op name so the output node can be identified.
        for op in graph.get_operations():
            print(op.name)

        # Step 3: freeze the variables reachable from the listed nodes.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session
            sess.graph_def,  # input_graph_def is useful for retrieving the nodes
            ["Input", "output/Sigmoid"])

    # Step 4: write the frozen graph in binary form.
    # NOTE(review): the directory is built by plain string concatenation, so
    # pb_model presumably needs a trailing separator - confirm with callers.
    output_fld = './'
    output_model_file = 'model.pb'
    graph_io.write_graph(output_graph_def, pb_model + output_fld, output_model_file, as_text=False)
开发者ID:junqiangchen,项目名称:VNet,代码行数:28,代码来源:util.py

示例12: main

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def main(unused_args):
    """Optimize the graph named by ``FLAGS.input`` for inference and save it.

    The input is parsed as a binary GraphDef when ``FLAGS.frozen_graph`` is
    set, otherwise as a text-format GraphDef. The result is written to
    ``FLAGS.output`` in the same representation.

    Args:
        unused_args: Ignored.

    Returns:
        0 on success, -1 when the input file does not exist.
    """
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
        if FLAGS.frozen_graph:
            input_graph_def.ParseFromString(data)
        else:
            text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum,
        FLAGS.toco_compatible)

    if FLAGS.frozen_graph:
        # Serialized protobufs are bytes: open in binary mode, and use a
        # context manager so the file is flushed and closed before returning.
        with gfile.FastGFile(FLAGS.output, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0
开发者ID:jiny2001,项目名称:dcscn-super-resolution,代码行数:30,代码来源:optimize_for_inference.py

示例13: convertGraph

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def convertGraph(modelPath, output, outPath):
    '''
    Converts an HDF5 file to a .pb file for use with Tensorflow.

    Besides the binary graph, a text dump named
    ``<basename>.reference.pb.ascii`` is written next to it.

    Args:
        modelPath (str): path to the .h5 file
           output (str): name of the referenced output
          outPath (str): path to the output .pb file; its directory part is
                         resolved against the directory containing this
                         script, and is created when missing
    Returns:
        None
    '''

    script_dir = os.path.dirname(os.path.realpath(__file__))
    outdir = os.path.join(script_dir, os.path.dirname(outPath))
    name = os.path.basename(outPath)
    basename = os.path.splitext(name)[0]

    # makedirs (not mkdir) so nested output directories are created as well.
    if not os.path.isdir(outdir):
        os.makedirs(outdir)

    K.set_learning_phase(0)

    net_model = load_model(modelPath)

    # Alias the outputs in the model - this sometimes makes them easier to access in TF
    tf.identity(net_model.output, name=output)

    sess = K.get_session()

    net_model.summary()

    # Write the graph in human readable
    f = '{}.reference.pb.ascii'.format(basename)
    tf.train.write_graph(sess.graph.as_graph_def(), outdir, f, as_text=True)
    print('Saved the graph definition in ascii format at: ', os.path.join(outdir, f))

    # Write the graph in binary .pb file
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output])
    graph_io.write_graph(constant_graph, outdir, name, as_text=False)
    # Report the location actually written, which is relative to the script
    # directory and can differ from outPath as given.
    print('Saved the constant graph (ready for inference) at: ', os.path.join(outdir, name))
开发者ID:tympanix,项目名称:subsync,代码行数:44,代码来源:convert.py

示例14: testStripUnused

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def testStripUnused(self):
    """Strip a graph down to one input/output pair and check the result runs."""
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Build an input graph that subtracts 3.0 from a constant 1.0, multiplies
    # the result by 2.0, and then feeds it into an extra add node that lies
    # beyond the requested output.
    with ops.Graph().as_default():
      constant_node = constant_op.constant(1.0, name="constant_node")
      wanted_input_node = math_ops.subtract(constant_node,
                                            3.0,
                                            name="wanted_input_node")
      output_node = math_ops.multiply(
          wanted_input_node, 2.0, name="output_node")
      math_ops.add(output_node, 2.0, name="later_node")
      # The context manager closes the session once the graph is written.
      with session.Session() as sess:
        output = sess.run(output_node)
        self.assertNear(-4.0, output, 0.00001)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # The graph is now on disk in text form; run the strip-unused routine on
    # it and ask for a binary result.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_binary = False
    input_node_names = "wanted_input_node"
    output_binary = True
    output_node_names = "output_node"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
                                             output_graph_path, output_binary,
                                             input_node_names,
                                             output_node_names,
                                             dtypes.float32.as_datatype_enum)

    # Check that the add and subtract nodes are gone, that the input node
    # carries a shape attribute, and that the stripped graph still computes
    # the expected result when the input is fed.
    with ops.Graph().as_default():
      output_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")

      self.assertEqual(3, len(output_graph_def.node))
      for node in output_graph_def.node:
        self.assertNotEqual("Add", node.op)
        self.assertNotEqual("Sub", node.op)
        if node.name == input_node_names:
          self.assertTrue("shape" in node.attr)

      with session.Session() as sess:
        input_node = sess.graph.get_tensor_by_name("wanted_input_node:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [10.0]})
        self.assertNear(20.0, output, 0.00001)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:56,代码来源:strip_unused_test.py

示例15: testStripUnusedMultipleInputs

# 需要导入模块: from tensorflow.python.framework import graph_io [as 别名]
# 或者: from tensorflow.python.framework.graph_io import write_graph [as 别名]
def testStripUnusedMultipleInputs(self):
    """Strip a graph with two requested inputs and check the result runs."""
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Build an input graph that multiplies two input nodes, each derived from
    # a constant, and then feeds the product into an extra add node that lies
    # beyond the requested output.
    with ops.Graph().as_default():
      constant_node1 = constant_op.constant(1.0, name="constant_node1")
      constant_node2 = constant_op.constant(2.0, name="constant_node2")
      input_node1 = math_ops.subtract(constant_node1, 3.0, name="input_node1")
      input_node2 = math_ops.subtract(constant_node2, 5.0, name="input_node2")
      output_node = math_ops.multiply(
          input_node1, input_node2, name="output_node")
      math_ops.add(output_node, 2.0, name="later_node")
      # The context manager closes the session once the graph is written.
      with session.Session() as sess:
        output = sess.run(output_node)
        self.assertNear(6.0, output, 0.00001)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # The graph is now on disk in text form; run the strip-unused routine on
    # it and ask for a binary result.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_binary = False
    input_node_names = "input_node1,input_node2"
    input_node_types = [
        dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum
    ]
    output_binary = True
    output_node_names = "output_node"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
                                             output_graph_path, output_binary,
                                             input_node_names,
                                             output_node_names,
                                             input_node_types)

    # Check that the add and subtract nodes are gone, that each input node
    # carries a shape attribute, and that the stripped graph still computes
    # the expected result when both inputs are fed.
    with ops.Graph().as_default():
      output_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")

      # input_node_names is a comma-separated list, so compare against the
      # individual names; comparing to the whole string never matches and
      # would silently skip the shape check.
      expected_input_names = input_node_names.split(",")

      self.assertEqual(3, len(output_graph_def.node))
      for node in output_graph_def.node:
        self.assertNotEqual("Add", node.op)
        self.assertNotEqual("Sub", node.op)
        if node.name in expected_input_names:
          self.assertTrue("shape" in node.attr)

      with session.Session() as sess:
        input_node1 = sess.graph.get_tensor_by_name("input_node1:0")
        input_node2 = sess.graph.get_tensor_by_name("input_node2:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node,
                          feed_dict={input_node1: [10.0],
                                     input_node2: [-5.0]})
        self.assertNear(-50.0, output, 0.00001)


注:本文中的tensorflow.python.framework.graph_io.write_graph方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。