当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.scatter_update方法代码示例

本文整理汇总了Python中tensorflow.scatter_update方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.scatter_update方法的具体用法？Python tensorflow.scatter_update怎么用？Python tensorflow.scatter_update使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的其他用法示例。


在下文中一共展示了tensorflow.scatter_update方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: make_update_op

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Args:
      upd_idxs: Memory slot indices to overwrite.
      upd_keys: Keys written to `self.mem_keys` at `upd_idxs`.
      upd_vals: Values written to `self.mem_vals` at `upd_idxs`.
      batch_size: Number of zero ages written back at `upd_idxs`.
      use_recent_idx: When truthy, `self.recent_idx` is updated as well.
      intended_output: Indices into `self.recent_idx` that receive `upd_idxs`.

    Returns:
      One grouped op covering the age, key, value and (optional)
      recent-index updates.
    """
    # Every slot ages by one first; the rewritten slots are then reset to zero.
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    aged = self.mem_age.assign_add(age_step)
    with tf.control_dependencies([aged]):
      fresh_ages = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, fresh_ages)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in as a no-op when recency is not tracked.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:23,代码来源:memory.py

示例2: replace

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def replace(self, episodes, length, rows=None):
    """Overwrite whole episodes stored in the buffers.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
    if rows is None:
      rows = tf.range(self._capacity)
    assert rows.shape.ndims == 1
    capacity_check = tf.assert_less(
        rows, self._capacity, message='capacity exceeded')
    with tf.control_dependencies([capacity_check]):
      length_check = tf.assert_less_equal(
          length, self._max_length, message='max length exceeded')
    # Buffer writes wait on the length check, which itself waits on the
    # capacity check.
    with tf.control_dependencies([length_check]):
      buffer_writes = [
          tf.scatter_update(target, rows, source)
          for target, source in zip(self._buffers, episodes)]
    # The stored lengths are updated only after every buffer has been written.
    with tf.control_dependencies(buffer_writes):
      return tf.scatter_update(self._length, rows, length)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:27,代码来源:memory.py

示例3: reinit_nested_vars

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Indices along the first dimension to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  # Build the zeros in the variable's own dtype: tf.zeros defaults to float32,
  # while scatter_update requires updates of the same type as the variable.
  row_shape = variables.shape[1:].as_list()
  zeros = tf.zeros(
      [tf.shape(indices)[0]] + row_shape, dtype=variables.dtype.base_dtype)
  return tf.scatter_update(variables, indices, zeros)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:20,代码来源:utility.py

示例4: reset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def reset(self, indices=None):
    """Reset selected environments and store their fresh state.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.check_numerics(
        tf.py_func(self._batch_env.reset, [indices], dtype, name='reset'),
        'observ')
    zero_reward = tf.zeros_like(indices, tf.float32)
    not_done = tf.zeros_like(indices, tf.bool)
    state_writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, not_done),
    ]
    # Returning through identity under these dependencies makes the stored
    # state update whenever the returned observation is evaluated.
    with tf.control_dependencies(state_writes):
      return tf.identity(fresh_observ)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:24,代码来源:in_graph_batch_env.py

示例5: reinit_nested_vars

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Batch indices to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  # tf.zeros would default to float32; scatter_update needs the updates to
  # have the same type as the variable, so request the variable's dtype.
  trailing_dims = variables.shape[1:].as_list()
  zeros = tf.zeros(
      [tf.shape(indices)[0]] + trailing_dims,
      dtype=variables.dtype.base_dtype)
  return tf.scatter_update(variables, indices, zeros)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:20,代码来源:utility.py

示例6: assign_nested_vars

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def assign_nested_vars(variables, tensors, indices=None):
  """Assign tensors to matching nested tuple of variables.

  Args:
    variables: Nested tuple or list of variables to update.
    tensors: Nested tuple or list of tensors to assign.
    indices: Batch indices to assign to; default to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    # Forward `indices` so nested variables receive the same partial update
    # that a single variable would; omitting it assigns the full variable.
    return tf.group(*[
        assign_nested_vars(variable, tensor, indices)
        for variable, tensor in zip(variables, tensors)])
  if indices is None:
    return variables.assign(tensors)
  return tf.scatter_update(variables, indices, tensors)
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:21,代码来源:utility.py

示例7: _reset_non_empty

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def _reset_non_empty(self, indices):
    """Reset the given environments and record their first observations.

    Args:
      indices: The batch indices of the environments to reset.

    Returns:
      Batch tensor of the new observations.
    """
    dtype = utils.parse_dtype(self._batch_env.observation_space)
    raw_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    checked_observ = tf.check_numerics(raw_observ, 'observ')
    store_observ = tf.scatter_update(self._observ, indices, checked_observ)
    # Evaluating the returned tensor also runs the write into self._observ.
    with tf.control_dependencies([store_observ]):
      return tf.identity(checked_observ)
开发者ID:akzaidi,项目名称:fine-lm,代码行数:18,代码来源:py_func_batch_env.py

示例8: sparse_moving_average

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def sparse_moving_average(self, variable, unique_indices, accumulant, name='Accumulator', decay=.9):
    """Update a per-row moving average at the given rows only.

    Args:
      variable: Variable whose accumulators are looked up via
        `self.get_accumulator`; its first dimension sizes the iteration counter.
      unique_indices: Row indices to update (NOTE(review): presumably
        de-duplicated by the caller -- confirm).
      accumulant: New values for the selected rows; clipped to
        [-self.clip, self.clip] before use.
      name: Accumulator name; the counter is stored under '<name>/iteration'.
      decay: Python number. Below 1 the per-row weight is
        decay * (1 - decay**(t-1)) / (1 - decay**t); otherwise it is (t-1) / t,
        where t is the row's updated iteration count.

    Returns:
      Tuple `(accumulator, iteration)` holding the outputs of the final
      scatter ops on the accumulator and on the iteration counter.
    """
    
    accumulant = tf.clip_by_value(accumulant, -self.clip, self.clip)
    first_dim = variable.get_shape().as_list()[0]
    accumulator = self.get_accumulator(name, variable)
    # Previous accumulator rows, gathered before they are overwritten below.
    indexed_accumulator = tf.gather(accumulator, unique_indices)
    iteration = self.get_accumulator('{}/iteration'.format(name), variable, shape=[first_dim, 1])
    # This first gather only supplies the shape/dtype for the ones added next.
    indexed_iteration = tf.gather(iteration, unique_indices)
    iteration = tf.scatter_add(iteration, unique_indices, tf.ones_like(indexed_iteration))
    # Re-gather to read the incremented counts for the selected rows.
    indexed_iteration = tf.gather(iteration, unique_indices)
    
    # NOTE(review): the decay < 1 form looks like a bias-corrected exponential
    # average and the else branch like a running mean -- confirm intent.
    if decay < 1:
      current_indexed_decay = decay * (1-decay**(indexed_iteration-1)) / (1-decay**indexed_iteration)
    else:
      current_indexed_decay = (indexed_iteration-1) / indexed_iteration
    
    # Selected rows become decay * old + (1 - decay) * accumulant, in two steps.
    accumulator = tf.scatter_update(accumulator, unique_indices, current_indexed_decay*indexed_accumulator)
    accumulator = tf.scatter_add(accumulator, unique_indices, (1-current_indexed_decay)*accumulant)
    return accumulator, iteration
  
  #============================================================= 
开发者ID:tdozat,项目名称:Parser-v3,代码行数:24,代码来源:optimizer.py

示例9: _reset_non_empty

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def _reset_non_empty(self, indices):
    """Reset the wrapped environments and refill their stacked observations.

    Args:
      indices: Batch indices of the environments to reset.

    Returns:
      The rows of `self.observ` at `indices`, read after the write completed.
    """
    # pylint: disable=protected-access
    reset_observ = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    history_frames = getattr(self._batch_env, "history_observations", None)
    if history_frames is None:
      # No history buffer: repeat the reset observation `self.history` times
      # along the last axis and leave every other axis untouched.
      rank = tf.size(tf.shape(reset_observ))
      leading_ones = tf.ones(rank, dtype=tf.int64)[:-1]
      multiples = tf.concat([leading_ones, [self.history]], axis=0)
      stacked = tf.tile(reset_observ, multiples)
    else:
      # Using history buffer frames for initialization, if they are available.
      with tf.control_dependencies([reset_observ]):
        # Transpose to [batch, height, width, history, channels] and merge
        # history and channels into one dimension.
        permuted = tf.transpose(history_frames, [0, 2, 3, 1, 4])
        stacked = tf.reshape(permuted, (len(self),) + self.observ_shape)
    write_op = tf.scatter_update(self._observ, indices, stacked)
    with tf.control_dependencies([write_op]):
      return tf.gather(self.observ, indices)
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:27,代码来源:tf_atari_wrappers.py

示例10: testWhileUpdateVariable_1

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def testWhileUpdateVariable_1(self):
    """Checks that scatter_update inside a while_loop overwrites every entry."""
    with self.test_session():
      select = tf.Variable([3.0, 4.0, 5.0])
      start = tf.constant(0)

      def keep_going(i):
        return tf.less(i, 3)

      def step(i):
        # Tie the counter increment to the variable write so the write runs.
        written = tf.scatter_update(select, i, 10.0)
        next_i = tf.add(i, 1)
        barrier = control_flow_ops.group(written)
        return [control_flow_ops.with_dependencies([barrier], next_i)]

      final_i = tf.while_loop(keep_going, step, [start],
                              parallel_iterations=1)
      self.assertTrue(check_op_order(start.graph))
      tf.global_variables_initializer().run()
      self.assertEqual(3, final_i.eval())
      observed = select.eval()
      self.assertAllClose(np.array([10.0, 10.0, 10.0]), observed)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:24,代码来源:control_flow_ops_py_test.py

示例11: testWhileUpdateVariable_3

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def testWhileUpdateVariable_3(self):
    """Checks scatter_update when its result is carried as a loop variable."""
    with self.test_session():
      select = tf.Variable([3.0, 4.0, 5.0])
      start = tf.constant(0)

      def keep_going(i, _):
        return tf.less(i, 3)

      def step(i, _):
        # The updated variable value is threaded through the loop state.
        written = tf.scatter_update(select, i, 10.0)
        return [tf.add(i, 1), written]

      loop_vars = [start, tf.identity(select)]
      outputs = tf.while_loop(keep_going, step, loop_vars,
                              parallel_iterations=1)
      tf.global_variables_initializer().run()
      observed = outputs[1].eval()
    self.assertTrue(check_op_order(start.graph))
    self.assertAllClose(np.array([10.0, 10.0, 10.0]), observed)

  # b/24814703 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:24,代码来源:control_flow_ops_py_test.py

示例12: insert

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN.

    Builds the ops under a control dependency on `self.last_ops` and then
    replaces `self.last_ops` with the newly created ops. Nothing is returned.

    Args:
      ids: Indices into `self.id_to_score` whose scores are written.
      scores: Scores to store for `ids`.
    """
    with tf.control_dependencies(self.last_ops):
      # Always record the score in the full id -> score table.
      scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
      # NOTE(review): self.sl_scores[0] presumably holds the shortlist
      # threshold (see the comment before tf.cond below) -- confirm.
      larger_scores = tf.greater(scores, self.sl_scores[0])

      def shortlist_insert():
        # Keep only the candidates that beat the threshold.
        larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
        larger_score_values = tf.boolean_mask(scores, larger_scores)
        shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
            self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
        # Write the slots returned by top_n_insert into both shortlist tables.
        u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
        u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
        return tf.group(u1, u2)

      # We only need to insert into the shortlist if there are any
      # scores larger than the threshold.
      cond_op = tf.cond(
          tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
      with tf.control_dependencies([cond_op]):
        # Later ops built by this object will depend on these two ops.
        self.last_ops = [scatter_op, cond_op]
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:23,代码来源:topn.py

示例13: remove

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN."""
    with tf.control_dependencies(self.last_ops):
      # Removed ids get the lowest float32 value as their score.
      cleared_scores = tf.ones_like(ids, dtype=tf.float32) * tf.float32.min
      scatter_op = tf.scatter_update(self.id_to_score, ids, cleared_scores)
      # We assume that removed ids are almost always in the shortlist,
      # so it makes no sense to hide the Op behind a tf.cond
      shortlist_ids_to_remove, new_length = self.ops.top_n_remove(
          self.sl_ids, ids)
      # Slot 0 of sl_ids receives new_length; every removed slot receives -1.
      id_slots = tf.concat(0, [[0], shortlist_ids_to_remove])
      id_values = tf.concat(
          0, [new_length, tf.ones_like(shortlist_ids_to_remove) * -1])
      u1 = tf.scatter_update(self.sl_ids, id_slots, id_values)
      removed_scores = tf.float32.min * tf.ones_like(
          shortlist_ids_to_remove, dtype=tf.float32)
      u2 = tf.scatter_update(
          self.sl_scores, shortlist_ids_to_remove, removed_scores)
      self.last_ops = [scatter_op, u1, u2]
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:24,代码来源:topn.py

示例14: reset

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def reset(self, indices=None):
    """Reset selected environments and store their fresh state.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    zero_reward = tf.zeros_like(indices, tf.float32)
    zero_done = tf.zeros_like(indices, tf.int32)
    state_writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, tf.to_int32(zero_done)),
    ]
    # Evaluating the returned tensor also runs the three state writes.
    with tf.control_dependencies(state_writes):
      return tf.identity(fresh_observ)
开发者ID:google-research,项目名称:planet,代码行数:23,代码来源:in_graph_batch_env.py

示例15: curvature_range

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import scatter_update [as 别名]
def curvature_range(self):
    """Build the ops that track the curvature range over a sliding window.

    Stores the squared gradient norm in a circular window indexed by the
    global step, takes the min and max of the filled part of the window, and
    feeds both into `self._moving_averager`. The averaged values end up in
    `self._h_min` and `self._h_max`.

    Returns:
      A one-element list holding the moving-average apply op.
    """
    # set up the curvature window
    window_init = np.zeros([self._curv_win_width])
    self._curv_win = tf.Variable(
        window_init, dtype=tf.float32, name="curv_win", trainable=False)
    slot = self._global_step % self._curv_win_width
    self._curv_win = tf.scatter_update(
        self._curv_win, slot, self._grad_norm_squared)
    # note here the iterations start from iteration 0
    window_begin = tf.constant([0])
    filled = tf.minimum(
        tf.constant(self._curv_win_width), self._global_step + 1)
    window_size = tf.expand_dims(filled, dim=0)
    valid_window = tf.slice(self._curv_win, window_begin, window_size)
    self._h_min_t = tf.reduce_min(valid_window)
    self._h_max_t = tf.reduce_max(valid_window)

    with tf.control_dependencies([self._h_min_t, self._h_max_t]):
      avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t])
      # Read the averages only after the averager has been updated.
      with tf.control_dependencies([avg_op]):
        self._h_min = tf.identity(
            self._moving_averager.average(self._h_min_t))
        self._h_max = tf.identity(
            self._moving_averager.average(self._h_max_t))
    return [avg_op]
开发者ID:Zehaos,项目名称:MobileNet,代码行数:22,代码来源:yellowfin.py


注:本文中的tensorflow.scatter_update方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。