本文整理汇总了Python中tensorflow.gather方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.gather方法的具体用法？Python tensorflow.gather怎么用？Python tensorflow.gather使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。
在下文中一共展示了tensorflow.gather方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_hint_pool_idxs
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_hint_pool_idxs(self, normalized_query):
    """Collect a small pool of candidate memory indices for each query.

    The queries are hashed once, and the memory indices stored in the
    matching hash slots of every hash table are gathered. Doing this cheap
    look-up up front lets later, more expensive nearest-neighbour operations
    work on these candidates instead of the whole memory.

    Args:
        normalized_query: A Tensor of shape [None, key_dim].

    Returns:
        A Tensor of shape [None, choose_k] of indices in memory that are
        closest to the queries. The results from all hash tables are
        concatenated along axis 1, and every index is clamped into
        [0, memory_size - 1].
    """
    slot_idxs_per_table = self.get_hash_slots(normalized_query)
    candidates = []
    for table_idx, slot_idxs in enumerate(slot_idxs_per_table):
        # Memory indices stored in this table's slots for each query ...
        looked_up = tf.gather(self.hash_slots[table_idx], slot_idxs)
        # ... clamped into the valid range [0, memory_size - 1].
        clamped = tf.maximum(tf.minimum(looked_up, self.memory_size - 1), 0)
        candidates.append(clamped)
    return tf.concat(axis=1, values=candidates)
示例2: scheduled_sample
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    """Build a batch that mixes ground-truth and generated data points.

    The batch positions 0..batch_size-1 are randomly permuted. The first
    ``num_ground_truth`` permuted positions take their row from
    ``ground_truth_x`` and the rest from ``generated_x``. Output row ``p`` is
    therefore either ``ground_truth_x[p]`` or ``generated_x[p]``.

    Args:
        ground_truth_x: tensor of ground-truth data points.
        generated_x: tensor of generated data points.
        batch_size: batch size (converted with ``int``).
        num_ground_truth: number of ground-truth examples to include in batch.

    Returns:
        New batch with num_ground_truth rows sampled from ground_truth_x and
        the rest from generated_x.
    """
    total = int(batch_size)
    shuffled = tf.random_shuffle(tf.range(total))
    gt_positions = tf.gather(shuffled, tf.range(num_ground_truth))
    gen_positions = tf.gather(shuffled, tf.range(num_ground_truth, total))
    picked_gt = tf.gather(ground_truth_x, gt_positions)
    picked_gen = tf.gather(generated_x, gen_positions)
    # dynamic_stitch writes each picked row back at the position it came from.
    return tf.dynamic_stitch(
        [gt_positions, gen_positions],
        [picked_gt, picked_gen])
示例3: build_cross_entropy_loss
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def build_cross_entropy_loss(logits, gold):
    """Constructs a mean cross-entropy loss from logits and sparse gold labels.

    Rows whose gold label is the magic value -1 (any value <= -1) are skipped.
    Only rows with gold > -1 contribute to the cost and to the counts.

    Args:
        logits: float Tensor of scores, one row per example.
        gold: int Tensor of gold class ids, one id per row (not one-hot).
            -1 marks a row to skip.

    Returns:
        cost, correct, total:
        - cost: the cross entropy averaged over the valid rows; 0 when there
          are no valid rows (previously NaN).
        - correct: the number of valid rows whose gold label is the top-1
          prediction.
        - total: the number of valid rows.
    """
    # Indices of the rows with a real label, flattened to a 1-D vector.
    valid = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
    gold = tf.gather(gold, valid)
    logits = tf.gather(logits, valid)
    correct = tf.reduce_sum(
        tf.cast(tf.nn.in_top_k(logits, gold, 1), tf.int32))
    total = tf.size(gold)
    # With no valid rows the summed cost is 0. Dividing by total == 0 would
    # produce NaN and poison training, so divide by at least 1.
    denominator = tf.cast(tf.maximum(total, 1), tf.float32)
    # tf.nn version replaces the removed tf.contrib "deprecated_flipped"
    # variant; keyword arguments avoid the old (logits, labels) order issue.
    cost = tf.reduce_sum(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(gold, tf.int64), logits=logits)) / denominator
    return cost, correct, total
示例4: clip_tensor
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def clip_tensor(t, length):
    """Keep only the first ``length`` entries of ``t`` along dimension 0.

    Args:
        t: the input tensor, assumed to have rank >= 1.
        length: a tensor of shape [1] or a Python integer giving the size of
            dimension 0 after clipping; assumed to satisfy length <= t.shape[0].

    Returns:
        The clipped tensor. When ``length`` is not a tensor (a plain
        integer), the static size of the first dimension is also set to
        ``length``.
    """
    leading = tf.gather(t, tf.range(length))
    if _is_tensor(length):
        return leading
    return _set_dim_0(leading, length)
示例5: filter_field_value_equals
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def filter_field_value_equals(boxlist, field, value, scope=None):
    """Keep only the boxes whose ``field`` entry equals ``value``.

    Args:
        boxlist: BoxList holding N boxes.
        field: name of the field to compare.
        value: scalar value to compare against.
        scope: name scope.

    Returns:
        A BoxList holding the M <= N matching boxes.

    Raises:
        ValueError: if boxlist is not a BoxList object, or if it does not
            have the specified field.
    """
    with tf.name_scope(scope, 'FilterFieldValueEquals'):
        if not isinstance(boxlist, box_list.BoxList):
            raise ValueError('boxlist must be a BoxList')
        if not boxlist.has_field(field):
            raise ValueError('boxlist must contain the specified field')
        matches = tf.equal(boxlist.get_field(field), value)
        # Flatten tf.where's coordinates into a 1-D index vector
        # (assumes the field is 1-D, one entry per box -- TODO confirm).
        keep_indices = tf.reshape(tf.where(matches), [-1])
        return gather(boxlist, keep_indices)
示例6: memory_run
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def memory_run(step, nmaps, mem_size, batch_size, vocab_size,
               global_step, do_training, update_mem, decay_factor, num_gpus,
               target_emb_weights, output_w, gpu_targets_tn, it):
    """Run memory.

    Queries ``memory_call`` for position ``it`` and embeds the returned ids
    with ``target_emb_weights``. With a decaying probability, it blends in
    the (dropped-out) gold target embeddings.

    Returns:
        (mem, mem_loss, update_mem):
        - mem: the (possibly blended) embedding reshaped to [-1, 1, 1, nmaps].
        - mem_loss: the loss returned by ``memory_call``.
        - update_mem: the ``update_mem`` argument, unchanged.
    """
    query = step[:, 0, it, :]
    labels_now = gpu_targets_tn[:, it, 0]
    retrieved, mask, mem_loss = memory_call(
        query, labels_now, nmaps, mem_size, vocab_size, num_gpus, update_mem)
    # Embed the retrieved ids and scale each row by the first mask column.
    retrieved_emb = tf.gather(target_emb_weights, retrieved)
    retrieved_emb = retrieved_emb * tf.expand_dims(mask[:, 0], 1)
    # Gold embeddings with dropout (0.7 is passed as-is to tf.nn.dropout).
    gold_emb = tf.nn.dropout(tf.gather(target_emb_weights, labels_now), 0.7)
    # Mix gold and original in the first steps, 20% later. The weight decays
    # linearly in global_step, is floored at 0.2, and is scaled by do_training.
    decayed = 1.0 - tf.cast(global_step, tf.float32) / (1000. * decay_factor)
    gold_weight = tf.maximum(decayed, 0.2) * do_training

    def _blend():
        return gold_weight * gold_emb + (1.0 - gold_weight) * retrieved_emb

    def _keep_retrieved():
        return retrieved_emb

    coin = tf.random_uniform([])
    chosen = tf.cond(tf.less(coin, gold_weight), _blend, _keep_retrieved)
    return tf.reshape(chosen, [-1, 1, 1, nmaps]), mem_loss, update_mem
示例7: batch_transformer
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Batch Spatial Transformer Layer.

    Each input in ``U`` is repeated once per transformation in ``thetas``.
    The repeated batch is passed to ``transformer`` together with ``thetas``.

    Parameters
    ----------
    U : float
        tensor of inputs [num_batch,height,width,num_channels]
    thetas : float
        a set of transformations for each input [num_batch,num_transforms,6]
    out_size : int
        the size of the output [out_height,out_width]

    Returns: float
        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
    """
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # Row indices 0,0,..,0,1,1,..,1,...: each input repeated
        # num_transforms times. `range` replaces the Python-2-only `xrange`
        # (a NameError on Python 3). The explicit int32 dtype keeps the
        # indices integral even when the list is empty; an empty Python list
        # would otherwise convert to float32.
        repeat_idx = tf.constant(
            [i for i in range(num_batch) for _ in range(num_transforms)],
            dtype=tf.int32)
        input_repeated = tf.gather(U, repeat_idx)
        return transformer(input_repeated, thetas, out_size)
示例8: data
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def data(self, rows=None):
    """Access a batch of episodes from the memory.

    Padding elements after the length of each episode are unspecified and
    might contain old data.

    Args:
        rows: 1-D episode indices to select, as a Tensor or anything
            ``tf.convert_to_tensor`` accepts (e.g. a list of ints). Defaults
            to all ``self._capacity`` episodes.

    Returns:
        A pair ``(episode, length)``:
        - episode: a list with one gathered tensor per buffer, holding the
          transition quantities with batch and time dimensions.
        - length: the gathered batch of sequence lengths.

    Raises:
        ValueError: if ``rows`` is not statically known to be rank 1.
    """
    if rows is None:
        rows = tf.range(self._capacity)
    else:
        # Accept plain Python sequences too; this is a no-op for a Tensor.
        rows = tf.convert_to_tensor(rows)
    # Raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let a mis-shaped index tensor silently change output shapes.
    if rows.shape.ndims != 1:
        raise ValueError(
            'rows must be a rank-1 tensor, got shape {}'.format(rows.shape))
    episode = [tf.gather(buffer_, rows) for buffer_ in self._buffers]
    length = tf.gather(self._length, rows)
    return episode, length
示例9: scheduled_sample
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def scheduled_sample(self,
                     ground_truth_x,
                     generated_x,
                     batch_size,
                     num_ground_truth):
    """Return a batch mixing ground-truth and generated data points.

    A random permutation of the positions 0..batch_size-1 is split after
    ``num_ground_truth`` entries. Positions in the first part take their row
    from ``ground_truth_x`` and the remaining positions from ``generated_x``,
    each row staying at its own position.

    Args:
        ground_truth_x: tensor of ground-truth data points.
        generated_x: tensor of generated data points.
        batch_size: batch size.
        num_ground_truth: number of ground-truth examples to include in batch.

    Returns:
        New batch with num_ground_truth rows sampled from ground_truth_x and
        the rest from generated_x.
    """
    permutation = tf.random_shuffle(tf.range(batch_size))
    take_gt = tf.gather(permutation, tf.range(num_ground_truth))
    take_gen = tf.gather(permutation, tf.range(num_ground_truth, batch_size))
    pieces = [tf.gather(ground_truth_x, take_gt),
              tf.gather(generated_x, take_gen)]
    return tf.dynamic_stitch([take_gt, take_gen], pieces)
示例10: convert_gradient_to_tensor
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
@tf.custom_gradient
def convert_gradient_to_tensor(x):
    """Identity operation whose gradient is converted to a `Tensor`.

    Currently, the gradient to `tf.concat` is particularly expensive to
    compute if dy is an `IndexedSlices` (a lack of GPU implementation
    forces the gradient operation onto CPU). This situation occurs when
    the output of the `tf.concat` is eventually passed to `tf.gather`.
    It is sometimes faster to convert the gradient to a `Tensor`, so as
    to get the cheaper gradient for `tf.concat`. To do this, replace
    `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`.

    Args:
        x: A `Tensor`.

    Returns:
        A `Tensor` equal to `x`. In backprop, its incoming gradient is
        densified with `tf.convert_to_tensor`.
    """
    def grad(dy):
        # dy may arrive as tf.IndexedSlices (e.g. from a downstream
        # tf.gather); converting it yields a dense Tensor gradient.
        return tf.convert_to_tensor(dy)

    # A bare `return x` attaches no gradient override, so it would not
    # implement the conversion described above. tf.custom_gradient does.
    return tf.identity(x), grad
示例11: add_positional_embedding
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def add_positional_embedding(x, max_length, name, positions=None):
    """Add a learned positional embedding to ``x``.

    Args:
        x: a Tensor with shape [batch, length, depth]
        max_length: an integer. static maximum size of any dimension.
        name: a name for this layer (also the embedding variable's name).
        positions: an optional tensor with shape [batch, length]

    Returns:
        a Tensor the same shape as x.
    """
    _, length, depth = common_layers.shape_list(x)
    table = tf.cast(tf.get_variable(name, [max_length, depth]), x.dtype)
    if positions is not None:
        # Explicit positions: look up one embedding row per position.
        return x + tf.gather(table, tf.to_int32(positions))

    def _take_prefix():
        return tf.slice(table, [0, 0], [length, -1])

    def _pad_to_length():
        return tf.pad(table, [[0, length - max_length], [0, 0]])

    # Shorter sequences use the first `length` rows of the table. Otherwise
    # the table is padded (tf.pad default: zeros) up to `length` rows.
    pos_emb = tf.cond(tf.less(length, max_length), _take_prefix, _pad_to_length)
    return x + tf.expand_dims(pos_emb, 0)
示例12: get_shifted_center_blocks
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_shifted_center_blocks(x, indices):
    """Get right-shifted blocks for masked local attention 2d.

    Args:
        x: A tensor with shape [batch, heads, height, width, depth]
        indices: The indices to gather blocks

    Returns:
        A tensor of extracted blocks, each block shifted right by one along
        axis 3 (the second-to-last axis).
    """
    blocks = gather_blocks_2d(x, indices)
    # Prepend one zero slot on axis 3, then drop that axis's last element.
    padded = tf.pad(blocks, [[0, 0], [0, 0], [0, 0], [1, 0], [0, 0]])
    return padded[:, :, :, :-1, :]
示例13: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def __init__(self, pc: _Network3D, config, centers, sess, freqs_resolution=1e9):
    """Build the graph that predicts symbol probabilities for one context.

    :param pc: network whose class provides ``get_context_shape`` and whose
        instance provides ``logits``.
    :param config: configuration passed to ``get_context_shape``.
    :param centers: tensor gathered with the integer context symbols.
    :param sess: Must be set at the latest before using get_pr or get_freqs
    :param freqs_resolution: factor applied to the probabilities before they
        are cast to int64 frequencies.
    """
    self.pc_class = pc.__class__
    self.config = config
    self.input_ctx_shape = self.pc_class.get_context_shape(config)
    # The context is fed as integer symbols.
    self.input_ctx = tf.placeholder(tf.int64, self.input_ctx_shape)
    # Leading batch axis, plus a trailing axis for the 3d conv.
    ctx_batched = tf.expand_dims(tf.expand_dims(self.input_ctx, 0), -1)
    # Unlike pc.bitcost(...), q needs no padding here because it is part of
    # some context. Per the original comment, the logits are a 1111L vector,
    # i.e. the prediction of the next pixel.
    q = tf.gather(centers, ctx_batched)
    self.pr = tf.nn.softmax(pc.logits(q, is_training=False))
    self.freqs = tf.squeeze(tf.cast(self.pr * freqs_resolution, tf.int64))
    self.sess = sess
    self._get_freqs = None
示例14: prune_small_boxes
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def prune_small_boxes(boxlist, min_side, scope=None):
    """Drop boxes whose width or height is smaller than ``min_side``.

    Args:
        boxlist: BoxList holding N boxes.
        min_side: Minimum width AND height of box to survive pruning.
        scope: name scope.

    Returns:
        A pruned boxlist containing only boxes with width >= min_side and
        height >= min_side.
    """
    with tf.name_scope(scope, 'PruneSmallBoxes'):
        height, width = height_width(boxlist)
        wide_enough = tf.greater_equal(width, min_side)
        tall_enough = tf.greater_equal(height, min_side)
        keep = tf.reshape(
            tf.where(tf.logical_and(wide_enough, tall_enough)), [-1])
        return gather(boxlist, keep)
示例15: get_random_labels_tf
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_random_labels_tf(self, minibatch_size):  # => labels
    """Sample ``minibatch_size`` labels uniformly at random, with replacement.

    Rows of ``self._tf_labels_var`` are picked with int32 indices drawn
    uniformly from [0, self._np_labels.shape[0]). When ``self.label_size`` is
    not positive, a [minibatch_size, 0] zeros tensor of dtype
    ``self.label_dtype`` is returned instead.
    """
    if not self.label_size > 0:
        return tf.zeros([minibatch_size, 0], self.label_dtype)
    num_labels = self._np_labels.shape[0]
    picks = tf.random_uniform([minibatch_size], 0, num_labels, dtype=tf.int32)
    return tf.gather(self._tf_labels_var, picks)
# Get random labels as NumPy array.