本文整理汇总了Python中tensorflow.python.framework.ops.get_collection_proto_type函数的典型用法代码示例。如果您正苦于以下问题:Python get_collection_proto_type函数的具体用法?Python get_collection_proto_type怎么用?Python get_collection_proto_type使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_collection_proto_type函数的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: assert_meta_graph_protos_equal
def assert_meta_graph_protos_equal(tester, a, b):
    """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

    Collections whose key has a registered proto type (as reported by
    `ops.get_collection_proto_type`) are compared entry by entry after parsing
    each serialized `bytes_list` item; every other collection is compared as a
    raw `CollectionDef` value. The remaining fields are then compared with
    `tester.assertProtoEquals`.

    Note: `a` and `b` are modified in place -- their `collection_def` field is
    cleared once it has been checked, so the final comparison covers only the
    other fields.

    Args:
      tester: Object providing `assertEqual` and `assertProtoEquals`
        (e.g. a test case instance).
      a: First `MetaGraphDef` to compare (mutated).
      b: Second `MetaGraphDef` to compare (mutated).
    """
    # Both protos must define exactly the same set of collection keys.
    tester.assertEqual(set(a.collection_def), set(b.collection_def))
    for key in a.collection_def:
        a_value = a.collection_def[key]
        b_value = b.collection_def[key]
        proto_type = ops.get_collection_proto_type(key)
        if proto_type:
            a_proto = proto_type()
            b_proto = proto_type()
            # Number of entries in the collections is the same.
            tester.assertEqual(len(a_value.bytes_list.value),
                               len(b_value.bytes_list.value))
            for a_item, b_item in zip(a_value.bytes_list.value,
                                      b_value.bytes_list.value):
                # ParseFromString clears the message first, so the two
                # scratch protos can be reused for every entry.
                a_proto.ParseFromString(a_item)
                b_proto.ParseFromString(b_item)
                tester.assertProtoEquals(a_proto, b_proto)
        else:
            # `assertEqual`, not the deprecated alias `assertEquals`, which
            # was removed from unittest in Python 3.12.
            tester.assertEqual(a_value, b_value)
    # Compared the fields directly, remove their raw values from the
    # proto comparison below.
    a.ClearField("collection_def")
    b.ClearField("collection_def")
    tester.assertProtoEquals(a, b)
示例2: add_collection_def
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    If a `to_proto` function is registered for `key`, each item is serialized
    with it into `bytes_list`. Otherwise the `CollectionDef` kind is chosen
    from the first item via `_get_kind_name`. If serialization fails, a
    warning is logged and the (possibly partial) entry for `key` is removed
    from `meta_graph_def`.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      graph: The `Graph` from which to get collections. Falls back to the
        default graph when not given.
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      TypeError: If `graph` is given and is not a `Graph`.
    """
    if graph and not isinstance(graph, ops.Graph):
        # Format the message eagerly: passing `type(graph)` as a second
        # exception argument would leave the "%s" placeholder unfilled.
        raise TypeError("graph must be of type Graph, not %s" % type(graph))
    if not isinstance(key, six.string_types) and not isinstance(key, bytes):
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s", type(key))
        return

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()
    collection_list = graph.get_collection(key)
    if not collection_list:
        return

    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                proto = to_proto(x, export_scope=export_scope)
                # `to_proto` may return a falsy value (presumably for items
                # it declines to export); such items are skipped.
                if not proto:
                    continue
                # Explicit check rather than `assert`, so the validation still
                # runs under `python -O`. The TypeError is caught below and
                # logged together with this message.
                if not isinstance(proto, proto_type):
                    raise TypeError("proto %s is not type %s" % (proto, proto_type))
                getattr(col_def, kind).value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                for x in collection_list:
                    # Only nodes inside `export_scope` are exported, with
                    # that scope stripped from their names.
                    if not export_scope or x.name.startswith(export_scope):
                        getattr(col_def, kind).value.append(
                            ops.strip_name_scope(x.name, export_scope))
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the
                # fact that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def, kind).value.extend(collection_list)
    except Exception as e:  # pylint: disable=broad-except
        logging.warning("Error encountered when serializing %s.\n"
                        "Type is unsupported, or the types of the items don't "
                        "match field type in CollectionDef.\n%s", key, str(e))
        # Drop the partially written entry so the proto stays consistent.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
示例3: _import_meta_graph_def
def _import_meta_graph_def(meta_graph_def):
    """Recreates a Graph saved in a `MetaGraphDef` proto.

    This function adds all the nodes from the meta graph def proto to the
    current graph, recreates all the collections, and returns a saver from
    `saver_def`.

    A collection with a registered `from_proto` function that is stored as a
    `bytes_list` is deserialized with that function. Any other collection is
    restored from its raw field values. This includes a collection that has a
    `from_proto` function but was stored in another kind, possibly by an older
    release; a warning is logged in that case.

    Args:
      meta_graph_def: `MetaGraphDef` protocol buffer.

    Returns:
      A `Saver` constructed from `saver_def` in `meta_graph_def`, or a
      default-constructed `Saver()` when `saver_def` is not set.
    """
    # Gathers the list of nodes we are interested in.
    importer.import_graph_def(meta_graph_def.graph_def, name="")
    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
        kind = col_def.WhichOneof("kind")
        if kind is None:
            logging.error("Cannot identify data type for collection %s. Skipping.",
                          key)
            continue
        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            for value in col_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                ops.add_to_collection(key, from_proto(proto))
            continue
        if from_proto:
            # Don't `assert` on loaded data: the assertion crashes the import,
            # and under `python -O` it vanishes and the empty `bytes_list`
            # would silently restore nothing. Fall back to the raw values.
            logging.warning(
                "Collection %s has a from_proto function but is stored as %s "
                "rather than bytes_list; restoring its raw values.", key, kind)
        field = getattr(col_def, kind)
        if kind == "node_list":
            default_graph = ops.get_default_graph()
            for value in field.value:
                ops.add_to_collection(key, default_graph.as_graph_element(value))
        elif kind == "int64_list":
            # NOTE(opensource): This force conversion is to work around the fact
            # that Python2 distinguishes between int and long, while Python3 has
            # only int.
            for value in field.value:
                ops.add_to_collection(key, int(value))
        else:
            for value in field.value:
                ops.add_to_collection(key, value)
    if meta_graph_def.HasField("saver_def"):
        return Saver(saver_def=meta_graph_def.saver_def)
    return Saver()
示例4: _get_all_protos_from_collection
def _get_all_protos_from_collection(meta_graph_def, collection_key):
"""Obtain node names from a collection."""
if collection_key not in meta_graph_def.collection_def:
return []
collection = meta_graph_def.collection_def[collection_key]
if not collection.bytes_list.value:
raise ValueError(
'Collection {} is present but type is not bytes_list.'.format(
collection_key))
proto_type = _ops.get_collection_proto_type(collection_key)
result = []
for value in collection.bytes_list.value:
proto = proto_type()
proto.ParseFromString(value)
result.append(proto)
return result
示例5: add_collection_def
def add_collection_def(meta_graph_def, key):
    """Serializes the default graph's collection `key` into `meta_graph_def`.

    If a `to_proto` function is registered for `key`, every item is converted
    with it and stored as serialized bytes in `bytes_list`. Otherwise the
    `CollectionDef` kind is derived from the first item. On any failure a
    warning is logged and the (possibly partial) entry for `key` is removed.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    key_is_string = isinstance(key, six.string_types) or isinstance(key, bytes)
    if not key_is_string:
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s", type(key))
        return

    items = ops.get_collection(key)
    if not items:
        return

    try:
        collection_def = meta_graph_def.collection_def[key]
        serialize = ops.get_to_proto_function(key)
        expected_type = ops.get_collection_proto_type(key)
        if serialize:
            target = collection_def.bytes_list
            for item in items:
                # Guard against a serializer that returns the wrong proto type.
                proto = serialize(item)
                if not isinstance(proto, expected_type):
                    raise TypeError("proto %s is not type %s" % (proto, expected_type))
                target.value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(items[0])
            target = getattr(collection_def, kind)
            if kind == "node_list":
                payload = [item.name for item in items]
            elif kind == "bytes_list":
                # NOTE(opensource): Force conversion because Python 3
                # distinguishes between bytes and text strings.
                payload = [compat.as_bytes(item) for item in items]
            else:
                payload = list(items)
            target.value.extend(payload)
    except Exception as e:  # pylint: disable=broad-except
        logging.warning(
            "Error encountered when serializing %s.\n"
            "Type is unsupported, or the types of the items don't "
            "match field type in CollectionDef.\n%s",
            key,
            str(e),
        )
        # Remove whatever was partially written for this key.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
示例6: _restore_collections
def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
    """Restores collections that we need to keep.

    Copies each collection named in `collection_keys` from
    `src_meta_graph_def` into `dest_graph`. Items that cannot be rebuilt or
    looked up in `dest_graph` are skipped instead of raising. Judging by the
    comments below, this happens because the graph may have been optimized.

    Args:
      dest_graph: Graph that receives the restored collection items.
      src_meta_graph_def: `MetaGraphDef` holding the serialized collections.
        It is not modified.
      collection_keys: Iterable of collection keys to restore.
    """
    scope = ""
    for key in collection_keys:
        # Check membership before indexing: indexing a protobuf message map
        # with a missing key inserts an empty entry, which would mutate the
        # source proto. The log message matches the one the empty entry
        # would have produced below.
        if key not in src_meta_graph_def.collection_def:
            tf_logging.error(
                "Cannot identify data type for collection %s. Skipping.", key)
            continue
        collection_def = src_meta_graph_def.collection_def[key]
        kind = collection_def.WhichOneof("kind")
        if kind is None:
            tf_logging.error(
                "Cannot identify data type for collection %s. Skipping.", key)
            continue
        from_proto = ops.get_from_proto_function(key)
        if from_proto and kind == "bytes_list":
            proto_type = ops.get_collection_proto_type(key)
            # It is assumed that there are no Variables Keys in collections
            for value in collection_def.bytes_list.value:
                proto = proto_type()
                proto.ParseFromString(value)
                try:
                    new_value = from_proto(proto, import_scope=scope)
                except Exception:  # pylint: disable=broad-except
                    # Best effort: skip items that can no longer be rebuilt.
                    # Catch `Exception` rather than using a bare `except:` so
                    # that KeyboardInterrupt/SystemExit still propagate.
                    continue
                dest_graph.add_to_collection(key, new_value)
        else:
            field = getattr(collection_def, kind)
            if kind == "node_list":
                for value in field.value:
                    name = ops.prepend_name_scope(value, scope)
                    # Since the graph has been optimized, the node may no
                    # longer exist.
                    try:
                        col_op = dest_graph.as_graph_element(name)
                    except (TypeError, ValueError, KeyError):
                        continue
                    dest_graph.add_to_collection(key, col_op)
            elif kind == "int64_list":
                # NOTE(opensource): This force conversion is to work around the
                # fact that Python2 distinguishes between int and long, while
                # Python3 has only int.
                for value in field.value:
                    dest_graph.add_to_collection(key, int(value))
            else:
                for value in field.value:
                    dest_graph.add_to_collection(
                        key, ops.prepend_name_scope(value, scope))
示例7: _add_collection_def
def _add_collection_def(meta_graph_def, key):
"""Adds a collection to MetaGraphDef protocol buffer.
Args:
meta_graph_def: MetaGraphDef protocol buffer.
key: One of the GraphKeys or user-defined string.
"""
if not isinstance(key, (str, bytes, unicode)):
logging.warning("Only collections with string type keys will be "
"serialized. This key has %s" % type(key))
return
collection_list = ops.get_collection(key)
if not collection_list:
return
try:
col_def = meta_graph_def.collection_def[key]
to_proto = ops.get_to_proto_function(key)
proto_type = ops.get_collection_proto_type(key)
if to_proto:
kind = "bytes_list"
for x in collection_list:
# Additional type check to make sure the returned proto is indeed
# what we expect.
proto = to_proto(x)
assert isinstance(proto, proto_type)
getattr(col_def, kind).value.append(proto.SerializeToString())
else:
kind = _get_kind_name(collection_list[0])
if kind == "node_list":
getattr(col_def, kind).value.extend([x.name for x in collection_list])
else:
getattr(col_def, kind).value.extend([x for x in collection_list])
except Exception, e: # pylint: disable=broad-except
logging.warning("Type is unsupported, or the types of the items don't "
"match field type in CollectionDef.\n%s" % str(e))
if key in meta_graph_def.collection_def:
del meta_graph_def.collection_def[key]
return
示例8: import_scoped_meta_graph
# ......... (part of the code is omitted here) .........
# The lines below are the tail of `import_scoped_meta_graph`; its signature
# and the start of its body were omitted by the page.
# Presumably still inside a loop over collections looking for the
# `unbound_inputs` collection (loop header omitted): the import is refused
# unless `input_map` names exactly the recorded unbound inputs (compared as
# sorted lists).
kind = col_def.WhichOneof("kind")
field = getattr(col_def, kind)
if field.value and (
not input_map or
sorted([compat.as_str(v) for v in field.value]) !=
sorted(input_map)):
raise ValueError("Graph contains unbound inputs: %s. Must "
"provide these inputs through input_map." %
",".join([compat.as_str(v) for v in field.value
if not input_map or v not in input_map]))
break
# Sets graph to default graph if it's not passed in.
graph = graph or ops.get_default_graph()
# Gathers the list of nodes we are interested in.
with graph.as_default():
producer_op_list = None
if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
input_graph_def = meta_graph_def.graph_def
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:
for node in input_graph_def.node:
node.device = ""
# Pick a scope name that is unique in `graph` without reserving it
# (mark_as_used=False); it prefixes the names restored from collections.
scope_to_prepend_to_names = graph.unique_name(
import_scope or "", mark_as_used=False)
importer.import_graph_def(
input_graph_def,
name=(import_scope or scope_to_prepend_to_names),
input_map=input_map,
producer_op_list=producer_op_list)
# Restores all the other collections.
# Deserialized variables keyed by their serialized bytes, so a variable that
# appears in several variable collections is rebuilt only once and the same
# object is shared between those collections.
variable_objects = {}
for key, col_def in sorted(meta_graph_def.collection_def.items()):
# Don't add unbound_inputs to the new graph.
if key == unbound_inputs_col_name:
continue
# Caller-supplied filter deciding which collections to restore.
if not restore_collections_predicate(key):
continue
kind = col_def.WhichOneof("kind")
if kind is None:
logging.error("Cannot identify data type for collection %s. Skipping.",
key)
continue
from_proto = ops.get_from_proto_function(key)
if from_proto and kind == "bytes_list":
proto_type = ops.get_collection_proto_type(key)
if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access
for value in col_def.bytes_list.value:
variable = variable_objects.get(value, None)
if variable is None:
proto = proto_type()
proto.ParseFromString(value)
variable = from_proto(
proto, import_scope=scope_to_prepend_to_names)
variable_objects[value] = variable
graph.add_to_collection(key, variable)
else:
for value in col_def.bytes_list.value:
proto = proto_type()
proto.ParseFromString(value)
graph.add_to_collection(
key, from_proto(
proto, import_scope=scope_to_prepend_to_names))
else:
# Raw (non-proto) collection. A collection from _COMPAT_COLLECTION_LIST
# stored in a non-bytes kind is still restored, but with a warning.
field = getattr(col_def, kind)
if key in _COMPAT_COLLECTION_LIST:
# NOTE(review): the message says 'byte_list'; the proto field is named
# 'bytes_list'.
logging.warning(
"The saved meta_graph is possibly from an older release:\n"
"'%s' collection should be of type 'byte_list', but instead "
"is of type '%s'.", key, kind)
if kind == "node_list":
for value in field.value:
col_op = graph.as_graph_element(
ops.prepend_name_scope(value, scope_to_prepend_to_names))
graph.add_to_collection(key, col_op)
elif kind == "int64_list":
# NOTE(opensource): This force conversion is to work around the fact
# that Python2 distinguishes between int and long, while Python3 has
# only int.
for value in field.value:
graph.add_to_collection(key, int(value))
else:
for value in field.value:
graph.add_to_collection(
key, ops.prepend_name_scope(value, scope_to_prepend_to_names))
# Return the imported global variables, keyed by name with the import scope
# stripped.
var_list = {}
variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=scope_to_prepend_to_names)
for v in variables:
var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
return var_list
示例9: import_scoped_meta_graph
# NOTE(review): the docstring summary says this returns a saver built from
# `saver_def`, while its Returns section documents a dict of Variables. The
# end of this function is not shown here; confirm which statement is accurate.
def import_scoped_meta_graph(meta_graph_or_file,
clear_devices=False,
graph=None,
import_scope=None,
input_map=None,
unbound_inputs_col_name="unbound_inputs"):
"""Recreates a `Graph` saved in a `MetaGraphDef` proto.

This function takes a `MetaGraphDef` protocol buffer as input. If
the argument is a file containing a `MetaGraphDef` protocol buffer,
it constructs a protocol buffer from the file content. The function
then adds all the nodes from the `graph_def` field to the
current graph, recreates all the collections, and returns a saver
constructed from the `saver_def` field.

In combination with `export_scoped_meta_graph()`, this function can be used to:

* Serialize a graph along with other Python objects such as `QueueRunner`,
`Variable` into a `MetaGraphDef`.
* Restart training from a saved graph and checkpoints.
* Run inference from a saved graph and checkpoints.

Args:
meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
the path) containing a `MetaGraphDef`.
clear_devices: Boolean which controls whether to clear device information
from graph_def. Default false.
graph: The `Graph` to import into. If `None`, use the default graph.
import_scope: Optional `string`. Name scope into which to import the
subgraph. If `None`, the graph is imported to the root name scope.
input_map: A dictionary mapping input names (as strings) in `graph_def` to
`Tensor` objects. The values of the named input tensors in the imported
graph will be re-mapped to the respective `Tensor` values.
unbound_inputs_col_name: Collection name for looking up unbound inputs.

Returns:
A dictionary of all the `Variables` imported into the name scope.

Raises:
ValueError: If the graph_def contains unbound inputs.
"""
# Accept either an in-memory MetaGraphDef or a path to read one from.
if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
meta_graph_def = meta_graph_or_file
else:
meta_graph_def = read_meta_graph_file(meta_graph_or_file)
# Refuse the import unless `input_map` names exactly the recorded unbound
# inputs (compared as sorted lists).
if unbound_inputs_col_name:
for key, col_def in meta_graph_def.collection_def.items():
if key == unbound_inputs_col_name:
kind = col_def.WhichOneof("kind")
field = getattr(col_def, kind)
if field.value and (
not input_map or
sorted([compat.as_str(v) for v in field.value]) !=
sorted(input_map)):
raise ValueError("Graph contains unbound inputs: %s. Must "
"provide these inputs through input_map." %
",".join([compat.as_str(v) for v in field.value]))
break
# Sets graph to default graph if it's not passed in.
graph = graph or ops.get_default_graph()
# Gathers the list of nodes we are interested in.
with graph.as_default():
producer_op_list = None
if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
input_graph_def = meta_graph_def.graph_def
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:
for node in input_graph_def.node:
node.device = ""
# Imported nodes go under `import_scope` (the root scope when it is None).
importer.import_graph_def(
input_graph_def, name=(import_scope or ""), input_map=input_map,
producer_op_list=producer_op_list)
# Restores all the other collections.
for key, col_def in meta_graph_def.collection_def.items():
# Don't add unbound_inputs to the new graph.
if key == unbound_inputs_col_name:
continue
kind = col_def.WhichOneof("kind")
if kind is None:
logging.error("Cannot identify data type for collection %s. Skipping.",
key)
continue
from_proto = ops.get_from_proto_function(key)
if from_proto:
# Collections with a from_proto function are expected to be stored as
# serialized protos. NOTE(review): this validates loaded data with
# `assert`, which is stripped under `python -O`.
assert kind == "bytes_list"
proto_type = ops.get_collection_proto_type(key)
for value in col_def.bytes_list.value:
proto = proto_type()
proto.ParseFromString(value)
graph.add_to_collection(
key, from_proto(proto, import_scope=import_scope))
# ......... (the rest of the code is omitted here) .........