当前位置: 首页>>代码示例>>Python>>正文


Python state_ops.scatter_update方法代码示例

本文整理汇总了Python中tensorflow.python.ops.state_ops.scatter_update方法的典型用法代码示例。如果您正苦于以下问题：Python state_ops.scatter_update方法的具体用法是什么？state_ops.scatter_update怎么用？有哪些使用state_ops.scatter_update的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.state_ops的用法示例。


在下文中一共展示了state_ops.scatter_update方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _apply_sparse

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def _apply_sparse(self, grad, var):
    """Sparse update scaled by sign agreement with a decayed magnitude slot.

    For the rows selected by ``grad.indices``:

      m[idx]   <- max(beta * m[idx] + eps, |g|)
      var[idx] <- var[idx] - lr * g * exp(log(alpha) * sign(g) * sign(m[idx]))

    where ``g = grad.values`` and ``m`` is the "m" slot of ``var``.

    Returns:
      A grouped op that completes once both the variable update and the
      "m" slot update have run.
    """
    base_dtype = var.dtype.base_dtype
    lr_t = math_ops.cast(self._lr_t, base_dtype)
    alpha_t = math_ops.cast(self._alpha_t, base_dtype)
    beta_t = math_ops.cast(self._beta_t, base_dtype)

    eps = 1e-7  # small constant added to the decayed slot before the max

    idx = grad.indices
    g = grad.values

    # Refresh the touched rows of the "m" slot.
    m = self.get_slot(var, "m")
    old_rows = tf.gather(m, idx)
    new_rows = tf.maximum(beta_t * old_rows + eps, tf.abs(g))
    m_t = state_ops.scatter_update(m, idx, new_rows)
    m_rows = tf.gather(m_t, idx)

    # exp(log(alpha) * s) == alpha ** s with s = sign(g) * sign(m_new).
    # scatter_sub subtracts the step from the touched rows of `var`.
    scaled_g = lr_t * g
    exponent = tf.log(alpha_t) * tf.sign(g) * tf.sign(m_rows)
    step = scaled_g * tf.exp(exponent)
    var_update = state_ops.scatter_sub(var, idx, step)

    # The grouped op finishes only after every input op has finished.
    return control_flow_ops.group(var_update, m_t)
开发者ID:ChenglongChen,项目名称:tensorflow-XNN,代码行数:20,代码来源:optimizer.py

示例2: scatter_update

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def scatter_update(cls, factor, indices, values, sharding_func, name=None):
    """Helper function for doing sharded scatter update.

    Writes ``values`` into rows ``indices`` of ``factor``, which is a list of
    one or more variable shards.

    Args:
      cls: not referenced in the body.
      factor: list of shard variables. With a single shard, the update is
        applied to it directly, colocated with that shard.
      indices: row ids to update.
      values: replacement rows, parallel to ``indices``.
      sharding_func: called only when there are several shards. It maps
        ``indices`` to ``(assignments, new_ids)``: the shard index of each row
        and the id used inside that shard (NOTE(review): inferred from how the
        outputs are partitioned below; confirm against the callers).
      name: optional name for the returned op.

    Returns:
      The scatter-update op (single shard), or one grouped op covering every
      shard's scatter update.
    """
    assert isinstance(factor, list)
    if len(factor) == 1:
      with ops.colocate_with(factor[0]):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(factor[0], indices, values,
                                        name=name).op

    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    # Route every (id, value) pair to the shard named by `assignments`.
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    # Iterate the shards with zip rather than the Python-2-only `xrange`,
    # which is a NameError on Python 3 without a compatibility import.
    updates = [
        state_ops.scatter_update(shard, shard_ids, shard_values)
        for shard, shard_ids, shard_values in zip(factor, sharded_ids,
                                                  sharded_values)
    ]
    return control_flow_ops.group(*updates, name=name)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:24,代码来源:factorization_ops.py

示例3: scatter_update

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def scatter_update(cls, factor, indices, values, sharding_func, name=None):
    """Helper function for doing sharded scatter update.

    Writes ``values`` into rows ``indices`` of ``factor``, which is a list of
    one or more variable shards.

    Args:
      cls: not referenced in the body.
      factor: list of shard variables. With a single shard, the update is
        applied to it directly, colocated with that shard.
      indices: row ids to update.
      values: replacement rows, parallel to ``indices``.
      sharding_func: called only when there are several shards. It maps
        ``indices`` to ``(assignments, new_ids)``: the shard index of each row
        and the id used inside that shard (NOTE(review): inferred from how the
        outputs are partitioned below; confirm against the callers).
      name: optional name for the returned op. Defaults to None, which is
        exactly the previous (unnamed) behavior.

    Returns:
      The scatter-update op (single shard), or one grouped op covering every
      shard's scatter update.
    """
    assert isinstance(factor, list)
    if len(factor) == 1:
      with ops.colocate_with(factor[0]):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(factor[0], indices, values,
                                        name=name).op

    num_shards = len(factor)
    assignments, new_ids = sharding_func(indices)
    assert assignments is not None
    assignments = math_ops.cast(assignments, dtypes.int32)
    # Route every (id, value) pair to the shard named by `assignments`.
    sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                  num_shards)
    sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                     num_shards)
    # Iterate the shards with zip rather than the Python-2-only `xrange`,
    # which is a NameError on Python 3 without a compatibility import.
    updates = [
        state_ops.scatter_update(shard, shard_ids, shard_values)
        for shard, shard_ids, shard_values in zip(factor, sharded_ids,
                                                  sharded_values)
    ]
    return control_flow_ops.group(*updates, name=name)
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:24,代码来源:factorization_ops.py

示例4: insert

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN.

    Every (id, score) pair is written into ``self.id_to_score``. The shortlist
    (``self.sl_ids`` / ``self.sl_scores``) is updated only when at least one
    score is greater than ``self.sl_scores[0]``, the threshold.

    Nothing is returned. Instead, ``self.last_ops`` is replaced with the ops
    built here, and all work here runs under
    ``ops.control_dependencies(self.last_ops)``. Successive calls are
    therefore ordered after the previous call's ops.

    Args:
      ids: ids to insert; cast to int64 before the shortlist update.
      scores: scores parallel to ``ids``.
    """
    # Order this insert after whatever the previous insert/remove produced.
    with ops.control_dependencies(self.last_ops):
      # Unconditionally record every score in the full id -> score table.
      scatter_op = state_ops.scatter_update(self.id_to_score, ids, scores)
      # Boolean mask of the scores that beat the shortlist threshold.
      larger_scores = math_ops.greater(scores, self.sl_scores[0])

      def shortlist_insert():
        # Keep only the candidates that beat the threshold.
        larger_ids = array_ops.boolean_mask(
            math_ops.to_int64(ids), larger_scores)
        larger_score_values = array_ops.boolean_mask(scores, larger_scores)
        # NOTE(review): going by the usage below, `shortlist_ids` are the
        # shortlist positions to overwrite and `new_ids`/`new_scores` are the
        # values written there; confirm against top_n_insert's definition.
        shortlist_ids, new_ids, new_scores = tensor_forest_ops.top_n_insert(
            self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
        u1 = state_ops.scatter_update(self.sl_ids, shortlist_ids, new_ids)
        u2 = state_ops.scatter_update(self.sl_scores, shortlist_ids, new_scores)
        return control_flow_ops.group(u1, u2)

      # We only need to insert into the shortlist if there are any
      # scores larger than the threshold.
      cond_op = control_flow_ops.cond(
          math_ops.reduce_any(larger_scores), shortlist_insert,
          control_flow_ops.no_op)
      # NOTE(review): no op is created inside this control_dependencies
      # scope (it only assigns a Python list), so the scope has no effect
      # on the graph.
      with ops.control_dependencies([cond_op]):
        self.last_ops = [scatter_op, cond_op] 
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:25,代码来源:topn.py

示例5: remove

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN.

    The removed ids get score ``dtypes.float32.min`` in
    ``self.id_to_score``. Their shortlist positions, as reported by
    ``tensor_forest_ops.top_n_remove``, are cleared in ``self.sl_ids`` and
    ``self.sl_scores``. All of this runs after ``self.last_ops``, which is
    then replaced with the three update ops built here.

    Args:
      ids: ids to remove.
    """
    # Order this removal after whatever the previous insert/remove produced.
    with ops.control_dependencies(self.last_ops):
      # Overwrite the removed ids' scores with the smallest float32 value.
      scatter_op = state_ops.scatter_update(
          self.id_to_score,
          ids,
          array_ops.ones_like(
              ids, dtype=dtypes.float32) * dtypes.float32.min)
      # We assume that removed ids are almost always in the shortlist,
      # so it makes no sense to hide the Op behind a tf.cond
      shortlist_ids_to_remove, new_length = tensor_forest_ops.top_n_remove(
          self.sl_ids, ids)
      # Write `new_length` at position 0 of sl_ids and -1 at every cleared
      # position. NOTE(review): presumably sl_ids[0] stores the shortlist
      # length and -1 marks an empty slot; the concat also requires
      # `new_length` to be a 1-element vector. Confirm against top_n_remove.
      u1 = state_ops.scatter_update(
          self.sl_ids,
          array_ops.concat([[0], shortlist_ids_to_remove], 0),
          array_ops.concat(
              [new_length, array_ops.ones_like(shortlist_ids_to_remove) * -1],
              0))
      # Reset the cleared positions' scores to the smallest float32 value.
      u2 = state_ops.scatter_update(
          self.sl_scores,
          shortlist_ids_to_remove,
          dtypes.float32.min * array_ops.ones_like(
              shortlist_ids_to_remove, dtype=dtypes.float32))
      # The next insert/remove will depend on these three updates.
      self.last_ops = [scatter_op, u1, u2] 
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:26,代码来源:topn.py

示例6: tree_initialization

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def tree_initialization(self):
    """Build a conditional op that writes [-1, -1] into row 0 of the tree.

    The write happens only when element [0, 0] of ``self.variables.tree``
    equals -2 (NOTE(review): presumably the "not yet initialized" marker;
    confirm). Otherwise the op is a no-op.

    Returns:
      The op produced by ``control_flow_ops.cond``.
    """
    first_cell = array_ops.squeeze(
        array_ops.strided_slice(self.variables.tree, [0, 0], [1, 1]))
    needs_init = math_ops.equal(first_cell, -2)

    def _write_first_row():
      update = state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]])
      return update.op

    return control_flow_ops.cond(needs_init, _write_first_row,
                                 lambda: control_flow_ops.no_op())
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:14,代码来源:tensor_forest.py

示例7: _apply_sparse

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def _apply_sparse(self, grad, var):
    """Lazy Adam step that touches only the rows named by ``grad.indices``.

    With ``g = grad.values`` and ``idx = grad.indices``:

      m[idx]   <- beta1 * m[idx] + (1 - beta1) * g
      v[idx]   <- beta2 * v[idx] + (1 - beta2) * g**2
      var[idx] <- var[idx] - lr * m[idx] / (sqrt(v[idx]) + epsilon)

    where ``lr = lr_t * sqrt(1 - beta2_power) / (1 - beta1_power)``.
    Every hyperparameter is cast to ``var``'s base dtype first.

    Returns:
      An op grouping the variable update and both slot updates.
    """
    dtype = var.dtype.base_dtype

    def _as_var_dtype(value):
      return math_ops.cast(value, dtype)

    beta1_power = _as_var_dtype(self._beta1_power)
    beta2_power = _as_var_dtype(self._beta2_power)
    lr_t = _as_var_dtype(self._lr_t)
    beta1_t = _as_var_dtype(self._beta1_t)
    beta2_t = _as_var_dtype(self._beta2_t)
    epsilon_t = _as_var_dtype(self._epsilon_t)
    # Bias-corrected step size.
    lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

    idx = grad.indices
    g = grad.values
    locking = self._use_locking

    # First moment, touched rows only.
    m = self.get_slot(var, "m")
    m_decayed = beta1_t * array_ops.gather(m, idx)
    m_fresh = (1 - beta1_t) * g
    m_t = state_ops.scatter_update(m, idx, m_decayed + m_fresh,
                                   use_locking=locking)

    # Second moment, touched rows only.
    v = self.get_slot(var, "v")
    v_decayed = beta2_t * array_ops.gather(v, idx)
    v_fresh = (1 - beta2_t) * math_ops.square(g)
    v_t = state_ops.scatter_update(v, idx, v_decayed + v_fresh,
                                   use_locking=locking)

    # Apply the Adam step to the same rows of the variable.
    m_rows = array_ops.gather(m_t, idx)
    v_rows = array_ops.gather(v_t, idx)
    denom = math_ops.sqrt(v_rows) + epsilon_t
    step = lr * m_rows / denom
    var_update = state_ops.scatter_sub(var, idx, step, use_locking=locking)

    return control_flow_ops.group(var_update, m_t, v_t)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:33,代码来源:lazy_adam_optimizer.py

示例8: _apply_sparse

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def _apply_sparse(self, grad, var):
        """Apply a sparse Nadam update to the rows of ``var`` in ``grad.indices``.

        Only the touched rows of the "m"/"v" slots and of ``var`` are read and
        written. Every hyperparameter is first cast to ``var``'s base dtype.

        Args:
          grad: sparse gradient. ``grad.indices`` selects rows and
            ``grad.values`` holds their gradients.
          var: the variable being updated.

        Returns:
          An op grouping the ``var`` update with both slot updates.
        """
        # 1-based step number: iteration counter + 1.
        t = math_ops.cast(self._iterations, var.dtype.base_dtype) + 1.
        # NOTE(review): `self._m_schedule` is read but never written in this
        # method; presumably it is advanced elsewhere (e.g. `_finish`). Confirm.
        m_schedule = math_ops.cast(self._m_schedule, var.dtype.base_dtype)
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        schedule_decay_t = math_ops.cast(self._schedule_decay_t, var.dtype.base_dtype)

        # Due to the recommendations in [2], i.e. warming momentum schedule
        # ([1] and [2] are citations defined outside this snippet.)
        # NOTE(review): `_get_momentum_cache` is not visible here; presumably it
        # returns 0.96 ** (t * schedule_decay) as in Keras' Nadam. Confirm.
        momentum_cache_power = self._get_momentum_cache(schedule_decay_t, t)
        # Momentum coefficient for this step (mu_t).
        momentum_cache_t = beta1_t * (1. - 0.5 * momentum_cache_power)
        # Momentum coefficient for the next step (mu_{t+1}). Multiplying the
        # power by `_momentum_cache_const` presumably advances it one step.
        momentum_cache_t_1 = beta1_t * (1. - 0.5 * momentum_cache_power * self._momentum_cache_const)
        # Momentum schedule scaled by mu_t, and additionally by mu_{t+1}.
        m_schedule_new = m_schedule * momentum_cache_t
        m_schedule_next = m_schedule_new * momentum_cache_t_1

        # the following equations given in [1]
        # m_t = beta1 * m + (1 - beta1) * g_t   (touched rows only)
        m = self.get_slot(var, "m")
        m_t = state_ops.scatter_update(m, grad.indices,
                                       beta1_t * array_ops.gather(m, grad.indices) +
                                       (1. - beta1_t) * grad.values,
                                       use_locking=self._use_locking)
        # Schedule-corrected gradient g' and first moment m', blended as
        # m_bar = (1 - mu_t) * g' + mu_{t+1} * m'.
        g_prime_slice = grad.values / (1. - m_schedule_new)
        m_t_prime_slice = array_ops.gather(m_t, grad.indices) / (1. - m_schedule_next)
        m_t_bar_slice = (1. - momentum_cache_t) * g_prime_slice + momentum_cache_t_1 * m_t_prime_slice

        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)   (touched rows only)
        v = self.get_slot(var, "v")
        v_t = state_ops.scatter_update(v, grad.indices,
                                       beta2_t * array_ops.gather(v, grad.indices) +
                                       (1. - beta2_t) * tf.square(grad.values),
                                       use_locking=self._use_locking)
        # Bias-corrected second moment: v_t / (1 - beta2 ** t).
        v_t_prime_slice = array_ops.gather(v_t, grad.indices) / (1. - tf.pow(beta2_t, t))

        # var[idx] -= lr * m_bar / (sqrt(v') + epsilon)
        var_update = state_ops.scatter_sub(var, grad.indices,
                                           lr_t * m_t_bar_slice / (math_ops.sqrt(v_t_prime_slice) + epsilon_t),
                                           use_locking=self._use_locking)

        # Finishes only after the variable and both slot updates have run.
        return control_flow_ops.group(*[var_update, m_t, v_t]) 
开发者ID:yyht,项目名称:BERT,代码行数:42,代码来源:nadam.py

示例9: tree_initialization

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def tree_initialization(self):
    """Return an op that seeds row 0 of the tree with [-1, -1] when needed.

    The seed is written only if element [0, 0] of ``self.variables.tree``
    equals -2 (NOTE(review): presumably the "uninitialized" marker; confirm).
    Otherwise a no-op is executed.
    """
    top_left = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [1, 1]))
    should_seed = math_ops.equal(top_left, -2)

    def _seed():
      return state_ops.scatter_update(
          self.variables.tree, [0], [[-1, -1]]).op

    return control_flow_ops.cond(should_seed, _seed,
                                 lambda: control_flow_ops.no_op())
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:13,代码来源:tensor_forest.py

示例10: _apply_sparse

# 需要导入模块: from tensorflow.python.ops import state_ops [as 别名]
# 或者: from tensorflow.python.ops.state_ops import scatter_update [as 别名]
def _apply_sparse(self, grad, var):
    """Lazy Adam step applied only to the rows in ``grad.indices``.

    With ``g = grad.values``:

      m[idx]   <- beta1 * m[idx] + (1 - beta1) * g
      v[idx]   <- beta2 * v[idx] + (1 - beta2) * g * g
      var[idx] <- var[idx] - lr * m[idx] / (sqrt(v[idx]) + epsilon)

    with ``lr = lr_t * sqrt(1 - beta2_power) / (1 - beta1_power)``.

    All hyperparameters are cast to ``var``'s base dtype first. Without the
    casts, the graph fails with a dtype mismatch for non-float32 variables
    (e.g. float16/float64). For float32 variables the casts leave the
    values unchanged.

    Returns:
      An op grouping the variable update and both slot updates.
    """
    dtype = var.dtype.base_dtype
    beta1_power = math_ops.cast(self._beta1_power, dtype)
    beta2_power = math_ops.cast(self._beta2_power, dtype)
    lr_t = math_ops.cast(self._lr_t, dtype)
    beta1_t = math_ops.cast(self._beta1_t, dtype)
    beta2_t = math_ops.cast(self._beta2_t, dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, dtype)

    # Bias-corrected learning rate.
    lr = lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)

    # m_t = beta1 * m + (1 - beta1) * g_t   (touched rows only)
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad.values * (1 - beta1_t)
    m_scaled = gen_array_ops.gather(m, grad.indices) * beta1_t
    m_t = state_ops.scatter_update(m, grad.indices,
                                   m_scaled + m_scaled_g_values,
                                   use_locking=self._use_locking)
    m_tp = gen_array_ops.gather(m_t, grad.indices)

    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)   (touched rows only)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad.values * grad.values) * (1 - beta2_t)
    v_scaled = gen_array_ops.gather(v, grad.indices) * beta2_t
    v_t = state_ops.scatter_update(v, grad.indices,
                                   v_scaled + v_scaled_g_values,
                                   use_locking=self._use_locking)
    v_tp = gen_array_ops.gather(v_t, grad.indices)
    v_sqrtp = math_ops.sqrt(v_tp)

    # Subtract the Adam step from the same rows of the variable.
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr * m_tp / (v_sqrtp + epsilon_t),
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, m_t, v_t)
开发者ID:chentingpc,项目名称:NNCF,代码行数:29,代码来源:optimizer.py


注:本文中的tensorflow.python.ops.state_ops.scatter_update方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。