当前位置: 首页>>代码示例>>Python>>正文


Python graph_util.convert_variables_to_constants函数代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_util.convert_variables_to_constants函数的典型用法代码示例。如果您想了解Python convert_variables_to_constants函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了convert_variables_to_constants函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testConvertVariablesToConsts

  def testConvertVariablesToConsts(self):
    """Checks whitelist, no-whitelist and blacklist modes of freezing.

    Builds a graph with one used and one unused variable, then:
      1. freezes with `variable_names_whitelist` (only `variable_node` needs
         a value, so the unused, uninitialized variable is not read);
      2. freezes again without a whitelist after initializing everything and
         checks both results serialize identically;
      3. freezes with `variable_names_blacklist` and checks `variable_node`
         stays a `VariableV2` op.
    Finally re-imports the whitelisted result into a fresh graph and checks it
    contains no variable ops and still evaluates to 2.0.
    """
    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      # Never initialized before step 1; exercises the whitelist path.
      _ = variables.Variable(1.0, name="unused_variable_node")
      output_node = math_ops_lib.multiply(
          variable_node, 2.0, name="output_node")
      with session.Session() as sess:
        init = variables.initialize_variables([variable_node])
        sess.run(init)
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
        variable_graph_def = sess.graph.as_graph_def()
        # First get the constant_graph_def when variable_names_whitelist is set,
        # note that if variable_names_whitelist is not set an error will be
        # thrown because unused_variable_node is not initialized.
        constant_graph_def = graph_util.convert_variables_to_constants(
            sess,
            variable_graph_def, ["output_node"],
            variable_names_whitelist=set(["variable_node"]))

        # Then initialize the unused variable, and get another
        # constant_graph_def when variable_names_whitelist is not set.
        sess.run(variables.global_variables_initializer())
        constant_graph_def_without_variable_whitelist = (
            graph_util.convert_variables_to_constants(sess, variable_graph_def,
                                                      ["output_node"]))

        # The unused variable should be cleared so the two graphs should be
        # equivalent.
        self.assertEqual(
            str(constant_graph_def),
            str(constant_graph_def_without_variable_whitelist))

        # Test variable name black list. This should result in the variable not
        # being a const.
        sess.run(variables.global_variables_initializer())
        constant_graph_def_with_blacklist = (
            graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def, ["output_node"],
                variable_names_blacklist=set(["variable_node"])))
        # Reuse the name `variable_node` to hold the matching NodeDef found in
        # the blacklisted output.
        variable_node = None
        for node in constant_graph_def_with_blacklist.node:
          if node.name == "variable_node":
            variable_node = node
        self.assertIsNotNone(variable_node)
        self.assertEqual(variable_node.op, "VariableV2")

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      # Pins the pruned node count; presumably the frozen variable, its read,
      # the 2.0 constant and output_node — TODO confirm if this ever changes.
      self.assertEqual(4, len(constant_graph_def.node))
      for node in constant_graph_def.node:
        self.assertNotEqual("Variable", node.op)
        self.assertNotEqual("VariableV2", node.op)
      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node)
        self.assertNear(2.0, output, 0.00001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:60,代码来源:graph_util_test.py

示例2: freeze_graph

def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: Path to a GraphDef file, binary or text (see input_binary).
    input_saver: Optional path to a SaverDef file. When empty, the checkpoint
      is restored by running `restore_op_name` instead.
    input_binary: True if input_graph / input_saver are binary protos, False
      if they are in text format.
    input_checkpoint: Checkpoint path, or checkpoint prefix for Saver V2.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Restore op to run when input_saver is empty.
    filename_tensor_name: Tensor fed with input_checkpoint when running
      restore_op_name.
    output_graph: Path the frozen binary GraphDef is written to.
    clear_devices: If True, clear the device field of every node.
    initializer_nodes: Optional op name(s) run after the checkpoint is
      restored, for either restore path. A string is treated as a
      comma-separated list.

  Returns:
    -1 if an input is missing; otherwise None after writing output_graph.
  """

  def _read_text(handle):
    # Accept either bytes or str from the file handle so both text-format
    # parses below always receive str.
    data = handle.read()
    return data.decode("utf-8") if isinstance(data, bytes) else data

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if input_saver and not tf.gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not tf.train.checkpoint_exists(input_checkpoint):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(_read_text(f), input_graph_def)
  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ""
  _ = tf.import_graph_def(input_graph_def, name="")

  with tf.Session() as sess:
    if input_saver:
      with tf.gfile.FastGFile(input_saver, mode) as f:
        saver_def = tf.train.SaverDef()
        if input_binary:
          saver_def.ParseFromString(f.read())
        else:
          text_format.Merge(_read_text(f), saver_def)
        saver = tf.train.Saver(saver_def=saver_def)
        saver.restore(sess, input_checkpoint)
    else:
      sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})

    # Run initializers regardless of how the checkpoint was restored. A
    # comma-separated string is split into names, like output_node_names.
    if initializer_nodes:
      if isinstance(initializer_nodes, str):
        initializer_nodes = initializer_nodes.replace(" ", "").split(",")
      sess.run(initializer_nodes)

    variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if
                                FLAGS.variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_blacklist=variable_names_blacklist)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:curtiszimmerman,项目名称:tensorflow,代码行数:60,代码来源:freeze_graph.py

示例3: graph_def_from_checkpoint

def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Converts checkpoint data to GraphDef.

  Reads the latest checkpoint data and produces a GraphDef in which the
  variables have been converted to constants. The meta graph is imported into
  a fresh graph, so the caller's default graph is left untouched and repeated
  calls do not accumulate nodes.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef from the latest checkpoint

  Raises:
    ValueError: if no checkpoint is found, or the checkpoint's meta graph has
      no variables to restore.
  """
  checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    raise ValueError('Could not find a checkpoint at: {0}.'
                     .format(checkpoint_dir))

  with ops.Graph().as_default() as graph:
    saver_for_restore = saver_lib.import_meta_graph(
        checkpoint_path + '.meta', clear_devices=True)
    # import_meta_graph returns None when the meta graph has no variables.
    if saver_for_restore is None:
      raise ValueError('No variables to restore in checkpoint: {0}.'
                       .format(checkpoint_path))
    with session.Session(graph=graph) as sess:
      saver_for_restore.restore(sess, checkpoint_path)
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, graph.as_graph_def(), output_node_names)

  return output_graph_def
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:strip_pruning_vars_lib.py

示例4: freeze_graph_with_def_protos

def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: GraphDef to freeze. When clear_devices is True its device
      fields are cleared in place.
    input_saver_def: Optional SaverDef used to restore the checkpoint. When
      empty, a Saver is built over the checkpoint variables that also exist as
      tensors in the imported graph.
    input_checkpoint: Checkpoint path, or checkpoint prefix for Saver V2.
    output_node_names: Comma-separated names of the output nodes to keep.
    restore_op_name: Unused; kept for backward compatibility.
    filename_tensor_name: Unused; kept for backward compatibility.
    clear_devices: If True, clear device placements before importing.
    initializer_nodes: Optional op name(s) run after restoring, for either
      restore path. A string is treated as a comma-separated list.
    variable_names_blacklist: Comma-separated names of variables that must not
      be converted to constants.

  Returns:
    The frozen GraphDef.

  Raises:
    ValueError: if the checkpoint does not exist or no output node is given.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    raise ValueError(
        'Input checkpoint "' + input_checkpoint + '" does not exist!')

  if not output_node_names:
    raise ValueError(
        'You must supply the name of a node to --output_node_names.')

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    for node in input_graph_def.node:
      node.device = ''

  # Imported into the current default graph, which the Session below uses.
  _ = importer.import_graph_def(input_graph_def, name='')

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ':0')
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)

    # Run initializers regardless of how the checkpoint was restored. A
    # comma-separated string is split into names, like output_node_names.
    if initializer_nodes:
      if isinstance(initializer_nodes, str):
        initializer_nodes = initializer_nodes.replace(' ', '').split(',')
      sess.run(initializer_nodes)

    variable_names_blacklist = (variable_names_blacklist.split(',') if
                                variable_names_blacklist else None)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(','),
        variable_names_blacklist=variable_names_blacklist)

  return output_graph_def
开发者ID:chenxiang204,项目名称:code,代码行数:60,代码来源:exporter.py

示例5: from_session

  def from_session(cls,
                   sess,
                   input_tensors,
                   output_tensors,
                   freeze_variables=False):
    """Creates a TocoConverter class from a TensorFlow Session.

    Args:
      sess: TensorFlow Session.
      input_tensors: List of input tensors. Type and shape are computed using
        `foo.get_shape()` and `foo.dtype`.
      output_tensors: List of output tensors (only .name is used from this).
      freeze_variables: Boolean indicating whether the variables need to be
        converted into constants via the freeze_graph.py script.
        (default False). Only global variables that are still uninitialized
        are initialized before freezing; values already held in `sess`
        (restored or trained) are preserved.

    Returns:
      TocoConverter class.
    """

    # Get GraphDef.
    if freeze_variables:
      # Local import: only the freezing path needs these helpers.
      from tensorflow.python.ops import variables as tf_variables  # pylint: disable=g-import-not-at-top
      # Initializing every variable here would overwrite the session's current
      # values and freeze initial values instead, so only fill the gaps.
      with sess.graph.as_default():
        global_vars = tf_variables.global_variables()
        uninitialized_names = set(
            sess.run(tf_variables.report_uninitialized_variables(global_vars)))
        uninitialized_vars = [
            var for var in global_vars
            if var.op.name.encode("utf-8") in uninitialized_names
        ]
        if uninitialized_vars:
          sess.run(tf_variables.variables_initializer(uninitialized_vars))
      output_arrays = [tensor_name(tensor) for tensor in output_tensors]
      graph_def = tf_graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_arrays)
    else:
      graph_def = sess.graph_def

    # Create TocoConverter class.
    return cls(graph_def, input_tensors, output_tensors)
开发者ID:jinxin0924,项目名称:tensorflow,代码行数:31,代码来源:lite.py

示例6: testConvertVariablesToConstsWithEmbeddings

  def testConvertVariablesToConstsWithEmbeddings(self):
    """Freezes a graph with embeddings.

    Builds a Keras model with an Embedding layer, freezes its session graph,
    checks no variable-related ops remain, and compares the frozen graph's
    output with `model.predict` on the same input.
    """
    # A fixed index inside [0, input_dim). Casting np.random.random_sample
    # (values in [0, 1)) to int32 always produced 0, so only embedding row 0
    # was ever exercised.
    input_data = np.array([[42]], dtype=np.int32)

    # Make model.
    state_input = keras.layers.Input(
        shape=(1,), name="state_input", dtype="int32")
    output = keras.layers.Embedding(
        output_dim=16, input_dim=100, input_length=1, name="state")(
            state_input)
    model = keras.models.Model(inputs=[state_input], outputs=[output])
    model.compile(
        loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")

    # Get associated session.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_tensor = [tensor.name.split(":")[0] for tensor in model.outputs]
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Ensure graph has no variables.
    for node in constant_graph_def.node:
      self.assertNotIn(
          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

    # Compare the value of the graphs.
    expected_value = model.predict(input_data)
    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
                                            model.outputs, [input_data])
    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:graph_util_test.py

示例7: freeze_session

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The sequence
                        passed in is copied, never modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_names or []))
        # Copy first: `+=` on the caller's list would extend it in place and
        # leak variable names into later calls that reuse the same list.
        output_names = list(output_names or [])
        # Every variable op is also listed as an output, so pruning keeps the
        # variable nodes even when no requested output depends on them.
        output_names += global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        # The 4th positional argument is variable_names_whitelist.
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
开发者ID:DXZ,项目名称:git_test,代码行数:28,代码来源:keras_to_tfsevring_final.py

示例8: freeze_graph

def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef for `sess`'s graph.

  If OpHint outputs are present, the graph is handed straight to the OpHint
  conversion path, because Grappler's function inlining would break that
  transformation. Otherwise a function-inlining Grappler pass is run on the
  session's GraphDef. When the session graph still has variables, the
  optimized GraphDef is frozen and returned; when it is already frozen,
  `sess.graph_def` is returned unchanged.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  if len(find_all_hinted_output_nodes(sess)) > 0:  #  pylint: disable=g-explicit-length-test
    return _convert_op_hints_if_present(sess, output_tensors)

  inline_functions_config = get_grappler_config(function_only=True)
  optimized_graph_def = run_graph_optimizations(
      sess.graph_def,
      input_tensors,
      output_tensors,
      inline_functions_config,
      graph=sess.graph)

  if is_frozen_graph(sess):
    return sess.graph_def

  output_names = [get_tensor_name(t) for t in output_tensors]
  return tf_graph_util.convert_variables_to_constants(
      sess, optimized_graph_def, output_names)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:33,代码来源:util.py

示例9: testConvertVariablesToConstsWithFunctions

  def testConvertVariablesToConstsWithFunctions(self):
    """Checks that freezing preserves the GraphDef's function library.

    The graph calls a Defun (`plus_one`) on a variable; after freezing with a
    whitelist, the frozen GraphDef must carry the same `library` as the
    original GraphDef.
    """
    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      # Never initialized; forces the whitelist path below.
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      output_node = math_ops_lib.multiply(
          defun_node, 2.0, name="output_node")

      with session.Session() as sess:
        init = variables.initialize_variables([variable_node])
        sess.run(init)
        output = sess.run(output_node)
        # (1.0 + 1.0) * 2.0
        self.assertNear(4.0, output, 0.00001)
        variable_graph_def = sess.graph.as_graph_def()

        # First get the constant_graph_def when variable_names_whitelist is set,
        # note that if variable_names_whitelist is not set an error will be
        # thrown because unused_variable_node is not initialized.
        constant_graph_def = graph_util.convert_variables_to_constants(
            sess,
            variable_graph_def, ["output_node"],
            variable_names_whitelist=set(["variable_node"]))

        self.assertEqual(variable_graph_def.library,
                         constant_graph_def.library)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:29,代码来源:graph_util_test.py

示例10: _freeze_graph_with_def_protos

def _freeze_graph_with_def_protos(input_graph_def, output_node_names,
                                  initializer_names, shared_init_op_name,
                                  input_saver_def, input_checkpoint):
  """Freezes a GraphDef against a checkpoint while keeping initializer ops.

  Freezing would normally prune initializer ops (e.g. table initializers)
  that outputs do not depend on. They are therefore kept alive by listing
  them as extra outputs. Afterwards they are re-attached to the shared
  initializer node, rather than working out which of that node's
  dependencies (e.g. group_deps) must survive.

  Args:
    input_graph_def: GraphDef proto to freeze.
    output_node_names: Names of the output nodes.
    initializer_names: Names of the initializer nodes to keep.
    shared_init_op_name: Name of the shared initializer node that the kept
      initializer nodes are connected back to.
    input_saver_def: SaverDef proto used to restore the checkpoint.
    input_checkpoint: Path of the checkpoint to restore.

  Returns:
    The frozen GraphDef.
  """
  with _ops.Graph().as_default():
    _importer.import_graph_def(input_graph_def, name='')

    with _session.Session() as sess:
      restorer = _saver_lib.Saver(saver_def=input_saver_def)
      restorer.restore(sess, input_checkpoint)

      nodes_to_keep = output_node_names + initializer_names
      frozen_graph_def = _graph_util.convert_variables_to_constants(
          sess, input_graph_def, nodes_to_keep)
      _connect_to_shared_init_op(
          frozen_graph_def, shared_init_op_name, initializer_names)

  return frozen_graph_def
开发者ID:bikong2,项目名称:tensorflow,代码行数:35,代码来源:meta_graph_transform.py

示例11: freeze_graph_def

def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Freeze `input_graph_def`, rewriting ref-based update ops first.

    The node rewrites are applied to ``input_graph_def`` in place.

    Args:
        sess: Session holding the variable values to bake into the graph.
        input_graph_def: GraphDef to freeze (mutated in place).
        output_node_names: Comma-separated names of the output nodes.

    Returns:
        The frozen GraphDef.
    """
    # Rewrite RefSwitch / AssignSub / AssignAdd into their plain
    # counterparts; presumably so no ref-typed ops remain once variables
    # become constants.
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    # Point at the '<name>/read' node rather than the variable
                    # itself (assumes such a read node exists for each one).
                    node.input[index] = input_name + '/read'
        elif node.op in ('AssignSub', 'AssignAdd'):
            node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
            # The plain Sub/Add ops take no 'use_locking' attribute.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Only variables under these name prefixes are converted to constants.
    whitelist_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(whitelist_prefixes)]

    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
开发者ID:citysir,项目名称:facenet,代码行数:26,代码来源:freeze_graph.py

示例12: saveData

    def saveData(self,step):
        """Save a checkpoint, the graph text and a frozen .pb, then zip and announce them.

        Args:
            step: Training step; embedded in the frozen-graph and zip file names.
        """
        print('{} Saving checkpoint file to: {}'.format(
            datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
            self.output_dir))

        version_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION)
        model_dir = os.path.join(version_dir, 'model')

        # Save the graph's weights.
        self.saver.save(
            self.sess, self.ckpt_file, global_step=self.global_step)
        # Save the graph's structure.
        tf.train.write_graph(self.sess.graph_def, model_dir, 'train.pbtxt')

        # Bake the weights into the graph to produce a .pb usable on Android.
        # Use the session's own graph for both the node list and the
        # conversion so the two always agree.
        graph_def = self.sess.graph.as_graph_def()

        print("global_variables are")
        # NOTE(review): every node is listed as an output, so nothing is
        # pruned; the original author marked this output list as still to be
        # revised.
        output_node_names = []
        for node in graph_def.node:
            print("{}:{}".format(str(node.name), type(node)))
            output_node_names.append(str(node.name).split(':')[0])

        print("--------------")
        # Persist the model with variable values fixed as constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            self.sess, graph_def, output_node_names)
        # str(step) once: concatenating `step` directly raised TypeError for
        # an int step and repeated the step for a str one.
        pb_path = os.path.join(model_dir, 'train.' + str(step) + '.pb')
        with tf.gfile.GFile(pb_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())  # Serialized output.
        print("%d ops in the final graph." % len(output_graph_def.node))

        freezetime = datetime.datetime.now().strftime('%m-%d-%H-%M-%S')
        zu = ZipUtil()
        zipfilename = cfg.DATA_UploadZipFileName + '.' + str(step) + '.' + freezetime
        # The step argument lets zip_dir compress only this step's output
        # instead of everything.
        zu.zip_dir(version_dir, step, zipfilename)

        uploader = Uploader()
        # NOTE(review): credentials are committed to source and should be
        # rotated; the environment variables take precedence when set.
        uploader.setQiniuKEY(
            os.environ.get('QINIU_ACCESS_KEY',
                           'mMQxjyif6Uk8nSGIn9ZD3I19MBMEK3IUGngcX8_p'),
            os.environ.get('QINIU_SECRET_KEY',
                           'J5gFhdpQ-1O1rkCnlqYnzPiH3XTst2Szlv9GlmQM'))

        # Direct upload is disabled; only a notification is sent:
        # uploader.upload2qiniu(cfg.DATA_UploadZipFileName + '.' + freezetime, zipfilename).start()
        sendData = {"state": "prepared",
                    "filename": str(zipfilename),
                    "filepath": os.path.join(version_dir, zipfilename),
                    "step": step,
                    }

        uploader.notifyForTrans(sendData)
开发者ID:wanglikang,项目名称:zzuARTensorflow2,代码行数:59,代码来源:train.py

示例13: train_network

def _single_sample_feed(graph, image, label):
    """Build a feed_dict holding one image and its 0/1 label as a one-hot pair."""
    return {
        graph['x']: np.reshape(image, (1, 224, 224, 3)),
        graph['y']: ([[1, 0]] if label == 0 else [[0, 1]]),
    }


def _evaluate_samples(sess, graph, images, labels, count):
    """Return (correct, summed cost, samples) over the first `count` samples, fed one at a time."""
    total_correct = 0
    total_cost = 0.
    for i in range(count):
        correct_in_feed, mean_cost = sess.run(
            [graph['correct_times_in_batch'], graph['cost']],
            feed_dict=_single_sample_feed(graph, images[i], labels[i]))
        total_correct += correct_in_feed
        # Each feed holds exactly one sample, so its mean cost is its total.
        total_cost += mean_cost
    return total_correct, total_cost, count


def train_network(graph, batch_size, num_epochs, pb_file_path):
    """Train the network, report metrics every 2 epochs, and export a frozen .pb.

    Samples are fed one at a time with shape (1, 224, 224, 3): the first 12
    of the module-level `x_train`/`y_train` and the first 3 of `x_val`/`y_val`.

    Args:
        graph: Dict of graph endpoints; reads 'x', 'y', 'optimize',
            'correct_times_in_batch' and 'cost'.
        batch_size: Kept for interface compatibility. Every feed holds a
            single sample, so metrics are normalized by the number of samples
            actually evaluated rather than by this value.
        num_epochs: Number of passes over the training samples.
        pb_file_path: Output path of the frozen GraphDef (node "output").
    """
    del batch_size  # See docstring: feeds are always one sample.
    init = tf.global_variables_initializer()
    epoch_delta = 2
    with tf.Session() as sess:
        sess.run(init)
        for epoch_index in range(num_epochs):
            for i in range(12):
                sess.run([graph['optimize']],
                         feed_dict=_single_sample_feed(graph, x_train[i], y_train[i]))

            if epoch_index % epoch_delta == 0:
                (total_correct_times_in_train_set, total_cost_in_train_set,
                 total_samples_in_train_set) = _evaluate_samples(
                     sess, graph, x_train, y_train, 12)
                (total_correct_times_in_test_set, total_cost_in_test_set,
                 total_samples_in_test_set) = _evaluate_samples(
                     sess, graph, x_val, y_val, 3)

                acy_on_test = total_correct_times_in_test_set / float(total_samples_in_test_set)
                acy_on_train = total_correct_times_in_train_set / float(total_samples_in_train_set)
                print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                    epoch_index,
                    acy_on_test * 100.0,
                    total_correct_times_in_test_set,
                    total_samples_in_test_set,
                    total_cost_in_test_set,
                    acy_on_train * 100.0,
                    total_correct_times_in_train_set,
                    total_samples_in_train_set,
                    total_cost_in_train_set))

        # Export once after training. Freezing inside the epoch loop rewrote
        # the whole file every epoch, and only the final write survived.
        if num_epochs > 0:
            constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
开发者ID:JGyoung33,项目名称:tensorflow-vgg16-train-and-test,代码行数:57,代码来源:train_vgg.py

示例14: save_graph_to_file

def save_graph_to_file(graph, graph_file_name, model_info, class_count):
  """Write a frozen copy of a freshly built evaluation graph to disk.

  Args:
    graph: Unused. The graph that is written is the one built by
      build_eval_session; the argument is kept for call-site compatibility.
    graph_file_name: Destination path for the serialized GraphDef.
    model_info: Passed through to build_eval_session.
    class_count: Passed through to build_eval_session.
  """
  del graph  # Unused: the evaluation session builds its own graph.
  sess, _, _, _, _ = build_eval_session(model_info, class_count)
  try:
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), [FLAGS.final_tensor_name])
  finally:
    # The session is created for this call only and is not returned, so
    # release it here instead of leaking it.
    sess.close()

  with gfile.FastGFile(graph_file_name, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
开发者ID:google,项目名称:makerfaire-2016,代码行数:9,代码来源:retrain.py

示例15: _convert_op_hints_if_present

def _convert_op_hints_if_present(sess, output_tensors):
  """Freeze `sess`'s graph and replace OpHint-annotated subgraphs with stubs.

  The hinted output nodes are kept alongside the requested outputs during
  freezing. The result is passed through convert_op_hints_to_stubs and then
  remove_training_nodes.

  Args:
    sess: Session whose graph must still contain variables.
    output_tensors: Output tensors (only their names are used).

  Returns:
    The converted GraphDef.

  Raises:
    ValueError: if the session's graph is already frozen.
  """
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")

  hint_output_names = find_all_hinted_output_nodes(sess)
  requested_names = [get_tensor_name(t) for t in output_tensors]
  frozen_graph_def = tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, requested_names + hint_output_names)

  stubbed_graph_def = convert_op_hints_to_stubs(graph_def=frozen_graph_def)
  return tf_graph_util.remove_training_nodes(stubbed_graph_def)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:10,代码来源:util.py


注:本文中的tensorflow.python.framework.graph_util.convert_variables_to_constants函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。