当前位置: 首页>>代码示例>>Python>>正文


Python v1.batch_gather方法代码示例

本文整理汇总了 Python 中 tensorflow.compat.v1.batch_gather 方法的典型用法代码示例。如果您正苦于以下问题:Python v1.batch_gather 方法的具体用法?Python v1.batch_gather 怎么用?Python v1.batch_gather 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.compat.v1 的用法示例。


在下文中一共展示了 v1.batch_gather 方法的 7 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: _top_k_sample

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is given, the corresponding logits are pushed down by 1e10
    so that those tokens are never sampled.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw for each batch element
    :param k: how many of the highest-probability tokens to sample from, either a Python int or a
              [batch_size] integer vector
    :return: dict with
             'probs': softmax over the (masked) logits, same shape as logits
             'sample': [batch_size, num_samples] sampled token ids

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    # NOTE: the scope is called 'top_p_sample' even though this is top-k; it is kept as-is so
    # that op names in already-built graphs do not change.
    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        # Apply the ignore mask once and reuse it for BOTH the probabilities and the sampling
        # logits. Sampling from the raw logits would let ignored tokens be drawn whenever k is
        # larger than the number of non-ignored tokens.
        if ignore_ids is None:
            masked_logits = logits
        else:
            masked_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(masked_logits, axis=-1)

        # [batch_size, vocab_perm] - token ids ordered from most to least probable
        indices = tf.argsort(probs, direction='DESCENDING')

        # Everything at sorted position >= k is excluded.
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # Sample in the sorted space, then map the sampled positions back to token ids.
        logits_to_use = tf.batch_gather(masked_logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
开发者ID:imcaspar,项目名称:gpt2-ml,代码行数:35,代码来源:modeling.py

示例2: fast_tpu_gather

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def fast_tpu_gather(params, indices, name=None):
  """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do gather, which is faster
  than gather_nd on TPU. For params that have an integer dtype (sequences to
  gather from), batch_gather is used to keep accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    def _gather(params, indices):
      """Fast gather using one_hot and batch matmul."""
      if dtype != tf.float32:
        # tf.to_float is deprecated; tf.cast is the supported equivalent.
        params = tf.cast(params, tf.float32)
      shape = common_layers.shape_list(params)
      indices_shape = common_layers.shape_list(indices)
      ndims = params.shape.ndims
      # Adjust the shape of params to match one-hot indices, which is the
      # requirement of Batch MatMul.
      if ndims == 2:
        params = tf.expand_dims(params, axis=-1)
      if ndims > 3:
        # Flatten all trailing dimensions so params is rank 3 for the matmul.
        params = tf.reshape(params, [shape[0], shape[1], -1])
      gather_result = tf.matmul(
          tf.one_hot(indices, shape[1], dtype=params.dtype), params)
      if ndims == 2:
        gather_result = tf.squeeze(gather_result, axis=-1)
      if ndims > 3:
        # Restore the trailing dimensions; only axis 1 changed size.
        shape[1] = indices_shape[1]
        gather_result = tf.reshape(gather_result, shape)
      if dtype != tf.float32:
        # Hand the result back in the caller's original dtype.
        gather_result = tf.cast(gather_result, dtype)
      return gather_result

    # If the dtype is int, use the gather instead of one_hot matmul to avoid
    # precision loss. The max int value can be represented by bfloat16 in MXU is
    # 256, which is smaller than the possible id values. Encoding/decoding can
    # potentially be used to make it work, but the benefit is small right now.
    if dtype.is_integer:
      gather_result = tf.batch_gather(params, indices)
    else:
      gather_result = _gather(params, indices)

    return gather_result
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:57,代码来源:beam_search.py

示例3: _top_p_sample

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. If ignore_ids is given, the corresponding logits are pushed down by 1e10
    so that those tokens are never sampled.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw for each batch element
    :param p: topp threshold to use, either a float or a [batch_size] vector
    :return: dict with
             'probs': softmax over the (masked) logits, same shape as logits
             'sample': [batch_size, num_samples] sampled token ids

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        # Apply the ignore mask once and reuse it everywhere below. Sampling from the raw
        # logits would let an ignored token be drawn whenever it falls inside the nucleus
        # (e.g. p given as a tensor equal to 1.0, or cumulative-sum rounding error).
        if ignore_ids is None:
            masked_logits = logits
        else:
            masked_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(masked_logits, axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=masked_logits, num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm] - token ids ordered from most to least probable
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cutoff everything!
        # A sorted position is kept if its cumulative probability is still below p, or if it is
        # position 0 (so at least one token always survives).
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # Sample in the sorted space, then map the sampled positions back to token ids.
        logits_to_use = tf.batch_gather(masked_logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
开发者ID:imcaspar,项目名称:gpt2-ml,代码行数:54,代码来源:modeling.py

示例4: sample_step

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold; when do_topk is set it is cast to int32 and used as k instead
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True use top-k sampling, otherwise top-p sampling
    :return: dict with
             'new_tokens', size [batch_size]
             'new_probs', also size [batch_size]
             'new_cache', size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH: un-flatten the logits and keep only the last position.
    _, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    # Both samplers receive ignore_ids so the "never predict" mask holds regardless of which
    # strategy is selected.
    if do_topk:
        sample_info = _top_k_sample(next_logits, ignore_ids=ignore_ids, num_samples=1,
                                    k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    # num_samples is 1, so drop that axis; the probability of each sampled token is looked up
    # from the returned distribution.
    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
开发者ID:imcaspar,项目名称:gpt2-ml,代码行数:45,代码来源:modeling.py

示例5: _build_verified_loss

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _build_verified_loss(self, labels):
    """Set up `self._verified_loss` from upper bounds on the specification.

    The bounds returned by `self._get_specification_bounds()` are first
    reduced according to `self._interval_bounds_loss_mode` ('all', 'most',
    'random' or 'least') and then turned into a scalar loss according to
    `self._interval_bounds_loss_type` ('xent', 'softplus' or 'hinge').

    Args:
      labels: Not read by this method body.
    """
    if not self._specification:
      # No specification: both quantities are constant zeros.
      self._verified_loss = tf.constant(0.)
      self._interval_bounds_accuracy = tf.constant(0.)
      return

    # Interval bounds.
    spec_bounds = self._get_specification_bounds()

    # Select specifications ('all' keeps the bounds untouched).
    selection_mode = self._interval_bounds_loss_mode
    if selection_mode == 'most':
      spec_bounds = tf.reduce_max(spec_bounds, axis=1, keepdims=True)
    elif selection_mode == 'random':
      # Draw `_interval_bounds_loss_n` column indices per row, uniformly.
      sample_shape = [tf.shape(spec_bounds)[0], self._interval_bounds_loss_n]
      random_columns = tf.random.uniform(
          sample_shape, 0, tf.shape(spec_bounds)[1], dtype=tf.int32)
      spec_bounds = tf.batch_gather(spec_bounds, random_columns)
    elif selection_mode != 'all':
      assert self._interval_bounds_loss_mode == 'least'
      # Pick the least violated constraint.
      # `is_negative` is 1 where a bound is below zero; adding _BIG_NUMBER
      # there removes those entries from the minimum.
      is_negative = tf.cast(spec_bounds < 0., tf.float32)
      smallest_violation = tf.reduce_min(
          spec_bounds + is_negative * _BIG_NUMBER, axis=1, keepdims=True)
      # True for rows where not every bound is negative.
      num_negative = tf.reduce_sum(is_negative, axis=1, keepdims=True)
      has_violations = tf.less(
          num_negative + .5, tf.cast(tf.shape(spec_bounds)[1], tf.float32))
      largest_bounds = tf.reduce_max(spec_bounds, axis=1, keepdims=True)
      spec_bounds = tf.where(
          has_violations, smallest_violation, largest_bounds)

    loss_type = self._interval_bounds_loss_type
    if loss_type == 'xent':
      # Append a zero logit column and make it the target class.
      extra_column = [tf.shape(spec_bounds)[0], 1]
      xent_logits = tf.concat(
          [spec_bounds, tf.zeros(extra_column, dtype=spec_bounds.dtype)],
          axis=1)
      xent_targets = tf.concat(
          [tf.zeros_like(spec_bounds),
           tf.ones(extra_column, dtype=spec_bounds.dtype)],
          axis=1)
      per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=tf.stop_gradient(xent_targets), logits=xent_logits)
      self._verified_loss = tf.reduce_mean(per_example)
    elif loss_type == 'softplus':
      shifted = spec_bounds + self._interval_bounds_hinge_margin
      self._verified_loss = tf.reduce_mean(tf.nn.softplus(shifted))
    else:
      assert self._interval_bounds_loss_type == 'hinge'
      clipped = tf.maximum(spec_bounds, -self._interval_bounds_hinge_margin)
      self._verified_loss = tf.reduce_mean(clipped)
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:50,代码来源:loss.py

示例6: compute_target_topk_q

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def compute_target_topk_q(reward, gamma, next_actions, next_q_values,
                          next_states, terminals):
  """Computes the target Q value using a greedily chosen top-K slate.

  This algorithm corresponds to the method "TT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate. Only its
      static second dimension (the slate size) is read.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  slate_size = next_actions.get_shape().as_list()[1]
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Rank documents by score * Q value and take the best `slate_size` of them
  # as if they formed the optimal slate.
  _, top_doc_indices = tf.math.top_k(next_q_values * scores, k=slate_size)
  slate_indices = tf.cast(top_doc_indices, dtype=tf.int32)

  # Q values and scores of the chosen documents, each [batch_size, slate_size].
  slate_q_values = tf.batch_gather(next_q_values, slate_indices)
  slate_scores = tf.batch_gather(scores, slate_indices)

  # Score-weighted Q value of the slate; the no-click score is added to the
  # normalizer.
  weighted_q_sum = tf.reduce_sum(
      input_tensor=slate_q_values * slate_scores, axis=1)
  normalizer = tf.reduce_sum(
      input_tensor=slate_scores, axis=1) + score_no_click
  next_q_target_topk = weighted_q_sum / normalizer

  # Terminal steps contribute no bootstrapped value.
  discounted_next_q = gamma * next_q_target_topk
  return reward + discounted_next_q * (1. - tf.cast(terminals, tf.float32))
开发者ID:google-research,项目名称:recsim,代码行数:47,代码来源:slate_decomp_q_agent.py

示例7: _build_train_op

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _build_train_op(self):
    """Creates the op that runs one optimizer step on replay data.

    The loss is the mean squared difference between the Q value of the clicked
    document and the (gradient-stopped) target from
    `self._build_target_q_op()`, taken over the batch rows whose click
    indicators sum to 1. If no click indicator is set in the batch, the loss
    is the constant 0.

    Returns:
      An op performing one step of training from replay data.
    """
    # Shapes (B = batch, S = slate size, A = number of actions):
    #   click_indicator: [B, S]
    #   q_values:        [B, A]
    #   actions:         [B, S]
    #   slate_q_values:  [B, S]
    #   replay_click_q:  [B]
    click_indicator = self._replay.rewards[:, :, self._click_response_index]
    q_values = self._replay_net_outputs.q_values
    slate_actions = tf.cast(self._replay.actions, dtype=tf.int32)
    slate_q_values = tf.batch_gather(q_values, slate_actions)

    # Only get the Q from the clicked document.
    clicked_q_per_slot = slate_q_values * click_indicator
    replay_click_q = tf.reduce_sum(
        input_tensor=clicked_q_per_slot, axis=1, name='replay_click_q')

    target = tf.stop_gradient(self._build_target_q_op())

    clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
    single_click_rows = tf.where(tf.equal(clicked, 1))
    clicked_indices = tf.squeeze(single_click_rows, axis=1)
    # clicked_indices is a vector and tf.gather selects the batch dimension.
    q_clicked = tf.gather(replay_click_q, clicked_indices)
    target_clicked = tf.gather(target, clicked_indices)

    def _clicked_mse():
      # Mean squared error over the selected rows, with an optional summary.
      td_error = q_clicked - target_clicked
      mse = tf.reduce_mean(input_tensor=tf.square(td_error))
      if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
          tf.summary.scalar('Loss', mse)

      return mse

    any_click = tf.greater(tf.reduce_sum(input_tensor=clicked), 0)
    loss = tf.cond(
        pred=any_click,
        true_fn=_clicked_mse,
        false_fn=lambda: tf.constant(0.),
        name='')

    return self.optimizer.minimize(loss)
开发者ID:google-research,项目名称:recsim,代码行数:46,代码来源:slate_decomp_q_agent.py


注:本文中的tensorflow.compat.v1.batch_gather方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。