当前位置: 首页>>代码示例>>Python>>正文


Python v1.trainable_variables方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.trainable_variables方法的典型用法代码示例。如果您正苦于以下问题：Python v1.trainable_variables方法的具体用法？Python v1.trainable_variables怎么用？Python v1.trainable_variables使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.trainable_variables方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: trainable_variables_on_device

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def trainable_variables_on_device(self, rel_device_num, abs_device_num,
                                    writable):
    """Returns the trainable variables as seen from one device.

    Args:
      rel_device_num: local worker device index; selects the staging table in
        `self.variable_mgr.staging_vars_on_devices`.
      abs_device_num: global graph device index (unused).
      writable: if True, return the trainable variables themselves; otherwise
        return, for each of them, the get op stored in the device's staging
        table.

    Returns:
      A list ordered like `tf.trainable_variables()`.
    """
    del abs_device_num  # Staging is looked up by the local index only.
    trainable = tf.trainable_variables()
    if writable:
      return trainable

    def _staged_get_op(var):
      # Staging entries are keyed by the variable name without ':<port>' and
      # hold a 2-tuple whose second element is the get op.
      op_name = var.name.split(':')[0]
      _, get_op = self.variable_mgr.staging_vars_on_devices[rel_device_num][
          op_name]
      return get_op

    return [_staged_get_op(var) for var in trainable]
开发者ID:tensorflow,项目名称:benchmarks,代码行数:25,代码来源:variable_mgr_util.py

示例2: trainable_variables_on_device

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def trainable_variables_on_device(self,
                                    rel_device_num,
                                    abs_device_num,
                                    writable=False):
    """Returns the trainable variables that belong to one device.

    Args:
      rel_device_num: local worker device index (ignored).
      abs_device_num: global graph device index. When every tower has its own
        variables, only variables whose name starts with `v<abs_device_num>/`
        are returned.
      writable: ignored; accepted for interface compatibility.

    Returns:
      A list of trainable variables.
    """
    del rel_device_num, writable
    if not self.each_tower_has_variables():
      # Variables are shared by all towers, so every device sees all of them.
      return tf.trainable_variables()
    tower_prefix = 'v%s/' % abs_device_num
    return [var for var in tf.trainable_variables()
            if var.name.startswith(tower_prefix)]
开发者ID:tensorflow,项目名称:benchmarks,代码行数:25,代码来源:variable_mgr.py

示例3: savable_variables

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def savable_variables(self):
    """Returns a dict of savable variables to pass to tf.train.Saver.

    Keys are checkpoint names: the variable name with its port suffix removed
    and, for global variables, the PS shadow-variable prefix removed as well.

    Returns:
      A dict mapping checkpoint names to variables.

    Raises:
      AssertionError: if a global variable is neither under the
        `<PS_SHADOW_VAR_PREFIX>/v0/` scope nor one of the known bookkeeping
        variables.
    """
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/')
              or v.name in ('global_step:0', 'loss_scale:0',
                            'loss_scale_normal_steps:0')), (
                                'Invalid global variable: %s' % v)
      # We store variables in the checkpoint with the shadow variable prefix
      # removed so we can evaluate checkpoints in non-distributed replicated
      # mode. The checkpoints can also be loaded for training in
      # distributed_replicated mode.
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v
    # Fetch the trainable-variable list once, rather than calling
    # tf.trainable_variables() again for every local variable checked below.
    trainable_vars = tf.trainable_variables()
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in trainable_vars:
        params[self._strip_port(v.name)] = v
    return params
开发者ID:tensorflow,项目名称:benchmarks,代码行数:24,代码来源:variable_mgr.py

示例4: find_var

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def find_var(name, vars_=None):
  """Looks up a variable by its fully qualified name.

  Args:
    name: The full variable name, including all enclosing scopes. It is
      compared with `==` against each candidate's `.name`.
    vars_: Iterable of candidate variables. When None, the search runs over
      `tf.trainable_variables()`.

  Returns:
    The first candidate whose `.name` equals `name`, or None if none match.
  """
  candidates = tf.trainable_variables() if vars_ is None else vars_
  for var in candidates:
    if var.name == name:
      return var
  return None
开发者ID:deepmind,项目名称:lamb,代码行数:19,代码来源:utils.py

示例5: _load_checkpoint

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
  """Restores global variables and saveable objects into the default session.

  Args:
    checkpoint_filename: A checkpoint path, or a directory, in which case the
      latest checkpoint in that directory is used.
    extra_vars: Objects to drop from the restore set if the first restore
      attempt fails with a missing key (the "old checkpoint format" fallback).
    trainable_only: If True, restrict the restore set to objects that are also
      in `tf.trainable_variables()`.
  """
  if tf.gfile.IsDirectory(checkpoint_filename):
    checkpoint_filename = tf.train.latest_checkpoint(checkpoint_filename)
  logging.info('Loading checkpoint %s', checkpoint_filename)

  def _restore(var_list):
    # Build a Saver over `var_list` and restore into the default session.
    saver = tf.train.Saver(var_list=var_list)
    saver.restore(tf.get_default_session(), checkpoint_filename)

  saveables = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
               tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
  if trainable_only:
    saveables = list(set(saveables) & set(tf.trainable_variables()))
  try:
    _restore(saveables)
  except (ValueError, tf.errors.NotFoundError):
    # The full set could not be restored; retry without `extra_vars`.
    logging.info('Missing key in checkpoint. Trying old checkpoint format.')
    _restore(list(set(saveables) - set(extra_vars)))
开发者ID:deepmind,项目名称:lamb,代码行数:18,代码来源:training.py

示例6: weight_decay_and_noise

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise.

  Args:
    loss: The loss tensor to regularize.
    hparams: Hyperparameters; `weight_decay` and `weight_noise` are read.
    learning_rate: Passed through to `weight_noise`.
    var_list: Variables to regularize; defaults to `tf.trainable_variables()`.
      Any iterable is accepted. Weight decay covers every variable, while
      weight noise covers only variables whose name contains "/body/".

  Returns:
    `loss` plus the weight-decay term. The result is built on a
    `tf.identity` of `loss` that carries control dependencies on the
    weight-noise ops.
  """
  if var_list is None:
    var_list = tf.trainable_variables()
  # Materialize once: both selections below iterate over var_list, so a
  # one-shot iterator would otherwise leave `noise_vars` silently empty.
  decay_vars = list(var_list)
  noise_vars = [v for v in decay_vars if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  # Route the loss through an identity so that evaluating it forces the
  # weight-noise ops to run first.
  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:21,代码来源:optimize.py

示例7: summarize_variables

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def summarize_variables(var_list=None, tag=None):
  """Adds a histogram summary for each variable.

  Args:
    var_list: Variables to summarize; defaults to `tf.trainable_variables()`.
      Variables that share a name are summarized once, and the last such
      variable wins.
    tag: Prefix prepended to each variable name to form the summary name;
      defaults to "training_variables/".
  """
  variables = tf.trainable_variables() if var_list is None else var_list
  prefix = "training_variables/" if tag is None else tag

  # De-duplicate by name. The dict keeps first-seen name order and the
  # last-seen variable for each name.
  by_name = {}
  for var in variables:
    by_name[var.name] = var
  for var_name, var in by_name.items():
    tf.summary.histogram(prefix + var_name, var)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:18,代码来源:optimize.py

示例8: test_adam

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def test_adam(self):
    """AdamWeightDecayOptimizer should drive `w` onto the regression target."""
    with self.test_session() as sess:
      w = tf.get_variable(
          "w",
          shape=[3],
          initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
      target = tf.constant([0.4, 0.2, -0.5])
      loss = tf.reduce_mean(tf.square(target - w))
      tvars = tf.trainable_variables()
      grads_and_vars = list(zip(tf.gradients(loss, tvars), tvars))
      global_step = tf.train.get_or_create_global_step()
      optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
      train_op = optimizer.apply_gradients(grads_and_vars, global_step)
      sess.run(tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer()))
      # The loss is minimized at w == target, so after enough steps `w` should
      # match the target values within tolerance.
      for _ in range(100):
        sess.run(train_op)
      self.assertAllClose(sess.run(w).flat, [0.4, 0.2, -0.5],
                          rtol=1e-2, atol=1e-2)
开发者ID:google-research,项目名称:albert,代码行数:22,代码来源:optimization_test.py

示例9: testCompatibleNames

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def testCompatibleNames(self):
    """LSTMBlockCell must create the same variable names/shapes as LSTMCell."""

    def _variable_shapes(make_cell, make_peephole_cell):
      # In a fresh graph, unroll a plain and a peephole cell under the
      # "basic" and "peephole" scopes, then map each trainable variable's
      # name to its static shape.
      with self.session(use_gpu=True, graph=tf.Graph()):
        cell = make_cell()
        pcell = make_peephole_cell()
        inputs = [tf.zeros([4, 5])] * 6
        tf.nn.static_rnn(cell, inputs, dtype=tf.float32, scope="basic")
        tf.nn.static_rnn(pcell, inputs, dtype=tf.float32, scope="peephole")
        return {var.name: var.get_shape() for var in tf.trainable_variables()}

    basic_names = _variable_shapes(
        lambda: rnn_cell.LSTMCell(10),
        lambda: rnn_cell.LSTMCell(10, use_peepholes=True))
    block_names = _variable_shapes(
        lambda: contrib_rnn.LSTMBlockCell(10),
        lambda: contrib_rnn.LSTMBlockCell(10, use_peephole=True))
    self.assertEqual(basic_names, block_names)
开发者ID:magenta,项目名称:magenta,代码行数:26,代码来源:rnn_test.py

示例10: _test_model_params

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def _test_model_params(self,
                         model_name,
                         input_size,
                         expected_params,
                         override_params=None,
                         features_only=False,
                         pooled_features_only=False):
    """Builds `model_name` and checks its trainable parameter count.

    The model is built in training mode on a zero image batch of shape
    (1, input_size, input_size, 3). The element counts of all variables in
    `tf.trainable_variables()` are then summed and compared with
    `expected_params`.
    """
    images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
    build_kwargs = dict(
        model_name=model_name,
        override_params=override_params,
        training=True,
        features_only=features_only,
        pooled_features_only=pooled_features_only)
    efficientnet_builder.build_model(images, **build_kwargs)
    per_variable = [np.prod(var.shape) for var in tf.trainable_variables()]
    self.assertEqual(np.sum(per_variable), expected_params)
开发者ID:JunweiLiang,项目名称:Object_Detection_Tracking,代码行数:19,代码来源:efficientnet_builder_test.py

示例11: _test_model_params

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def _test_model_params(self,
                         model_name,
                         input_size,
                         expected_params,
                         override_params=None,
                         features_only=False,
                         pooled_features_only=False):
    """Builds the lite `model_name` and checks its trainable parameter count.

    The model is built in training mode on a zero image batch of shape
    (1, input_size, input_size, 3). The element counts of all variables in
    `tf.trainable_variables()` are then summed and compared with
    `expected_params`.
    """
    images = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
    build_kwargs = dict(
        model_name=model_name,
        override_params=override_params,
        training=True,
        features_only=features_only,
        pooled_features_only=pooled_features_only)
    efficientnet_lite_builder.build_model(images, **build_kwargs)
    per_variable = [np.prod(var.shape) for var in tf.trainable_variables()]

    self.assertEqual(np.sum(per_variable), expected_params)
开发者ID:JunweiLiang,项目名称:Object_Detection_Tracking,代码行数:20,代码来源:efficientnet_lite_builder_test.py

示例12: test_inner_loop_reuse

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def test_inner_loop_reuse(self, learn_inner_lr):
    """The inner loop must not create extra trainable vars in 'inner_loop'.

    The inner loop should create as many trainable vars in the 'inner_loop'
    scope as a direct call to inference_network_fn would. Learned learning
    rates and learned loss variables should be created *outside* the
    'inner_loop' scope, since they do not adapt.
    """
    with tf.Session(graph=tf.Graph()):
      inputs = create_inputs()
      features, _ = inputs
      # Reference: how many trainable vars one direct call creates.
      with tf.variable_scope('test_scope'):
        inference_network_fn(features)
      reference_count = len(tf.trainable_variables(scope='test_scope'))

      maml = maml_inner_loop.MAMLInnerLoopGradientDescent(
          learning_rate=LEARNING_RATE, learn_inner_lr=learn_inner_lr)
      maml.inner_loop([inputs] * 3, inference_network_fn,
                      learned_model_train_fn)
      self.assertEqual(reference_count,
                       len(tf.trainable_variables(scope='inner_loop')))
开发者ID:google-research,项目名称:tensor2robot,代码行数:23,代码来源:maml_inner_loop_test.py

示例13: initialize_networks

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def initialize_networks(self):
        """Creates the saver, initializes variables, and optionally restores.

        The method does the following, in order:
        - builds a `tf.train.Saver` over the trainable variables;
        - ensures `<cwd>/models` exists;
        - runs the global variable initializer;
        - unless `self.FLAGS.retrain` is set, restores the latest checkpoint
          found in that directory.

        Raises:
          ValueError: if weights should be restored but no checkpoint exists
            in the model directory.
        """
        model_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(model_vars)

        # Set up directory for saving models
        self.model_dir = os.path.join(os.getcwd(), 'models')
        self.model_loc = os.path.join(self.model_dir, 'HAC.ckpt')

        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # Initialize actor/critic networks
        self.sess.run(tf.global_variables_initializer())

        # If we are not retraining from scratch, restore the saved weights.
        if not self.FLAGS.retrain:
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            # latest_checkpoint returns None when nothing has been saved yet;
            # fail with an actionable message instead of an opaque restore
            # error.
            if checkpoint is None:
                raise ValueError(
                    'No checkpoint found in %s; enable retraining to train '
                    'from scratch.' % self.model_dir)
            self.saver.restore(self.sess, checkpoint)


    # Save neural network parameters 
开发者ID:andrew-j-levy,项目名称:Hierarchical-Actor-Critc-HAC-,代码行数:24,代码来源:agent.py

示例14: restore_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def restore_model(sess, checkpoint_path, enable_ema=True):
  """Initializes all variables in `sess`, then restores them from a checkpoint.

  Args:
    sess: A tensorflow session where the checkpoint will be loaded.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) Whether to load the exponential moving average (ema)
      version of the tensorflow variables. Defaults to True.
  """
  var_dict = None  # None makes the Saver cover all saveable variables.
  if enable_ema:
    # The EMA object is used only for `variables_to_restore`, which maps each
    # variable to its moving-average name in the checkpoint. The decay value
    # is presumably irrelevant for that purpose.
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    candidates = tf.trainable_variables() + tf.get_collection("moving_vars")
    candidates.extend(
        var for var in tf.global_variables()
        if "moving_mean" in var.name or "moving_variance" in var.name)
    var_dict = ema.variables_to_restore(list(set(candidates)))

  sess.run(tf.global_variables_initializer())
  tf.train.Saver(var_dict, max_to_keep=1).restore(sess, checkpoint_path)
开发者ID:tensorflow,项目名称:models,代码行数:25,代码来源:post_training_quantization.py

示例15: testCreateLogisticClassifier

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import trainable_variables [as 别名]
def testCreateLogisticClassifier(self):
    """Creates and optimizes a single-clone logistic classifier.

    With one clone, the test expects the following:
    - `create_clones` adds exactly two variables and no update ops;
    - `optimize_clones` returns one (gradient, variable) pair per trainable
      variable and a loss op named 'total_loss';
    - gradients are placed on GPU:0 and variables on CPU:0.
    """
    g = tf.Graph()
    with g.as_default():
      tf.set_random_seed(0)
      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
      tf_labels = tf.constant(self._labels, dtype=tf.float32)

      model_fn = LogisticClassifier
      clone_args = (tf_inputs, tf_labels)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
      self.assertEqual(len(slim.get_variables()), 2)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      self.assertEqual(update_ops, [])

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
                                                                optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      # Distinct loop names so the graph `g` above is not shadowed.
      for grad, var in grads_and_vars:
        self.assertDeviceEqual(grad.device, 'GPU:0')
        self.assertDeviceEqual(var.device, 'CPU:0')
开发者ID:tensorflow,项目名称:models,代码行数:27,代码来源:model_deploy_test.py


注：本文中的tensorflow.compat.v1.trainable_variables方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。