本文整理汇总了 Python 中 tensorflow.python.framework.ops.get_collection 方法的典型用法代码示例。如果您正苦于以下问题：Python ops.get_collection 方法的具体用法？Python ops.get_collection 怎么用？Python ops.get_collection 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.python.framework.ops 的用法示例。
在下文中一共展示了ops.get_collection方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _get_first_op_from_collection
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def _get_first_op_from_collection(self, key):
    """Fetch the leading entry of a graph collection.

    Args:
        key: String name of the collection to look up.

    Returns:
        The first item stored under `key`, or `None` when the collection is
        empty or a `LookupError` is raised while reading it.
    """
    first_op = None
    try:
        found = ops.get_collection(key)
        count = len(found)
        if count > 1:
            # Several candidates exist: report it, but still use the first.
            logging.info("Found %d %s operations. Returning the first one.",
                         count, key)
        if found:
            first_op = found[0]
    except LookupError:
        # Handled the same way as an empty collection.
        pass
    return first_op
示例2: start_queue_runners
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def start_queue_runners(self, sess, queue_runners=None):
    """Launch the threads of a set of `QueueRunners`.

    Queue runners registered under the graph key `QUEUE_RUNNERS` are already
    started automatically when a session is created through the supervisor,
    so an explicit call is only needed for runners that were not collected.

    Args:
        sess: A `Session`.
        queue_runners: A list of `QueueRunners`. When omitted, the runners
            gathered in the graph under `GraphKeys.QUEUE_RUNNERS` are used.

    Returns:
        The list of threads started for the `QueueRunners`.
    """
    runners = queue_runners
    if runners is None:
        runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    # Flatten the per-runner thread lists into a single list, in runner order.
    return [
        thread
        for runner in runners
        for thread in runner.create_threads(
            sess, coord=self._coord, daemon=True, start=True)
    ]
示例3: convert_collection_to_dict
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def convert_collection_to_dict(collection, clear_collection=False):
    """Build an ordered alias -> tensor mapping from a graph collection.

    Args:
        collection: A collection.
        clear_collection: When True, the collection is cleared on the default
            graph once the mapping has been built.

    Returns:
        An OrderedDict of {alias: tensor}.
    """
    alias_map = OrderedDict()
    for tensor in ops.get_collection(collection):
        for alias in get_tensor_aliases(tensor):
            # A repeated alias keeps its first position but takes the
            # latest tensor, exactly as feeding pairs to OrderedDict would.
            alias_map[alias] = tensor
    if clear_collection:
        ops.get_default_graph().clear_collection(collection)
    return alias_map
示例4: _get_main_op_tensor
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def _get_main_op_tensor(meta_graph_def_to_load):
    """Look up the SavedModel main op tensor, when one is recorded.

    Args:
        meta_graph_def_to_load: The meta graph def from the SavedModel to be
            loaded.

    Returns:
        The main op tensor, or `None` when the meta graph def has no entry
        for the main op key.

    Raises:
        RuntimeError: If the collection def for the main op key lists
            anything other than exactly one node.
    """
    collection_defs = meta_graph_def_to_load.collection_def
    if constants.MAIN_OP_KEY not in collection_defs:
        return None

    main_op_nodes = collection_defs[constants.MAIN_OP_KEY].node_list.value
    if len(main_op_nodes) != 1:
        raise RuntimeError("Expected exactly one SavedModel main op.")
    # The count is validated against the collection def; the tensor itself is
    # read from the collection of the same key.
    return ops.get_collection(constants.MAIN_OP_KEY)[0]
示例5: _get_legacy_init_op_tensor
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
    """Look up the legacy init op tensor, when one is recorded.

    Args:
        meta_graph_def_to_load: The meta graph def from the SavedModel to be
            loaded.

    Returns:
        The legacy init op tensor, or `None` when the meta graph def has no
        entry for the legacy init op key.

    Raises:
        RuntimeError: If the collection def for the legacy init op key lists
            anything other than exactly one node.
    """
    collection_defs = meta_graph_def_to_load.collection_def
    if constants.LEGACY_INIT_OP_KEY not in collection_defs:
        return None

    init_op_nodes = collection_defs[
        constants.LEGACY_INIT_OP_KEY].node_list.value
    if len(init_op_nodes) != 1:
        raise RuntimeError("Expected exactly one legacy serving init op.")
    # The count is validated against the collection def; the tensor itself is
    # read from the collection of the same key.
    return ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
示例6: merge_all_summaries
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
    """Merge every summary gathered under `key` into a single op.

    This op is deprecated. Please switch to tf.summary.merge_all, which has
    identical behavior.

    Args:
        key: `GraphKey` used to collect the summaries. Defaults to
            `GraphKeys.SUMMARIES`.

    Returns:
        `None` if no summaries were collected. Otherwise a scalar `Tensor` of
        type `string` containing the serialized `Summary` protocol buffer
        resulting from the merging.
    """
    collected = ops.get_collection(key)
    return merge_summary(collected) if collected else None
示例7: global_variables
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are variables that are shared across machines in a
    distributed environment. The `Variable()` constructor or `get_variable()`
    automatically adds new variables to that collection. Local variables
    (see `tf.local_variables`) are the alternative.

    Returns:
        A list of `Variable` objects.
    """
    collection_key = ops.GraphKeys.GLOBAL_VARIABLES
    return ops.get_collection(collection_key)
示例8: local_variables
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def local_variables():
    """Return the contents of the `GraphKeys.LOCAL_VARIABLES` collection.

    Local variables are per-process variables, usually not saved/restored to
    checkpoint and used for temporary or intermediate values, for example
    counters for metrics computation or the number of epochs this machine has
    read data. The `tf.contrib.framework.local_variable()` function
    automatically adds the new variable to that collection. Global variables
    (see `tf.global_variables`) are the alternative.

    Returns:
        A list of local `Variable` objects.
    """
    collection_key = ops.GraphKeys.LOCAL_VARIABLES
    return ops.get_collection(collection_key)
示例9: merge_all
# 需要导入模块: from tensorflow.python.framework import ops as _ops (本例代码使用别名 _ops)
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def merge_all(key=_ops.GraphKeys.SUMMARIES):
    """Merge every summary gathered under `key` into a single op.

    Args:
        key: `GraphKey` used to collect the summaries. Defaults to
            `GraphKeys.SUMMARIES`.

    Returns:
        `None` if no summaries were collected. Otherwise a scalar `Tensor` of
        type `string` containing the serialized `Summary` protocol buffer
        resulting from the merging.
    """
    collected = _ops.get_collection(key)
    if collected:
        return merge(collected)
    return None
示例10: _get_concat_variable
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def _get_concat_variable(name, shape, dtype, num_shards):
    """Return a sharded variable as one tensor, concatenating when needed.

    A single shard is returned as is. Otherwise the shards are concatenated
    along axis 0, and the result is registered in
    `GraphKeys.CONCATENATED_VARIABLES` so that a later call computing the
    same full name can return the registered tensor.
    """
    shards = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(shards) == 1:
        return shards[0]

    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"

    # Reuse a previously registered concatenation with the same full name.
    registered = ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES)
    cached = next(
        (tensor for tensor in registered if tensor.name == concat_full_name),
        None)
    if cached is not None:
        return cached

    merged = array_ops.concat(shards, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, merged)
    return merged
示例11: _export_graph
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Exports graph via session_bundle, by creating a Session.

    Restores `checkpoint_path` into a fresh session on `graph`, configures an
    `Exporter` with the given signatures and the `ASSET_FILEPATHS`
    collection, and returns the result of its `export` call.
    """
    with graph.as_default(), tf_session.Session('') as session:
        # NOTE(review): the results of these two calls are discarded, and the
        # same initializers are built again for `init_op` below. Confirm
        # whether this first pair is needed.
        variables.local_variables_initializer()
        lookup_ops.tables_initializer()
        saver.restore(session, checkpoint_path)

        bundle_exporter = exporter.Exporter(saver)
        # Intermediates are built in the order the original call evaluated
        # its arguments, so graph ops are created in the same sequence.
        startup_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer())
        asset_paths = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        bundle_exporter.init(
            init_op=startup_op,
            default_graph_signature=default_graph_signature,
            named_graph_signatures=named_graph_signatures,
            assets_collection=asset_paths)

        global_step = contrib_variables.get_global_step()
        return bundle_exporter.export(export_dir, global_step, session,
                                      exports_to_keep=exports_to_keep)
示例12: get_variables
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Gets the list of variables, filtered by scope and/or suffix.

    Args:
        scope: an optional scope for filtering the variables to return. Can
            be a variable scope or a string.
        suffix: an optional suffix for filtering the variables to return.
        collection: in which collection to search. Defaults to
            `GraphKeys.GLOBAL_VARIABLES`.

    Returns:
        a list of variables in collection with scope and suffix.
    """
    # A VariableScope object is reduced to its name.
    scope_filter = (
        scope.name
        if isinstance(scope, variable_scope.VariableScope)
        else scope)

    if suffix is not None:
        # A suffix with no ':' gets one appended before it is joined to the
        # scope with '.*'.
        tail = suffix if ':' in suffix else suffix + ':'
        scope_filter = (scope_filter or '') + '.*' + tail

    return ops.get_collection(collection, scope_filter)
示例13: testClearDevices
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def testClearDevices(self):
    """Saves a variable with clear_devices=True and reads it back as 42."""
    export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Create the variable under an explicit device, then add the meta graph
    # with device information cleared.
    ops.reset_default_graph()
    two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(target="", config=two_cpu_config) as sess:
        with sess.graph.device("/cpu:0"):
            self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING], clear_devices=True)

    # Write the SavedModel to disk.
    builder.save()

    # Load it back into a fresh graph using the single predefined tag; the
    # variables were saved without any device information.
    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, restored[0].eval())
示例14: get_summary_op
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def get_summary_op():
    """Returns a single Summary op that would run all summaries.

    The first entry of the `SUMMARY_OP` collection is used when there is one.
    Otherwise all existing summaries are merged and, if that yields an op, it
    is added to the `SUMMARY_OP` collection.

    Returns:
        `None` if no summaries were collected. Otherwise a scalar `Tensor` of
        type `string` containing the serialized `Summary` protocol buffer
        resulting from the merging.
    """
    registered = ops.get_collection(ops.GraphKeys.SUMMARY_OP)
    # A missing (None) or empty collection both mean "nothing registered".
    summary_op = registered[0] if registered else None
    if summary_op is not None:
        return summary_op

    summary_op = merge_all_summaries()
    if summary_op is not None:
        ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
    return summary_op
示例15: global_variables
# 需要导入模块: from tensorflow.python.framework import ops [as 别名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 别名]
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are variables that are shared across machines in a
    distributed environment. The `Variable()` constructor or `get_variable()`
    automatically adds new variables to that collection. Local variables
    (see `tf.local_variables()`) are the alternative.

    Returns:
        A list of `Variable` objects.
    """
    global_key = ops.GraphKeys.GLOBAL_VARIABLES
    variables_in_graph = ops.get_collection(global_key)
    return variables_in_graph