本文整理汇总了 Python 中 tensorflow.python.framework.ops.add_to_collection 函数的典型用法代码示例。如果您正苦于以下问题：Python add_to_collection 函数的具体用法？Python add_to_collection 怎么用？Python add_to_collection 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 add_to_collection 函数的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: _compute_weighted_loss
def _compute_weighted_loss(losses, weight):
  """Returns the weighted mean of `losses` and records it as a graph loss.

  Both inputs are cast to float. The weighted sum, the count of present
  elements and their safe mean are delegated to `_scale_losses`,
  `_num_present` and `_safe_mean`. The resulting scalar is appended to the
  `GraphKeys.LOSSES` collection.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.

  Returns:
    A scalar `Tensor` holding the weighted loss.

  Raises:
    ValueError: If the rank of `losses` or `weight` is not statically known.
      Shape-compatibility errors are presumably raised by the helpers this
      function calls (not visible here).
  """
  float_losses = math_ops.to_float(losses)
  float_weight = math_ops.to_float(ops.convert_to_tensor(weight))
  # Static ranks are required; `losses` is validated before `weight`.
  rank_checks = (
      (float_losses, "losses.get_shape().ndims cannot be None"),
      (float_weight, "weight.get_shape().ndims cannot be None"),
  )
  for tensor, message in rank_checks:
    if tensor.get_shape().ndims is None:
      raise ValueError(message)
  weighted_sum = _scale_losses(float_losses, float_weight)
  present_count = _num_present(float_losses, float_weight)
  mean_loss = _safe_mean(weighted_sum, present_count)
  ops.add_to_collection(ops.GraphKeys.LOSSES, mean_loss)
  return mean_loss
示例2: fertile_stats_variable
def fertile_stats_variable(params, stats_config, name, container=None):
  r"""Creates a fertile-stats object and returns its resource handle.

  The new `FertileStatsVariable` is registered twice:
  - Its checkpoint saveable is added to `GraphKeys.SAVEABLE_OBJECTS`.
  - Its handle is registered with `resources`, together with its create op
    and its is-initialized op.

  Args:
    params: A TensorForestParams object.
    stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
    name: A name for the variable.
    container: An optional `string`. Defaults to `""`.

  Returns:
    The `resource_handle` of the created stats variable.
  """
  with ops.name_scope(name, "FertileStatsVariable") as scope_name:
    stats_var = FertileStatsVariable(params, stats_config, scope_name,
                                     container)
    handle = stats_var.resource_handle
    init_op = stats_var.initializer
    initialized_check = stats_var.is_initialized()
    # Build the saveable under the handle's graph name, then expose it
    # through the SAVEABLE_OBJECTS collection so savers can find it.
    saveable_factories = stats_var._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
    saveable = saveable_factories["fertile_stats_variable"](name=handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    resources.register_resource(handle, init_op, initialized_check)
    return handle
示例3: _get_default_variable_store
def _get_default_variable_store():
  """Returns the `_VariableStore` kept in the default graph's collection.

  If the `_VARSTORE_KEY` collection is empty, a new store is created, stored
  there, and returned. Otherwise the first stored entry is returned.
  """
  existing = ops.get_collection(_VARSTORE_KEY)
  if not existing:
    new_store = _VariableStore()
    ops.add_to_collection(_VARSTORE_KEY, new_store)
    return new_store
  return existing[0]
示例4: _get_or_create_global_step_read
def _get_or_create_global_step_read(graph=None):
  """Gets or creates the global-step read tensor in `graph`.

  Args:
    graph: The graph in which to create the global step read tensor. If
      missing, the default graph is used.

  Returns:
    The global step read tensor, or None when the graph has no global step.
  """
  target_graph = graph or ops.get_default_graph()
  cached_read = _get_global_step_read(target_graph)
  if cached_read is not None:
    return cached_read
  step = get_global_step(target_graph)
  if step is None:
    return None
  with target_graph.as_default() as g, g.name_scope(None):
    # Reading through initialized_value() guarantees global_step is
    # initialized before this read executes (e.g. Estimator builds the whole
    # model_fn under a dependency on this tensor).
    if isinstance(step, variables.Variable):
      source = step.initialized_value()
    else:
      source = step
    # Adding zero yields a Tensor copy rather than the variable itself.
    ops.add_to_collection(GLOBAL_STEP_READ_KEY, source + 0)
  return _get_global_step_read(target_graph)
示例5: _maybe_add_main_op
def _maybe_add_main_op(self, main_op):
  """Registers `main_op` in the MAIN_OP_KEY collection of the SavedModel graph.

  Args:
    main_op: Main op to run as part of graph initialization. If None, nothing
      is added.

  Raises:
    TypeError: if `main_op` is provided but is not an `ops.Operation`.
    ValueError: if the MAIN_OP_KEY or LEGACY_INIT_OP_KEY collection already
      holds an op.
  """
  if main_op is None:
    return
  if not isinstance(main_op, ops.Operation):
    raise TypeError("main_op needs to be an Operation: %r" % main_op)
  # Only one initialization op is allowed. Check the main-op collection first,
  # then the legacy one, and stop at the first non-empty collection.
  occupied_key = next(
      (key for key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY)
       if ops.get_collection(key)),
      None)
  if occupied_key is not None:
    raise ValueError(
        "Graph already contains one or more main ops under the "
        "collection {}.".format(occupied_key))
  ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
示例6: initialize_from
def initialize_from(self, keys, values, name=None):
  """Builds an op that fills this table from `keys` and `values` tensors.

  The op is also added to `GraphKeys.TABLE_INITIALIZERS`.

  Args:
    keys: The tensor for the keys.
    values: The tensor for the values.
    name: Optional name for the op; defaults to "<table name>_initialize_table".

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types (presumably raised by `convert_to_tensor`).
  """
  op_name = "%s_initialize_table" % self.name if name is None else name
  with ops.op_scope([keys, values], None, op_name):
    key_tensor = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
    value_tensor = ops.convert_to_tensor(
        values, dtype=self.value_dtype, name="values")
    init_op = gen_data_flow_ops._initialize_table(  # pylint: disable=protected-access
        self.table_ref, key_tensor, value_tensor, name=op_name)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
示例7: __init__
def __init__(self, iterator_resource, initializer, output_types,
             output_shapes, output_classes):
  """Wraps an existing iterator resource.

  Most users should not call this directly; use
  `Dataset.make_initializable_iterator()` or
  `Dataset.make_one_shot_iterator()` instead.

  Args:
    iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
      iterator.
    initializer: A `tf.Operation` that should be run to initialize this
      iterator.
    output_types: A nested structure of `tf.DType` objects, one per component
      of an element.
    output_shapes: A nested structure of `tf.TensorShape` objects, one per
      component of an element.
    output_classes: A nested structure of Python `type` objects, one per
      component of an element.

  Raises:
    ValueError: if any of `output_types`, `output_shapes` or
      `output_classes` is None.
  """
  self._iterator_resource = iterator_resource
  self._initializer = initializer
  legacy_specs = (output_types, output_shapes, output_classes)
  if any(spec is None for spec in legacy_specs):
    raise ValueError("If `structure` is not specified, all of "
                     "`output_types`, `output_shapes`, and `output_classes`"
                     " must be specified.")
  self._structure = structure_lib.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  self._string_handle = gen_dataset_ops.iterator_to_string_handle(
      self._iterator_resource)
  self._get_next_call_count = 0
  # Make the resource discoverable through the GLOBAL_ITERATORS collection.
  ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
示例8: apply_regularization
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output. If it returns `None` for a weight (e.g. a
      regularizer disabled by a zero scale), that weight contributes a zero
      penalty.
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None` or empty.

  Returns:
    A scalar representing the overall regularization penalty. The same tensor
    is added to the `GraphKeys.REGULARIZATION_LOSSES` collection.

  Raises:
    ValueError: If there are no weights to regularize, or if `regularizer`
      does not return a scalar output.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    # `add_n` cannot sum an empty list; fail early with a clear message.
    raise ValueError('No weights to regularize.')
  with ops.op_scope(weights_list, 'get_regularization_penalty') as scope:
    penalties = [regularizer(w) for w in weights_list]
    # A None penalty would crash on `.get_shape()`; treat it as zero instead.
    penalties = [
        p if p is not None else ops.convert_to_tensor(0.0) for p in penalties
    ]
    for p in penalties:
      rank = p.get_shape().ndims
      if rank != 0:
        # `rank` is None for unknown shapes. Use '%s' so formatting cannot
        # raise TypeError and mask the intended ValueError.
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %s.' % rank)
    summed_penalty = math_ops.add_n(penalties, name=scope)
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
    return summed_penalty
示例9: testKeepNodes
def testKeepNodes(self):
  """Grappler must keep nodes referenced by collections or pinned by attr."""
  graph = ops.Graph()
  with graph.as_default():
    # Kept: lives in the 'variables' collection.
    var = variables.VariableV1(1.0)
    # Kept: explicitly added to a collection.
    kept_const = constant_op.constant(0, shape=[50, 50], name='keep')
    ops.add_to_collection('a2', kept_const)
    # Kept: tagged with the _grappler_do_not_remove attribute.
    with graph._attr_scope(
        {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
      pinned_const = constant_op.constant(0, name='keep2')
    lhs = constant_op.constant(1, shape=[100, 10])
    rhs = constant_op.constant(0, shape=[10, 30])
    product = math_ops.matmul(lhs, rhs)
    # The matmul result is the fetch node.
    ops.add_to_collection('train_op', product)

  meta = meta_graph.create_meta_graph_def(graph=graph)
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.min_graph_nodes = -1
  optimized = tf_optimizer.OptimizeGraph(config, meta)

  node_names = [node.name for node in optimized.node]
  expected_names = [
      product.op.name, var.op.name, kept_const.op.name, pinned_const.op.name,
      'Variable/initial_value', 'Variable/Assign'
  ]
  self.assertEqual(len(node_names), len(expected_names))
  self.assertAllInSet(node_names, expected_names)
示例10: _CreateParamsSavable
def _CreateParamsSavable(params,
                         model,
                         base_variable_scope=None,
                         name="params_canonical"):
  """Create a RNNParamsSaveable for the weight and bias parameters.

  The saveable class is chosen from `model._rnn_mode`. The created object is
  added to `GraphKeys.SAVEABLE_OBJECTS`.

  Args:
    params: a Variable for weight and bias parameters.
    model: a CudnnRNN model.
    base_variable_scope: a string, prefix of names of saved variables.
    name: a string, name of the RNNParamsSaveable object.

  Returns:
    a RNNParamsSaveable object.

  Raises:
    ValueError: if `model._rnn_mode` is not one of CUDNN_LSTM, CUDNN_GRU,
      CUDNN_RNN_TANH or CUDNN_RNN_RELU.
  """
  rnn_mode = model._rnn_mode  # pylint: disable=protected-access
  if rnn_mode == CUDNN_LSTM:
    fn = cudnn_rnn_ops.CudnnLSTMSaveable
  elif rnn_mode == CUDNN_GRU:
    fn = cudnn_rnn_ops.CudnnGRUSaveable
  elif rnn_mode == CUDNN_RNN_TANH:
    fn = cudnn_rnn_ops.CudnnRNNTanhSaveable
  elif rnn_mode == CUDNN_RNN_RELU:
    fn = cudnn_rnn_ops.CudnnRNNReluSaveable
  else:
    # Without this branch `fn` would be unbound and the call below would
    # fail with an unhelpful UnboundLocalError.
    raise ValueError("Unsupported cuDNN RNN mode: %r" % (rnn_mode,))
  params_saveable = fn(
      params,
      model.num_layers,
      model.num_units,
      model.input_size,
      model.input_mode,
      model.direction,
      scope=base_variable_scope,
      name=name)
  ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
  return params_saveable
示例11: _train_model
def _train_model(self, input_fn, hooks):
  """Builds a fresh training graph, runs it to completion, returns last loss.

  Steps:
  - Install NaN-check and logging hooks on top of `hooks` and the spec's
    training hooks.
  - Register a sharded, deferred `Saver` when no saver exists yet.
  - Add a `CheckpointSaverHook` for the chief when checkpointing is
    configured and no such hook is present.

  Returns:
    The loss from the final `run` call, or None if no step ran.
  """
  all_hooks = []
  with ops.Graph().as_default() as g, g.device(self._device_fn):
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = training.create_global_step(g)
    # The input pipeline is placed on the CPU.
    with ops.device('/cpu:0'):
      features, labels = input_fn()
    estimator_spec = self._call_model_fn(features, labels,
                                         model_fn_lib.ModeKeys.TRAIN)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

    nan_hook = training.NanTensorHook(estimator_spec.loss)
    logging_hook = training.LoggingTensorHook(
        {
            'loss': estimator_spec.loss,
            'step': global_step_tensor
        },
        every_n_iter=100)
    all_hooks += [nan_hook, logging_hook]
    all_hooks += hooks
    all_hooks += estimator_spec.training_hooks

    has_saver = (estimator_spec.scaffold.saver or
                 ops.get_collection(ops.GraphKeys.SAVERS))
    if not has_saver:
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              defer_build=True))

    chief_hooks = []
    checkpointing_enabled = (self._config.save_checkpoints_secs or
                             self._config.save_checkpoints_steps)
    if checkpointing_enabled:
      candidate_hooks = all_hooks + estimator_spec.training_chief_hooks
      has_saver_hook = any(
          isinstance(h, training.CheckpointSaverHook) for h in candidate_hooks)
      if not has_saver_hook:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]

    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=all_hooks,
        chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
        save_checkpoint_secs=0,  # Checkpointing is handled by a hook.
        save_summaries_steps=self._config.save_summary_steps,
        config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
    return loss
示例12: initialize
def initialize(self, table):
  """Builds an op that loads this initializer's keys/values into `table`.

  The op is also added to `GraphKeys.TABLE_INITIALIZERS`.

  Args:
    table: The table to initialize.

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types (checked by `_check_table_dtypes`).
  """
  _check_table_dtypes(table, self._keys.dtype, self._values.dtype)
  scope_inputs = (table.table_ref, self._keys, self._values)
  with ops.name_scope(self._name, values=scope_inputs) as scope:
    op_name = scope
    if context.executing_eagerly():
      # In eager mode, append a unique id to the name so unrelated tables do
      # not accidentally share state.
      op_name = scope + str(ops.uid())
    init_op = gen_lookup_ops.initialize_table_v2(
        table.table_ref, self._keys, self._values, name=op_name)
  ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
  return init_op
示例13: _add_iterator_ops_to_collection
def _add_iterator_ops_to_collection(self, init_op, get_next):
  """Stores `init_op` and every flattened `get_next` element in "iterator_ops".

  `get_next` may be a nested structure (e.g. a tuple from TensorSliceDataset).
  Collections cannot hold tuples, so it is flattened here. Its structure is
  restored in `_get_iterator_ops_from_collection`.
  """
  collection_key = "iterator_ops"
  ops.add_to_collection(collection_key, init_op)
  for flat_element in nest.flatten(get_next):
    ops.add_to_collection(collection_key, flat_element)
示例14: tree_variable
def tree_variable(params, tree_config, stats_handle, name, container=None):
  r"""Creates a decision-tree resource and returns its handle.

  A `TreeVariableSavable` for the tree is added to
  `GraphKeys.SAVEABLE_OBJECTS`. The handle is registered with `resources`
  together with its create and is-initialized ops.

  Args:
    params: A TensorForestParams object.
    tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
    stats_handle: Resource handle to the stats object.
    name: A name for the variable.
    container: An optional `string`. Defaults to `""`.

  Returns:
    The resource handle produced by `decision_tree_resource_handle_op`.
  """
  with ops.name_scope(name, "TreeVariable") as scope_name:
    handle = gen_model_ops.decision_tree_resource_handle_op(
        container, shared_name=scope_name, name=scope_name)
    create_op = gen_model_ops.create_tree_variable(
        handle, tree_config, params=params.serialized_params_proto)
    is_initialized_op = gen_model_ops.tree_is_initialized_op(handle)
    # Expose the tree to savers via the SAVEABLE_OBJECTS collection.
    tree_saveable = TreeVariableSavable(params, handle, stats_handle,
                                        create_op, handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, tree_saveable)
    resources.register_resource(handle, create_op, is_initialized_op)
    return handle
示例15: testCustomSaveable
def testCustomSaveable(self):
  """A table registered as a custom saveable survives a SavedModel round trip."""
  export_dir = self._get_export_dir("custom_saveable")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  def _two_cpu_session():
    # Each session gets its own fresh graph and two CPU devices.
    return session.Session(
        graph=ops.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2}))

  with _two_cpu_session() as sess:
    # CheckpointedOp is a key-value table that registers itself in the
    # SAVEABLE_OBJECTS collection, so it is saved across sessions.
    table = saver_test_utils.CheckpointedOp(name="v1")
    variables.global_variables_initializer().run()
    table.insert("k1", 3.0).run()
    # Keep a reference to the table so it can be located after loading.
    ops.add_to_collection("table_ref", table.table_ref)
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Write the SavedModel to disk.
  builder.save()

  with _two_cpu_session() as sess:
    loader.load(sess, ["foo"], export_dir)
    # Rebuild a wrapper around the restored table reference.
    restored = saver_test_utils.CheckpointedOp(
        name="v1", table_ref=ops.get_collection("table_ref")[0])
    self.assertEqual(b"k1", restored.keys().eval())
    self.assertEqual(3.0, restored.values().eval())