当前位置: 首页>>代码示例>>Python>>正文


Python v1.gather_nd方法代码示例

本文整理汇总了 Python 中 tensorflow.compat.v1.gather_nd 方法的典型用法代码示例。如果您正苦于以下问题：Python v1.gather_nd 方法的具体用法？Python v1.gather_nd 怎么用？Python v1.gather_nd 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.compat.v1 的用法示例。


在下文中一共展示了v1.gather_nd方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: remove

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def remove(self, x):
    """Strip the padded entries from `x` along its leading dimension.

    The rows to keep are selected with `self.nonpad_ids`.

    Args:
      x (tf.Tensor): tensor of shape [dim_origin, ...].

    Returns:
      A tensor of shape [dim_compressed, ...] with
      dim_compressed <= dim_origin.
    """
    with tf.name_scope("pad_reduce/remove"):
      static_shape = x.get_shape().as_list()
      compressed = tf.gather_nd(x, indices=self.nonpad_ids)
      if tf.executing_eagerly():
        return compressed
      # Workaround: in graph mode gather_nd hands back a tensor whose static
      # shape is undefined, so restore every dimension except the first
      # (which is now data dependent) by hand.
      compressed.set_shape([None] + static_shape[1:])
      return compressed
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:22,代码来源:expert_utils.py

示例2: argmax_with_score

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def argmax_with_score(logits, axis=None):
  """Return the argmax of `logits` together with the logit it selects.

  Args:
    logits: tensor of scores. The scores are always looked up along the last
      dimension of this tensor.
    axis: axis handed to `tf.argmax`. Any falsy value (`None`, but also `0`)
      is replaced by the index of the last dimension.

  Returns:
    A `(predictions, scores)` pair; `scores` is reshaped to the shape of
    `logits` without its last dimension.
  """
  # NOTE(review): `or` also overrides an explicit axis=0. The flattening
  # below only handles the last axis anyway, so this is kept as is.
  axis = axis or len(logits.get_shape()) - 1
  predictions = tf.argmax(logits, axis=axis)

  dims = shape_list(logits)
  vocab_size = dims[-1]
  prefix_shape = dims[:-1]
  num_rows = 1
  for dim in prefix_shape:
    num_rows *= dim

  # Collapse every leading dimension into one so that each row can be
  # addressed with a (row, predicted column) index pair.
  flat_logits = tf.reshape(logits, [num_rows, vocab_size])
  flat_predictions = tf.reshape(predictions, [num_rows])
  row_ids = tf.range(tf.to_int64(num_rows))
  col_ids = tf.to_int64(flat_predictions)
  flat_indices = tf.stack([row_ids, col_ids], axis=1)
  flat_scores = tf.gather_nd(flat_logits, flat_indices)

  # Restore the leading dimensions.
  scores = tf.reshape(flat_scores, prefix_shape)
  return predictions, scores
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:26,代码来源:common_layers.py

示例3: resize_and_crop_boxes

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def resize_and_crop_boxes(self):
    """Scale the boxes, shift them by the crop offset, then clip and filter.

    Returns:
      A `(boxes, classes)` pair restricted to the entries whose clipped box
      coordinates do not sum to zero.
    """
    scaled_boxlist = preprocessor.box_list_scale(
        preprocessor.box_list.BoxList(self._boxes),
        self._scaled_height, self._scaled_width)
    boxes = scaled_boxlist.get()

    # Move the boxes into the cropped frame: the same (y, x) offset is
    # subtracted from both corners of every box.
    offset = tf.stack([self._crop_offset_y, self._crop_offset_x,
                       self._crop_offset_y, self._crop_offset_x])
    offset = tf.cast(tf.reshape(offset, [1, 4]), tf.float32)
    boxes = boxes - offset

    boxes = self.clip_boxes(boxes)

    # Drop ground-truth boxes that are all zeros.
    coord_sums = tf.reduce_sum(boxes, axis=1)
    keep = tf.where(tf.not_equal(coord_sums, 0))
    kept_boxes = tf.gather_nd(boxes, keep)
    kept_classes = tf.gather_nd(self._classes, keep)
    return kept_boxes, kept_classes
开发者ID:JunweiLiang,项目名称:Object_Detection_Tracking,代码行数:18,代码来源:dataloader.py

示例4: _build_train_op

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def _build_train_op(self):
    """Create the op that runs one training step on replay data.

    Returns:
      train_op: the result of `self.optimizer.minimize` applied to the Huber
        loss averaged first over axis 0 and then over the remaining values.
    """
    replay_actions = self._replay.actions
    # Pair every row index with the action stored for that row.
    row_ids = tf.range(replay_actions.shape[0])
    gather_idx = tf.stack([row_ids, replay_actions], axis=-1)
    chosen_q = tf.gather_nd(
        self._replay_net_outputs.q_heads, indices=gather_idx)

    # The target is treated as a constant: no gradient flows through it.
    target_q = tf.stop_gradient(self._build_target_q_op())
    elementwise_loss = tf.losses.huber_loss(
        target_q, chosen_q, reduction=tf.losses.Reduction.NONE)
    per_head_loss = tf.reduce_mean(elementwise_loss, axis=0)
    total_loss = tf.reduce_mean(per_head_loss)

    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('HuberLoss', total_loss)

    train_op = self.optimizer.minimize(total_loss)
    return train_op
开发者ID:google-research,项目名称:batch_rl,代码行数:21,代码来源:multi_head_dqn_agent.py

示例5: _build_target_distribution

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def _build_target_distribution(self):
    """Build the target: rewards plus discounted next-state logits.

    Returns:
      `rewards + gamma_with_terminal * next_logits`, where `next_logits` are
      the target-network logits gathered at the argmax-Q action of every
      batch element.
    """
    replay = self._replay
    batch_size = tf.shape(replay.rewards)[0]
    # A trailing axis is appended so rewards broadcast against the logits.
    rewards = replay.rewards[:, None]

    # Fold the terminal flag into the discount: the multiplier is
    # 1 - terminals, so the bootstrap term vanishes where terminals == 1.
    not_terminal = 1. - tf.cast(replay.terminals, tf.float32)
    gamma_with_terminal = (self.cumulative_gamma * not_terminal)[:, None]

    target_outputs = self._replay_next_target_net_outputs
    # Greedy action per batch element, as a column: batch_size x 1.
    best_action = tf.argmax(target_outputs.q_values, axis=1)[:, None]
    row_ids = tf.range(tf.to_int64(batch_size))[:, None]
    # (row, action) index pairs: batch_size x 2.
    gather_idx = tf.concat([row_ids, best_action], axis=1)
    next_logits = tf.gather_nd(target_outputs.logits, gather_idx)
    return rewards + gamma_with_terminal * next_logits
开发者ID:google-research,项目名称:batch_rl,代码行数:26,代码来源:quantile_agent.py

示例6: sparse_dense_mul

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def sparse_dense_mul(sp_mat, dense_mat):
  """Multiply a sparse tensor element-wise by a dense tensor.

  Only the entries present in `sp_mat` are touched, so the result is sparse
  with the same indices and dense shape. When `dense_mat` has lower rank than
  `sp_mat`, it is indexed with the leading `rank(dense_mat)` coordinates of
  every sparse entry, i.e. it is broadcast over the rightmost dimensions.

  Args:
    sp_mat: SparseTensor.
    dense_mat: dense tensor with rank <= rank of `sp_mat`.

  Returns:
    SparseTensor.
  """
  dense_rank = dense_mat.get_shape().ndims
  # Keep only as many coordinates per entry as dense_mat has dimensions.
  leading_coords = sp_mat.indices[:, :dense_rank]
  factors = tf.gather_nd(dense_mat, leading_coords)
  scaled_values = sp_mat.values * factors
  return tf.SparseTensor(sp_mat.indices, scaled_values, sp_mat.dense_shape)
开发者ID:google-research,项目名称:language,代码行数:21,代码来源:model_fns.py

示例7: rpn_class_loss_graph

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def rpn_class_loss_graph(rpn_match, rpn_class_logits):
    """Classification loss of the RPN anchors.

    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for BG/FG.

    Returns the mean sparse categorical cross-entropy over the non-neutral
    anchors, or 0.0 when no anchor contributes.
    """
    # Drop the trailing singleton dimension.
    match = tf.squeeze(rpn_match, -1)
    # Map the match type to a class id: 1 for positive anchors, 0 otherwise.
    is_positive = K.cast(K.equal(match, 1), tf.int32)
    # Neutral anchors (match == 0) are excluded from the loss.
    contributing = tf.where(K.not_equal(match, 0))
    selected_logits = tf.gather_nd(rpn_class_logits, contributing)
    selected_classes = tf.gather_nd(is_positive, contributing)

    xent = K.sparse_categorical_crossentropy(target=selected_classes,
                                             output=selected_logits,
                                             from_logits=True)
    # Fall back to a constant zero when nothing was selected.
    has_anchors = tf.size(xent) > 0
    mean_xent = K.mean(xent)
    zero = tf.constant(0.0)
    return K.switch(has_anchors, mean_xent, zero)
开发者ID:OCR-D,项目名称:ocrd_anybaseocr,代码行数:25,代码来源:model.py

示例8: get_batch_predictions_from_indices

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def get_batch_predictions_from_indices(batch_predictions, indices):
  """Look up the predictions of a batch at the given indices.

  The indices are expected to be produced by the offset-target generation
  functions of this library; the result is meant to be consumed by a loss
  function.

  Args:
    batch_predictions: A tensor of shape [batch_size, height, width, channels]
      or [batch_size, height, width, class, channels] for class-specific
      features (e.g. keypoint joint offsets).
    indices: A tensor of shape [num_instances, 3] for single class features or
      [num_instances, 4] for multiple classes features.

  Returns:
    values: A tensor of shape [num_instances, channels] holding the predicted
      values at the given indices.
  """
  values = tf.gather_nd(batch_predictions, indices)
  return values
开发者ID:tensorflow,项目名称:models,代码行数:21,代码来源:target_assigner.py

示例9: extract_random_video_patch

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def extract_random_video_patch(videos, num_frames=-1):
  """Cut a random run of `num_frames` consecutive frames out of every video.

  Args:
    videos: 5-D Tensor, (NTHWC)
    num_frames: Integer, if -1 then the entire video is returned.
  Returns:
    video_patch: 5-D Tensor, (NTHWC) with T = num_frames.
  Raises:
    ValueError: If num_frames is greater than the number of total frames in
                the video.
  """
  if num_frames == -1:
    return videos

  batch_size, num_total_frames, h, w, c = common_layers.shape_list(videos)
  if num_total_frames < num_frames:
    raise ValueError("Expected num_frames <= %d, got %d" %
                     (num_total_frames, num_frames))

  # One random start frame per video; maxval is exclusive, hence the +1.
  max_start = num_total_frames - num_frames + 1
  start_frames = tf.random_uniform(
      shape=(batch_size,), minval=0, maxval=max_start, dtype=tf.int32)

  # Frame ids, flattened video by video:
  # start[0], start[0] + 1, ..., start[0] + num_frames - 1, start[1], ...
  step_row = tf.expand_dims(tf.range(num_frames), axis=0)
  start_col = tf.expand_dims(start_frames, axis=1)
  flat_frame_ids = tf.reshape(step_row + start_col, [-1])

  # Matching video ids: every batch index repeated num_frames times.
  batch_col = tf.expand_dims(tf.range(batch_size), axis=1)
  repeated = tf.tile(batch_col, [1, num_frames])
  flat_batch_ids = tf.reshape(repeated, [-1])

  gather_inds = tf.stack((flat_batch_ids, flat_frame_ids), axis=1)
  frames = tf.gather_nd(videos, gather_inds)
  return tf.reshape(frames, (batch_size, num_frames, h, w, c))
开发者ID:tensorflow,项目名称:tensor2tensor,代码行数:40,代码来源:common_video.py

示例10: _batch_slice

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def _batch_slice(self, ary, start_ijk, w, batch_size):
    """Extract a batch of w x w x w cubes from a rank-3 grid.

    Args:
      ary: tensor, rank = 3.
      start_ijk: [batch_size, 3] tensor, starting index.
      w: width of cube to extract.
      batch_size: int, batch size. NOTE: this argument is not used; the batch
        size is always taken from `start_ijk.shape[0]`.

    Returns:
      batched_slices: [batch_size, w, w, w] tensor, batched slices of ary.
    """
    num_cubes = start_ijk.shape[0]
    index_shape = [num_cubes, w, w, w, 3]

    # Local (i, j, k) coordinates of every cell inside one cube: [w, w, w, 3].
    axis_range = tf.range(w, dtype=tf.int32)
    mesh = tf.meshgrid(axis_range, axis_range, axis_range, indexing='ij')
    local_idx = tf.stack(mesh, axis=-1)
    local_idx = tf.broadcast_to(local_idx[tf.newaxis], index_shape)

    # Shift the local coordinates by the start corner of each cube.
    corner = tf.broadcast_to(
        start_ijk[:, tf.newaxis, tf.newaxis, tf.newaxis, :], index_shape)
    global_idx = local_idx + corner

    # The indices are [num_cubes, w, w, w, 3]; gathering from the rank-3
    # grid yields [num_cubes, w, w, w].
    return tf.gather_nd(ary, global_idx)
开发者ID:tensorflow,项目名称:graphics,代码行数:28,代码来源:evaluator.py

示例11: get_final

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def get_final(sequence, sequence_length, time_major=True):
  """Fetch the final item of every sequence in a batch.

  The gather indices are produced by
  `_get_final_index(sequence_length, time_major)`.
  """
  return tf.gather_nd(
      sequence, _get_final_index(sequence_length, time_major))
开发者ID:magenta,项目名称:magenta,代码行数:6,代码来源:lstm_utils.py

示例12: flatten_maybe_padded_sequences

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def flatten_maybe_padded_sequences(maybe_padded_sequences, lengths=None):
  """Merge a batch of sequences into one, dropping padding where present.

  Args:
    maybe_padded_sequences: A tensor of possibly padded sequences to flatten,
        sized `[N, M, ...]` where M = max(lengths).
    lengths: Optional length of each sequence, sized `[N]`. If None, assumes no
        padding.

  Returns:
     flatten_maybe_padded_sequences: The flattened sequence tensor, sized
         `[sum(lengths), ...]`.
  """
  def _merge_leading_dims():
    # Every sequence has the same length, so collapsing the first two
    # dimensions is all that is needed.
    trailing_dims = maybe_padded_sequences.shape.as_list()[2:]
    return tf.reshape(maybe_padded_sequences, [-1] + trailing_dims)

  def _gather_valid_steps():
    # Keep only the positions that lie inside each sequence's length.
    valid_positions = tf.where(tf.sequence_mask(lengths))
    return tf.gather_nd(maybe_padded_sequences, valid_positions)

  if lengths is None:
    return _merge_leading_dims()

  # If even the shortest sequence spans the whole time dimension, nothing is
  # padded and the plain reshape suffices.
  shortest = tf.reduce_min(lengths)
  num_steps = tf.shape(maybe_padded_sequences)[1]
  has_no_padding = tf.equal(shortest, num_steps)
  return tf.cond(has_no_padding, _merge_leading_dims, _gather_valid_steps)
开发者ID:magenta,项目名称:magenta,代码行数:32,代码来源:sequence_example_lib.py

示例13: _build_train_op

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def _build_train_op(self):
    """Create the op that runs one training step on replay data.

    A deep copy of `self.optimizer` is made for each of the
    `self.num_networks` sub-networks and stored in `self.optimizers`; each
    copy minimizes the shared loss over the trainable variables found under
    the scope 'Online/subnet_<i>'.

    Returns:
      train_op: a single op grouping the per-network training ops.
    """
    replay_actions = self._replay.actions
    # Pair every row index with the action stored for that row.
    row_ids = tf.range(replay_actions.shape[0])
    gather_idx = tf.stack([row_ids, replay_actions], axis=-1)
    chosen_q = tf.gather_nd(
        self._replay_net_outputs.q_networks, indices=gather_idx)

    # The target is treated as a constant: no gradient flows through it.
    target_q = tf.stop_gradient(self._build_target_q_op())
    elementwise_loss = tf.losses.huber_loss(
        target_q, chosen_q, reduction=tf.losses.Reduction.NONE)
    per_network_loss = tf.reduce_mean(elementwise_loss, axis=0)
    total_loss = tf.reduce_mean(per_network_loss)

    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar('HuberLoss', total_loss)

    self.optimizers = [
        copy.deepcopy(self.optimizer) for _ in range(self.num_networks)
    ]
    train_ops = [
        self.optimizers[i].minimize(
            total_loss,
            var_list=tf.trainable_variables(
                scope='Online/subnet_{}'.format(i)))
        for i in range(self.num_networks)
    ]
    return tf.group(*train_ops, name='merged_train_op')
开发者ID:google-research,项目名称:batch_rl,代码行数:28,代码来源:multi_network_dqn_agent.py

示例14: _build

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def _build(self, modules):
    """Bound `wrong-class logit - correct-class logit` from above.

    Without elision the result is built directly from the interval output
    bounds of the last module. With elision (`self.collapse` is set and the
    last module is a `LinearFCWrapper`) the last linear layer is folded into
    the specification and the bound is computed from that layer's input
    bounds.

    Args:
      modules: sequence of wrapped modules; only the last one is read.

    Returns:
      Tensor of upper bounds on the logit differences.
    """
    elide_last_layer = (
        self.collapse and
        isinstance(modules[-1], verifiable_wrapper.LinearFCWrapper))

    if not elide_last_layer:
      logging.info('Elision of last layer disabled.')
      out_bounds = bounds_lib.IntervalBounds.convert(
          modules[-1].output_bounds)
      # Worst case: lowest correct-class logit against highest wrong-class
      # logits.
      correct_lower = tf.gather_nd(out_bounds.lower, self._correct_idx)
      wrong_upper = tf.gather_nd(out_bounds.upper, self._wrong_idx)
      return wrong_upper - tf.expand_dims(correct_lower, 1)

    logging.info('Elision of last layer active.')
    in_bounds = bounds_lib.IntervalBounds.convert(modules[-1].input_bounds)
    batch_size = tf.shape(in_bounds.lower)[0]
    weights = modules[-1].module.w
    bias = modules[-1].module.b

    # Replicate the transposed weights and the bias once per batch element so
    # the per-example index tensors can be applied to them.
    tiled_w = tf.tile(
        tf.expand_dims(tf.transpose(weights), 0), [batch_size, 1, 1])
    tiled_b = tf.tile(tf.expand_dims(bias, 0), [batch_size, 1])
    w_correct = tf.expand_dims(tf.gather_nd(tiled_w, self._correct_idx), -1)
    b_correct = tf.expand_dims(tf.gather_nd(tiled_b, self._correct_idx), 1)
    w_wrong = tf.transpose(tf.gather_nd(tiled_w, self._wrong_idx), [0, 2, 1])
    b_wrong = tf.gather_nd(tiled_b, self._wrong_idx)
    w_diff = w_wrong - w_correct
    b_diff = b_wrong - b_correct

    # Maximize z * w + b s.t. lower <= z <= upper: push the interval centre
    # through w and add the interval radius pushed through |w|.
    center = (in_bounds.lower + in_bounds.upper) / 2.
    radius = (in_bounds.upper - in_bounds.lower) / 2.
    center_out = tf.einsum('ij,ijk->ik', center, w_diff)
    if b_diff is not None:
      center_out += b_diff
    radius_out = tf.einsum('ij,ijk->ik', radius, tf.abs(w_diff))
    return center_out + radius_out
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:34,代码来源:specification.py

示例15: evaluate

# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import gather_nd [as 别名]
def evaluate(self, logits):
    """Compute wrong-class logits minus the correct-class logit.

    The gather indices `self._correct_idx` and `self._wrong_idx` are defined
    outside this method. `logits` may have rank 2, 3 or 4; the layouts assumed
    for ranks 3 and 4 are given by the comments inside the matching branch.

    Args:
      logits: tensor of rank 2, 3 or 4.

    Returns:
      `wrong_class_logits - correct_class_logit`.

    Raises:
      AssertionError: if `logits` has a rank other than 2, 3 or 4 (only while
        assertions are enabled).
    """
    if len(logits.shape) == 2:
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      # Trailing axis so the subtraction broadcasts over the wrong classes.
      correct_class_logit = tf.expand_dims(correct_class_logit, -1)
      wrong_class_logits = tf.gather_nd(logits, self._wrong_idx)
    elif len(logits.shape) == 3:
      # [num_restarts, batch_size, num_classes] to
      # [num_restarts, batch_size, num_specs]
      logits = tf.transpose(logits, [1, 2, 0])  # Put restart dimension last.
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      # Bring the restart dimension back to the front.
      correct_class_logit = tf.transpose(correct_class_logit)
      correct_class_logit = tf.expand_dims(correct_class_logit, -1)
      wrong_class_logits = tf.gather_nd(logits, self._wrong_idx)
      # Bring the restart dimension back to the front.
      wrong_class_logits = tf.transpose(wrong_class_logits, [2, 0, 1])
    else:
      assert len(logits.shape) == 4
      # [num_restarts, num_specs, batch_size, num_classes] to
      # [num_restarts, batch_size, num_specs].
      logits = tf.transpose(logits, [2, 3, 1, 0])
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      correct_class_logit = tf.transpose(correct_class_logit, [2, 0, 1])
      batch_size = tf.shape(logits)[0]
      # Append the specification number to every wrong-class index; after the
      # transpose above, that number addresses the third axis of `logits`.
      wrong_idx = tf.concat([
          self._wrong_idx,
          tf.tile(tf.reshape(tf.range(self.num_specifications, dtype=tf.int32),
                             [1, self.num_specifications, 1]),
                  [batch_size, 1, 1])], axis=-1)
      wrong_class_logits = tf.gather_nd(logits, wrong_idx)
      wrong_class_logits = tf.transpose(wrong_class_logits, [2, 0, 1])
    return wrong_class_logits - correct_class_logit
开发者ID:deepmind,项目名称:interval-bound-propagation,代码行数:32,代码来源:specification.py


注:本文中的tensorflow.compat.v1.gather_nd方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。