本文整理汇总了 Python 中 tensorflow.compat.v1.batch_gather 方法的典型用法代码示例。如果您正苦于以下问题：Python v1.batch_gather 方法的具体用法？Python v1.batch_gather 怎么用？Python v1.batch_gather 使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.compat.v1 的用法示例。
在下文中一共展示了v1.batch_gather方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _top_k_sample
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling: draws token ids from the k most probable entries of each row.
    If ignore_ids is given, 1e10 is subtracted from the flagged logits before the ranking AND
    before the sampling, so those ids are effectively never drawn.
    :param logits: [batch_size, vocab_size] float32 tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                       like padding maybe
    :param num_samples: number of samples to draw for every row
    :param k: number of top entries to keep; a python int, a scalar tensor or a [batch_size] vector
    :return: dict with
             'probs': [batch_size, vocab_size] softmax over the (masked) logits
             'sample': [batch_size, num_samples] sampled vocabulary ids
    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    # NOTE: the scope name is left untouched so that op names in existing graphs stay stable.
    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        # Mask once and reuse the masked logits everywhere; sampling from the raw logits
        # would let an ignored id through whenever the top-k cutoff reaches it.
        if ignore_ids is not None:
            logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(logits, axis=-1)

        # [batch_size, vocab_perm] - vocabulary ids ordered from most to least probable
        indices = tf.argsort(probs, direction='DESCENDING')

        # Everything at sorted position >= k is excluded.
        # reshape (rather than k[:, None]) also accepts scalar tensors: [] -> [1, 1], [batch_size] -> [batch_size, 1]
        if isinstance(k, int):
            k_expanded = k
        else:
            k_expanded = tf.cast(tf.reshape(k, [-1, 1]), tf.int32)
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # Sample in the sorted space, then map the sampled positions back to vocabulary ids.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
示例2: fast_tpu_gather
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def fast_tpu_gather(params, indices, name=None):
    """Gathers rows along axis 1 of `params` with a TPU-friendly formulation.

    Non-integer `params` are gathered by multiplying a one-hot encoding of
    `indices` with `params` (a batch matmul); the original authors state this is
    faster than gather_nd on TPU. Integer `params` (e.g. id sequences) go
    through tf.batch_gather instead so that no precision is lost.

    Args:
      params: A tensor from which to gather values.
        [batch_size, original_size, ...]
      indices: A tensor used as the index to gather values.
        [batch_size, selected_size].
      name: A string, name of the operation (optional).

    Returns:
      gather_result: A tensor that has the same rank as params.
        [batch_size, selected_size, ...]
    """
    with tf.name_scope(name):
        source_dtype = params.dtype

        def _one_hot_matmul_gather(values, positions):
            """Gathers via one_hot + batch matmul, computing in float32."""
            needs_cast = source_dtype != tf.float32
            if needs_cast:
                values = tf.to_float(values)
            value_shape = common_layers.shape_list(values)
            position_shape = common_layers.shape_list(positions)
            rank = values.shape.ndims

            # Batch matmul against the one-hot selector needs a rank-3 operand:
            # rank 2 gains a trailing unit axis, rank > 3 has its trailing axes
            # flattened into one.
            if rank == 2:
                values = tf.expand_dims(values, axis=-1)
            elif rank > 3:
                values = tf.reshape(values, [value_shape[0], value_shape[1], -1])

            selector = tf.one_hot(positions, value_shape[1], dtype=values.dtype)
            gathered = tf.matmul(selector, values)

            # Undo the rank adjustment made above.
            if rank == 2:
                gathered = tf.squeeze(gathered, axis=-1)
            elif rank > 3:
                value_shape[1] = position_shape[1]
                gathered = tf.reshape(gathered, value_shape)

            if needs_cast:
                gathered = tf.cast(gathered, source_dtype)
            return gathered

        # Integer inputs skip the one_hot matmul to avoid precision loss. Per the
        # original authors' note, the largest int bfloat16 can represent in the
        # MXU is 256, which is smaller than typical id values; an
        # encoding/decoding scheme could work around that but the benefit was
        # judged small.
        if source_dtype.is_integer:
            return tf.batch_gather(params, indices)
        return _one_hot_matmul_gather(params, indices)
示例3: _top_p_sample
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p (nucleus) sampling. If ignore_ids is given, 1e10 is subtracted from the flagged logits
    before the ranking AND before the sampling, so those ids are effectively never drawn.
    :param logits: [batch_size, vocab_size] float32 tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                       like padding maybe
    :param num_samples: number of samples to draw for every row
    :param p: topp threshold to use, either a python float, a scalar tensor or a [batch_size] vector.
              A python float above 0.999999 disables the truncation altogether.
    :return: dict with
             'probs': [batch_size, vocab_size] softmax over the (masked) logits
             'sample': [batch_size, num_samples] sampled vocabulary ids
    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        # Mask once and reuse the masked logits everywhere; sampling from the raw logits
        # would let an ignored id through whenever the top-p cutoff reaches it.
        if ignore_ids is not None:
            logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(logits, axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(logits=logits, num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm] - vocabulary ids ordered from most to least probable
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top pth index to cut off. careful we don't want to cutoff everything!
        # reshape (rather than p[:, None]) also accepts scalar tensors: [] -> [1, 1], [batch_size] -> [batch_size, 1]
        if isinstance(p, float):
            p_expanded = p
        else:
            p_expanded = tf.cast(tf.reshape(p, [-1, 1]), probs.dtype)

        # A sorted position is kept when its inclusive cumulative probability is still below p;
        # position 0 is always kept so that at least one candidate survives.
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # Sample in the sorted space, then map the sampled positions back to vocabulary ids.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
示例4: sample_step
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict; honoured by both the
                       top-p and the top-k sampler
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p or top-k threshold; when do_topk is set it is cast to int32 and used as k
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True sample with _top_k_sample, otherwise with _top_p_sample
    :return: dict with
             'new_tokens', size [batch_size]
             'new_probs', also size [batch_size] - the probability of each sampled token
             'new_cache', size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # logits_flat is [batch_size * seq_length, vocab_size]; un-flatten it and keep only the
    # logits of the FINAL position of every sequence.
    _, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        # ignore_ids must be forwarded here as well, otherwise top-k sampling could emit
        # tokens the caller asked us never to predict.
        sample_info = _top_k_sample(next_logits, ignore_ids=ignore_ids, num_samples=1,
                                    k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    # Both samplers return [batch_size, 1] samples; drop the singleton sample axis.
    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
示例5: _build_verified_loss
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _build_verified_loss(self, labels):
    """Builds `self._verified_loss` from upper bounds on the specification.

    Without a specification the verified loss and the interval-bounds accuracy
    are both set to constant zero. Otherwise the bounds returned by
    `self._get_specification_bounds()` are first reduced along axis 1 according
    to `self._interval_bounds_loss_mode` ('all', 'most', 'random' or 'least')
    and then turned into a scalar loss according to
    `self._interval_bounds_loss_type` ('xent', 'softplus' or 'hinge').

    Args:
      labels: Not referenced by this method.
    """
    if not self._specification:
        self._verified_loss = tf.constant(0.)
        self._interval_bounds_accuracy = tf.constant(0.)
        return

    # Interval bounds.
    bounds = self._get_specification_bounds()

    # Select which specification bounds contribute to the loss.
    selection_mode = self._interval_bounds_loss_mode
    if selection_mode == 'all':
        # Every bound is kept untouched.
        pass
    elif selection_mode == 'most':
        # Only the largest bound of each row.
        bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
    elif selection_mode == 'random':
        # Uniformly drawn column indices, _interval_bounds_loss_n per row.
        sample_shape = [tf.shape(bounds)[0], self._interval_bounds_loss_n]
        num_columns = tf.shape(bounds)[1]
        idx = tf.random.uniform(sample_shape, 0, num_columns, dtype=tf.int32)
        bounds = tf.batch_gather(bounds, idx)
    else:
        assert selection_mode == 'least'
        # This picks the least violated constraint: negative bounds are pushed
        # out of the way by adding _BIG_NUMBER before taking the minimum.
        negative_mask = tf.cast(bounds < 0., tf.float32)
        smallest_violation = tf.reduce_min(
            bounds + negative_mask * _BIG_NUMBER, axis=1, keepdims=True)
        # True unless every bound in the row is negative (the +.5 keeps the
        # float comparison against the column count robust).
        negative_count = tf.reduce_sum(negative_mask, axis=1, keepdims=True)
        has_violations = tf.less(
            negative_count + .5, tf.cast(tf.shape(bounds)[1], tf.float32))
        largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
        bounds = tf.where(has_violations, smallest_violation, largest_bounds)

    # Turn the selected bounds into a scalar loss.
    loss_type = self._interval_bounds_loss_type
    if loss_type == 'xent':
        # Cross-entropy against an extra all-zero logit column that carries the
        # whole target mass.
        zero_column = tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)
        xent_logits = tf.concat([bounds, zero_column], axis=1)
        one_column = tf.ones([tf.shape(bounds)[0], 1], dtype=bounds.dtype)
        xent_targets = tf.concat([tf.zeros_like(bounds), one_column], axis=1)
        per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(xent_targets), logits=xent_logits)
        self._verified_loss = tf.reduce_mean(per_example)
    elif loss_type == 'softplus':
        shifted = bounds + self._interval_bounds_hinge_margin
        self._verified_loss = tf.reduce_mean(tf.nn.softplus(shifted))
    else:
        assert loss_type == 'hinge'
        clipped = tf.maximum(bounds, -self._interval_bounds_hinge_margin)
        self._verified_loss = tf.reduce_mean(clipped)
示例6: compute_target_topk_q
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def compute_target_topk_q(reward, gamma, next_actions, next_q_values,
                          next_states, terminals):
    """Computes the target Q value using a greedily built top-K slate.

    The slate is filled with the documents that have the largest
    `score * Q value` product and is then treated as if it were the optimal
    slate. The original authors identify this as method "TT" in
    Ie et al. https://arxiv.org/abs/1905.12767.

    Args:
      reward: [batch_size] tensor, the immediate reward.
      gamma: float, discount factor with the usual RL meaning.
      next_actions: [batch_size, slate_size] tensor, the next slate. Only its
        static second dimension (the slate size) is read.
      next_q_values: [batch_size, num_of_documents] tensor, the q values of the
        documents in the next step.
      next_states: [batch_size, 1 + num_of_documents] tensor, the features for
        the user and the documents in the next step.
      terminals: [batch_size] tensor, indicating if this is a terminal step.

    Returns:
      [batch_size] tensor, the target q values.
    """
    slate_size = next_actions.get_shape().as_list()[1]
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Rank documents by score * Q value and keep the best `slate_size` of them.
    ranking_key = next_q_values * scores
    _, topk_optimal_slate = tf.math.top_k(ranking_key, k=slate_size)
    slate_indices = tf.cast(topk_optimal_slate, dtype=tf.int32)

    # Q values and (unnormalized) scores of the chosen documents.
    # Both are [batch_size, slate_size].
    next_q_values_selected = tf.batch_gather(next_q_values, slate_indices)
    scores_selected = tf.batch_gather(scores, slate_indices)

    # Score-weighted average of the slate's Q values; the no-click score takes
    # part in the normalizer only.
    weighted_q_sum = tf.reduce_sum(
        input_tensor=next_q_values_selected * scores_selected, axis=1)
    normalizer = tf.reduce_sum(
        input_tensor=scores_selected, axis=1) + score_no_click
    next_q_target_topk = weighted_q_sum / normalizer

    # Terminal steps contribute no bootstrapped value.
    discounted_next_q = gamma * next_q_target_topk
    not_terminal = 1. - tf.cast(terminals, tf.float32)
    return reward + discounted_next_q * not_terminal
示例7: _build_train_op
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import batch_gather [as 别名]
def _build_train_op(self):
    """Builds a training op.

    The loss is the mean squared difference between the Q value of the clicked
    document and the (gradient-stopped) target Q value, computed over the
    replay rows whose click count equals 1. When the batch has no clicks at
    all, the loss is a constant zero.

    Returns:
      An op performing one step of training from replay data.
    """
    # Shapes, as annotated by the original authors:
    #   click_indicator: [B, S]
    #   q_values: [B, A]
    #   actions: [B, S]
    #   slate_q_values: [B, S]
    #   replay_click_q: [B]
    click_indicator = self._replay.rewards[:, :, self._click_response_index]
    q_values = self._replay_net_outputs.q_values
    slate_actions = tf.cast(self._replay.actions, dtype=tf.int32)
    slate_q_values = tf.batch_gather(q_values, slate_actions)

    # Only get the Q from the clicked document.
    replay_click_q = tf.reduce_sum(
        input_tensor=slate_q_values * click_indicator,
        axis=1,
        name='replay_click_q')

    target = tf.stop_gradient(self._build_target_q_op())

    # Number of clicks per row, then the batch indices of rows with exactly one.
    clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
    single_click_rows = tf.where(tf.equal(clicked, 1))
    clicked_indices = tf.squeeze(single_click_rows, axis=1)

    # clicked_indices is a vector and tf.gather selects the batch dimension.
    q_clicked = tf.gather(replay_click_q, clicked_indices)
    target_clicked = tf.gather(target, clicked_indices)

    def _clicked_squared_error():
        """Mean squared error on the clicked rows; also logs it as a summary."""
        squared_error = tf.square(q_clicked - target_clicked)
        mean_error = tf.reduce_mean(input_tensor=squared_error)
        if self.summary_writer is not None:
            with tf.variable_scope('Losses'):
                tf.summary.scalar('Loss', mean_error)
        return mean_error

    # NOTE(review): the predicate tests "any click in the batch" while the rows
    # above are selected by "exactly one click" - presumably a row never holds
    # more than one click; confirm against the environment.
    batch_has_clicks = tf.greater(tf.reduce_sum(input_tensor=clicked), 0)
    loss = tf.cond(
        pred=batch_has_clicks,
        true_fn=_clicked_squared_error,
        false_fn=lambda: tf.constant(0.),
        name='')

    return self.optimizer.minimize(loss)