当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.local_variables函数代码示例

本文整理汇总了 Python 中 tensorflow.local_variables 函数的典型用法代码示例。如果您正苦于以下问题:Python local_variables 函数的具体用法是什么?Python local_variables 怎么用?有哪些 Python local_variables 的使用例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。


下文一共展示了 local_variables 函数的 15 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: test_local_variable

 def test_local_variable(self):
   """Check that `local_variable` registers variables as local variables.

   Two local variables are created from plain Python values; the test then
   verifies that they are listed by `tf.local_variables()`, that reading them
   before initialization raises, and that they hold the given values once
   initialized.
   """
   with self.test_session() as sess:
     # No local variables are expected before any have been created.
     self.assertEqual([], tf.local_variables())
     value0 = 42
     tf.contrib.framework.local_variable(value0)
     value1 = 43
     tf.contrib.framework.local_variable(value1)
     variables = tf.local_variables()
     self.assertEqual(2, len(variables))
     # Evaluating the variables before running an initializer must fail.
     self.assertRaises(tf.OpError, sess.run, variables)
     tf.variables_initializer(variables).run()
     # Compare as sets: the order of the collection is not what is under test.
     self.assertAllEqual({value0, value1}, set(sess.run(variables)))
开发者ID:jeffzheng1,项目名称:tensorflow,代码行数:12,代码来源:variables_test.py

示例2: testNotInLocalVariables

 def testNotInLocalVariables(self):
   """A model variable is global and in MODEL_VARIABLES, but never local."""
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.model_variable('a', [5])
       # assertIn/assertNotIn report both operands on failure, unlike
       # assertTrue(x in y), which only reports "False is not true".
       self.assertIn(a, tf.global_variables())
       self.assertIn(a, tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
       self.assertNotIn(a, tf.local_variables())
开发者ID:jeffzheng1,项目名称:tensorflow,代码行数:7,代码来源:variables_test.py

示例3: get_post_init_ops

 def get_post_init_ops():
     """
     Build one grouped op that overwrites every 'tower*' variable with the
     value of the variable of the same name minus its leading 'towerN/' part.

     Returns:
         the op named 'sync_variables_from_main_tower'.
     """
     # global and local collections together, so optimizer-internal variables get synced as well
     every_var = tf.global_variables() + tf.local_variables()
     lookup = {var.name: var for var in every_var}
     sync_ops = []
     for var in every_var:
         full_name = var.name
         if not full_name.startswith('tower'):
             # not a replica: nothing to overwrite
             continue
         if full_name.startswith('tower0'):
             logger.warn("[SyncMultiGPUReplicatedBuilder] variable "
                         "{} has prefix 'tower0', this is unexpected.".format(full_name))
             continue        # TODO some vars (EMA) may still startswith tower0
         # the master copy is registered without the 'towerN/' prefix
         tower, _, master_name = full_name.partition('/')
         if tower in master_name:
             logger.error("[SyncMultiGPUReplicatedBuilder] variable "
                          "{} has its prefix {} appears multiple times in its name!".format(full_name, tower))
         source = lookup.get(master_name)
         assert source is not None, lookup.keys()
         sync_ops.append(var.assign(source.read_value()))
     logger.info(
         "'sync_variables_from_main_tower' includes {} operations.".format(len(sync_ops)))
     return tf.group(*sync_ops, name='sync_variables_from_main_tower')
开发者ID:caserzer,项目名称:tensorpack,代码行数:28,代码来源:training.py

示例4: count_variables_by_type

def count_variables_by_type(variables=None):
  """Returns a dict mapping dtypes to number of variables and scalars.

  Args:
    variables: iterable of `tf.Variable`s, or None. If None is passed, then all
      global and local variables in the current graph are used. One-shot
      iterables (e.g. generators) are supported.

  Returns:
    A dict mapping tf.dtype keys to a dict containing the keys 'num_scalars' and
      'num_variables'. An empty iterable yields an empty dict.
  """
  if variables is None:
    variables = tf.global_variables() + tf.local_variables()
  # Materialise once so that a generator argument is not exhausted before
  # the counting below has seen every variable.
  variables = list(variables)

  results_dict = {}
  # Single pass: group by base dtype while accumulating both counters.
  for v in variables:
    dtype = v.dtype.base_dtype
    if dtype not in results_dict:
      if dtype == tf.string:
        # Warn once per call, the first time a string variable is seen.
        tf.logging.warning(
            "NB: string Variables present. The memory usage for these  Variables "
            "will not be accurately computed as it depends on the exact strings "
            "stored in a particular session.")
      results_dict[dtype] = {"num_variables": 0, "num_scalars": 0}
    entry = results_dict[dtype]
    entry["num_variables"] += 1
    entry["num_scalars"] += v.shape.num_elements()
  return results_dict
开发者ID:ccchang0111,项目名称:sonnet,代码行数:28,代码来源:util.py

示例5: cnn_train

def cnn_train(config, data_len, embed, pf_r1, pf_r2):
    """Build the train/validation models and run the queue-driven training loop.

    After every `config.data_len // config.batch_size` training steps the
    validation loss and accuracy are averaged over the same number of batches,
    and the model is saved whenever the averaged validation loss improves.

    Args:
        config: configuration object. This function sets `data_len` on it and
            reads `csv_file`, `batch_size` and `model_path` from it; it is also
            passed to `ch_model` and `read_batch`.
        data_len: stored on `config`; presumably the number of examples
            (TODO confirm against the caller).
        embed: value handed to `m_train.assign_word_embed`.
        pf_r1: not used by this function.
        pf_r2: not used by this function.
    """
    config.data_len = data_len
    tf.reset_default_graph()
    with tf.Session() as session:
        # build model
        with tf.variable_scope("cnn_ch", reuse=None):
            m_train = ch_model(config)
        with tf.variable_scope("cnn_ch", reuse=True):
            m_valid = ch_model(config)

        doc_datas, pf_r1s, pf_r2s, labels = read_batch(config.csv_file, config, True)
        doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v = read_batch(config.csv_file, config, False)

        # print() with one pre-formatted string emits the same text on
        # Python 2 and Python 3 (the former print statements were Python 2 only).
        # NOTE(review): tf.all_variables / tf.initialize_*_variables look like
        # legacy TensorFlow names - confirm the targeted TF version before changing.
        for item in tf.all_variables():
            print("var:  %s" % (item,))
        for item in tf.local_variables():
            print("local: %s" % (item,))

        loss, _ = m_train.inference(doc_datas, pf_r1s, pf_r2s, labels)
        loss_v, acc_v = m_valid.inference(doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v)
        train_op = m_train.train(loss)

        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        m_train.assign_word_embed(session, embed)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=session)

        epoch = 0
        step = 0
        # Infinity works as a "no best cost yet" sentinel on both Python 2
        # and 3; sys.maxint no longer exists on Python 3.
        min_cost = float("inf")
        try:
            while not coord.should_stop():
                _, f_l = session.run([train_op, loss])
                step += 1
                if step == config.data_len // config.batch_size:
                    # End of an epoch: average validation metrics over `step` batches.
                    cost = 0.0
                    acc = 0.0
                    for _ in range(step):
                        v_l, acc_l = session.run([loss_v, acc_v])
                        cost += v_l
                        acc += acc_l
                    cost /= step
                    acc /= step
                    if cost < min_cost:
                        min_cost = cost
                        print("save model as cost: %s" % (cost,))
                        m_train.saver.save(session, config.model_path)
                    print("epoch:  %s loss:  %s acc:  %s step: %s" % (epoch, cost, acc, step))
                    step = 0
                    epoch += 1
        except tf.errors.OutOfRangeError:
            # Raised by the input queues once they are exhausted.
            print("Done training")
        finally:
            coord.request_stop()
        coord.join(threads)
开发者ID:sww9370,项目名称:Relation_Extraction,代码行数:58,代码来源:extraction_tensorflow.py

示例6: testCreateVariable

 def testCreateVariable(self):
   """A plain framework variable is scoped, shaped and global-only."""
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.variable('a', [5])
       # The enclosing variable scope must prefix the op name.
       self.assertEqual(a.op.name, 'A/a')
       self.assertListEqual(a.get_shape().as_list(), [5])
       # assertIn/assertNotIn report both operands on failure, unlike
       # assertTrue(x in y) / assertFalse(x in y).
       self.assertIn(a, tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
       self.assertNotIn(a, tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
       self.assertNotIn(a, tf.local_variables())
开发者ID:jeffzheng1,项目名称:tensorflow,代码行数:9,代码来源:variables_test.py

示例7: setup_graph

 def setup_graph(self):
     """ Locate the variable named `self.var_name` and store it in `self.var`.

     Raises:
         ValueError: if no global or local variable has that name.
     """
     candidates = tf.global_variables() + tf.local_variables()
     # first variable whose name matches, or None when there is no match
     match = next((var for var in candidates if var.name == self.var_name), None)
     if match is None:
         raise ValueError("{} is not a variable in the graph!".format(self.var_name))
     self.var = match
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:9,代码来源:param.py

示例8: testUsage

  def testUsage(self, custom_getter_fn):
    """Only the module built under the overriding getter has local variables."""
    # Baseline: a module constructed outside of any custom getter.
    plain = snt.Linear(10)

    # A second module, constructed inside a scope whose custom getter
    # overrides the `collections` argument.
    getter = custom_getter_fn(collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=getter):
      overridden = snt.Linear(10)

    # Connecting the modules to an input is what creates their variables.
    batch = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    plain(batch)
    overridden(batch)

    # The baseline weight is global only ...
    self.assertIn(plain.w, tf.global_variables())
    self.assertNotIn(plain.w, tf.local_variables())
    # ... while the overridden weight is local only.
    self.assertIn(overridden.w, tf.local_variables())
    self.assertNotIn(overridden.w, tf.global_variables())
开发者ID:ccchang0111,项目名称:sonnet,代码行数:19,代码来源:override_args_test.py

示例9: testExplicitArgOverridden

  def testExplicitArgOverridden(self):
    """The getter's `collections` beats one passed to `tf.get_variable`."""
    # Custom getter that forces variables into the LOCAL_VARIABLES collection.
    getter = snt.custom_getters.override_args(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=getter):
      # Ask for the global collection explicitly, contradicting the getter.
      requested = [tf.GraphKeys.GLOBAL_VARIABLES]
      v = tf.get_variable("v", (), collections=requested)

    # The variable must have ended up where the custom getter said.
    self.assertIn(v, tf.local_variables())
    self.assertNotIn(v, tf.global_variables())
开发者ID:ccchang0111,项目名称:sonnet,代码行数:11,代码来源:override_args_test.py

示例10: _initialize_variables

def _initialize_variables():
    """Initialize, on the fly, the local variables not yet flagged as initialized.

    A variable counts as initialized when its `_keras_initialized` attribute
    is truthy; each variable picked up here gets that flag set.
    """
    pending = []
    for var in tf.local_variables():
        if getattr(var, '_keras_initialized', False):
            # already handled by an earlier call
            continue
        var._keras_initialized = True
        pending.append(var)
    if not pending:
        return
    session = K.get_session()
    session.run(tf.variables_initializer(pending))
开发者ID:sinianyutian,项目名称:keras-fcn-1,代码行数:12,代码来源:metrics.py

示例11: log_variables

def log_variables(variables=None):
  """Writes one info-level log line per formatted variable row.

  The rows are produced by `format_variables`, which is called with
  `join_lines=False`.

  Args:
    variables: iterable of variables to log. When None, the global and local
        variables returned by TensorFlow are logged instead.
  """
  to_log = variables
  if to_log is None:
    to_log = tf.global_variables() + tf.local_variables()
  rows = format_variables(to_log, join_lines=False)
  for line in rows:
    tf.logging.info(line)
开发者ID:geniusjiqing,项目名称:sonnet,代码行数:14,代码来源:util.py

示例12: testWithNested

  def testWithNested(self, custom_getter_fn):
    """A scope-level getter and a module-level getter must both apply."""
    # The outer getter overrides `collections`; the module additionally
    # carries its own custom getter.
    outer_getter = custom_getter_fn(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=outer_getter):
      module = snt.Linear(10, custom_getter=_suffix_custom_getter)

    # Variables are created when the module is connected to an input.
    batch = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    module(batch)

    # Effect of the outer getter: the weight is local, not global.
    self.assertIn(module.w, tf.local_variables())
    self.assertNotIn(module.w, tf.global_variables())
    # Effect of the module's own getter: the op name carries the suffix.
    self.assertEqual("linear/w_test", module.w.op.name)
开发者ID:ccchang0111,项目名称:sonnet,代码行数:16,代码来源:override_args_test.py

示例13: guarantee_initialized_variables

def guarantee_initialized_variables(session, variables=None):
    """Guarantee that all the specified variables are initialized.

    If a variable is already initialized, leave it alone. Otherwise, initialize it.

    If no variables are specified, `tf.report_uninitialized_variables` decides
    which variables are checked (its default variable list).

    Args:
        session (tf.Session): session used to find and initialize the variables.
        variables (list[tf.Variable]): variables to check, or None.

    Returns:
        list[tf.Variable]: the variables that this call initialized.

    Raises:
        KeyError: if an uninitialized variable is reported whose op name
            matches neither a global/local variable nor one of `variables`.
    """
    candidates = tf.global_variables() + tf.local_variables()
    if variables is not None:
        # Explicitly passed variables need not be registered in either
        # collection, so make sure they can be looked up by name as well.
        candidates = candidates + list(variables)
    name_to_var = {v.op.name: v for v in candidates}

    reported = session.run(tf.report_uninitialized_variables(variables))
    uninitialized_variables = []
    for name in reported:
        # String tensors are fetched as bytes; on Python 3 these never equal
        # the str op names used as keys, so decode before the lookup.
        if isinstance(name, bytes):
            name = name.decode('utf-8')
        uninitialized_variables.append(name_to_var[name])

    init_op = tf.variables_initializer(uninitialized_variables)
    session.run(init_op)
    return uninitialized_variables
开发者ID:siddk,项目名称:lang2program,代码行数:16,代码来源:utils.py

示例14: log_variables

def log_variables(variables=None):
  """Writes one info-level log line per formatted variable row.

  The rows are produced by `format_variables`, which is called with
  `join_lines=False`. In the "Device" columns, the nature of the variable
  (legacy or resource (for ResourceVariables)) is also specified in
  parenthesis.

  Args:
    variables: iterable of variables to log. When None, the global and local
        variables returned by TensorFlow are logged instead.
  """
  selected = variables
  if selected is None:
    selected = tf.global_variables() + tf.local_variables()
  formatted_rows = format_variables(selected, join_lines=False)
  for entry in formatted_rows:
    tf.logging.info(entry)
开发者ID:ccchang0111,项目名称:sonnet,代码行数:16,代码来源:util.py

示例15: get_post_init_ops

  def get_post_init_ops(self):
    """Build assign ops copying shadow global variables into local variables.

    For every global variable under `PS_SHADOW_VAR_PREFIX + '/v0/'`, the
    matching local variable 'v<i>...' of each GPU index `i` (when one exists)
    is assigned the global variable's value.

    Returns:
      A list of assign ops.
    """
    shadow_root = PS_SHADOW_VAR_PREFIX + '/v0'
    # Local variables keyed by their name as returned by `_strip_port`.
    local_lookup = {
        self._strip_port(var.name): var for var in tf.local_variables()}
    assign_ops = []
    for shadow in tf.global_variables():
      if not shadow.name.startswith(shadow_root + '/'):
        continue
      # What remains after the shadow root, still starting with '/'.
      suffix = self._strip_port(shadow.name[len(shadow_root):])
      for gpu in range(self.benchmark_cnn.num_gpus):
        local_name = 'v%s%s' % (gpu, suffix)
        if local_name in local_lookup:
          target = local_lookup[local_name]
          assign_ops.append(target.assign(shadow.read_value()))
    return assign_ops
开发者ID:Ericyuanhui,项目名称:Build_learning,代码行数:18,代码来源:variable_mgr.py


注:本文中的 tensorflow.local_variables 函数示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的 License;未经允许,请勿转载。