当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.gather方法代码示例

本文整理汇总了Python中tensorflow.gather方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.gather方法的具体用法？Python tensorflow.gather怎么用？Python tensorflow.gather使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。


在下文中一共展示了tensorflow.gather方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_hint_pool_idxs

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_hint_pool_idxs(self, normalized_query):
    """Collect candidate memory indices for nearest-neighbor search.

    The query is hashed with `self.get_hash_slots`; for every hash table
    the matching entries of `self.hash_slots[i]` are gathered and clamped
    into the valid index range `[0, self.memory_size - 1]`.  The per-table
    results are joined along axis 1.

    Args:
      normalized_query: A Tensor of shape [None, key_dim] (per the
        original contract; not checked here).

    Returns:
      A Tensor of memory indices, one row per query, formed by
      concatenating the clamped look-ups of all hash tables.
    """
    slot_ids_per_table = self.get_hash_slots(normalized_query)

    candidates = []
    for table_no, slot_ids in enumerate(slot_ids_per_table):
        # Look up the memory indices stored in the selected hash slots.
        stored = tf.gather(self.hash_slots[table_no], slot_ids)
        # Clamp so that every index addresses a real memory row.
        capped = tf.minimum(stored, self.memory_size - 1)
        candidates.append(tf.maximum(capped, 0))

    return tf.concat(axis=1, values=candidates)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:27,代码来源:memory.py

示例2: scheduled_sample

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
  """Build a batch that mixes ground-truth and generated examples.

  A random permutation of `range(batch_size)` is drawn.  Its first
  `num_ground_truth` entries select rows from `ground_truth_x`, the
  remaining entries select rows from `generated_x`, and
  `tf.dynamic_stitch` places every selected row back at its own index.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size; converted with `int()`.
    num_ground_truth: number of ground-truth examples to include in batch.
  Returns:
    A batch in which `num_ground_truth` randomly chosen positions hold
    ground-truth rows and all other positions hold generated rows.
  """
  total = int(batch_size)
  shuffled = tf.random_shuffle(tf.range(total))

  # Split the permutation into the two position sets.
  truth_positions = tf.gather(shuffled, tf.range(num_ground_truth))
  generated_positions = tf.gather(shuffled, tf.range(num_ground_truth, total))

  truth_rows = tf.gather(ground_truth_x, truth_positions)
  generated_rows = tf.gather(generated_x, generated_positions)

  positions = [truth_positions, generated_positions]
  rows = [truth_rows, generated_rows]
  return tf.dynamic_stitch(positions, rows)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:22,代码来源:prediction_model.py

示例3: build_cross_entropy_loss

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def build_cross_entropy_loss(logits, gold):
  """Builds a softmax cross-entropy loss that ignores rows labelled -1.

  Rows whose gold label is greater than -1 are kept; all others are
  dropped from both `logits` and `gold` before the loss is computed.

  Args:
    logits: float Tensor of scores, one row per example.
    gold: int Tensor of labels.  The labels are used as class ids (they
      are passed as targets to `tf.nn.in_top_k` and as sparse labels to
      the cross-entropy op); -1 marks a row to skip.

  Returns:
    cost, correct, total: the summed cross entropy divided by the number
        of valid rows, the number of rows whose top-1 prediction matches
        the label, and the number of valid rows.
  """
  # NOTE(review): if no row is valid, `total` is 0 and the division below
  # is 0/0 -- confirm callers never pass an all -1 batch.
  keep_rows = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  kept_gold = tf.gather(gold, keep_rows)
  kept_logits = tf.gather(logits, keep_rows)

  hits = tf.nn.in_top_k(kept_logits, kept_gold, 1)
  correct = tf.reduce_sum(tf.to_int32(hits))
  total = tf.size(kept_gold)

  xent = tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits
  per_row_loss = xent(kept_logits, tf.cast(kept_gold, tf.int64))
  cost = tf.reduce_sum(per_row_loss) / tf.cast(total, tf.float32)
  return cost, correct, total
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:24,代码来源:bulk_component.py

示例4: clip_tensor

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def clip_tensor(t, length):
  """Keeps only the first `length` entries of `t` along dimension 0.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor or an integer giving the size of the first dimension
      after clipping, assuming length <= t.shape[0].

  Returns:
    The clipped tensor.  When `length` is not a tensor, the result is
    additionally passed through `_set_dim_0` so that its first dimension
    is recorded as `length`.
  """
  head = tf.gather(t, tf.range(length))
  if _is_tensor(length):
    # Dynamic length: nothing more can be pinned down here.
    return head
  return _set_dim_0(head, length)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:19,代码来源:shape_utils.py

示例5: filter_field_value_equals

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def filter_field_value_equals(boxlist, field, value, scope=None):
  """Selects the boxes whose `field` entry equals `value`.

  Args:
    boxlist: BoxList holding N boxes.
    field: field name for filtering.
    value: scalar value.
    scope: name scope.

  Returns:
    a BoxList holding M boxes where M <= N

  Raises:
    ValueError: if boxlist not a BoxList object or if it does not have
      the specified field.
  """
  with tf.name_scope(scope, 'FilterFieldValueEquals'):
    if not isinstance(boxlist, box_list.BoxList):
      raise ValueError('boxlist must be a BoxList')
    if not boxlist.has_field(field):
      raise ValueError('boxlist must contain the specified field')

    matches = tf.equal(boxlist.get_field(field), value)
    # tf.where yields one row per match; flatten to a 1-D index vector.
    keep = tf.reshape(tf.where(matches), [-1])
    return gather(boxlist, keep)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:26,代码来源:box_list_ops.py

示例6: memory_run

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def memory_run(step, nmaps, mem_size, batch_size, vocab_size,
               global_step, do_training, update_mem, decay_factor, num_gpus,
               target_emb_weights, output_w, gpu_targets_tn, it):
  """Query the memory at position `it` and build the memory input.

  The query `step[:, 0, it, :]` and the labels `gpu_targets_tn[:, it, 0]`
  are passed to `memory_call`.  The returned ids are embedded with
  `target_emb_weights` and multiplied by the first mask column.  With a
  probability that starts at 1, decreases linearly with `global_step`,
  is floored at 0.2 and is scaled by `do_training`, the result is
  blended with the (dropout-regularised) embedding of the gold labels.

  Returns:
    A tuple `(mem, mem_loss, update_mem)` where `mem` is reshaped to
    `[-1, 1, 1, nmaps]`, `mem_loss` comes from `memory_call` and
    `update_mem` is passed through unchanged.
  """
  query = step[:, 0, it, :]
  labels = gpu_targets_tn[:, it, 0]
  retrieved_ids, mask, mem_loss = memory_call(
      query, labels, nmaps, mem_size, vocab_size, num_gpus, update_mem)

  retrieved = tf.gather(target_emb_weights, retrieved_ids)
  retrieved = retrieved * tf.expand_dims(mask[:, 0], 1)

  # Second positional argument of tf.nn.dropout (keep_prob in TF 1.x).
  gold = tf.nn.dropout(tf.gather(target_emb_weights, labels), 0.7)

  # Linear decay from 1.0, never below 0.2; do_training scales the result.
  decayed = 1.0 - tf.cast(global_step, tf.float32) / (1000. * decay_factor)
  gold_prob = tf.maximum(decayed, 0.2) * do_training

  def _blend_with_gold():
    return gold_prob * gold + (1.0 - gold_prob) * retrieved

  def _retrieved_only():
    return retrieved

  coin = tf.less(tf.random_uniform([]), gold_prob)
  mem = tf.cond(coin, _blend_with_gold, _retrieved_only)
  return tf.reshape(mem, [-1, 1, 1, nmaps]), mem_loss, update_mem
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:21,代码来源:neural_gpu.py

示例7: batch_transformer

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Batch Spatial Transformer Layer

    Every input in `U` is repeated once per transformation so that the
    repeated batch can be handed to `transformer` together with `thetas`.

    Parameters
    ----------

    U : float
        tensor of inputs [num_batch,height,width,num_channels]
    thetas : float
        a set of transformations for each input [num_batch,num_transforms,6].
        The first two dimensions must be statically known because they are
        converted with `int()`.
    out_size : int
        the size of the output [out_height,out_width]
    name : str
        variable scope under which the ops are created

    Returns: float
        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
    """
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # `range` instead of the Python-2-only `xrange`, so the function
        # also runs on Python 3.  Row i holds num_transforms copies of i;
        # flattened this gives [0, 0, ..., 1, 1, ...].
        indices = [[i] * num_transforms for i in range(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:23,代码来源:spatial_transformer.py

示例8: data

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def data(self, rows=None):
    """Read a batch of episodes out of the memory.

    Elements past the length of an episode are padding; their content is
    unspecified and may be stale data.

    Args:
      rows: 1-D tensor of episode indices to select; when omitted, all
        `self._capacity` rows are selected.

    Returns:
      A pair `(episode, length)`: a list with one gathered tensor per
      buffer in `self._buffers` (transition quantities with batch and
      time dimensions), and the gathered sequence lengths.
    """
    if rows is None:
        rows = tf.range(self._capacity)
    assert rows.shape.ndims == 1

    episode = []
    for buffer_ in self._buffers:
        episode.append(tf.gather(buffer_, rows))
    length = tf.gather(self._length, rows)
    return episode, length
开发者ID:utra-robosoccer,项目名称:soccer-matlab,代码行数:20,代码来源:memory.py

示例9: scheduled_sample

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def scheduled_sample(self,
                       ground_truth_x,
                       generated_x,
                       batch_size,
                       num_ground_truth):
    """Mix ground-truth and generated data points into one batch.

    The indices `0 .. batch_size - 1` are shuffled; the first
    `num_ground_truth` shuffled indices pick rows of `ground_truth_x`,
    the rest pick rows of `generated_x`.  `tf.dynamic_stitch` then writes
    each picked row to the position given by its index.

    Args:
      ground_truth_x: tensor of ground-truth data points.
      generated_x: tensor of generated data points.
      batch_size: batch size
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth rows taken from ground_truth_x and
      the remaining rows taken from generated_x.
    """
    permutation = tf.random_shuffle(tf.range(batch_size))

    # Leading part of the permutation -> ground truth, trailing -> generated.
    truth_at = tf.gather(permutation, tf.range(num_ground_truth))
    generated_at = tf.gather(permutation,
                             tf.range(num_ground_truth, batch_size))

    truth_rows = tf.gather(ground_truth_x, truth_at)
    generated_rows = tf.gather(generated_x, generated_at)

    stitch_indices = [truth_at, generated_at]
    stitch_values = [truth_rows, generated_rows]
    return tf.dynamic_stitch(stitch_indices, stitch_values)
开发者ID:akzaidi,项目名称:fine-lm,代码行数:26,代码来源:next_frame.py

示例10: convert_gradient_to_tensor

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def convert_gradient_to_tensor(x):
  """Returns `x` unchanged; intended as a hook for densifying gradients.

  Motivation: the gradient of `tf.concat` is costly when the incoming
  gradient is an `IndexedSlices` (no GPU implementation, so the op is
  placed on CPU).  That happens when the result of `tf.concat` is later
  fed to `tf.gather`.  Converting the gradient to a dense `Tensor` can be
  cheaper; the intended usage is to wrap the concat,
  `convert_gradient_to_tensor(tf.concat(x))`.

  NOTE(review): as shown here the function body is a plain identity and
  no gradient override is visible; presumably a decorator supplying the
  custom gradient was attached in the original file -- confirm.

  Args:
    x: A `Tensor`.

  Returns:
    The same object that was passed in.
  """
  return x
开发者ID:akzaidi,项目名称:fine-lm,代码行数:20,代码来源:common_layers.py

示例11: add_positional_embedding

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def add_positional_embedding(x, max_length, name, positions=None):
  """Adds a learned positional embedding to `x`.

  A variable of shape [max_length, depth] is created (or fetched) under
  `name` and cast to the dtype of `x`.  Without `positions`, the table is
  cut to the first `length` rows when `length < max_length`, otherwise it
  is padded at the end with `length - max_length` rows, and the result is
  broadcast over the batch.  With `positions`, the rows are looked up
  per element instead.

  Args:
    x: a Tensor with shape [batch, length, depth]
    max_length: an integer.  static maximum size of any dimension.
    name: a name for this layer.
    positions: an optional tensor with shape [batch, length]

  Returns:
    a Tensor the same shape as x.
  """
  _, length, depth = common_layers.shape_list(x)
  table = tf.cast(tf.get_variable(name, [max_length, depth]), x.dtype)

  if positions is not None:
    # Explicit positions: plain embedding look-up.
    return x + tf.gather(table, tf.to_int32(positions))

  def _truncate_table():
    return tf.slice(table, [0, 0], [length, -1])

  def _pad_table():
    return tf.pad(table, [[0, length - max_length], [0, 0]])

  fitted = tf.cond(tf.less(length, max_length), _truncate_table, _pad_table)
  return x + tf.expand_dims(fitted, 0)
开发者ID:akzaidi,项目名称:fine-lm,代码行数:24,代码来源:common_attention.py

示例12: get_shifted_center_blocks

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_shifted_center_blocks(x, indices):
  """Gathers blocks and shifts each one right by one position.

  Used for masked local attention 2d.  The blocks are extracted with
  `gather_blocks_2d`; afterwards one padding element is inserted at the
  front of the second-to-last dimension and the last element of that
  dimension is dropped, so its size stays the same.

  Args:
    x: A tensor with shape [batch, heads, height, width, depth]
    indices: The indices to gather blocks

  Returns:
    x_shifted: a tensor of extracted blocks, each block right shifted along
      length.
  """
  blocks = gather_blocks_2d(x, indices)

  # Pad one step in front of the length axis, then trim one from its end.
  front_padded = tf.pad(blocks, [[0, 0], [0, 0], [0, 0], [1, 0], [0, 0]])
  return front_padded[:, :, :, :-1, :]
开发者ID:akzaidi,项目名称:fine-lm,代码行数:24,代码来源:common_attention.py

示例13: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def __init__(self, pc: _Network3D, config, centers, sess, freqs_resolution=1e9):
        """
        Build the graph that maps one context of symbols to the predicted
        distribution (`self.pr`) and to integer frequencies (`self.freqs`).

        :param pc: network object providing `logits(q, is_training=...)`; its
            class must provide `get_context_shape(config)`
        :param config: passed to `get_context_shape` and stored
        :param centers: tensor indexed by the symbols to obtain the network input
        :param sess: Must be set at the latest before using get_pr or get_freqs
        :param freqs_resolution: factor the probabilities are multiplied with
            before the cast to int64
        """
        # Plain state first.
        self.pc_class = pc.__class__
        self.config = config
        self.sess = sess
        self._get_freqs = None

        # Placeholder holding the symbols of a single context.
        self.input_ctx_shape = self.pc_class.get_context_shape(config)
        self.input_ctx = tf.placeholder(tf.int64, self.input_ctx_shape)

        # Prepend a batch axis and append a trailing axis for the 3d conv.
        ctx = tf.expand_dims(self.input_ctx, 0)
        ctx = tf.expand_dims(ctx, -1)

        # Unlike pc.bitcost(...), no padding is applied here: the input
        # already is a complete context.  The logits describe the
        # prediction of the next symbol.
        centers_of_ctx = tf.gather(centers, ctx)
        next_logits = pc.logits(centers_of_ctx, is_training=False)
        self.pr = tf.nn.softmax(next_logits)

        scaled = tf.cast(self.pr * freqs_resolution, tf.int64)
        self.freqs = tf.squeeze(scaled)
开发者ID:fab-jul,项目名称:imgcomp-cvpr,代码行数:21,代码来源:probclass.py

示例14: prune_small_boxes

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def prune_small_boxes(boxlist, min_side, scope=None):
  """Drops every box whose width or height is below `min_side`.

  Args:
    boxlist: BoxList holding N boxes.
    min_side: Minimum width AND height of box to survive pruning.
    scope: name scope.

  Returns:
    A pruned boxlist.
  """
  with tf.name_scope(scope, 'PruneSmallBoxes'):
    height, width = height_width(boxlist)
    wide_enough = tf.greater_equal(width, min_side)
    tall_enough = tf.greater_equal(height, min_side)
    survivors = tf.where(tf.logical_and(wide_enough, tall_enough))
    # Flatten the [M, 1] result of tf.where into a 1-D index vector.
    return gather(boxlist, tf.reshape(survivors, [-1]))
开发者ID:datitran,项目名称:object_detector_app,代码行数:18,代码来源:box_list_ops.py

示例15: get_random_labels_tf

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import gather [as 别名]
def get_random_labels_tf(self, minibatch_size): # => labels
        """Return `minibatch_size` randomly drawn labels as a TF tensor.

        With `self.label_size > 0`, uniformly random int32 row indices in
        `[0, self._np_labels.shape[0])` are drawn and used to gather from
        `self._tf_labels_var`.  Otherwise a tensor of shape
        `[minibatch_size, 0]` with dtype `self.label_dtype` is returned.
        """
        if self.label_size > 0:
            num_rows = self._np_labels.shape[0]
            picks = tf.random_uniform([minibatch_size], 0, num_rows, dtype=tf.int32)
            return tf.gather(self._tf_labels_var, picks)
        return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array. 
开发者ID:zalandoresearch,项目名称:disentangling_conditional_gans,代码行数:9,代码来源:dataset.py


注:本文中的tensorflow.gather方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。