当前位置: 首页>>代码示例>>Python>>正文


Python v1.get_collection方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.get_collection方法的典型用法代码示例。如果您正苦于以下问题:Python v1.get_collection方法的具体用法?Python v1.get_collection怎么用?Python v1.get_collection使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.compat.v1的用法示例。


在下文中一共展示了v1.get_collection方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testVariablesSetDeviceMobileModel

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testVariablesSetDeviceMobileModel(self):
    """Checks that NASNet-Mobile variables are placed on their scope's device.

    The mobile model is built twice, once inside a '/cpu:0' device scope and
    once inside '/gpu:0'. Every global variable created under each variable
    scope must then report the matching device.
    """
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    images = tf.random_uniform((batch_size, height, width, 3))
    tf.train.create_global_step()
    scope_to_device = (('on_cpu', '/cpu:0'), ('on_gpu', '/gpu:0'))
    # Force all Variables to reside on the device of their scope.
    for scope_name, device_name in scope_to_device:
      with tf.variable_scope(scope_name), tf.device(device_name):
        with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
          nasnet.build_nasnet_mobile(images, num_classes)
    for scope_name, device_name in scope_to_device:
      scoped_vars = tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name)
      for variable in scoped_vars:
        self.assertDeviceEqual(variable.device, device_name)
开发者ID:tensorflow,项目名称:benchmarks,代码行数:19,代码来源:nasnet_test.py

示例2: _load_checkpoint

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def _load_checkpoint(checkpoint_filename, extra_vars, trainable_only=False):
  """Restores graph variables from a checkpoint into the default session.

  Args:
    checkpoint_filename: Path of a checkpoint, or a directory whose latest
      checkpoint should be loaded.
    extra_vars: Variables to leave out on the fallback attempt, which is made
      when restoring everything fails (older checkpoint format).
    trainable_only: If True, only trainable variables are restored.
  """
  if tf.gfile.IsDirectory(checkpoint_filename):
    # Resolve a directory to its most recent checkpoint file.
    checkpoint_filename = tf.train.latest_checkpoint(checkpoint_filename)
  logging.info('Loading checkpoint %s', checkpoint_filename)
  global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  saveable_objects = tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)
  saveables = global_vars + saveable_objects
  if trainable_only:
    saveables = list(set(saveables) & set(tf.trainable_variables()))
  try:
    # First attempt: restore every collected saveable.
    tf.train.Saver(var_list=saveables).restore(
        tf.get_default_session(), checkpoint_filename)
  except (ValueError, tf.errors.NotFoundError):
    # A key is missing; retry as an old-format checkpoint without extra_vars.
    logging.info('Missing key in checkpoint. Trying old checkpoint format.')
    legacy_saveables = list(set(saveables) - set(extra_vars))
    tf.train.Saver(var_list=legacy_saveables).restore(
        tf.get_default_session(), checkpoint_filename)
开发者ID:deepmind,项目名称:lamb,代码行数:18,代码来源:training.py

示例3: estimator_spec_eval

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
    """Constructs `tf.estimator.EstimatorSpec` for EVAL (evaluation) mode.

    Starts from the parent class's EVAL spec. Unless running under XLA, it
    appends a `SummarySaverHook` that writes the loss summaries every 100
    steps to `<hparams.model_dir>/eval`.

    Args:
      features: Passed through unchanged to the parent implementation.
      logits: Passed through unchanged to the parent implementation.
      labels: Passed through unchanged to the parent implementation.
      loss: Scalar loss tensor; also summarized as "loss" for the hook.
      losses_dict: Passed through unchanged to the parent implementation.

    Returns:
      The parent's EstimatorSpec, with the summary hook added to its
      evaluation hooks when not XLA-compiled.
    """
    estimator_spec = super(TransformerAE, self).estimator_spec_eval(
        features, logits, labels, loss, losses_dict)
    if common_layers.is_xla_compiled():
      # For TPUs (and XLA more broadly?), do not add summary hooks that depend
      # on losses; they are not supported.
      return estimator_spec

    # `tf.get_collection` filters by `re.match(scope, name)`, so scope="loss"
    # also matches every summary under "losses/". Keep each summary op only
    # once (by identity) so it is not written twice per save step.
    summary_op = []
    seen_ids = set()
    for scope in ("losses", "loss"):
      for summary in tf.get_collection(tf.GraphKeys.SUMMARIES, scope=scope):
        if id(summary) not in seen_ids:
          seen_ids.add(id(summary))
          summary_op.append(summary)
    summary_op.append(tf.summary.scalar("loss", loss))
    summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=100,
        summary_op=summary_op,
        output_dir=os.path.join(self.hparams.model_dir, "eval"))

    hooks = list(estimator_spec.evaluation_hooks)
    hooks.append(summary_saver_hook)
    return estimator_spec._replace(evaluation_hooks=hooks)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:22,代码来源:transformer_vae.py

示例4: define_train_ops

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def define_train_ops(gan_model, gan_loss, **kwargs):
  """Builds the progressive GAN training ops with two Adam optimizers.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: Must contain
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of a `GANTrainOps` namedtuple and the list of global variables
    created under the 'progressive_gan_train_ops' scope (the optimizers'
    state).
  """
  with tf.variable_scope('progressive_gan_train_ops') as train_scope:
    # Both optimizers share the same Adam beta1/beta2.
    adam_betas = (kwargs['adam_beta1'], kwargs['adam_beta2'])
    generator_optimizer = tf.train.AdamOptimizer(
        kwargs['generator_learning_rate'], *adam_betas)
    discriminator_optimizer = tf.train.AdamOptimizer(
        kwargs['discriminator_learning_rate'], *adam_betas)
    train_ops = tfgan.gan_train_ops(
        gan_model, gan_loss, generator_optimizer, discriminator_optimizer)
  optimizer_state = tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=train_scope.name)
  return train_ops, optimizer_state
开发者ID:magenta,项目名称:magenta,代码行数:27,代码来源:train_util.py

示例5: testBatchNormUpdateImproveStatistics

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testBatchNormUpdateImproveStatistics(self):
    """Test that updating the moving_mean improves statistics."""
    _, _, inputs = _get_inputs()
    # Use small decay_rate to update faster.
    batch_norm = ibp.BatchNorm(offset=False, scale=False, decay_rate=0.1,
                               update_ops_collection=tf.GraphKeys.UPDATE_OPS)
    inference_output = batch_norm(inputs, is_training=False)
    # The training-mode call is only made to build the update ops.
    batch_norm(inputs, is_training=True)

    def max_abs_deviation(values):
      # Largest absolute distance from an all-zero [7, 6] target.
      return np.max(np.abs(np.zeros([7, 6]) - values))

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # Before updating the moving_mean the results are off.
      self.assertBetween(max_abs_deviation(sess.run(inference_output)), 2, 5)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      sess.run(tuple(update_ops))
      # After updating the moving_mean the results are better.
      self.assertBetween(max_abs_deviation(sess.run(inference_output)), 1, 2)
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:21,代码来源:layers_test.py

示例6: initialize

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def initialize(self):
    """Initializes the teacher model's variables from its checkpoint.

    Meant to be called once the graph has been constructed. Nothing happens
    when `fraction_soft` is 0.0, since the teacher is then not needed.
    Otherwise every global variable under the "teacher" scope is mapped to
    the checkpoint entry named after the variable with its "teacher/" prefix
    and ":<index>" suffix removed.
    """
    if self.fraction_soft == 0.0:
      return
    prefix_length = len("teacher/")
    assignment_map = {}
    for variable in tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="teacher"):
      checkpoint_name = variable.name[prefix_length:].split(":")[0]
      assignment_map[checkpoint_name] = variable
    tf.train.init_from_checkpoint(self.teacher_checkpoint, assignment_map)


# gin-configurable constructors 
开发者ID:tensorflow,项目名称:mesh,代码行数:18,代码来源:transformer.py

示例7: restore_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def restore_model(sess, checkpoint_path, enable_ema=True):
  """Restores checkpointed variables into `sess`.

  All global variables are initialized first and then overwritten from the
  checkpoint.

  Args:
    sess: A tensorflow session where the checkpoint will be loaded.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) Whether to load the exponential moving average (ema)
      version of the tensorflow variables. Defaults to True.
  """
  var_dict = None  # None lets the Saver handle every saveable variable.
  if enable_ema:
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    averaged = tf.trainable_variables() + tf.get_collection("moving_vars")
    # Batch-norm statistics are also restored from their averaged copies.
    averaged.extend(
        v for v in tf.global_variables()
        if "moving_mean" in v.name or "moving_variance" in v.name)
    var_dict = ema.variables_to_restore(list(set(averaged)))

  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(var_dict, max_to_keep=1)
  saver.restore(sess, checkpoint_path)
开发者ID:tensorflow,项目名称:models,代码行数:25,代码来源:post_training_quantization.py

示例8: testVariablesSetDevice

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testVariablesSetDevice(self):
    """Checks Inception-ResNet-v2 variables are placed on their scope's device.

    The network is built under a '/cpu:0' device scope and again under
    '/gpu:0'. Each variable scope's global variables must carry the matching
    device.
    """
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    placements = (('on_cpu', '/cpu:0'), ('on_gpu', '/gpu:0'))
    with self.test_session():
      images = tf.random.uniform((batch_size, height, width, 3))
      # Force all Variables to reside on the device.
      for scope_name, device_name in placements:
        with tf.variable_scope(scope_name), tf.device(device_name):
          inception.inception_resnet_v2(images, num_classes)
      for scope_name, device_name in placements:
        for variable in tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name):
          self.assertDeviceEqual(variable.device, device_name)
开发者ID:tensorflow,项目名称:models,代码行数:19,代码来源:inception_resnet_v2_test.py

示例9: testVariablesSetDeviceMobileModel

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testVariablesSetDeviceMobileModel(self):
    """Checks NASNet-Mobile variables are placed on their scope's device.

    One copy of the mobile model is built per (variable scope, device) pair.
    Every global variable under each scope is then asserted to live on that
    pair's device.
    """
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000
    images = tf.random.uniform((batch_size, height, width, 3))
    tf.train.create_global_step()
    placements = (('on_cpu', '/cpu:0'), ('on_gpu', '/gpu:0'))
    # Force all Variables to reside on the device.
    for scope_name, device_name in placements:
      with tf.variable_scope(scope_name), tf.device(device_name):
        with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
          nasnet.build_nasnet_mobile(images, num_classes)
    for scope_name, device_name in placements:
      for variable in tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name):
        self.assertDeviceEqual(variable.device, device_name)
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:nasnet_test.py

示例10: testVariablesSetDevice

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testVariablesSetDevice(self):
    """Checks Inception-v4 variables are placed on their scope's device.

    The network is built once per (variable scope, device) pair. Every global
    variable under each scope is then asserted to live on that pair's device.
    """
    batch_size = 5
    height, width = 299, 299
    num_classes = 1000
    images = tf.random.uniform((batch_size, height, width, 3))
    placements = (('on_cpu', '/cpu:0'), ('on_gpu', '/gpu:0'))
    # Force all Variables to reside on the device.
    for scope_name, device_name in placements:
      with tf.variable_scope(scope_name), tf.device(device_name):
        inception.inception_v4(images, num_classes)
    for scope_name, device_name in placements:
      for variable in tf.get_collection(
          tf.GraphKeys.GLOBAL_VARIABLES, scope=scope_name):
        self.assertDeviceEqual(variable.device, device_name)
开发者ID:tensorflow,项目名称:models,代码行数:18,代码来源:inception_v4_test.py

示例11: testCreateLogisticClassifier

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testCreateLogisticClassifier(self):
    """Checks variable/op placement for a single LogisticClassifier clone.

    With one clone, variables must live on the CPU, the clone on the GPU. The
    clone must have an empty scope, exactly one loss, and no update ops.
    """
    with tf.Graph().as_default():
      tf.set_random_seed(0)
      inputs = tf.constant(self._inputs, dtype=tf.float32)
      labels = tf.constant(self._labels, dtype=tf.float32)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      # Nothing exists before the clones are built.
      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(
          deploy_config, LogisticClassifier, (inputs, labels))
      clone = clones[0]

      model_variables = slim.get_variables()
      self.assertEqual(len(model_variables), 2)
      for variable in model_variables:
        self.assertDeviceEqual(variable.device, 'CPU:0')
        self.assertDeviceEqual(variable.value().device, 'CPU:0')

      self.assertEqual(clone.outputs.op.name,
                       'LogisticClassifier/fully_connected/Sigmoid')
      self.assertEqual(clone.scope, '')
      self.assertDeviceEqual(clone.device, 'GPU:0')
      self.assertEqual(len(slim.losses.get_losses()), 1)
      self.assertEqual(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])
开发者ID:tensorflow,项目名称:models,代码行数:27,代码来源:model_deploy_test.py

示例12: testCreateSingleclone

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def testCreateSingleclone(self):
    """Checks a single BatchNormClassifier clone and its optimization.

    Expects 5 variables and 2 update ops after cloning. After optimization
    there must be one (gradient, variable) pair per trainable variable, with
    gradients on the GPU and variables on the CPU.
    """
    graph = tf.Graph()
    with graph.as_default():
      tf.set_random_seed(0)
      inputs = tf.constant(self._inputs, dtype=tf.float32)
      labels = tf.constant(self._labels, dtype=tf.float32)
      deploy_config = model_deploy.DeploymentConfig(num_clones=1)

      self.assertEqual(slim.get_variables(), [])
      clones = model_deploy.create_clones(
          deploy_config, BatchNormClassifier, (inputs, labels))
      self.assertEqual(len(slim.get_variables()), 5)
      self.assertEqual(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)), 2)

      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
      self.assertEqual(total_loss.op.name, 'total_loss')
      # Distinct names avoid shadowing the `graph` binding above.
      for grad, var in grads_and_vars:
        self.assertDeviceEqual(grad.device, 'GPU:0')
        self.assertDeviceEqual(var.device, 'CPU:0')
开发者ID:tensorflow,项目名称:models,代码行数:27,代码来源:model_deploy_test.py

示例13: _get_variables_to_train

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def _get_variables_to_train():
  """Returns a list of variables to train.

  When `FLAGS.trainable_scopes` is None, every trainable variable is returned.
  Otherwise the flag is read as a comma-separated list of scopes. Only
  trainable variables matching one of those scopes are returned, each at most
  once, in first-match order.

  Returns:
    A list of variables to train by the optimizer.
  """
  if FLAGS.trainable_scopes is None:
    return tf.trainable_variables()

  # Drop blank entries (e.g. from a trailing comma or ",,"). `get_collection`
  # filters with `re.match(scope, name)`, and an empty pattern matches every
  # name, which would silently make the whole model trainable.
  scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
  scopes = [scope for scope in scopes if scope]
  if not scopes:
    # An all-blank flag previously matched everything; keep that meaning.
    return tf.trainable_variables()

  variables_to_train = []
  seen_ids = set()
  for scope in scopes:
    for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope):
      # Overlapping scopes would otherwise hand the same variable to the
      # optimizer more than once.
      if id(variable) not in seen_ids:
        seen_ids.add(id(variable))
        variables_to_train.append(variable)
  return variables_to_train
开发者ID:tensorflow,项目名称:models,代码行数:18,代码来源:train_image_classifier.py

示例14: test_expected_calibration_error_all_bins_filled

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def test_expected_calibration_error_all_bins_filled(self):
    """Test expected calibration error when all bins contain predictions."""
    y_true, y_pred = self._get_calibration_placeholders()
    ece_op, update_op = calibration_metrics.expected_calibration_error(
        y_true, y_pred, nbins=2)
    with self.test_session() as sess:
      metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      sess.run(tf.variables_initializer(var_list=metrics_vars))
      # Bin calibration errors (|confidence - accuracy| * bin_weight):
      # - [0, 0.5): |0.2 - 0.333| * (3/5) = 0.08
      # - [0.5, 1]: |0.75 - 0.5| * (2/5) = 0.1
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0., 0.2, 0.4, 0.5, 1.0]),
              y_true: np.array([0, 0, 1, 0, 1])
          })
      # Read the metric while the session context is still active instead of
      # relying on test_session() leaving its cached session open afterwards.
      computed_ece = sess.run(ece_op)
    reference_ece = 0.08 + 0.1
    self.assertAlmostEqual(reference_ece, computed_ece)
开发者ID:tensorflow,项目名称:models,代码行数:22,代码来源:calibration_metrics_tf1_test.py

示例15: test_expected_calibration_error_all_bins_not_filled

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import get_collection [as 别名]
def test_expected_calibration_error_all_bins_not_filled(self):
    """Test expected calibration error when no predictions for one bin."""
    y_true, y_pred = self._get_calibration_placeholders()
    ece_op, update_op = calibration_metrics.expected_calibration_error(
        y_true, y_pred, nbins=2)
    with self.test_session() as sess:
      metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      sess.run(tf.variables_initializer(var_list=metrics_vars))
      # Bin calibration errors (|confidence - accuracy| * bin_weight).
      # All three predictions fall into the first bin, which therefore has
      # weight 3/3; the second bin is empty and contributes nothing:
      # - [0, 0.5): |0.2 - 0.333| * (3/3) = 0.1333
      # - [0.5, 1]: no predictions
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0., 0.2, 0.4]),
              y_true: np.array([0, 0, 1])
          })
      # Read the metric while the session context is still active instead of
      # relying on test_session() leaving its cached session open afterwards.
      computed_ece = sess.run(ece_op)
    reference_ece = np.abs(0.2 - (1 / 3.))
    self.assertAlmostEqual(reference_ece, computed_ece)
开发者ID:tensorflow,项目名称:models,代码行数:22,代码来源:calibration_metrics_tf1_test.py


注:本文中的tensorflow.compat.v1.get_collection方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。