本文整理汇总了Python中tensorflow.get_collection_ref方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.get_collection_ref方法的具体用法？Python tensorflow.get_collection_ref怎么用？Python tensorflow.get_collection_ref使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。
在下文中一共展示了tensorflow.get_collection_ref方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: MarkAsNonTrainable
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def MarkAsNonTrainable(self):
    """Mark all the variables of this block as non-trainable.

    All the variables owned directly or indirectly (through subblocks) are
    removed from the graph's TRAINABLE_VARIABLES collection.

    This function along with CheckpointInitOp can be used to load a pretrained
    model that consists of only one part of the whole graph.
    """
    assert self._called
    # Match by object identity instead of `v in collection` / `list.remove`:
    # those go through Variable.__eq__, which becomes elementwise (and cannot
    # be used as a bool) when tensor equality is enabled. The old approach was
    # also O(len(collection) * len(variables)).
    owned_ids = {id(v) for v in self.VariableList()}
    collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    # get_collection_ref returns the graph's own list, so it must be mutated in
    # place (slice assignment); rebinding a new list would not affect the graph.
    collection[:] = [v for v in collection if id(v) not in owned_ids]
示例2: initialize_tbcnn_weights
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def initialize_tbcnn_weights(clz):
    """Create the TBCNN model variables through ``clz``.

    Initializes the embedding weights and excludes ``We`` from training. It then
    creates the combination, tree-convolution and fully-connected weights and
    biases, each initialized uniformly in [-0.2, 0.2).

    Raises:
      ValueError: if ``We`` is not in the TRAINABLE_VARIABLES collection
        (raised by ``list.remove``).
    """
    clz.initialize_embedding_weights()
    # Don't train We. get_collection_ref returns the graph's live list, so
    # removing from it takes We out of the trainable set.
    tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).remove(clz.get('We'))

    # (name, shape) for each variable, in creation order.
    variable_specs = (
        ('Wcomb1', (hyper.word_dim, hyper.word_dim)),
        ('Wcomb2', (hyper.word_dim, hyper.word_dim)),
        ('Wconvt', (hyper.word_dim, hyper.conv_dim)),
        ('Wconvl', (hyper.word_dim, hyper.conv_dim)),
        ('Wconvr', (hyper.word_dim, hyper.conv_dim)),
        ('Bconv', (hyper.conv_dim,)),
        ('FC1/weight', (hyper.conv_dim, hyper.fc_dim)),
        ('FC1/bias', (hyper.fc_dim,)),
        ('FC2/weight', (hyper.fc_dim, hyper.output_dim)),
        ('FC2/bias', (hyper.output_dim,)),
    )
    for name, shape in variable_specs:
        # Every variable, Wcomb1 included, uses the (minval, maxval) = (-.2, .2)
        # uniform initializer. tf.constant_initializer(-.2, .2) would be wrong
        # here: its second positional argument is dtype, not an upper bound.
        clz.create_variable(name, shape, tf.random_uniform_initializer(-.2, .2))
示例3: test_tied_weights_untied_bias_registered_bias
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_tied_weights_untied_bias_registered_bias(self):
    """Tests that an ambiguous graph raises AmbiguousRegistrationError.

    Graph search can find several possible registrations for the tensors.
    In this test, registering only b_1 as a linked parameter leaves the other
    branch of the graph ambiguous, so register_layers must raise.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for output_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[output_key])
        # A single tensor (not a tuple) is passed as the linked parameter.
        collection.define_linked_parameters(tensors['b_1'])
        global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        with self.assertRaises(gs.AmbiguousRegistrationError):
            gs.register_layers(collection, global_vars)
示例4: mixed_usage_test
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_mixed_usage(self):
    """Tests that graph search raises an error on mixed-type usage of tensors.

    Tensors can be reused in various locations in the tensorflow graph. This
    occurs regularly in the case of recurrent models or models with parallel
    graphs. However, the tensors must be used for the same operation in each
    location, or graph search should raise an error.

    Here ``w`` feeds both a matmul and an addition, so register_layers is
    expected to raise a ValueError mentioning 'mixed record types'.
    """
    with tf.Graph().as_default():
        w = tf.get_variable('W', [10, 10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10, 10))
        out_0 = tf.matmul(x, w)  # pylint: disable=unused-variable
        out_1 = y + w  # pylint: disable=unused-variable
        layer_collection = lc.LayerCollection()
        with self.assertRaises(ValueError) as cm:
            gs.register_layers(layer_collection,
                               tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        self.assertIn('mixed record types', str(cm.exception))

# unittest only collects methods whose names start with "test", so under its
# previous name this test never ran. The alias keeps existing references to
# the old name working. NOTE(review): confirm the test passes now that it runs.
mixed_usage_test = test_mixed_usage
示例5: import_ops
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def import_ops(self):
    """Imports ops from collections.

    When training, the train/learning-rate ops are read first (the first entry
    of each named collection). An RNNParamsSaveable is registered when a cell
    and "rnn_params" are present. In all modes, the cost tensor and the
    initial/final state tuples are then imported.
    """
    if self._is_training:
        for attr_name, key in (("_train_op", "train_op"),
                               ("_lr", "lr"),
                               ("_new_lr", "new_lr"),
                               ("_lr_update", "lr_update")):
            setattr(self, attr_name, tf.get_collection_ref(key)[0])
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            tf.add_to_collection(
                tf.GraphKeys.SAVEABLE_OBJECTS,
                tf.contrib.cudnn_rnn.RNNParamsSaveable(
                    self._cell,
                    self._cell.params_to_canonical,
                    self._cell.canonical_to_params,
                    rnn_params,
                    base_variable_scope="Model/RNN"))
    cost_key = util.with_prefix(self._name, "cost")
    self._cost = tf.get_collection_ref(cost_key)[0]
    # Only the "Train" model is replicated across GPUs.
    if self._name == "Train":
        num_replicas = FLAGS.num_gpus
    else:
        num_replicas = 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
示例6: import_ops
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def import_ops(self):
    """Imports ops from collections.

    When training, the train/learning-rate ops are read first (the first entry
    of each named collection). An RNNParamsSaveable is registered when a cell
    and "rnn_params" are present. The name-prefixed cost and KL-divergence
    tensors and the initial/final state tuples are then imported with a single
    replica.
    """
    if self._is_training:
        for attr_name, key in (("_train_op", "train_op"),
                               ("_lr", "lr"),
                               ("_new_lr", "new_lr"),
                               ("_lr_update", "lr_update")):
            setattr(self, attr_name, tf.get_collection_ref(key)[0])
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            tf.add_to_collection(
                tf.GraphKeys.SAVEABLE_OBJECTS,
                tf.contrib.cudnn_rnn.RNNParamsSaveable(
                    self._cell,
                    self._cell.params_to_canonical,
                    self._cell.canonical_to_params,
                    rnn_params,
                    base_variable_scope="Model/RNN"))
    for attr_name, suffix in (("_cost", "cost"), ("_kl_div", "kl_div")):
        prefixed_key = tf_util.with_prefix(self._name, suffix)
        setattr(self, attr_name, tf.get_collection_ref(prefixed_key)[0])
    num_replicas = 1
    self._initial_state = tf_util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = tf_util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
示例7: load_model
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def load_model(sess, checkpoint_path):
    """Import the graph of ``checkpoint_path`` under scope 'm2' and restore it.

    Args:
      sess: session the variables are restored into and the init ops run in.
      checkpoint_path: checkpoint prefix; its graph is read from
        ``checkpoint_path + '.meta'``.
    """
    meta_graph_location = checkpoint_path + '.meta'
    # clear_devices drops the device placements recorded in the meta graph.
    saver = tf.train.import_meta_graph(
        meta_graph_location, clear_devices=True, import_scope='m2'
    )
    # import_meta_graph returns None when the meta graph contains no
    # variables; there is then nothing to restore.
    if saver is not None:
        saver.restore(sess, checkpoint_path)
    # Run the ops set_up_init_ops builds for the LOCAL_VARIABLES collection
    # (presumably their initializers -- confirm against its definition).
    sess.run(
        set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES))
    )
示例8: global_mode
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def global_mode():
    """Returns the Tensor of global mode.

    This is a placeholder with default value of
    :tf_main:`tf.estimator.ModeKeys.TRAIN <estimator/ModeKeys>`.
    It is created on the first call and stored in the default graph's
    ``_GLOBAL_MODE_KEY`` collection; later calls return the same tensor.

    Example:

        .. code-block:: python

            mode = session.run(global_mode())
            # mode == tf.estimator.ModeKeys.TRAIN

            mode = session.run(
                global_mode(),
                feed_dict={global_mode(): tf.estimator.ModeKeys.PREDICT})
            # mode == tf.estimator.ModeKeys.PREDICT
    """
    mode = tf.get_collection_ref(_GLOBAL_MODE_KEY)
    if not mode:
        # get_collection_ref returns the graph's live list, so appending here
        # caches the placeholder in the collection for subsequent calls.
        mode.append(tf.placeholder_with_default(
            input=tf.estimator.ModeKeys.TRAIN,
            shape=(),
            name="global_mode"))
    return mode[0]
示例9: build_graph
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def build_graph(self, trainer, test_mode=False):
    """Build the summaries (outside test mode) and create ``self.saver``.

    In test mode only a saver is created. Otherwise, numerical and image
    summaries are built first. When ``trainer.train_extrap`` is set, the saver
    is restricted to global variables whose op names exist in the checkpoint
    directory being loaded.
    """
    if not test_mode:
        self._build_numerical_summaries(trainer)
        self._build_img_summaries(trainer)
        if trainer.train_extrap:
            # Extrap training has an additional Adam optimizer whose
            # parameters do not exist in the checkpoint, so restore only the
            # variables that are actually saved there.
            source_dir = self.force_load_from_dir or self.ckpt_dir
            saved_names = {
                entry[0] for entry in tf.train.list_variables(source_dir)
            }
            restore_var = [
                var
                for var in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                if var.op.name in saved_names
            ]
            self.saver = tf.train.Saver(max_to_keep=3, var_list=restore_var)
            return
    self.saver = tf.train.Saver(max_to_keep=3)
示例10: restore_collection
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def restore_collection(backup):
    """Restore graph collections from ``backup``.

    Args:
      backup: mapping from a collection key to the items that collection should
        contain. Each listed collection's contents are replaced in place;
        collections not present in ``backup`` are left untouched.
    """
    for key, items in backup.items():
        # Slice assignment replaces the graph's live list in one step. Unlike
        # `del ref[:]` followed by `ref.extend(items)`, it stays correct when
        # `items` is that same list object (clear-then-extend would leave the
        # collection empty).
        tf.get_collection_ref(key)[:] = items
示例11: clear_collection
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def clear_collection(keys):
    """Empty, in place, every graph collection named in ``keys``."""
    for key in keys:
        live_list = tf.get_collection_ref(key)
        # Clear the graph's own list rather than a copy; `del lst[:]` also
        # works on Python 2, which has no list.clear().
        del live_list[:]
示例12: test_multitower_multi_loss_function
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_multitower_multi_loss_function(self):
    """Test a multitower setup with multiple loss functions.

    The automatic graph scanner should handle multiple loss functions per
    tower, as long as they are registered in a consistent order. Its result
    must match a manual registration of the same fully-connected layers.
    """
    with tf.Graph().as_default():
        w_1 = tf.get_variable('w_1', shape=[10, 10])
        b_1 = tf.get_variable('b_1', shape=[10])
        w_2 = tf.get_variable('w_2', shape=[10, 10])
        b_2 = tf.get_variable('b_2', shape=[10])
        layer_collection = lc.LayerCollection()
        layer_collection_manual = lc.LayerCollection()
        both_collections = [layer_collection, layer_collection_manual]
        for tower_num in range(5):
            x = tf.placeholder(tf.float32, shape=(32, 10))
            logits_1 = tf.matmul(x, w_1) + b_1
            logits_2 = tf.matmul(x, w_2) + b_2
            # reuse is False for tower 0 and True for every later tower.
            with tf.variable_scope('tower%d' % tower_num,
                                   reuse=tower_num != 0):
                for collection in both_collections:
                    collection.register_categorical_predictive_distribution(
                        logits_1, name='loss_1')
                    collection.register_categorical_predictive_distribution(
                        logits_2, name='loss_2')
                layer_collection_manual.register_fully_connected(
                    (w_1, b_1), x, logits_1)
                layer_collection_manual.register_fully_connected(
                    (w_2, b_2), x, logits_2)
        gs.register_layers(
            layer_collection,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        assert_fisher_blocks_match(self, layer_collection,
                                   layer_collection_manual)
示例13: test_graph_search_match_fail
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_graph_search_match_fail(self):
    """Tests graph search with linked bias tensors.

    In this code snippet two non-adjacent bias tensors are linked together.
    No fisher block in kfac matches this configuration, so the biases should
    not be registered. register_layers is expected to raise a ValueError that
    names the unmatched linked group.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])
        # TODO(b/69055612): remove this manual registration once layer_collection
        # implements register_fully_connected_multi.
        collection.register_fully_connected(
            tensors['w'], tensors['x'], tensors['pre_bias_0'])
        linked_biases = (tensors['b_0'], tensors['b_1'])
        collection.define_linked_parameters(linked_biases)
        with self.assertRaises(ValueError) as cm:
            gs.register_layers(
                collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        message = str(cm.exception)
        for fragment in ('in linked group',
                         'was not matched',
                         str(frozenset(linked_biases))):
            self.assertIn(fragment, message)
示例14: test_specify_approximation_shared_parameters
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_specify_approximation_shared_parameters(self):
    """Test specifying approximations for layers with shared parameters.

    When linked parameters are declared together with an approximation, graph
    search should register those parameters with that approximation's fisher
    block type.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        collection.register_squared_error_loss(tensors['out_0'])
        collection.register_squared_error_loss(tensors['out_1'])
        # (parameter key, requested approximation, expected fisher block type)
        cases = (
            ('w', lc.APPROX_KRONECKER_INDEP_NAME, fb.FullyConnectedMultiIndepFB),
            ('b_0', lc.APPROX_DIAGONAL_NAME, fb.NaiveDiagonalFB),
            ('b_1', lc.APPROX_FULL_NAME, fb.FullFB),
        )
        for key, approximation, _ in cases:
            collection.define_linked_parameters(
                tensors[key], approximation=approximation)
        gs.register_layers(
            collection,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=1)
        for key, _, expected_block_type in cases:
            self.assertIsInstance(collection.fisher_blocks[tensors[key]],
                                  expected_block_type)
示例15: test_tied_weights_untied_bias_registered_weights
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_tied_weights_untied_bias_registered_weights(self):
    """Tests that graph search produces the right solution on a toy model.

    Linking only the shared weight ``w`` must yield the same fisher blocks as
    manually registering ``w`` as a multi-use fully-connected layer and each
    bias generically.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        loss_outputs = (tensors['out_0'], tensors['out_1'])

        expected = lc.LayerCollection()
        for out in loss_outputs:
            expected.register_squared_error_loss(out)
        expected.register_fully_connected_multi(
            tensors['w'], (tensors['x'], tensors['y']),
            (tensors['pre_bias_0'], tensors['pre_bias_1']))
        for bias_key in ('b_0', 'b_1'):
            expected.register_generic(tensors[bias_key], batch_size=1)

        actual = lc.LayerCollection()
        for out in loss_outputs:
            actual.register_squared_error_loss(out)
        # A single tensor (not a tuple) is passed as the linked parameter.
        actual.define_linked_parameters(tensors['w'])
        gs.register_layers(
            actual,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=1)
        assert_fisher_blocks_match(self, actual, expected)