当前位置: 首页>>代码示例>>Python>>正文


Python graph_pb2.GraphDef方法代码示例

本文整理汇总了Python中tensorflow.core.framework.graph_pb2.GraphDef方法的典型用法代码示例。如果您正苦于以下问题：Python graph_pb2.GraphDef方法的具体用法？Python graph_pb2.GraphDef怎么用？Python graph_pb2.GraphDef使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.core.framework.graph_pb2的用法示例。


在下文中一共展示了graph_pb2.GraphDef方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: load_graph

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def load_graph(graph_path,tensorboard=False,**kwargs):
    '''Load a binary GraphDef (``.pb``) file into a new TensorFlow graph.

    :param graph_path: path of the binary pb file.
    :param tensorboard: if True, also write the graph to the ``log/``
        directory so it can be inspected with TensorBoard.
    :param kwargs: accepted for call-site compatibility; currently unused.
    :return: a tensorflow graph containing the imported nodes (imported with
        an empty name prefix, so node names are unchanged).
    '''
    with gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # Close (which also flushes) the writer; without this the event
            # file may be left incomplete and its handle stays open.
            writer.close()

    return graph
开发者ID:bill9800,项目名称:speech_separation,代码行数:19,代码来源:utils.py

示例2: get_header

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
  """Build the ops_to_register.h text used by tensorflow SELECTIVE_REGISTRATION.

  Args:
    graphs: list of paths to the GraphDef files whose ops should be kept.
    proto_fileformat: on-disk format of the files in `graphs`, either
      'textproto' or 'rawproto' (the default).
    default_ops: comma-separated operator:kernel pairs that are always
      included. The special value 'all' includes every operator and kernel.
      Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.
  Returns:
    The header contents as a string, or the integer 1 (after printing an
    error message) when no ops/kernels could be collected.
  """
  collected = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
  if collected:
    return get_header_from_ops_and_kernels(collected, default_ops == 'all')
  print('Error reading graph!')
  return 1
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:23,代码来源:selective_registration_header_lib.py

示例3: get_stats_for_node_def

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def get_stats_for_node_def(graph, node, statistic_type):
  """Looks up the node's statistics function in the registry and calls it.

  This function takes a Graph object and a NodeDef from a GraphDef, and if
  there's an associated statistics method, calls it and returns a result. If no
  function has been registered for the particular node type, it returns an empty
  statistics object.

  Args:
    graph: A Graph object that's been set up with the node's graph.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.
  Returns:
    An OpStats object containing information about resource usage.
  Raises:
    Any exception raised by the registered statistics function itself,
    including LookupError subclasses, is propagated.
  """
  try:
    stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
  except LookupError:
    # No statistics function is registered for this op type.
    return OpStats(statistic_type)
  # Called outside the try block so that a LookupError (e.g. KeyError) raised
  # by the stats function itself is not mistaken for "not registered".
  return stats_func(graph, node)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:24,代码来源:ops.py

示例4: as_graph_def

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return this graph serialized as a `GraphDef` protocol buffer.

    The result can be loaded into another `Graph` (via
    @{tf.import_graph_def}) or handed to the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. When given, only the nodes added to this graph
        after its `version` property had this value are included.
      add_shapes: If true, every node gets an "_output_shapes" list attr
        holding the inferred shapes of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    # The private helper returns a pair; only the GraphDef is exposed here.
    graph_def, _unused = self._as_graph_def(from_version, add_shapes)
    return graph_def
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:27,代码来源:ops.py

示例5: Graph

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def Graph(self):
    """Return the graph definition stored for this run.

    Only a graph stored directly (as serialized bytes in ``self._graph``) is
    returned; this implementation does not fall back to a graph embedded in
    a stored metagraph.

    Raises:
      ValueError: If there is no graph for this run.

    Returns:
      The parsed `graph_def` proto.
    """
    # Check first so no GraphDef is allocated when there is nothing to parse.
    if self._graph is None:
      raise ValueError('There is no graph in this EventAccumulator')
    graph = graph_pb2.GraphDef()
    graph.ParseFromString(self._graph)
    return graph
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:19,代码来源:event_accumulator.py

示例6: testAll

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def testAll(self):
    default_ops = 'all'
    graph_defs = [text_format.Parse(txt, graph_pb2.GraphDef())
                  for txt in (GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)]
    graph_files = self.WriteGraphFiles(graph_defs)
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)

    header = print_selective_registration_header.get_header(
        ops_and_kernels, default_ops)
    # With default_ops == 'all' every op, kernel and gradient macro is "true".
    expected_lines = [
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(expected_lines, header.split('\n'))
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:23,代码来源:print_selective_registration_header_test.py

示例7: as_graph_def

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return this graph serialized as a `GraphDef` protocol buffer.

    The result can be loaded into another `Graph` (via
    [`import_graph_def()`](#import_graph_def)) or handed to the
    [C++ Session API](../../api_docs/cc/index.md).

    This method is thread-safe.

    Args:
      from_version: Optional. When given, only the nodes added to this graph
        after its `version` property had this value are included.
      add_shapes: If true, every node gets an "_output_shapes" list attr
        holding the inferred shapes of its outputs.

    Returns:
      A [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
      protocol buffer.

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    # The private helper returns a pair; only the GraphDef is exposed here.
    graph_def, _unused = self._as_graph_def(from_version, add_shapes)
    return graph_def
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:27,代码来源:ops.py

示例8: ProcessGraphDefParam

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def ProcessGraphDefParam(graph_def):
    """Type-check and, if necessary, canonicalize `graph_def`.

    Parameters
    ----------
    graph_def : Obj
        TensorFlow graph definition: a ``graph_pb2.GraphDef`` or a message
        that can be merged into one.

    Returns
    -------
    graph_def : Obj
        The input unchanged if it already is a ``graph_pb2.GraphDef``;
        otherwise a new ``graph_pb2.GraphDef`` populated via ``MergeFrom``.

    Raises
    ------
    TypeError
        If `graph_def` cannot be merged into a ``graph_pb2.GraphDef``. The
        original error is chained as ``__cause__``.
    """
    if isinstance(graph_def, graph_pb2.GraphDef):
        return graph_def

    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach.
    canonical = graph_pb2.GraphDef()
    try:
        canonical.MergeFrom(graph_def)
    except TypeError as err:
        raise TypeError('graph_def must be a GraphDef proto.') from err
    return canonical
开发者ID:mlperf,项目名称:training_results_v0.6,代码行数:27,代码来源:tf.py

示例9: read

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def read(self, pb_path: str) -> Graph:
        """Read a TF ``.pb`` file and build a model graph from it.

        Args:
            pb_path (str): Path to the binary TF GraphDef file.

        Returns:
            Graph: The graph built by ``Importer.make_graph`` from the parsed
            GraphDef. If the file cannot be opened or read (``IOError``), a
            message is printed and the graph is built from an empty GraphDef.

        """

        # load tensorflow model
        graph_def = graph_pb2.GraphDef()
        try:
            # `with` guarantees the file is closed even if read() or
            # ParseFromString() raises (the manual open/close pair leaked it).
            with open(path.abspath(pb_path), "rb") as f:
                graph_def.ParseFromString(f.read())
        except IOError:
            # NOTE(review): despite the message, no file is created here; the
            # empty GraphDef is simply passed on.
            print("Could not open file. Creating a new one.")

        # import graph
        graph = Importer.make_graph(graph_def)

        return graph
开发者ID:blue-oil,项目名称:blueoil,代码行数:26,代码来源:tensorflow.py

示例10: create_tfevent_from_pb

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def create_tfevent_from_pb(model,optimized=False):
    """Import a model's .pb graph and write it as a TensorBoard event file.

    Args:
      model: name of the model directory under ``ROOT_DIR + '/models/'``.
      optimized: if True, read ``optimized_inference_graph.pb`` and write to
        ``log_opt/``; otherwise read ``frozen_inference_graph.pb`` and write
        to ``log/``.
    """
    print("> creating tfevent of model: {}".format(model))

    if optimized:
        model_path = ROOT_DIR + '/models/{}/optimized_inference_graph.pb'.format(model)
        log_dir = ROOT_DIR + '/models/{}/log_opt/'.format(model)
    else:
        model_path = ROOT_DIR + '/models/{}/frozen_inference_graph.pb'.format(model)
        log_dir = ROOT_DIR + '/models/{}/log/'.format(model)

    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_path, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
            importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        try:
            pb_visual_writer.add_graph(sess.graph)
        finally:
            # Close (and thereby flush) the event file so the graph actually
            # reaches disk; the writer was previously never closed.
            pb_visual_writer.close()
    print("> Model {} Imported. \nVisualize by running: \
    tensorboard --logdir={}".format(model_path, log_dir))

# Gather all Model Names in models/ 
开发者ID:gustavz,项目名称:realtime_object_detection,代码行数:23,代码来源:all_models_to_tensorboard.py

示例11: testStrippedOpListRecursiveFunctions

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def testStrippedOpListRecursiveFunctions(self):
    # Recursive functions cannot be built with the function module, so the
    # cycle is assembled by hand: A calls B, and B calls Const and then A.
    graph_def = graph_pb2.GraphDef()
    fn_a = graph_def.library.function.add()
    fn_b = graph_def.library.function.add()
    fn_a.signature.name = "A"
    fn_b.signature.name = "B"
    fn_a.node.add().op = "B"
    for callee in ("Const", "A"):
        fn_b.node.add().op = callee

    # Reference A from the top-level graph.
    graph_def.node.add().op = "A"

    # A and B are library functions, so only Const should remain.
    stripped = tf.contrib.util.stripped_op_list_for_graph(graph_def)
    self.assertEqual(["Const"], [op.name for op in stripped.op])
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:20,代码来源:meta_graph_test.py

示例12: main

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def main(unused_args):
  """Convert the GraphDef at FLAGS.graph into a Graphviz DOT file.

  The graph is parsed as a binary proto when FLAGS.input_binary is set, and as
  a text proto otherwise. One DOT node (labelled with its op) is written per
  graph node, plus one edge per input, to FLAGS.dot_output.

  Returns:
    -1 if the input graph file does not exist; otherwise None.
  """
  if not gfile.Exists(FLAGS.graph):
    print("Input graph file '" + FLAGS.graph + "' does not exist!")
    return -1

  graph = graph_pb2.GraphDef()
  # Binary protos must be read as bytes; text protos as str.
  read_mode = "rb" if FLAGS.input_binary else "r"
  with open(FLAGS.graph, read_mode) as f:
    if FLAGS.input_binary:
      graph.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), graph)

  # print() emits str, so the DOT output must be opened in text mode
  # ("wb" raises TypeError on Python 3).
  with open(FLAGS.dot_output, "w") as f:
    print("digraph graphname {", file=f)
    for node in graph.node:
      output_name = node.name
      print("  \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
      for input_full_name in node.input:
        # Drop any ":<index>" suffix and a leading "^" to recover the name of
        # the producing node.
        parts = input_full_name.split(":")
        input_name = re.sub(r"^\^", "", parts[0])
        print("  \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
    print("}", file=f)
  print("Created DOT file '" + FLAGS.dot_output + "'.")
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:25,代码来源:graph_to_dot.py

示例13: load

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def load(self, model_path, inputs=None, outputs=None):
        """Load a frozen TensorFlow GraphDef and open a session on it.

        Args:
          model_path: path to a binary GraphDef (``.pb``) file.
          inputs: input tensor names. Required, because the graph carries no
            input/output metadata, so it has to come from config.
          outputs: output tensor names. Required for the same reason.

        Returns:
          self, with ``inputs``, ``outputs`` and ``sess`` set.

        Raises:
          ValueError: if ``inputs`` or ``outputs`` is empty or None.
        """
        # there is no input/output meta data in the graph so it needs to come from config.
        if not inputs:
            raise ValueError("BackendTensorflow needs inputs")
        if not outputs:
            raise ValueError("BackendTensorflow needs outputs")
        self.outputs = outputs
        self.inputs = inputs

        # TODO: support checkpoint and saved_model formats?
        graph_def = graph_pb2.GraphDef()
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        # import_graph_def() returns None when no return_elements are requested,
        # so import into an explicit Graph and hand that to the session instead
        # of relying on the implicit default graph.
        g = tf.Graph()
        with g.as_default():
            tf.compat.v1.import_graph_def(graph_def, name='')
        self.sess = tf.compat.v1.Session(graph=g)
        return self
开发者ID:mlperf,项目名称:inference,代码行数:18,代码来源:backend_tf.py

示例14: _parse_input_graph_proto

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def _parse_input_graph_proto(input_graph, input_binary):
  """Parse the graph file at `input_graph` into a GraphDef proto.

  Args:
    input_graph: path to the serialized graph.
    input_binary: True for a binary proto, False for a text proto.

  Returns:
    The populated GraphDef, or -1 (after printing a message) when the file
    does not exist.
  """
  if gfile.Exists(input_graph):
    parsed = graph_pb2.GraphDef()
    if input_binary:
      with gfile.FastGFile(input_graph, "rb") as handle:
        parsed.ParseFromString(handle.read())
    else:
      with gfile.FastGFile(input_graph, "r") as handle:
        text_format.Merge(handle.read(), parsed)
    return parsed
  print("Input graph file '" + input_graph + "' does not exist!")
  return -1
开发者ID:rockingdingo,项目名称:deepnlp,代码行数:15,代码来源:freeze_graph.py

示例15: do_quantize_training_on_graphdef

# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def do_quantize_training_on_graphdef(input_graph, num_bits):
  """Run the native quantize-training rewrite on `input_graph`.

  Args:
    input_graph: GraphDef proto to rewrite; it is serialized before being
      handed to DoQuantizeTrainingOnGraphDefHelper.
    num_bits: passed through unchanged to the native helper.

  Returns:
    A new GraphDef parsed from the serialized result of the helper.
  """
  from tensorflow.core.framework.graph_pb2 import GraphDef
  from tensorflow.python.framework import errors
  # The status object is checked by this context manager; presumably it raises
  # a TF error if the native call reports a non-OK status.
  with errors.raise_exception_on_not_ok_status() as status:
    rewritten_bytes = DoQuantizeTrainingOnGraphDefHelper(
        input_graph.SerializeToString(), num_bits, status)

  rewritten = GraphDef()
  rewritten.ParseFromString(rewritten_bytes)
  return rewritten
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:12,代码来源:pywrap_tensorflow_internal.py


注:本文中的tensorflow.core.framework.graph_pb2.GraphDef方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。