本文整理汇总了Python中tensorflow.core.framework.graph_pb2.GraphDef类的典型用法代码示例。如果您正苦于以下问题：Python graph_pb2.GraphDef的具体用法？Python graph_pb2.GraphDef怎么用？Python graph_pb2.GraphDef使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解GraphDef所在模块tensorflow.core.framework.graph_pb2的用法示例。
下文共展示了graph_pb2.GraphDef类的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐更好的Python代码示例。
示例1: load_graph
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def load_graph(graph_path, tensorboard=False, **kwargs):
    '''Load a binary (serialized) GraphDef file into a new TensorFlow graph.

    :param graph_path: the path of the binary .pb file
    :param tensorboard: if True, also write the loaded graph to the "log/"
        directory so it can be inspected with TensorBoard
    :param kwargs: accepted for call-site compatibility; currently unused
    :return: a new tensorflow graph into which the GraphDef was imported with
        name="" (no name-scope prefix is added to node names)
    '''
    with gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    if tensorboard:
        writer = tf.summary.FileWriter("log/")
        try:
            writer.add_graph(graph)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            writer.close()
    return graph
示例2: get_header
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def get_header(graphs,
               proto_fileformat='rawproto',
               default_ops='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'):
    """Compute a header for use with tensorflow SELECTIVE_REGISTRATION.

    Args:
      graphs: a list of paths to GraphDef files to include.
      proto_fileformat: optional format of proto file, either 'textproto' or
        'rawproto' (default).
      default_ops: optional comma-separated string of operator:kernel pairs
        whose implementation is always included. Pass 'all' to include every
        operator and kernel. Default: 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'.

    Returns:
      The header text that should be written as ops_to_register.h. If no
      ops/kernels could be collected, 'Error reading graph!' is printed and
      the integer 1 is returned instead of a string.
    """
    collected = get_ops_and_kernels(proto_fileformat, graphs, default_ops)
    if collected:
        include_everything = default_ops == 'all'
        return get_header_from_ops_and_kernels(collected, include_everything)
    print('Error reading graph!')
    return 1
示例3: get_stats_for_node_def
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def get_stats_for_node_def(graph, node, statistic_type):
    """Compute a statistic for `node` using its registered statistics function.

    The function registered in `_stats_registry` under the key
    "<node.op>,<statistic_type>" is called as `stats_func(graph, node)` and
    its result is returned. If the lookup raises `LookupError` (i.e. nothing
    is registered for that key), an empty `OpStats(statistic_type)` is
    returned instead.

    Args:
      graph: A Graph object that's been set up with the node's graph.
      node: A NodeDef describing the operator.
      statistic_type: A string identifying the statistic we're interested in.

    Returns:
      An OpStats object containing information about resource usage.

    Note:
      A `LookupError` (e.g. `KeyError`) raised by the statistics function
      itself is caught as well and also yields the empty `OpStats`.
    """
    try:
        registry_key = node.op + "," + statistic_type
        return _stats_registry.lookup(registry_key)(graph, node)
    except LookupError:
        return OpStats(statistic_type)
示例4: as_graph_def
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return a serialized `GraphDef` representation of this graph.

    The result can be imported into another `Graph` (using
    `tf.import_graph_def`) or used with the C++ Session API.

    This method is thread-safe.

    Args:
      from_version: Optional. When set, the returned `GraphDef` contains only
        the nodes added to this graph since its `version` property had this
        value.
      add_shapes: If true, an "_output_shapes" list attr is added to each node
        with the inferred shapes of each of its outputs.

    Returns:
      A `GraphDef` protocol buffer
      (see tensorflow/core/framework/graph.proto).

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    # `_as_graph_def` yields a pair; only its first element is exposed here.
    graph_def, _second = self._as_graph_def(from_version, add_shapes)
    return graph_def
示例5: Graph
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def Graph(self):
    """Return the graph definition stored in this accumulator.

    Only the directly stored serialized graph (`self._graph`) is consulted.
    It is parsed into a fresh `GraphDef` proto on every call.

    Raises:
      ValueError: If there is no graph for this run.

    Returns:
      The `graph_def` proto.
    """
    # NOTE: no metagraph fallback is implemented here.
    if self._graph is None:
        raise ValueError('There is no graph in this EventAccumulator')
    graph = graph_pb2.GraphDef()
    graph.ParseFromString(self._graph)
    return graph
示例6: testAll
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def testAll(self):
    # With default_ops='all', the expected header registers every op,
    # every kernel and every gradient unconditionally.
    default_ops = 'all'
    graph_defs = [
        text_format.Parse(graph_text, graph_pb2.GraphDef())
        for graph_text in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]
    ]
    graph_files = self.WriteGraphFiles(graph_defs)
    ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
        'rawproto', graph_files, default_ops)
    # NOTE(review): this passes (ops_and_kernels, default_ops) to get_header;
    # confirm that matches get_header's signature in the module under test.
    header = print_selective_registration_header.get_header(
        ops_and_kernels, default_ops)
    expected_lines = [
        '#ifndef OPS_TO_REGISTER',
        '#define OPS_TO_REGISTER',
        '#define SHOULD_REGISTER_OP(op) true',
        '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
        '#define SHOULD_REGISTER_OP_GRADIENT true',
        '#endif',
    ]
    self.assertListEqual(expected_lines, header.split('\n'))
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:23,代码来源:print_selective_registration_header_test.py
示例7: as_graph_def
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def as_graph_def(self, from_version=None, add_shapes=False):
    """Return a serialized `GraphDef` representation of this graph.

    The returned proto can be imported into another `Graph` with
    `import_graph_def()` or used with the C++ Session API.

    This method is thread-safe.

    Args:
      from_version: Optional. If set, only nodes added to this graph since its
        `version` property had this value are included.
      add_shapes: If true, each node gets an "_output_shapes" list attr holding
        the inferred shapes of each of its outputs.

    Returns:
      A `GraphDef` protocol buffer
      (see tensorflow/core/framework/graph.proto).

    Raises:
      ValueError: If the `graph_def` would be too large.
    """
    # The private helper returns a 2-tuple; callers only get the proto.
    serialized, _unused = self._as_graph_def(from_version, add_shapes)
    return serialized
示例8: ProcessGraphDefParam
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def ProcessGraphDefParam(graph_def):
    """Type-check `graph_def` and canonicalize it to a `GraphDef` proto.

    Parameters
    ----------
    graph_def : Obj
        tensorflow graph definition. A `graph_pb2.GraphDef` instance is
        returned unchanged; any other object is merged into a fresh
        `GraphDef` with `MergeFrom` (per the original comment, this is meant
        to accept dynamically-created messages).

    Returns
    -------
    graph_def : Obj
        tensorflow graph definition

    Raises
    ------
    TypeError
        If `graph_def` cannot be merged into a `GraphDef`.
    """
    if isinstance(graph_def, graph_pb2.GraphDef):
        return graph_def
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach.
    canonical = graph_pb2.GraphDef()
    try:
        canonical.MergeFrom(graph_def)
    except TypeError as err:
        # Keep the original error attached for debugging.
        raise TypeError('graph_def must be a GraphDef proto.') from err
    return canonical
示例9: read
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def read(self, pb_path: str) -> Graph:
    """Read a binary TF GraphDef file and build a Graph from it.

    Args:
        pb_path (str): Path to the serialized (binary) TF GraphDef file.

    Returns:
        Graph: The result of `Importer.make_graph` on the parsed GraphDef.
        If the file cannot be opened or read (IOError), an empty GraphDef
        is passed to `make_graph` instead.
    """
    # load tensorflow model
    graph_def = graph_pb2.GraphDef()
    try:
        # `with` closes the file even if parsing raises.
        with open(path.abspath(pb_path), "rb") as f:
            graph_def.ParseFromString(f.read())
    except IOError:
        # NOTE(review): despite the message, no file is created here; the
        # empty GraphDef is simply used as-is.
        print("Could not open file. Creating a new one.")
    # import graph
    graph = Importer.make_graph(graph_def)
    return graph
示例10: create_tfevent_from_pb
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def create_tfevent_from_pb(model, optimized=False):
    """Write a TensorBoard event file that visualizes a model's .pb graph.

    Reads ROOT_DIR/models/<model>/frozen_inference_graph.pb, or
    optimized_inference_graph.pb when `optimized` is true. Imports it into a
    fresh graph and writes that graph to ROOT_DIR/models/<model>/log/, or
    log_opt/ when `optimized` is true.

    Args:
      model: model directory name under ROOT_DIR/models/.
      optimized: select the optimized graph file and log directory.
    """
    print("> creating tfevent of model: {}".format(model))
    if optimized:
        graph_file, log_subdir = 'optimized_inference_graph.pb', 'log_opt/'
    else:
        graph_file, log_subdir = 'frozen_inference_graph.pb', 'log/'
    model_dir = ROOT_DIR + '/models/{}/'.format(model)
    model_path = model_dir + graph_file
    log_dir = model_dir + log_subdir
    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_path, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
            importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        try:
            pb_visual_writer.add_graph(sess.graph)
        finally:
            # FileWriter buffers events; close() flushes them to disk.
            pb_visual_writer.close()
    # Adjacent literals avoid embedding source indentation in the message
    # (a backslash continuation inside a string keeps the next line's spaces).
    print("> Model {} Imported. \nVisualize by running: "
          "tensorboard --logdir={}".format(model_path, log_dir))
# Gather all Model Names in models/
示例11: testStrippedOpListRecursiveFunctions
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def testStrippedOpListRecursiveFunctions(self):
    # The function module can't express recursive functions, so the recursion
    # is assembled by hand: function A calls B; B calls Const and then A.
    graph = graph_pb2.GraphDef()
    func_a = graph.library.function.add()
    func_b = graph.library.function.add()
    func_a.signature.name = "A"
    func_b.signature.name = "B"
    func_a.node.add().op = "B"
    for callee in ("Const", "A"):
        func_b.node.add().op = callee
    # The top-level graph uses A.
    graph.node.add().op = "A"
    # The stripped op list should contain just Const.
    stripped = tf.contrib.util.stripped_op_list_for_graph(graph)
    op_names = [op.name for op in stripped.op]
    self.assertEqual(["Const"], op_names)
示例12: main
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def main(unused_args):
    """Convert the GraphDef in FLAGS.graph into a Graphviz DOT file.

    FLAGS.input_binary selects a binary (serialized) or text-format proto.
    The file written to FLAGS.dot_output has one vertex per node, labelled
    with the node's op, and one edge per node input. For each input name,
    everything from the first ':' onward and a leading '^' are removed.

    Returns:
      -1 if the input graph file does not exist, otherwise None.
    """
    if not gfile.Exists(FLAGS.graph):
        print("Input graph file '" + FLAGS.graph + "' does not exist!")
        return -1
    graph = graph_pb2.GraphDef()
    # ParseFromString needs bytes, so binary protos must be opened in "rb".
    mode = "rb" if FLAGS.input_binary else "r"
    with open(FLAGS.graph, mode) as f:
        if FLAGS.input_binary:
            graph.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), graph)
    # print() writes str, so the DOT file must be opened in text mode.
    with open(FLAGS.dot_output, "w") as f:
        print("digraph graphname {", file=f)
        for node in graph.node:
            output_name = node.name
            print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
            for input_full_name in node.input:
                parts = input_full_name.split(":")
                input_name = re.sub(r"^\^", "", parts[0])
                print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
        print("}", file=f)
    print("Created DOT file '" + FLAGS.dot_output + "'.")
示例13: load
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def load(self, model_path, inputs=None, outputs=None):
    """Load a binary GraphDef model file and open a TF1-style session on it.

    Args:
      model_path: path of the serialized (binary) GraphDef file.
      inputs: input tensor names (required; must be non-empty).
      outputs: output tensor names (required; must be non-empty).

    Returns:
      self, with `inputs`, `outputs` and `sess` set.

    Raises:
      ValueError: if `inputs` or `outputs` is missing or empty.
    """
    # There is no input/output metadata in the graph, so it needs to come
    # from the config.
    if not inputs:
        raise ValueError("BackendTensorflow needs inputs")
    if not outputs:
        raise ValueError("BackendTensorflow needs outputs")
    self.outputs = outputs
    self.inputs = inputs
    # TODO: support checkpoint and saved_model formats?
    graph_def = graph_pb2.GraphDef()
    with open(model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    # import_graph_def returns None when return_elements is not given, so the
    # old `Session(graph=g)` silently fell back to the process default graph.
    # Import into an explicit graph and hand that graph to the session.
    g = tf.Graph()
    with g.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')
    self.sess = tf.compat.v1.Session(graph=g)
    return self
示例14: _parse_input_graph_proto
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def _parse_input_graph_proto(input_graph, input_binary):
    """Parse the TensorFlow graph file `input_graph` into a GraphDef proto.

    Args:
      input_graph: path of the graph file.
      input_binary: True for a serialized (binary) proto, False for text
        format.

    Returns:
      The parsed GraphDef, or -1 (after printing a message) if `input_graph`
      does not exist.
    """
    if gfile.Exists(input_graph):
        parsed = graph_pb2.GraphDef()
        open_mode = "rb" if input_binary else "r"
        with gfile.FastGFile(input_graph, open_mode) as handle:
            contents = handle.read()
            if input_binary:
                parsed.ParseFromString(contents)
            else:
                text_format.Merge(contents, parsed)
        return parsed
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
示例15: do_quantize_training_on_graphdef
# 需要导入模块: from tensorflow.core.framework import graph_pb2 [as 别名]
# 或者: from tensorflow.core.framework.graph_pb2 import GraphDef [as 别名]
def do_quantize_training_on_graphdef(input_graph, num_bits):
    """Run `DoQuantizeTrainingOnGraphDefHelper` on `input_graph`.

    Args:
      input_graph: a proto exposing SerializeToString(); its serialized bytes
        are handed to the helper.
      num_bits: forwarded unchanged to the helper (presumably the quantization
        bit width -- confirm).

    Returns:
      A new GraphDef parsed from the serialized graph returned by the helper.
    """
    from tensorflow.core.framework.graph_pb2 import GraphDef
    from tensorflow.python.framework import errors

    with errors.raise_exception_on_not_ok_status() as status:
        serialized_result = DoQuantizeTrainingOnGraphDefHelper(
            input_graph.SerializeToString(), num_bits, status)
    # Parse only after the `with` block exits; the context manager presumably
    # raises there if `status` is not OK.
    quantized = GraphDef()
    quantized.ParseFromString(serialized_result)
    return quantized