当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.get_collection_ref方法代码示例

本文整理汇总了Python中tensorflow.get_collection_ref方法的典型用法代码示例。如果您想了解tensorflow.get_collection_ref方法的具体用法、调用方式或实际使用示例,这里精选的代码示例或许能为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


在下文中一共展示了tensorflow.get_collection_ref方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: MarkAsNonTrainable

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def MarkAsNonTrainable(self):
    """Mark all the variables of this block as non-trainable.

    Every variable owned directly or indirectly (through subblocks), as
    returned by ``VariableList()``, is removed from the default graph's
    TRAINABLE_VARIABLES collection. Owned variables that are not in that
    collection are ignored.

    This function along with CheckpointInitOp can be used to load a pretrained
    model that covers only part of the whole graph.
    """
    assert self._called

    all_variables = self.VariableList()
    # Match by object identity in a single pass over the collection. The old
    # ``v in collection`` + ``collection.remove(v)`` did two linear scans per
    # variable (O(n*m)) and relied on Variable.__eq__, which in some TF
    # versions compares element-wise instead of by identity.
    owned_ids = {id(v) for v in all_variables}
    collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    # Slice assignment mutates the live collection list in place, so the
    # graph sees the change.
    collection[:] = [v for v in collection if id(v) not in owned_ids]
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:18,代码来源:block_base.py

示例2: initialize_tbcnn_weights

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def initialize_tbcnn_weights(clz):
        """Create all TBCNN model variables through ``clz``.

        The word-embedding matrix ``We`` is created by
        ``clz.initialize_embedding_weights()``. It is then removed from the
        TRAINABLE_VARIABLES collection so it is not trained. Every other
        weight and bias is created with ``clz.create_variable``, using a
        uniform initializer over [-0.2, 0.2]. Shapes come from ``hyper``.

        Raises:
            ValueError: if ``We`` is not in the TRAINABLE_VARIABLES collection
                (raised by ``list.remove``).
        """
        clz.initialize_embedding_weights()
        # Don't train We
        tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).remove(clz.get('We'))

        # (name, shape) pairs, kept in the original creation order.
        # Wcomb1 used to be built with ``tf.constant_initializer(-.2, .2)``.
        # That initializer's signature is (value, dtype), so .2 went into the
        # dtype slot. Depending on the TF version, Wcomb1 was therefore a
        # constant -0.2 or the call failed. It now uses the same uniform
        # initializer as its siblings.
        variable_specs = (
            ('Wcomb1', (hyper.word_dim, hyper.word_dim)),
            ('Wcomb2', (hyper.word_dim, hyper.word_dim)),
            ('Wconvt', (hyper.word_dim, hyper.conv_dim)),
            ('Wconvl', (hyper.word_dim, hyper.conv_dim)),
            ('Wconvr', (hyper.word_dim, hyper.conv_dim)),
            ('Bconv', (hyper.conv_dim,)),
            ('FC1/weight', (hyper.conv_dim, hyper.fc_dim)),
            ('FC1/bias', (hyper.fc_dim,)),
            ('FC2/weight', (hyper.fc_dim, hyper.output_dim)),
            ('FC2/bias', (hyper.output_dim,)),
        )
        for name, shape in variable_specs:
            clz.create_variable(name, shape,
                                tf.random_uniform_initializer(-.2, .2))
开发者ID:Aetf,项目名称:tensorflow-tbcnn,代码行数:27,代码来源:config.py

示例3: test_tied_weights_untied_bias_registered_bias

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_tied_weights_untied_bias_registered_bias(self):
    """Check that an ambiguous graph makes graph search raise.

    Graph search finds more than one candidate registration for the tensors.
    Linking only ``b_1`` leaves the other branch of the graph ambiguous, so
    ``register_layers`` is expected to raise ``AmbiguousRegistrationError``.
    """
    with tf.Graph().as_default():
      tensors = _build_model()

      collection = lc.LayerCollection()
      for out_key in ('out_0', 'out_1'):
        collection.register_squared_error_loss(tensors[out_key])

      # A single tensor is linked; the parentheses in the earlier spelling
      # ``(tensor_dict['b_1'])`` did not make a tuple.
      collection.define_linked_parameters(tensors['b_1'])

      global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
      with self.assertRaises(gs.AmbiguousRegistrationError):
        gs.register_layers(collection, global_vars)
开发者ID:tensorflow,项目名称:kfac,代码行数:21,代码来源:graph_search_test.py

示例4: mixed_usage_test

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_mixed_usage(self):
    """Tests that graph search raises error on mixed types usage for tensors.

    Tensors can be reused in various locations in the tensorflow graph. This
    occurs regularly in the case of recurrent models or models with parallel
    graphs. However the tensors must be used for the same operation in each
    location or graph search should raise an error.

    Here ``W`` feeds both a matmul and a ``+``. ``register_layers`` is
    expected to raise a ValueError whose message contains
    'mixed record types'.
    """
    with tf.Graph().as_default():
      w = tf.get_variable('W', [10, 10])
      x = tf.placeholder(tf.float32, shape=(32, 10))
      y = tf.placeholder(tf.float32, shape=(32, 10, 10))

      out_0 = tf.matmul(x, w)  # pylint: disable=unused-variable
      out_1 = y + w  # pylint: disable=unused-variable

      layer_collection = lc.LayerCollection()

      with self.assertRaises(ValueError) as cm:
        gs.register_layers(layer_collection,
                           tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

      self.assertIn('mixed record types', str(cm.exception))

# This test was originally named ``mixed_usage_test``. unittest only collects
# methods whose names start with "test", so under that name it never ran.
# The old name is kept as an alias so existing references still resolve.
mixed_usage_test = test_mixed_usage
开发者ID:tensorflow,项目名称:kfac,代码行数:25,代码来源:graph_search_test.py

示例5: import_ops

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def import_ops(self):
    """Re-bind this model's op and tensor handles from graph collections.

    Each handle is the first element of the collection with the given name
    (IndexError if that collection is empty). In training mode the training
    and learning-rate ops are restored as well. When a cell and a non-empty
    ``rnn_params`` collection both exist, an ``RNNParamsSaveable`` is added
    to the SAVEABLE_OBJECTS collection. State tuples are imported with
    ``FLAGS.num_gpus`` replicas for the "Train" model and 1 otherwise.
    """
    def first_in(key):
      return tf.get_collection_ref(key)[0]

    if self._is_training:
      self._train_op = first_in("train_op")
      self._lr = first_in("lr")
      self._new_lr = first_in("new_lr")
      self._lr_update = first_in("lr_update")
      saved_params = tf.get_collection_ref("rnn_params")
      if self._cell and saved_params:
        saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
            self._cell,
            self._cell.params_to_canonical,
            self._cell.canonical_to_params,
            saved_params,
            base_variable_scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    self._cost = first_in(util.with_prefix(self._name, "cost"))
    replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, replicas)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:24,代码来源:ptb_word_lm.py

示例6: import_ops

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def import_ops(self):
    """Re-bind this model's op and tensor handles from graph collections.

    Each handle is the first element of the collection with the given name
    (IndexError if that collection is empty). In training mode the training
    and learning-rate ops are restored as well. When a cell and a non-empty
    ``rnn_params`` collection both exist, an ``RNNParamsSaveable`` is added
    to the SAVEABLE_OBJECTS collection. The cost and KL-divergence tensors
    are looked up under this model's name prefix. State tuples are imported
    with a single replica.
    """
    if self._is_training:
        for attr, key in (("_train_op", "train_op"),
                          ("_lr", "lr"),
                          ("_new_lr", "new_lr"),
                          ("_lr_update", "lr_update")):
            setattr(self, attr, tf.get_collection_ref(key)[0])
        saved_params = tf.get_collection_ref("rnn_params")
        if self._cell and saved_params:
            saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                saved_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    for attr, suffix in (("_cost", "cost"), ("_kl_div", "kl_div")):
        prefixed_key = tf_util.with_prefix(self._name, suffix)
        setattr(self, attr, tf.get_collection_ref(prefixed_key)[0])
    self._initial_state = tf_util.import_state_tuples(
        self._initial_state, self._initial_state_name, 1)
    self._final_state = tf_util.import_state_tuples(
        self._final_state, self._final_state_name, 1)
开发者ID:mirceamironenco,项目名称:BayesianRecurrentNN,代码行数:25,代码来源:bayesian_rnn.py

示例7: load_model

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def load_model(sess, checkpoint_path):
    """Import and restore a checkpointed graph into ``sess``.

    The meta graph at ``checkpoint_path + '.meta'`` is imported under the
    scope 'm2' with device placements cleared. Variable values are then
    restored from ``checkpoint_path``. Finally, the op returned by
    ``set_up_init_ops`` for the LOCAL_VARIABLES collection is run.
    """
    meta_path = checkpoint_path + '.meta'
    restorer = tf.train.import_meta_graph(
        meta_path, import_scope='m2', clear_devices=True)
    restorer.restore(sess, checkpoint_path)

    local_vars = tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)
    sess.run(set_up_init_ops(local_vars))
开发者ID:devicehive,项目名称:devicehive-audio-analysis,代码行数:14,代码来源:model.py

示例8: global_mode

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def global_mode():
    """Returns the Tensor of global mode.

    This is a placeholder with default value of
    :tf_main:`tf.estimator.ModeKeys.TRAIN <estimator/ModeKeys>`.

    The tensor is created on the first call and stored as the first entry of
    the default graph's ``_GLOBAL_MODE_KEY`` collection. Later calls return
    that same tensor.

    Example:

        .. code-block:: python

            mode = session.run(global_mode())
            # mode == tf.estimator.ModeKeys.TRAIN

            mode = session.run(
                global_mode(),
                feed_dict={global_mode(): tf.estimator.ModeKeys.PREDICT})
            # mode == tf.estimator.ModeKeys.PREDICT
    """
    mode = tf.get_collection_ref(_GLOBAL_MODE_KEY)
    if not mode:
        # A placeholder with a default lets callers omit the feed (TRAIN) or
        # override it through feed_dict.
        mode_tensor = tf.placeholder_with_default(
            input=tf.estimator.ModeKeys.TRAIN,
            shape=(),
            name="global_mode")
        mode.append(mode_tensor)
    return mode[0]
开发者ID:qkaren,项目名称:Counterfactual-StoryRW,代码行数:33,代码来源:context.py

示例9: build_graph

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def build_graph(self, trainer, test_mode=False):
        """Build summaries (outside test mode) and the checkpoint Saver.

        In test mode only a Saver keeping at most 3 checkpoints is created.
        Otherwise numerical and image summaries are built first. When
        ``trainer.train_extrap`` is set, the Saver is limited to global
        variables whose op names appear in the checkpoint directory.
        """
        if not test_mode:
            self._build_numerical_summaries(trainer)
            self._build_img_summaries(trainer)

            if trainer.train_extrap:
                # Extrap training adds an extra Adam optimizer whose
                # parameters do not exist in the checkpoint, so only restore
                # variables that the checkpoint actually contains.
                source_dir = self.force_load_from_dir or self.ckpt_dir
                saved_names = {entry[0] for entry in tf.train.list_variables(source_dir)}
                restorable = [
                    var for var in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
                    if var.op.name in saved_names
                ]
                self.saver = tf.train.Saver(max_to_keep=3, var_list=restorable)
                return
        self.saver = tf.train.Saver(max_to_keep=3)
开发者ID:hubert0527,项目名称:COCO-GAN,代码行数:17,代码来源:logger.py

示例10: restore_collection

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def restore_collection(backup):
    """Overwrite graph collections in place with backed-up contents.

    Args:
        backup: mapping from collection key to an iterable of items. Each
            named collection of the default graph is emptied and refilled
            with those items. Collections not named in ``backup`` are left
            untouched.
    """
    for k, v in six.iteritems(backup):
        # Materialize the items before touching the collection. Clearing
        # first (``del ref[:]`` followed by ``ref.extend(v)``) would silently
        # leave the collection empty when ``v`` is, or reads from, the live
        # list returned by get_collection_ref.
        items = list(v)
        tf.get_collection_ref(k)[:] = items
开发者ID:anonymous-author1,项目名称:DDRL,代码行数:6,代码来源:common.py

示例11: clear_collection

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def clear_collection(keys):
    """Empty, in place, each default-graph collection named in ``keys``."""
    for key in keys:
        collection = tf.get_collection_ref(key)
        collection[:] = []
开发者ID:anonymous-author1,项目名称:DDRL,代码行数:5,代码来源:common.py

示例12: test_multitower_multi_loss_function

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_multitower_multi_loss_function(self):
    """Check graph search on a multitower model with two losses per tower.

    The automatic graph scanner should handle multiple loss functions per
    tower, as long as they are registered in a consistent order. The
    automatically registered collection must match one whose fully
    connected layers were registered by hand.
    """
    with tf.Graph().as_default():
      w_1 = tf.get_variable('w_1', shape=[10, 10])
      b_1 = tf.get_variable('b_1', shape=[10])
      w_2 = tf.get_variable('w_2', shape=[10, 10])
      b_2 = tf.get_variable('b_2', shape=[10])
      auto_collection = lc.LayerCollection()
      manual_collection = lc.LayerCollection()
      for tower in range(5):
        inputs = tf.placeholder(tf.float32, shape=(32, 10))
        logits_1 = tf.matmul(inputs, w_1) + b_1
        logits_2 = tf.matmul(inputs, w_2) + b_2
        # reuse is False for tower 0 and True for every later tower.
        with tf.variable_scope('tower%d' % tower, reuse=tower > 0):
          for coll in (auto_collection, manual_collection):
            coll.register_categorical_predictive_distribution(
                logits_1, name='loss_1')
            coll.register_categorical_predictive_distribution(
                logits_2, name='loss_2')
          manual_collection.register_fully_connected((w_1, b_1), inputs,
                                                     logits_1)
          manual_collection.register_fully_connected((w_2, b_2), inputs,
                                                     logits_2)

      global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
      gs.register_layers(auto_collection, global_vars)

      assert_fisher_blocks_match(self, auto_collection, manual_collection)
开发者ID:tensorflow,项目名称:kfac,代码行数:39,代码来源:graph_search_test.py

示例13: test_graph_search_match_fail

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_graph_search_match_fail(self):
    """Check that linking two non-adjacent biases fails to match.

    The two bias tensors ``b_0`` and ``b_1`` are linked together, but no
    kfac Fisher block matches that configuration. ``register_layers`` should
    therefore raise a ValueError that names the unmatched linked group.
    """
    with tf.Graph().as_default():
      tensors = _build_model()

      collection = lc.LayerCollection()
      collection.register_squared_error_loss(tensors['out_0'])
      collection.register_squared_error_loss(tensors['out_1'])

      # TODO(b/69055612): remove this manual registration once layer_collection
      # implements register_fully_connected_multi.
      collection.register_fully_connected(
          tensors['w'], tensors['x'], tensors['pre_bias_0'])
      linked_biases = (tensors['b_0'], tensors['b_1'])
      collection.define_linked_parameters(linked_biases)

      global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
      with self.assertRaises(ValueError) as cm:
        gs.register_layers(collection, global_vars)

      message = str(cm.exception)
      expected_fragments = (
          'in linked group',
          'was not matched',
          str(frozenset(linked_biases)),
      )
      for fragment in expected_fragments:
        self.assertIn(fragment, message)
开发者ID:tensorflow,项目名称:kfac,代码行数:32,代码来源:graph_search_test.py

示例14: test_specify_approximation_shared_parameters

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_specify_approximation_shared_parameters(self):
    """Check that approximations given for linked parameters are honored.

    Each parameter is linked with an explicit approximation name. After
    graph search, the Fisher block registered for that parameter must be
    of the class corresponding to that approximation.
    """
    with tf.Graph().as_default():
      tensors = _build_model()

      collection = lc.LayerCollection()
      collection.register_squared_error_loss(tensors['out_0'])
      collection.register_squared_error_loss(tensors['out_1'])

      # (tensor key, requested approximation, expected Fisher block class)
      cases = (
          ('w', lc.APPROX_KRONECKER_INDEP_NAME, fb.FullyConnectedMultiIndepFB),
          ('b_0', lc.APPROX_DIAGONAL_NAME, fb.NaiveDiagonalFB),
          ('b_1', lc.APPROX_FULL_NAME, fb.FullFB),
      )
      for key, approximation, _ in cases:
        collection.define_linked_parameters(
            tensors[key], approximation=approximation)

      gs.register_layers(
          collection,
          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
          batch_size=1)

      for key, _, expected_cls in cases:
        self.assertIsInstance(collection.fisher_blocks[tensors[key]],
                              expected_cls)
开发者ID:tensorflow,项目名称:kfac,代码行数:34,代码来源:graph_search_test.py

示例15: test_tied_weights_untied_bias_registered_weights

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import get_collection_ref [as 别名]
def test_tied_weights_untied_bias_registered_weights(self):
    """Check graph search against a hand-registered toy model.

    The shared weight ``w`` is linked and graph search fills in the rest.
    The result must match a collection that registers ``w`` as a multi-use
    fully connected layer and each bias as a generic parameter.
    """
    with tf.Graph().as_default():
      tensors = _build_model()
      outputs = (tensors['out_0'], tensors['out_1'])

      expected = lc.LayerCollection()
      for out in outputs:
        expected.register_squared_error_loss(out)
      expected.register_fully_connected_multi(
          tensors['w'], (tensors['x'], tensors['y']),
          (tensors['pre_bias_0'], tensors['pre_bias_1']))
      for bias_key in ('b_0', 'b_1'):
        expected.register_generic(tensors[bias_key], batch_size=1)

      found = lc.LayerCollection()
      for out in outputs:
        found.register_squared_error_loss(out)

      # A single tensor is linked; the earlier ``(tensor_dict['w'])`` was not
      # a tuple either.
      found.define_linked_parameters(tensors['w'])
      gs.register_layers(
          found,
          tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
          batch_size=1)

      assert_fisher_blocks_match(self, found, expected)
开发者ID:tensorflow,项目名称:kfac,代码行数:29,代码来源:graph_search_test.py


注:本文中的tensorflow.get_collection_ref方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。