本文整理匯總了Python中tensorflow.python.framework.ops.get_collection方法的典型用法代碼示例。如果您正苦於以下問題：Python ops.get_collection方法的具體用法？Python ops.get_collection怎麼用？Python ops.get_collection使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.framework.ops的用法示例。
在下文中一共展示了ops.get_collection方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: _get_first_op_from_collection
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def _get_first_op_from_collection(self, key):
    """Return the first `Operation` stored under a graph collection key.

    Args:
      key: A string collection key.

    Returns:
      The first op in the collection, or `None` when the collection is empty
      or the lookup raises `LookupError`.
    """
    try:
        candidates = ops.get_collection(key)
        count = len(candidates)
        if count > 1:
            # Several ops share the key; only the first is used.
            logging.info("Found %d %s operations. Returning the first one.",
                         count, key)
        return candidates[0] if candidates else None
    except LookupError:
        return None
示例2: start_queue_runners
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def start_queue_runners(self, sess, queue_runners=None):
    """Start the threads of the given `QueueRunners`.

    Queue runners registered in the graph under `QUEUE_RUNNERS` are already
    started automatically when a session is created through the supervisor,
    so this only needs to be called explicitly for queue runners that were
    not collected.

    Args:
      sess: A `Session`.
      queue_runners: A list of `QueueRunners`. When omitted, the queue
        runners stored in `self._graph` under `GraphKeys.QUEUE_RUNNERS` are
        used.

    Returns:
      A flat list of every thread started for the `QueueRunners`.
    """
    runners = (self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
               if queue_runners is None else queue_runners)
    return [
        thread
        for runner in runners
        for thread in runner.create_threads(
            sess, coord=self._coord, daemon=True, start=True)
    ]
示例3: convert_collection_to_dict
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def convert_collection_to_dict(collection, clear_collection=False):
    """Build an OrderedDict mapping each alias to its tensor for a collection.

    Args:
      collection: A collection key.
      clear_collection: If True, the collection is cleared from the default
        graph after the dictionary has been built.

    Returns:
      An OrderedDict of {alias: tensor}. If several tensors share an alias,
      the last one seen is kept, at the position where that alias first
      appeared.
    """
    aliased = OrderedDict()
    for tensor in ops.get_collection(collection):
        for alias in get_tensor_aliases(tensor):
            aliased[alias] = tensor
    if clear_collection:
        ops.get_default_graph().clear_collection(collection)
    return aliased
示例4: _get_main_op_tensor
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def _get_main_op_tensor(meta_graph_def_to_load):
    """Look up the SavedModel main op if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel to be
        loaded.

    Returns:
      The first entry of the `constants.MAIN_OP_KEY` collection when the meta
      graph's collection_def contains that key; `None` otherwise.

    Raises:
      RuntimeError: If the collection def for the main op key lists other
        than exactly one node.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.MAIN_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.MAIN_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one SavedModel main op.")
    # NOTE(review): the op is fetched from the default graph's collection,
    # not from `collection_def`; this presumably relies on the meta graph
    # having been imported into the default graph already — confirm.
    return ops.get_collection(constants.MAIN_OP_KEY)[0]
示例5: _get_legacy_init_op_tensor
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
    """Look up the legacy init op if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel to be
        loaded.

    Returns:
      The first entry of the `constants.LEGACY_INIT_OP_KEY` collection when
      the meta graph's collection_def contains that key; `None` otherwise.

    Raises:
      RuntimeError: If the collection def for the legacy init op key lists
        other than exactly one node.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.LEGACY_INIT_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.LEGACY_INIT_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one legacy serving init op.")
    # NOTE(review): the op is fetched from the default graph's collection,
    # not from `collection_def`; this presumably relies on the meta graph
    # having been imported into the default graph already — confirm.
    return ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
示例6: merge_all_summaries
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
    """Merge every summary collected in the default graph.

    Deprecated: switch to tf.summary.merge_all, which behaves identically.

    Args:
      key: `GraphKey` under which the summaries were collected. Defaults to
        `GraphKeys.SUMMARIES`.

    Returns:
      `None` when nothing was collected; otherwise a scalar `Tensor` of type
      `string` holding the serialized `Summary` protocol buffer produced by
      the merge.
    """
    collected = ops.get_collection(key)
    return merge_summary(collected) if collected else None
示例7: global_variables
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are shared across machines in a distributed
    environment. Both the `Variable()` constructor and `get_variable()` add
    newly created variables to this collection automatically. For
    per-process variables see @{tf.local_variables}.

    Returns:
      A list of `Variable` objects.
    """
    collection_key = ops.GraphKeys.GLOBAL_VARIABLES
    return ops.get_collection(collection_key)
示例8: local_variables
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def local_variables():
    """Return the contents of the `GraphKeys.LOCAL_VARIABLES` collection.

    Local variables are per-process variables. They are usually not saved to
    or restored from checkpoints and hold temporary or intermediate values,
    e.g. counters for metrics or the number of epochs this machine has read.
    `tf.contrib.framework.local_variable()` adds new variables to this
    collection automatically. For variables shared across machines see
    @{tf.global_variables}.

    Returns:
      A list of local `Variable` objects.
    """
    collection_key = ops.GraphKeys.LOCAL_VARIABLES
    return ops.get_collection(collection_key)
示例9: merge_all
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def merge_all(key=_ops.GraphKeys.SUMMARIES):
    """Merge every summary collected in the default graph.

    Args:
      key: `GraphKey` under which the summaries were collected. Defaults to
        `GraphKeys.SUMMARIES`.

    Returns:
      `None` when nothing was collected; otherwise a scalar `Tensor` of type
      `string` holding the serialized `Summary` protocol buffer produced by
      the merge.
    """
    collected = _ops.get_collection(key)
    if collected:
        return merge(collected)
    return None
示例10: _get_concat_variable
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def _get_concat_variable(name, shape, dtype, num_shards):
    """Return a sharded variable concatenated into a single tensor.

    The concatenated tensor is recorded in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call that computes
    the same full tensor name returns the recorded tensor instead of building
    a new concat op.

    Args:
      name: Base name of the sharded variable.
      shape: Passed through to `_get_sharded_variable`.
      dtype: Passed through to `_get_sharded_variable`.
      num_shards: Passed through to `_get_sharded_variable`.

    Returns:
      The only shard when there is exactly one. Otherwise the shards
      concatenated along axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    # At the root variable scope the scope name is "". Joining with "/"
    # unconditionally would produce "/<name>/concat:0", which can never equal
    # a tensor name. The lookup below would then always miss, and every call
    # would add yet another concat op to the graph and the collection.
    scope_name = vs.get_variable_scope().name
    prefix = scope_name + "/" if scope_name else ""
    concat_full_name = prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    # NOTE(review): the new op is named under the current *name* scope, while
    # the lookup above is keyed on the *variable* scope. The cache only hits
    # when the two agree — confirm against callers.
    concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
示例11: _export_graph
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Export `graph` via session_bundle, using a freshly created Session.

    Restores the checkpoint at `checkpoint_path` into the session, then
    configures an `exporter.Exporter` with an init op that groups the
    local-variable and table initializers, the given signatures and the
    graph's `ASSET_FILEPATHS` collection, and exports into `export_dir`.

    Returns:
      Whatever `export.export(...)` returns.
    """
    with graph.as_default(), tf_session.Session('') as session:
        # NOTE(review): these two calls only construct initializer ops. The
        # returned ops are discarded and never run in `session`. They are kept
        # so the exported graph (and its op names) stays exactly the same.
        variables.local_variables_initializer()
        lookup_ops.tables_initializer()
        saver.restore(session, checkpoint_path)

        export = exporter.Exporter(saver)
        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer())
        asset_files = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        export.init(init_op=init_op,
                    default_graph_signature=default_graph_signature,
                    named_graph_signatures=named_graph_signatures,
                    assets_collection=asset_files)
        global_step = contrib_variables.get_global_step()
        return export.export(export_dir, global_step, session,
                             exports_to_keep=exports_to_keep)
示例12: get_variables
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Return the variables of a collection, optionally filtered.

    Args:
      scope: Optional filter. Either a `VariableScope` (its name is used) or a
        string.
      suffix: Optional suffix filter. A ':' is appended when the suffix does
        not already contain one. The filter then becomes
        `(scope or '') + '.*' + suffix`.
      collection: Collection to search. Defaults to
        `GraphKeys.GLOBAL_VARIABLES`.

    Returns:
      The list returned by `ops.get_collection` for `collection`, with the
      computed scope string passed as its `scope` filter.
    """
    if isinstance(scope, variable_scope.VariableScope):
        scope = scope.name
    if suffix is None:
        return ops.get_collection(collection, scope)
    if ':' not in suffix:
        suffix = suffix + ':'
    pattern = (scope or '') + '.*' + suffix
    return ops.get_collection(collection, pattern)
示例13: testClearDevices
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def testClearDevices(self):
    """A SavedModel saved with clear_devices=True restores on a fresh graph."""
    export_path = os.path.join(test.get_temp_dir(), "test_clear_devices")
    model_builder = saved_model_builder.SavedModelBuilder(export_path)

    # Save a variable while pinned to an explicit device, and ask the builder
    # to strip device placements from the saved meta graph.
    ops.reset_default_graph()
    two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(target="", config=two_cpu_config) as sess:
        with sess.graph.device("/cpu:0"):
            self._init_and_validate_variable(sess, "v", 42)
            model_builder.add_meta_graph_and_variables(
                sess, [tag_constants.TRAINING], clear_devices=True)

    # Write the SavedModel to disk.
    model_builder.save()

    # Load it into a new graph under the single TRAINING tag. The saved
    # variable (no device info) must still evaluate to 42.
    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_path)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
        self.assertEqual(42, restored.eval())
示例14: get_summary_op
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def get_summary_op():
    """Return a single Summary op that runs all summaries.

    Uses the first op already stored under `GraphKeys.SUMMARY_OP` if there
    is one. Otherwise it merges all summaries with `merge_all_summaries()`
    and, if that yields an op, stores it under `GraphKeys.SUMMARY_OP`.

    Returns:
      `None` if there is no stored op and no summaries were collected.
      Otherwise a scalar `Tensor` of type `string` holding the serialized
      `Summary` protocol buffer produced by the merge.
    """
    stored = ops.get_collection(ops.GraphKeys.SUMMARY_OP)
    first = stored[0] if stored else None
    if first is not None:
        return first

    merged = merge_all_summaries()
    if merged is not None:
        ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, merged)
    return merged
示例15: global_variables
# 需要導入模塊: from tensorflow.python.framework import ops [as 別名]
# 或者: from tensorflow.python.framework.ops import get_collection [as 別名]
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are shared across machines in a distributed
    environment. Both the `Variable()` constructor and `get_variable()` add
    newly created variables to this collection automatically. For
    per-process variables see `tf.local_variables()`.

    Returns:
      A list of `Variable` objects.
    """
    collection_key = ops.GraphKeys.GLOBAL_VARIABLES
    return ops.get_collection(collection_key)