当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.gather_nd函数代码示例

本文整理汇总了Python中tensorflow.gather_nd函数的典型用法代码示例。如果您正苦于以下问题：Python gather_nd函数的具体用法？Python gather_nd怎么用？Python gather_nd使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了gather_nd函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testEmptyIndicesAndParamsOKButJustEmptyParamsFails

  def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
    """gather_nd accepts empty indices but rejects non-empty indices into empty params.

    Three OK cases:
      - empty full-depth indices into non-empty params;
      - empty partial-depth indices into non-empty params;
      - empty indices into empty params.

    One failing case: a non-empty index into params whose first dimension is 0.
    """
    with self.test_session(use_gpu=self.use_gpu):
      params = np.ones((3, 3), dtype=np.float32)
      params_empty = np.empty((0, 3), dtype=np.float32)

      # (params, indices, expected static shape, expected value)
      ok_cases = [
          # Full-depth (2) empty indices: each index would select a scalar.
          (params, np.empty((0, 2), dtype=np.int32), [0],
           np.empty((0,), dtype=np.float32)),
          # Partial-depth (1) empty indices: each index would select a row of 3.
          (params, np.empty((0, 1), dtype=np.int32), [0, 3],
           np.empty((0, 3), dtype=np.float32)),
          # Empty indices into empty params are also allowed.
          (params_empty, np.empty((0, 2), dtype=np.int32), [0],
           np.empty((0,), dtype=np.float32)),
      ]
      for case_params, case_indices, expected_shape, expected_val in ok_cases:
        gather_nd_ok_t = tf.gather_nd(case_params, case_indices)
        gather_nd_ok_val = gather_nd_ok_t.eval()
        self.assertEqual(expected_shape, gather_nd_ok_t.get_shape())
        self.assertAllEqual(expected_val, gather_nd_ok_val)

      # A non-empty index into empty params must fail when evaluated.
      indices_nonempty = np.zeros((1, 2), dtype=np.int32)
      gather_nd_break_t = tf.gather_nd(params_empty, indices_nonempty)
      with self.assertRaisesOpError(
          r"Requested more than 0 entries, but params is empty."):
        gather_nd_break_t.eval()
开发者ID:2020zyc,项目名称:tensorflow,代码行数:30,代码来源:gather_nd_op_test.py

示例2: conv3d_oneToMany

def conv3d_oneToMany(x, xShape, w, wShape, strideT, strideY, strideX, inName):
    """Strided "one-to-many" 3-D convolution built from a stride-1 conv3d.

    How it works:
      1. The kernel ``w`` has static shape ``wShape`` = [ntp, nyp, nxp, nifp, nofp].
         It is repacked with ``tf.gather_nd`` into a smaller kernel of shape
         [ntp/strideT, nyp/strideY, nxp/strideX, nifp, nofp*strideT*strideY*strideX].
         Each stride phase of the original kernel occupies its own block of
         ``nofp`` output channels.
      2. A stride-1 ``SAME`` conv3d is run with the packed kernel.
      3. ``tf.gather_nd`` moves every channel block back to its stride phase on an
         output grid ``strideT``/``strideY``/``strideX`` times larger.

    Args:
        x: Input tensor whose static shape is ``xShape``.
        xShape: [nb, nt, ny, nx, nf].
        w: Kernel tensor whose static shape is ``wShape``.
        wShape: [ntp, nyp, nxp, nifp, nofp]; ``nifp`` must equal ``nf``.
        strideT, strideY, strideX: Output upsampling factor per spatial axis.
        inName: Name given to the inner conv3d op.

    Returns:
        Tensor of shape [nb, nt*strideT, ny*strideY, nx*strideX, nofp].

    Raises:
        AssertionError: if a stride does not divide the matching kernel or
            input size, or if ``nifp != nf`` (only when asserts are enabled).
    """
    [ntp, nyp, nxp, nifp, nofp] = wShape
    [nb, nt, ny, nx, nf] = xShape

    # stride must be divisible by both weights and input
    assert ntp % strideT == 0
    assert nyp % strideY == 0
    assert nxp % strideX == 0
    assert nt % strideT == 0
    assert ny % strideY == 0
    assert nx % strideX == 0

    assert nifp == nf

    # Single-argument print(...) and `//` behave identically on Python 2 and 3.
    print("Building weight indices for conv3d")
    # Build gather indices for weights, laid out in the shape of the target
    # (packed) kernel. Given the divisibility asserts above, every source
    # kernel element maps to exactly one target slot, so a vectorized scatter
    # reproduces the original element-by-element loops.
    weightIdxs = np.zeros(
        (ntp // strideT, nyp // strideY, nxp // strideX, nifp,
         nofp * strideT * strideX * strideY, 5),
        dtype=np.int32)
    itp, iyp, ixp, iifp, iofp = np.indices((ntp, nyp, nxp, nifp, nofp),
                                           dtype=np.int32)
    # Spatial positions are reversed: per the original comment, the forward
    # conv is used as a transposed conv.
    otp = (ntp - itp - 1) // strideT
    oyp = (nyp - iyp - 1) // strideY
    oxp = (nxp - ixp - 1) // strideX
    # The stride phase of each source element selects which block of nofp
    # output channels it lands in; input features stay the same.
    kernelIdx = ((itp % strideT) * strideY * strideX
                 + (iyp % strideY) * strideX
                 + (ixp % strideX))
    oofp = iofp + nofp * kernelIdx
    weightIdxs[otp, oyp, oxp, iifp, oofp] = np.stack(
        [itp, iyp, ixp, iifp, iofp], axis=-1)

    print("Building output indices for conv3d")
    # Build gather indices for the output, laid out in the shape of the target
    # (upsampled) output. Each output voxel reads its coarse input voxel, from
    # the channel block matching its stride phase.
    oob, oot, ooy, oox, oof = np.indices(
        (nb, nt * strideT, ny * strideY, nx * strideX, nofp), dtype=np.int32)
    kernelIdx = ((oot % strideT) * strideY * strideX
                 + (ooy % strideY) * strideX
                 + (oox % strideX))
    dataIdxs = np.stack(
        [oob, oot // strideT, ooy // strideY, oox // strideX,
         oof + nofp * kernelIdx],
        axis=-1).astype(np.int32, copy=False)

    # Build convolution structure
    w_reshape = tf.gather_nd(w, weightIdxs)
    o_reshape = tf.nn.conv3d(x, w_reshape, strides=[1, 1, 1, 1, 1], padding="SAME", name=inName)
    o = tf.gather_nd(o_reshape, dataIdxs)
    return o
开发者ID:slundqui,项目名称:TFSparseCode,代码行数:60,代码来源:utils.py

示例3: parse_sequence_to_pairs_batch

def parse_sequence_to_pairs_batch(
    serialized_example, preprocess_fn, is_training, num_views, batch_size,
    window):
  """Parses a serialized sequence example into a batch of preprocessed pairs.

  Steps:
    1. Draws ``batch_size // 2`` (anchor, positive) pairs. Within a pair, both
       images share the same timestep index; each image has its own viewpoint
       index.
    2. Decodes the images.
    3. Preprocesses anchors and positives together as one batch.

  Args:
    serialized_example: A serialized SequenceExample.
    preprocess_fn: A function with the signature (raw_images, is_training) ->
      preprocessed_images.
    is_training: Boolean, whether or not we're in training.
    num_views: Int, the number of simultaneous viewpoints at each timestep in
      the dataset.
    batch_size: Int, total number of images (anchors plus positives). Only
      ``batch_size // 2`` pairs are drawn, so an odd value yields
      ``batch_size - 1`` images.
    window: Int, only take pairs from a maximum window of this size.

  Returns:
    A 7-tuple ``(anchor_prepro, positive_prepro, anchor_images, pos_images,
    anchor_labels, positive_labels, seq_len)``:
      anchor_prepro: Preprocessed anchor images (static batch dim num_pairs).
      positive_prepro: Preprocessed positive images (static batch dim
        num_pairs).
      anchor_images: Decoded float32 anchor images before preprocessing.
      pos_images: Decoded float32 positive images before preprocessing.
      anchor_labels: int32 labels 1..num_pairs.
      positive_labels: The same labels, so pair i shares label i.
      seq_len: The sequence length returned by parse_sequence_example.
  """
  _, views, seq_len = parse_sequence_example(serialized_example, num_views)

  # Get random (anchor, positive) timestep and viewpoint indices. Anchor and
  # positive share the same timestep index.
  num_pairs = batch_size // 2
  ap_time_indices, a_view_indices, p_view_indices = get_tcn_anchor_pos_indices(
      seq_len, num_views, num_pairs, window)

  # Gather the image strings at (view, time) coordinates. Stacking on axis 1
  # is equivalent to concatenating the expand_dims'd vectors on axis 1.
  combined_anchor_indices = tf.stack([a_view_indices, ap_time_indices], axis=1)
  combined_pos_indices = tf.stack([p_view_indices, ap_time_indices], axis=1)
  anchor_images = tf.gather_nd(views, combined_anchor_indices)
  pos_images = tf.gather_nd(views, combined_pos_indices)

  # Decode images.
  anchor_images = tf.map_fn(
      preprocessing.decode_image, anchor_images, dtype=tf.float32)
  pos_images = tf.map_fn(
      preprocessing.decode_image, pos_images, dtype=tf.float32)

  # Concatenate [anchor, positive] images into one batch so both halves get
  # the same preprocessing call, then split them back apart.
  concatenated = tf.concat([anchor_images, pos_images], 0)
  preprocessed = preprocess_fn(concatenated, is_training)
  anchor_prepro, positive_prepro = tf.split(preprocessed, num_or_size_splits=2,
                                            axis=0)

  # Set static batch dimensions for all image tensors
  ims = [anchor_prepro, positive_prepro, anchor_images, pos_images]
  ims = [set_image_tensor_batch_dim(i, num_pairs) for i in ims]
  [anchor_prepro, positive_prepro, anchor_images, pos_images] = ims

  # Assign each anchor and positive the same label.
  anchor_labels = tf.range(1, num_pairs+1)
  positive_labels = tf.range(1, num_pairs+1)

  return (anchor_prepro, positive_prepro, anchor_images, pos_images,
          anchor_labels, positive_labels, seq_len)
开发者ID:danabo,项目名称:models,代码行数:59,代码来源:data_providers.py

示例4: _get_coordinatewise_learning_rate

  def _get_coordinatewise_learning_rate(self, grad, var):
    """Returns the per-coordinate learning rate for `var` given `grad`.

    Updates moving averages of the gradient (slot 'first_moment') and of its
    squared deviation (slot 'second_moment') with a Welford-style update. The
    clipped second moment, scaled by the batch size, then serves as a diagonal
    preconditioner in the step size of Theorem 2 Corollary 1 of Mandt et al.
    (2017).

    Args:
      grad: A dense `tf.Tensor` gradient, or a `tf.IndexedSlices` whose 1-D
        `indices` select rows of `var`.
      var: The variable being optimized; its slots hold the moment estimates.

    Returns:
      The learning-rate tensor:
        - dense `grad`: one value per coordinate;
        - sparse `grad`: one value per selected row;
        - scalar when `self._use_single_learning_rate` is set.

    Raises:
      tf.errors.InvalidArgumentError: if `grad` is neither a `tf.Tensor` nor a
        `tf.IndexedSlices`.
    """
    # Compute the learning rate using a moving average for the diagonal of BB^T
    avg_first = self.get_slot(var, 'first_moment')
    avg_second = self.get_slot(var, 'second_moment')
    decay_tensor = tf.cast(self._decay_tensor, var.dtype)
    batch_size = tf.cast(self._batch_size_tensor, var.dtype)

    # Create an estimator for the moving average of gradient mean and variance
    # via Welford's algorithm
    if isinstance(grad, tf.Tensor):
      delta = grad - avg_first
      # While counter < 1 the mean is set straight to the gradient.
      first_moment_update = avg_first.assign_add(
          delta * tf.where(self._counter < 1,
                           tf.cast(1, var.dtype),
                           1. - decay_tensor))

      with tf.control_dependencies([first_moment_update]):
        # NOTE(review): the `counter < 1` mask only lets the second moment
        # change while the counter is below 1; confirm this is intended.
        second_moment_update = avg_second.assign_add(
            tf.cast(self._counter < 1, var.dtype) *
            -(1. - decay_tensor) * (
                avg_second - decay_tensor  * tf.square(delta)))
      diag_preconditioner = control_flow_ops.with_dependencies(
          [second_moment_update],
          tf.clip_by_value(avg_second, 1e-12, 1e12))
    elif isinstance(grad, tf.IndexedSlices):
      # `IndexedSlices.indices` is a 1-D vector of row ids, so rows are read
      # with tf.gather. tf.gather_nd would instead treat the whole vector as
      # one multi-dimensional coordinate.
      delta = grad.values - tf.gather(avg_first, grad.indices)
      first_moment_update = tf.scatter_add(
          avg_first,
          grad.indices,
          delta * tf.where(self._counter < 1,
                           tf.cast(1., var.dtype),
                           1. - decay_tensor))

      with tf.control_dependencies([first_moment_update]):
        # Keep the slot variable and its update op under distinct names.
        second_moment_update = tf.scatter_add(
            avg_second,
            grad.indices,
            tf.cast(self._counter < 1, var.dtype) *
            -(1. - decay_tensor) * (
                tf.gather(avg_second, grad.indices) - decay_tensor *
                tf.square(delta)))
        updated_second_rows = tf.gather(second_moment_update, grad.indices)
        # TODO(b/70783772): Needs dtype specific clipping.
        diag_preconditioner = tf.clip_by_value(
            updated_second_rows, 1e-12, 1e12)
    else:
      raise tf.errors.InvalidArgumentError(
          None, None, 'grad must of type Tensor or IndexedSlice')

    diag_preconditioner *= batch_size

    if self._use_single_learning_rate:
      diag_preconditioner = tf.reduce_mean(diag_preconditioner)

    # From Theorem 2 Corollary 1 of Mandt et al. 2017
    return 2. * batch_size / (
        tf.cast(self._total_num_examples, var.dtype.base_dtype) *
        diag_preconditioner)
开发者ID:asudomoeva,项目名称:probability,代码行数:57,代码来源:variational_sgd.py

示例5: get_valid_logits_and_labels

def get_valid_logits_and_labels(annotation_batch_tensor,
                                logits_batch_tensor,
                                class_labels):
    """Select the labels and logits at the positions deemed valid.

    Valid positions are those returned by
    ``get_valid_entries_indices_from_annotation_batch``. Both the label tensor
    (derived from the annotations) and ``logits_batch_tensor`` are gathered
    with those indices.

    Args:
        annotation_batch_tensor: Batch of annotation maps.
        logits_batch_tensor: Logits indexed the same way as the labels.
        class_labels: Passed through to both annotation helpers.

    Returns:
        ``(valid_labels, valid_logits)``. The labels come first, despite the
        function name.
    """
    labels = get_labels_from_annotation_batch(
        annotation_batch_tensor=annotation_batch_tensor,
        class_labels=class_labels)
    valid_indices = get_valid_entries_indices_from_annotation_batch(
        annotation_batch_tensor=annotation_batch_tensor,
        class_labels=class_labels)

    return (tf.gather_nd(params=labels, indices=valid_indices),
            tf.gather_nd(params=logits_batch_tensor, indices=valid_indices))
开发者ID:ruyi345,项目名称:Fully-convolutional-networks-TF,代码行数:15,代码来源:utils.py

示例6: batch_gather

def batch_gather(reference, indices):
    """
    C+P From Keras pull request https://github.com/keras-team/keras/pull/6377/files

    Batchwise gathering of row indices.

    The numpy equivalent is `reference[np.arange(batch_size), indices]`, where
    `batch_size` is the first dimension of the reference tensor.

    # Arguments
        reference: A tensor with ndim >= 2 of shape
          (batch_size, dim1, dim2, ..., dimN).
        indices: A 1d integer tensor (int32 or int64) of shape (batch_size)
          satisfying 0 <= i < dim1 for each element i.

    # Returns
        The selected tensor with shape (batch_size, dim2, ..., dimN).

    # Examples
        1. If reference is `[[3, 5, 7], [11, 13, 17]]` and indices is `[2, 1]`
        then the result is `[7, 13]`.

        2. If reference is
        ```
          [[[2, 3], [4, 5], [6, 7]],
           [[10, 11], [12, 13], [16, 17]]]
        ```
        and indices is `[2, 1]` then the result is `[[6, 7], [12, 13]]`.
    """
    indices = tf.convert_to_tensor(indices)
    batch_size = K.shape(reference)[0]
    # tf.range yields int32; cast it to the dtype of `indices` so tf.stack
    # also accepts int64 indices.
    batch_range = tf.cast(tf.range(batch_size), indices.dtype)
    gather_indices = tf.stack([batch_range, indices], axis=1)
    return tf.gather_nd(reference, gather_indices)
开发者ID:ymcidence,项目名称:neuron,代码行数:32,代码来源:utils.py

示例7: fastrcnn_inference

    def fastrcnn_inference(self, image_shape2d,
                           rcnn_boxes, rcnn_label_logits, rcnn_box_logits):
        """
        Decode the Fast R-CNN head outputs into the final detections.

        Args:
            image_shape2d: h, w; the decoded boxes are clipped to this size.
            rcnn_boxes (nx4): the proposal boxes
            rcnn_label_logits (n x #class): classification logits, softmaxed
                over the class axis.
            rcnn_box_logits: box regression outputs. Presumably
                (n x #class-1 x 4), to match the proposals tiled once per
                non-background category; confirm against the head.

        Returns:
            boxes (mx4): decoded boxes selected by fastrcnn_predictions.
            labels (m): selected category index + 1, so each >= 1.

        The selected probabilities are not returned; they are only exposed in
        the graph under the name 'final_probs'.
        """
        label_probs = tf.nn.softmax(rcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
        anchors = tf.tile(tf.expand_dims(rcnn_boxes, 1), [1, config.NUM_CLASS - 1, 1])   # #proposal x #Cat x 4
        decoded_boxes = decode_bbox_target(
            rcnn_box_logits /
            tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS), anchors)
        decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

        # indices: Nx2. Each index into (#proposal, #category)
        pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
        final_probs = tf.identity(final_probs, 'final_probs')
        final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
        final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
        return final_boxes, final_labels
开发者ID:wu-yy,项目名称:tensorpack,代码行数:26,代码来源:train.py

示例8: transpose5dWeight

def transpose5dWeight(w, wShape, strideT, strideY, strideX):
    """Expand a stride-packed 5-D conv3d kernel with ``tf.gather_nd``.

    ``wShape`` = [ntp, nyp, nxp, nifp, nofp] is the shape of the packed kernel
    ``w``, in terms of the already strided values.

    The result has shape
    [ntp*strideT, nyp*strideY, nxp*strideX, nifp, nofp // (strideT*strideY*strideX)].
    Each output position reads, from the spatially reversed packed position,
    the channel block that matches its stride phase.

    Args:
        w: Packed kernel tensor.
        wShape: Static shape of ``w`` as a 5-element list.
        strideT, strideY, strideX: Stride factors that were packed into the
            output-channel axis.

    Returns:
        The gathered kernel tensor.
    """
    # Single-argument print(...) and `//` behave identically on Python 2 and 3.
    print("Building weight indices for conv3d")
    # These shapes are in terms of the already strided values
    [ntp, nyp, nxp, nifp, nofp] = wShape
    # Translate to target output shape
    ntp *= strideT
    nyp *= strideY
    nxp *= strideX
    # Floor division keeps nofp an int (true division broke range/np.zeros on
    # Python 3).
    nofp = nofp // (strideT * strideX * strideY)

    # Build gather indices for weights, in the shape of the target output
    # weights, using vectorized index grids instead of nested loops.
    otp, oyp, oxp, oifp, oofp = np.indices((ntp, nyp, nxp, nifp, nofp),
                                           dtype=np.int32)
    # Reverse spatial positions; input features stay the same.
    itp = (ntp - otp - 1) // strideT
    iyp = (nyp - oyp - 1) // strideY
    ixp = (nxp - oxp - 1) // strideX
    # iofp uses oofp as offset, plus an nofp stride based on the stride phase.
    kernelIdx = ((otp % strideT) * strideY * strideX
                 + (oyp % strideY) * strideX
                 + (oxp % strideX))
    iofp = oofp + nofp * kernelIdx
    weightIdxs = np.stack([itp, iyp, ixp, oifp, iofp],
                          axis=-1).astype(np.int32, copy=False)
    return tf.gather_nd(w, weightIdxs)
开发者ID:slundqui,项目名称:TFSparseCode,代码行数:34,代码来源:utils.py

示例9: calculate_outputs

    def calculate_outputs(self, x):
        """Run the LSTM + dense stack and expose each sequence's final values.

        Side effects: sets ``self.final_states``, ``self.final_predictions``
        and ``self.prediction_tensors``. The final values are read at
        timestep ``history_length - 1`` of every batch row.

        Returns:
            ``y_hat``: per-timestep predictions, with axis 2 (the single
            output unit) squeezed away.
        """
        lstm_out = lstm_layer(x, self.history_length, self.lstm_size, scope='lstm-1')
        h_final = time_distributed_dense_layer(lstm_out, 50, activation=tf.nn.relu, scope='dense-1')
        one_unit = time_distributed_dense_layer(h_final, 1, scope='dense2')
        y_hat = tf.squeeze(one_unit, 2)

        # (row, last timestep) pairs, one per batch row.
        batch_rows = tf.range(tf.shape(self.history_length)[0])
        last_steps = self.history_length - 1
        final_temporal_idx = tf.stack([batch_rows, last_steps], axis=1)

        self.final_states = tf.gather_nd(h_final, final_temporal_idx)
        self.final_predictions = tf.gather_nd(y_hat, final_temporal_idx)

        self.prediction_tensors = dict(
            user_ids=self.user_id,
            final_states=self.final_states,
            predictions=self.final_predictions,
        )

        return y_hat
开发者ID:dengminna,项目名称:instacart-basket-prediction,代码行数:16,代码来源:rnn_order_size.py

示例10: train_speech_to_text_network

def train_speech_to_text_network():
    """Train the speech-to-text network with CTC loss for 16 epochs.

    Depends on module-level names defined elsewhere:
      - placeholders ``X`` and ``Y``;
      - ``sequence_len``, ``n_batch`` and ``batch_size``;
      - ``get_next_batches`` and ``MaxPropOptimizer``;
      - the global ``pointer``, which is reset to 0 at the start of every
        epoch.

    The learning rate starts at 0.001 and is multiplied by 0.97 each epoch. A
    checkpoint with prefix ``speech.module`` is saved every 5 epochs.
    """
    global pointer
    logit = speech_to_text_network()

    # CTC loss: only non-zero entries of Y become sparse targets, shifted down
    # by one.
    indices = tf.where(tf.not_equal(tf.cast(Y, tf.float32), 0.))
    # Dense shape is passed positionally: the keyword is `shape` before
    # TF 1.0 and `dense_shape` from TF 1.0 on.
    target = tf.SparseTensor(indices, tf.gather_nd(Y, indices) - 1,
                             tf.cast(tf.shape(Y), tf.int64))
    # Keyword arguments: ctc_loss's positional order changed in TF 1.0 from
    # (inputs, labels, ...) to (labels, inputs, ...).
    loss = tf.nn.ctc_loss(labels=target, inputs=logit,
                          sequence_length=sequence_len, time_major=False)
    # optimizer. The learning rate is updated through a single assign op fed
    # from a placeholder, so no new op is added to the graph every epoch.
    lr = tf.Variable(0.001, dtype=tf.float32, trainable=False)
    lr_feed = tf.placeholder(tf.float32, shape=[])
    update_lr = tf.assign(lr, lr_feed)
    optimizer = MaxPropOptimizer(learning_rate=lr, beta2=0.99)
    var_list = list(tf.trainable_variables())
    gradient = optimizer.compute_gradients(loss, var_list=var_list)
    optimizer_op = optimizer.apply_gradients(gradient)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())

        for epoch in range(16):
            sess.run(update_lr, feed_dict={lr_feed: 0.001 * (0.97 ** epoch)})

            pointer = 0
            for batch in range(n_batch):
                batches_wavs, batches_labels = get_next_batches(batch_size)
                train_loss, _ = sess.run([loss, optimizer_op], feed_dict={X: batches_wavs, Y: batches_labels})
                print(epoch, batch, train_loss)
            if epoch % 5 == 0:
                saver.save(sess, 'speech.module', global_step=epoch)
开发者ID:luohuayong,项目名称:tensorflow,代码行数:30,代码来源:t15.py

示例11: testUnknownIndices

 def testUnknownIndices(self):
   """With indices of unknown rank, gather_nd's static output shape is unknown."""
   result_shape = tf.gather_nd(
       tf.constant([[0, 1, 2]]), tf.placeholder(tf.int32)).get_shape()
   self.assertIsNone(result_shape.ndims)
   self.assertIsNone(result_shape[0].value)
开发者ID:0-T-0,项目名称:tensorflow,代码行数:7,代码来源:gather_nd_op_test.py

示例12: create_model

def create_model(input_shape, num_actions, model_name, create_network_fn, learning_rate):  # noqa: D103
    """Create the Q-network model.

    Everything is built under the ``model_name`` name scope:
      - the network from ``create_network_fn``;
      - the mean over rows of each row's max Q-value;
      - an int32 placeholder of [N, 2] indices used to gather entries of the
        network output (presumably (row, action) pairs);
      - ``mean_huber_loss`` between ``y_ph`` and the gathered values;
      - an RMSProp training step.

    Returns:
        ``(model, network_parameters)``. ``model`` is a dict of the graph
        endpoints; ``network_parameters`` comes from ``create_network_fn``.
    """
    with tf.name_scope(model_name):
        input_frames = tf.placeholder(
            tf.float32, [None, input_shape], name='input_frames')
        q_network, network_parameters = create_network_fn(
            input_frames, input_shape, num_actions)

        row_max_q = tf.reduce_max(q_network, axis=[1])
        mean_max_Q = tf.reduce_mean(row_max_q, name='mean_max_Q')

        Q_vector_indexes = tf.placeholder(
            tf.int32, [None, 2], name='Q_vector_indexes')
        gathered_outputs = tf.gather_nd(
            q_network, Q_vector_indexes, name='gathered_outputs')

        y_ph = tf.placeholder(tf.float32, name='y_ph')
        loss = mean_huber_loss(y_ph, gathered_outputs)
        optimizer = tf.train.RMSPropOptimizer(
            learning_rate,
            decay=RMSP_DECAY,
            momentum=RMSP_MOMENTUM,
            epsilon=RMSP_EPSILON)
        train_step = optimizer.minimize(loss)

    model = dict(
        q_network=q_network,
        input_frames=input_frames,
        Q_vector_indexes=Q_vector_indexes,
        y_ph=y_ph,
        train_step=train_step,
        mean_max_Q=mean_max_Q,
    )
    return model, network_parameters
开发者ID:codealphago,项目名称:melee-ai,代码行数:27,代码来源:dqn_atari.py

示例13: build_net

    def build_net(self):
        """Build the DQN graph: eval net, target net, TD target, loss and train op.

        Components:
          - placeholders ``s``, ``s_``, ``r`` and ``a``;
          - two networks, each with one 20-unit ReLU hidden layer, under the
            'eval_net' and 'target_net' variable scopes;
          - a mean squared-error loss between the TD target and the eval
            Q-value of the taken action;
          - an RMSProp train op.
        """
        self.s = tf.placeholder(tf.float32, [None, self.n_features])
        self.s_ = tf.placeholder(tf.float32, [None, self.n_features])
        self.r = tf.placeholder(tf.float32, [None, ])
        self.a = tf.placeholder(tf.int32, [None, ])

        w_initializer = tf.random_normal_initializer(0., 0.3)
        b_initializer = tf.constant_initializer(0.1)

        def dense(inputs, units, name, activation=None):
            # Every layer shares the same weight/bias initializers.
            return tf.layers.dense(inputs, units, activation,
                                   kernel_initializer=w_initializer,
                                   bias_initializer=b_initializer, name=name)

        # Eval network: state features in, one Q-value per action out.
        with tf.variable_scope('eval_net'):
            eval_hidden = dense(self.s, 20, 'eval_layer', activation=tf.nn.relu)
            self.q_eval = dense(eval_hidden, self.n_actions, 'output_layer1')
        # Target network: same architecture, fed with the next state.
        with tf.variable_scope('target_net'):
            target_hidden = dense(self.s_, 20, 'target_layer', activation=tf.nn.relu)
            self.q_next = dense(target_hidden, self.n_actions, 'output_layer2')
        with tf.variable_scope('q_target'):
            # Expected value r + gamma * max Q_next. stop_gradient makes it a
            # constant, so the loss does not backpropagate through it.
            bootstrap = self.gamma * tf.reduce_max(self.q_next, axis=1)
            self.q_target = tf.stop_gradient(self.r + bootstrap)
        with tf.variable_scope('q_eval'):
            # Pair each row index with its taken action to pick Q(s, a).
            row_ids = tf.range(tf.shape(self.a)[0])
            a_indices = tf.stack([row_ids, self.a], axis=1)
            self.q_eval_a = tf.gather_nd(params=self.q_eval, indices=a_indices)
        with tf.variable_scope('loss'):
            squared_err = tf.squared_difference(self.q_target, self.q_eval_a)
            self.loss = tf.reduce_mean(squared_err)
        with tf.variable_scope('train'):
            self.train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
开发者ID:wqw547243068,项目名称:wangqiwen,代码行数:30,代码来源:gym-dqn.py

示例14: killRegions

def killRegions(anchors, image_attr, axis=-1):
    """ Prune the anchors so that only those entirely within the image remain

    This function is the RPN-training analog of clipRegions, just more murderous

    Args:
        anchors: Boxes as (x1, y1, x2, y2); flattened to (-1, 4) first.
        image_attr: Indexable image attributes.
            ``image_attr[1] * image_attr[2]`` is used as the image width and
            ``image_attr[0] * image_attr[2]`` as the height (presumably
            [height, width, scale]; confirm with the caller).
        axis: Axis along which the 4 coordinates are unstacked.

    Output:
        The anchors that survive the slaughter, along with their indices
        (the tf.where output) into the flattened anchors.
    """

    with tf.device("/cpu:0"):
        # Assumes input of shape (numBaseAnchors, feature_h, feature_w, 4)
        # Or, was previously as above but then got flattened to (-1,4)

        anchors = tf.reshape(anchors, [-1, 4], name="flattened_anchors")
        x1, y1, x2, y2 = tf.unstack(anchors, num=4, axis=axis)

        zero = tf.constant([0.])

        # Largest valid coordinate on each axis. Not wrapped in a Python list:
        # comparing against `[tensor]` packs it with an extra leading
        # dimension, which broadcasts the mask (and the tf.where indices) to
        # the wrong rank.
        max_x = tf.subtract(image_attr[1] * image_attr[2], tf.constant([1.]),
                            name="murder_img_w")
        max_y = tf.subtract(image_attr[0] * image_attr[2], tf.constant([1.]),
                            name="murder_img_h")

        x1_valid = x1 >= zero
        x2_valid = x2 <= max_x
        y1_valid = y1 >= zero
        y2_valid = y2 <= max_y

        # Python `and` calls bool() on a tensor, which TensorFlow rejects;
        # combine the masks element-wise instead.
        anchor_valid = tf.logical_and(tf.logical_and(x1_valid, x2_valid),
                                      tf.logical_and(y1_valid, y2_valid))
        valid_indices = tf.where(anchor_valid, name="surviving_indices")
    return tf.gather_nd(anchors, valid_indices, name="surviving_anchors"), valid_indices
开发者ID:PentaHiggs,项目名称:fantastic-pancakes,代码行数:31,代码来源:rpn.py

示例15: gather_flat

def gather_flat(x: tf.Tensor,
                indices: tf.Tensor,
                batch_size: Union[int, tf.Tensor] = 1,
                beam_size: Union[int, tf.Tensor] = 1) -> tf.Tensor:
    """Gather values from a tensor flattened over batch and beam.

    ``x`` has a leading dimension of ``batch_size * beam_size`` elements.
    The function:
      1. views ``x`` as ``(batch_size, beam_size, ...)``;
      2. applies ``tf.gather_nd`` with ``indices``;
      3. reshapes the result back to ``[-1] + x.shape[1:]``.

    Rank-0 inputs are returned unchanged.

    Arguments:
        x: A flattened ``Tensor`` from which to gather values.
        indices: Index tensor for ``tf.gather_nd`` over the unflattened
            ``(batch, beam, ...)`` view.
        batch_size: The size of the batch.
        beam_size: The size of the beam.

    Returns:
        The ``Tensor`` of gathered values, flattened again over its leading
        dimensions.
    """
    if x.shape.ndims == 0:
        return x

    trailing_dims = get_shape_list(x)[1:]
    unflattened = tf.reshape(x, [batch_size, beam_size] + trailing_dims)
    picked = tf.gather_nd(unflattened, indices)
    return tf.reshape(picked, [-1] + trailing_dims)
开发者ID:ufal,项目名称:neuralmonkey,代码行数:26,代码来源:tf_utils.py


注：本文中的tensorflow.gather_nd函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。