当前位置: 首页>>代码示例>>Python>>正文


Python v1.global_variables方法代码示例

本文整理汇总了Python中tensorflow.compat.v1.global_variables方法的典型用法代码示例。如果您想了解v1.global_variables方法的具体用法、使用方式或实际使用例子，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.compat.v1的其他用法示例。


在下文中一共展示了v1.global_variables方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_post_init_ops

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def get_post_init_ops(self):
    """Builds ops copying parameter-server shadow variables to local copies.

    For every global variable named ``<PS_SHADOW_VAR_PREFIX>/v0/<rest>``, and
    for every GPU index ``i`` in ``range(self.benchmark_cnn.num_gpus)``, an
    assign op is created if a local variable named ``v<i>/<rest>`` exists.
    Names are compared after ``self._strip_port`` (presumably removes the
    ``:N`` output suffix — confirm against its definition).

    Returns:
      A list of assign ops that copy each shadow variable's value into the
      matching local variables.
    """
    # Hoisted loop invariant; the '/v0' part is removed from the name so the
    # remaining suffix starts with '/' and can be appended to 'v<i>'.
    shadow_v0 = variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0'
    shadow_v0_dir = shadow_v0 + '/'

    local_var_by_name = {
        self._strip_port(v.name): v for v in tf.local_variables()}
    post_init_ops = []
    for v in tf.global_variables():
      if not v.name.startswith(shadow_v0_dir):
        continue
      suffix = self._strip_port(v.name[len(shadow_v0):])
      for i in range(self.benchmark_cnn.num_gpus):
        copy_to = local_var_by_name.get('v%s%s' % (i, suffix))
        if copy_to is not None:
          post_init_ops.append(copy_to.assign(v.read_value()))
    return post_init_ops
开发者ID:tensorflow,项目名称:benchmarks,代码行数:20,代码来源:variable_mgr.py

示例2: savable_variables

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def savable_variables(self):
    """Returns a list/dict of savable variables to pass to tf.train.Saver.

    Global variables are stored in the checkpoint with the shadow variable
    prefix removed (via ``_remove_shadow_var_prefix_if_present``), so
    checkpoints can be evaluated in non-distributed replicated mode and also
    loaded for training in distributed_replicated mode. Non-trainable local
    variables under ``v0/`` are added under their own (port-stripped) names.

    Returns:
      A dict mapping checkpoint name -> variable.

    Raises:
      AssertionError: if a global variable is neither under
        ``<PS_SHADOW_VAR_PREFIX>/v0/`` nor one of 'global_step:0',
        'loss_scale:0', 'loss_scale_normal_steps:0'.
    """
    shadow_v0_prefix = variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/'
    allowed_global_names = ('global_step:0', 'loss_scale:0',
                            'loss_scale_normal_steps:0')
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(shadow_v0_prefix)
              or v.name in allowed_global_names), (
                  'Invalid global variable: %s' % v)
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v

    # Fetch the trainable collection once; previously it was rebuilt for every
    # local variable inside the loop below.
    trainable_vars = tf.trainable_variables()
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in trainable_vars:
        params[self._strip_port(v.name)] = v
    return params
开发者ID:tensorflow,项目名称:benchmarks,代码行数:24,代码来源:variable_mgr.py

示例3: evaluate

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def evaluate(self, env_fn, hparams, sampling_temp):
    """Runs one evaluation collection pass with the restored policy.

    Builds a fresh graph holding an in-graph evaluation environment and an
    evaluation collect loop (under the "rl_eval" name scope), restores the
    policy variables from ``self.agent_model_dir`` and runs the collect op
    once.

    Args:
      env_fn: environment factory, called as ``env_fn(in_graph=True)``.
      hparams: hyperparameters; global variables matching
        ``hparams.policy_network + "/.*"`` are restored from the checkpoint.
      sampling_temp: sampling temperature forwarded to the collect loop.
    """
    with tf.Graph().as_default(), tf.name_scope("rl_eval"):
      eval_env = env_fn(in_graph=True)
      collect_memory, _, collect_init = _define_collect(
          eval_env,
          hparams,
          "ppo_eval",
          eval_phase=True,
          frame_stack_size=self.frame_stack_size,
          force_beginning_resets=False,
          sampling_temp=sampling_temp,
          distributional_size=self._distributional_size,
      )
      # Only the policy network's variables are restored. For sharing params,
      # tf.global_variables("clean_scope.*") was noted as an alternative.
      policy_vars = tf.global_variables(hparams.policy_network + "/.*")
      model_saver = tf.train.Saver(policy_vars)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        collect_init(sess)
        trainer_lib.restore_checkpoint(
            self.agent_model_dir, model_saver, sess)
        sess.run(collect_memory)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:27,代码来源:ppo_learner.py

示例4: __init__

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def __init__(self, hparams, action_space, observation_space, policy_dir):
    """Builds a PPO policy graph and restores its weights from a checkpoint.

    Args:
      hparams: hparams with ``base_algo == "ppo"``; ``base_algo_params`` is
        passed to ``trainer_lib.create_hparams`` and ``frame_stack_size`` sets
        the number of stacked frames.
      action_space: action space passed to ``get_policy``.
      observation_space: observation space; its ``shape`` is the per-frame
        shape of the stacked input.
      policy_dir: directory holding the policy checkpoint to restore.
    """
    assert hparams.base_algo == "ppo"
    ppo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)

    frame_stack_shape = (1, hparams.frame_stack_size) + observation_space.shape
    self._frame_stack = np.zeros(frame_stack_shape, dtype=np.uint8)

    with tf.Graph().as_default():
      # Use the shape computed above: this constructor never assigns
      # ``self.frame_stack_shape``, and the local always equals
      # ``self._frame_stack.shape``.
      self.obs_t = tf.placeholder(shape=frame_stack_shape, dtype=np.uint8)
      self.logits_t, self.value_function_t = get_policy(
          self.obs_t, ppo_hparams, action_space
      )
      # Restore only the policy network's variables.
      model_saver = tf.train.Saver(
          tf.global_variables(scope=ppo_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
      )
      # The session stays open on the instance for later inference calls.
      self.sess = tf.Session()
      self.sess.run(tf.global_variables_initializer())
      trainer_lib.restore_checkpoint(policy_dir, model_saver,
                                     self.sess)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:21,代码来源:player_utils.py

示例5: __init__

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def __init__(self, batch_size, observation_space, action_space,
             policy_hparams, policy_dir, sampling_temp):
    """Builds the policy graph for a batch of observations and restores it.

    Creates an observation placeholder of shape
    ``(batch_size,) + observation_space.shape``, the policy logits/values,
    int32 actions sampled at ``sampling_temp``, and softmax action
    probabilities at the same temperature. Variables matching
    ``policy_hparams.policy_network + "/.*"`` are restored from
    ``policy_dir`` into a session kept on ``self._sess``.
    """
    super(PolicyAgent, self).__init__(
        batch_size, observation_space, action_space)
    self._sampling_temp = sampling_temp
    with tf.Graph().as_default():
      obs_shape = (batch_size,) + self.observation_space.shape
      self._observations_t = tf.placeholder(
          shape=obs_shape, dtype=self.observation_space.dtype)
      logits, self._values_t = rl.get_policy(
          self._observations_t, policy_hparams, self.action_space)
      sampled_actions = common_layers.sample_with_temperature(
          logits, sampling_temp)
      # NOTE(review): divides by sampling_temp; confirm callers never pass 0
      # here, or that _probs_t is unused in that case.
      self._probs_t = tf.nn.softmax(logits / sampling_temp)
      self._actions_t = tf.cast(sampled_actions, tf.int32)
      policy_vars = tf.global_variables(policy_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
      model_saver = tf.train.Saver(policy_vars)
      self._sess = tf.Session()
      self._sess.run(tf.global_variables_initializer())
      trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:27,代码来源:rl_utils.py

示例6: testVarNames

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def testVarNames(self):
    """TransformerScorer and Transformer must create the same variable names.

    Three fresh graphs are built: TransformerScorer in PREDICT mode (via
    ``infer``), TransformerScorer in EVAL mode and Transformer in EVAL mode.
    The sorted global-variable names of both scorer graphs must equal those
    of the Transformer graph.
    """
    def sorted_var_names(mode, model_cls, use_infer):
      # Build the model in its own graph and list its global variable names.
      with tf.Graph().as_default():
        model, features = get_model(mode=mode, model_cls=model_cls)
        if use_infer:
          _ = model.infer(features)
        else:
          _ = model(features)
        return sorted(var.name for var in tf.global_variables())

    modes = tf.estimator.ModeKeys
    scorer_vars = sorted_var_names(
        modes.PREDICT, transformer.TransformerScorer, use_infer=True)
    scorer_eval_vars = sorted_var_names(
        modes.EVAL, transformer.TransformerScorer, use_infer=False)
    transformer_vars = sorted_var_names(
        modes.EVAL, transformer.Transformer, use_infer=False)

    self.assertEqual(scorer_vars, transformer_vars)
    self.assertEqual(scorer_eval_vars, transformer_vars)
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:26,代码来源:transformer_test.py

示例7: underlying_variable

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def underlying_variable(t):
  """Find the underlying tf.Variable object.

  Args:
    t: a Tensor

  Returns:
    tf.Variable.
  """
  ref = underlying_variable_ref(t)
  assert ref is not None
  # A name -> variable index is cached on the default graph as `var_index`.
  # Only global variables past the already-indexed count are added, which
  # assumes the global collection grows by appending.
  graph = tf.get_default_graph()
  if not hasattr(graph, "var_index"):
    graph.var_index = {}
  index = graph.var_index
  already_indexed = len(index)
  for var in tf.global_variables()[already_indexed:]:
    index[var.name] = var
  return index[ref.name]
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:20,代码来源:common_layers.py

示例8: build_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def build_model(self):
    """Builds the two-output conv test model and its latency regularizer.

    Sets ``self.conv3`` / ``self.conv4`` (the outputs), ``self.name_to_var``
    (variable op name -> variable) and ``self.regularizer``.
    """
    # Test model (conv3 and conv4 are its two outputs):
    #
    #         -> conv1 --+     -> conv3 -->
    #        /           |    /
    #  image          [concat]
    #        \           |    \
    #         -> conv2 --+     -> conv4 -->
    #
    inputs = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])
    branch1 = slim.layers.conv2d(
        inputs, 13, [7, 5], padding='SAME', scope='conv1')
    branch2 = slim.layers.conv2d(
        inputs, 23, [1, 1], padding='SAME', scope='conv2')
    merged = tf.concat([branch1, branch2], 3)
    self.conv3 = slim.layers.conv2d(
        merged, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
    self.conv4 = slim.layers.conv2d(
        merged, 31, [1, 1], stride=1, padding='SAME', scope='conv4')
    self.name_to_var = {var.op.name: var for var in tf.global_variables()}

    output_ops = [self.conv3.op, self.conv4.op]
    self.regularizer = latency_regularizer.GammaLatencyRegularizer(
        output_ops, gamma_threshold=0.45, hardware=HARDWARE)
开发者ID:google-research,项目名称:morph-net,代码行数:26,代码来源:latency_regularizer_test.py

示例9: BuildModel

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def BuildModel(self):
    """Builds the two-output conv test graph used by the regularizer tests.

    Stores ``conv1``..``conv4`` and ``concat`` on ``self`` and a map from
    variable op name to variable in ``self.name_to_var``.
    """
    # Test model (conv3 and conv4 are its two outputs):
    #
    #         -> conv1 --+     -> conv3 -->
    #        /           |    /
    #  image          [concat]
    #        \           |    \
    #         -> conv2 --+     -> conv4 -->
    #

    # op.name: 'Const'
    inputs = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])
    # op.name: 'conv1/Conv2D'
    self.conv1 = slim.layers.conv2d(
        inputs, 13, [7, 5], padding='SAME', scope='conv1')
    self.conv2 = slim.layers.conv2d(
        inputs, 23, [1, 1], padding='SAME', scope='conv2')
    self.concat = tf.concat([self.conv1, self.conv2], 3)
    self.conv3 = slim.layers.conv2d(
        self.concat, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
    self.conv4 = slim.layers.conv2d(
        self.concat, 31, [1, 1], stride=1, padding='SAME', scope='conv4')

    by_op_name = {}
    for var in tf.global_variables():
      by_op_name[var.op.name] = var
    self.name_to_var = by_op_name
开发者ID:google-research,项目名称:morph-net,代码行数:27,代码来源:flop_regularizer_test.py

示例10: testLossCostDecorated

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def testLossCostDecorated(self):
    """Checks loss and cost of a decorated GammaModelSizeRegularizer.

    A single 1x1 conv with scaled batch norm is built, its two gammas are set
    to 1, and the regularizer is wrapped with DummyDecorator (scale 0.5).
    """
    conv_kwargs = {
        'trainable': True,
        'normalizer_fn': slim.batch_norm,
        'normalizer_params': {'scale': True},
    }
    with slim.arg_scope([slim.layers.conv2d], **conv_kwargs):
      image = tf.constant(0.0, shape=[1, 1, 1, NUM_CHANNELS])
      conv1 = slim.layers.conv2d(
          image, 2, [1, 1], padding='SAME', scope='conv1')

    with self.cached_session():
      tf.global_variables_initializer().run()
      by_op_name = {var.op.name: var for var in tf.global_variables()}
      # Set both batch-norm gamma entries of conv1 to 1.
      by_op_name['conv1/BatchNorm/gamma'].assign([1, 1]).eval()

    # Attribute name kept as-is; presumably read by the loss/cost helpers.
    self.gamma_flop_reg = model_size_regularizer.GammaModelSizeRegularizer(
        [conv1.op],
        gamma_threshold=0.1,
        regularizer_decorator=dummy_decorator.DummyDecorator,
        decorator_parameters={'scale': 0.5})

    conv = self.get_conv('conv1')
    self.assertEqual(_coeff(conv) * 3 * 1, self.loss([conv]))
    self.assertEqual(_coeff(conv) * 2 * NUM_CHANNELS, self.cost([conv]))
开发者ID:google-research,项目名称:morph-net,代码行数:25,代码来源:model_size_regularizer_test.py

示例11: run

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def run(config):
  """Entry point to run training.

  Each selected stage is trained in a freshly reset default graph. When
  ``config['train_progressive']`` is false, only the last stage is trained.

  Args:
    config: dict of training options. It is passed to
      ``init_data_normalizer`` and expanded as keyword arguments into the
      ``train_util`` helpers; ``'train_progressive'`` and ``'ps_tasks'`` are
      read directly here.
  """
  init_data_normalizer(config)

  stage_ids = train_util.get_stage_ids(**config)
  if not config['train_progressive']:
    # Non-progressive training: keep only the final stage.
    stage_ids = list(stage_ids)[-1:]

  # Train one stage at a time.
  for stage_id in stage_ids:
    batch_size = train_util.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    device_setter = tf.train.replica_device_setter(config['ps_tasks'])
    with tf.device(device_setter):
      model = lib_model.Model(stage_id, batch_size, config)
      model.add_summaries()
      print('Variables:')
      for var in tf.global_variables():
        print('\t', var.name, var.get_shape().as_list())
      logging.info('Calling train.train')
      train_util.train(model, **config)
开发者ID:magenta,项目名称:magenta,代码行数:22,代码来源:gansynth_train.py

示例12: _testScope

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def _testScope(self, factory, prefix="prefix", use_outer_scope=True):
    """Asserts that every global variable built by ``factory`` is in one scope.

    Args:
      factory: callable taking a scope argument, which may be None, a string
        or a VariableScope instance.
      prefix: scope name given to ``factory``; when None, variables are
        expected under "stack_bidirectional_rnn".
      use_outer_scope: if True, ``factory`` receives a VariableScope opened
        with ``prefix``; otherwise it receives ``prefix`` itself.
    """
    with self.session(use_gpu=True, graph=tf.Graph()):
      if not use_outer_scope:
        factory(prefix)
      else:
        with tf.variable_scope(prefix) as outer_scope:
          factory(outer_scope)

      # The initializer op is only created here; it is never run.
      tf.global_variables_initializer()
      all_vars = tf.global_variables()
      scope_name = prefix or "stack_bidirectional_rnn"
      scope_vars = [
          var for var in all_vars if var.name.startswith(scope_name + "/")]
      call_style = "scope" if use_outer_scope else "str"
      tf.logging.info("StackRNN with scope: %s (%s)" %
                      (scope_name, call_style))
      for var in scope_vars:
        tf.logging.info(var.name)
      # Every global variable must carry the expected scope prefix.
      self.assertEqual(len(scope_vars), len(all_vars))
开发者ID:magenta,项目名称:magenta,代码行数:23,代码来源:rnn_test.py

示例13: recoverer

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def recoverer(sess, model_path, meta_graph_path=None):
    """
    Restore parameters from a pretrained model into ``sess``.

    Args:
        sess: The tensorflow session instance to restore into.
        model_path: Checkpoint file path to restore from.
        meta_graph_path: Optional path to a meta graph file. If given, the
            graph is imported with ``tf.train.import_meta_graph`` and the
            saver it returns is used; otherwise a saver over all current
            global variables is built.
    Returns:
        Nothing
    Raises:
        ValueError: if the meta graph at ``meta_graph_path`` yields no saver
            (``import_meta_graph`` returns None when there are no variables
            to restore).
    """
    if meta_graph_path is None:
        restorer = tf.train.Saver(tf.global_variables())
    else:
        restorer = tf.train.import_meta_graph(meta_graph_path)
        # Previously this crashed later with an opaque AttributeError on
        # `None.restore`; fail with an explicit message instead.
        if restorer is None:
            raise ValueError(
                "No variables to restore in meta graph: %s" % meta_graph_path)
    restorer.restore(sess, model_path)


# from https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
# 0 = all messages are logged (default behavior)
# 1 = INFO messages are not printed
# 2 = INFO and WARNING messages are not printed
# 3 = INFO, WARNING, and ERROR messages are not printed 
开发者ID:luigifreda,项目名称:pyslam,代码行数:24,代码来源:utils_tf.py

示例14: restore_model

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def restore_model(sess, checkpoint_path, enable_ema=True):
  """Restore variables from the checkpoint into the provided session.

  All global variables are initialized first, then the checkpoint is
  restored on top of them.

  Args:
    sess: A tensorflow session where the checkpoint will be loaded.
    checkpoint_path: Path to the trained checkpoint.
    enable_ema: (optional) Whether to load the exponential moving average (ema)
      version of the tensorflow variables. Defaults to True.
  """
  var_dict = None  # None: the Saver covers all saveable variables.
  if enable_ema:
    ema = tf.train.ExponentialMovingAverage(decay=0.0)
    candidates = tf.trainable_variables() + tf.get_collection("moving_vars")
    # Batch-norm statistics are restored through their EMA names as well.
    candidates.extend(
        var for var in tf.global_variables()
        if "moving_mean" in var.name or "moving_variance" in var.name)
    # De-duplicate: a variable may come from more than one source above.
    var_dict = ema.variables_to_restore(list(set(candidates)))

  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(var_dict, max_to_keep=1)
  saver.restore(sess, checkpoint_path)
开发者ID:tensorflow,项目名称:models,代码行数:25,代码来源:post_training_quantization.py

示例15: get_global_variables_safely

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import global_variables [as 别名]
def get_global_variables_safely():
  """Returns tf.global_variables(), refusing to run under eager execution.

  Variables are not tracked in the global collection when executing
  eagerly, so in that case this raises instead. Under eager execution use a
  Keras model's ``.variables`` property.

  Returns:
    The result of tf.global_variables()

  Raises:
    ValueError: if ``tf.executing_eagerly()`` is true inside
      ``tf.init_scope()``.
  """
  with tf.init_scope():
    is_eager = tf.executing_eagerly()
  if is_eager:
    raise ValueError("Global variables collection is not tracked when "
                     "executing eagerly. Use a Keras model's `.variables` "
                     "attribute instead.")
  return tf.global_variables()
开发者ID:tensorflow,项目名称:models,代码行数:19,代码来源:variables_helper.py


注:本文中的tensorflow.compat.v1.global_variables方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。