Python torch.topk方法代码示例

示例1: greedy_decode

def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, flag):
    """Greedily decode the response for ``self.max_ts`` steps.

    At every step the decoder is run once, the single highest-probability
    token per batch element is picked with ``torch.topk(proba, 1)``, and
    that token is fed back as the next step's ``m_tm1``.

    :param pz_dec_outs: passed unchanged to the decoder at every step
    :param pz_proba: passed unchanged to the decoder at every step
    :param u_enc_out: passed unchanged to the decoder at every step
    :param m_tm1: decoder input for the first step; replaced by the chosen
        token (viewed as ``(1, B)``) after each step
    :param last_hidden: decoder hidden state, threaded through the steps
    :param flag: falsy selects ``self.m_decoder``, truthy ``self.p_decoder``
    :return: nested list with one inner list per batch element, holding the
        chosen token of every step as 0-dim tensors
    """
    decoded = []
    decoder = self.m_decoder if not flag else self.p_decoder
    for _step in range(self.max_ts):
        proba, last_hidden = decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden)
        _, mt_index = torch.topk(proba, 1)  # [B,1]
        mt_index = mt_index.data.view(-1)
        # Record this step's choice; without it torch.stack below receives an
        # empty list and raises.
        decoded.append(mt_index.clone())
        m_tm1 = cuda_(Variable(mt_index).view(1, -1))
    # [T,B] -> [B,T] so that each row is one decoded sequence.
    decoded = torch.stack(decoded, dim=0).transpose(0, 1)
    return [list(sequence) for sequence in decoded]

示例2: pz_selective_sampling

def pz_selective_sampling(self, pz_proba):
    """Selective sampling of pz: max-sampling that prevents repeated words.

    For every batch element and every step, the candidates are visited in
    descending probability and the first token not already picked for that
    batch element is kept. ``k`` equals the number of steps, so each step
    always has at least one unused candidate.

    NOTE(review): the bodies of the innermost ``if`` and of the batch loop
    were missing from this snippet; they are reconstructed here as
    "append the first unseen token, then move to the next step" -- confirm
    against the upstream implementation.

    :param pz_proba: tensor indexed as ``[Tz, B, V]`` (steps, batch, vocab),
        as implied by the ``topk(dim=2)`` / ``transpose(0, 1)`` calls
    :return: list with one entry per batch element; each entry is the list
        of sampled tokens (0-dim tensors) in step order
    """
    pz_proba = pz_proba.detach()
    _, z_token = torch.topk(pz_proba, pz_proba.size(0), dim=2)
    z_token = z_token.transpose(0, 1)  # [B,Tz,top_Tz]
    all_sampled_z = []
    for batch_tokens in z_token:
        sampled_z = []
        seen_ids = set()  # plain ints give O(1) membership tests
        for candidates in batch_tokens:
            for token in candidates:
                token_id = int(token)
                if token_id not in seen_ids:
                    seen_ids.add(token_id)
                    sampled_z.append(token)
                    break
        all_sampled_z.append(sampled_z)
    return all_sampled_z

示例3: greedy_decode

def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input):
    """Greedily decode the response for ``self.max_ts`` steps.

    Each step runs ``self.m_decoder`` once, keeps the top-1 token per batch
    element and feeds it back as the next ``m_tm1``.

    :param pz_dec_outs: passed unchanged to the decoder at every step
    :param pz_proba: passed unchanged to the decoder at every step
    :param u_enc_out: passed unchanged to the decoder at every step
    :param m_tm1: decoder input for the first step; replaced by the chosen
        token (viewed as ``(1, B)``) after each step
    :param last_hidden: decoder hidden state, threaded through the steps
    :param degree_input: passed unchanged to the decoder at every step
    :return: nested list with one inner list per batch element, holding the
        chosen token of every step as 0-dim tensors
    """
    decoded = []
    for _step in range(self.max_ts):
        proba, last_hidden, _ = self.m_decoder(
            pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden
        )
        _, mt_index = torch.topk(proba, 1)  # [B,1]
        mt_index = mt_index.data.view(-1)
        # Record this step's choice; the result list was otherwise left
        # empty, which makes torch.stack below raise.
        decoded.append(mt_index.clone())
        m_tm1 = cuda_(Variable(mt_index).view(1, -1))
    # [T,B] -> [B,T] so that each row is one decoded sequence.
    decoded = torch.stack(decoded, dim=0).transpose(0, 1)
    return [list(sequence) for sequence in decoded]

示例4: greedy_decode

def greedy_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
    """Greedily decode the response for ``self.max_ts`` steps.

    Each step runs ``self.m_decoder``, drops column 2 of the returned
    probabilities, keeps the top-1 index per batch element, and feeds that
    index back as the next ``m_tm1`` (indices at or above
    ``cfg.vocab_size`` are replaced by 2 before being fed back).

    :param pz_dec_outs: passed unchanged to the decoder at every step
    :param u_enc_out: passed unchanged to the decoder at every step
    :param m_tm1: decoder input for the first step; replaced after each step
    :param u_input_np: passed unchanged to the decoder at every step
    :param last_hidden: decoder hidden state, threaded through the steps
    :param degree_input: passed unchanged to the decoder at every step
    :param bspan_index: padded with ``pad_sequences`` and transposed once,
        then passed to the decoder at every step
    :return: nested list with one inner list per batch element, holding the
        chosen index of every step as 0-dim tensors
    """
    decoded = []
    bspan_index_np = pad_sequences(bspan_index).transpose((1, 0))
    for _step in range(self.max_ts):
        proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
                                               degree_input, last_hidden, bspan_index_np)
        proba = torch.cat((proba[:, :2], proba[:, 3:]), 1)
        _, mt_index = torch.topk(proba, 1)  # [B,1]
        mt_index = mt_index.data.view(-1)
        # Record this step's choice; the result list was otherwise left
        # empty, which makes torch.stack below raise. A copy is stored
        # because mt_index is modified in place right after.
        # NOTE(review): recording before the remap keeps the raw indices in
        # the output -- confirm that is the intended return value.
        decoded.append(mt_index.clone())
        # Out-of-vocabulary indices cannot be fed back to the decoder.
        mt_index[mt_index >= cfg.vocab_size] = 2  # unk
        m_tm1 = cuda_(Variable(mt_index).view(1, -1))
    # [T,B] -> [B,T] so that each row is one decoded sequence.
    decoded = torch.stack(decoded, dim=0).transpose(0, 1)
    return [list(sequence) for sequence in decoded]

示例5: greedy_decode

def greedy_decode(self, decoder_hidden, encoder_outputs, target_tensor):
    """Greedily decode ``self.max_len`` steps and return one string per batch row.

    The decoder is started from ``SOS_token`` and at each step its top-1
    output index is fed back as the next input. The collected indices are
    then mapped to words with ``self.output_index2word`` and joined with
    single spaces.

    NOTE(review): the body of the end-of-sentence branch was missing from
    this snippet. It is reconstructed as "stop at the first EOS word,
    otherwise append the word" -- confirm against the upstream code.

    :param decoder_hidden: initial decoder hidden state
    :param encoder_outputs: passed unchanged to the decoder at every step
    :param target_tensor: 2-D tensor; only its first dimension (batch size)
        is used
    :return: list of decoded sentences, one ``str`` per batch element
    """
    batch_size, _ = target_tensor.size()
    # torch.LongTensor(..., device=...) only accepts CPU devices; torch.full
    # builds the (batch, 1) start-token column directly on self.device.
    decoder_input = torch.full(
        (batch_size, 1), SOS_token, dtype=torch.long, device=self.device
    )

    decoded_words = torch.zeros((batch_size, self.max_len))
    for t in range(self.max_len):
        decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)

        _, topi = decoder_output.data.topk(1)  # best candidate per row
        topi = topi.view(-1)

        decoded_words[:, t] = topi
        decoder_input = topi.detach().view(-1, 1)

    # Looked up once instead of once per token; assumes output_index2word is
    # a pure lookup.
    eos_word = self.output_index2word(str(EOS_token))
    decoded_sentences = []
    for sentence in decoded_words:
        sent = []
        for ind in sentence:
            word = self.output_index2word(str(int(ind.item())))
            if word == eos_word:
                break
            sent.append(word)
        decoded_sentences.append(' '.join(sent))

    return decoded_sentences

示例6: test_knn

def test_knn():
    """Check KNNGraph and SegmentedKNNGraph against brute-force neighbours.

    For every node, the sources of its in-edges must equal the 3 points with
    the smallest ``th.cdist`` distance inside the node's own segment.
    """
    x = th.randn(8, 3)
    kg = dgl.nn.KNNGraph(3)
    d = th.cdist(x, x)

    def assert_neighbours(graph, lo, hi):
        # Distances restricted to the segment [lo, hi); indices are local to it.
        segment = d[lo:hi, lo:hi]
        for offset, node in enumerate(range(lo, hi)):
            sources, _ = graph.in_edges(node)
            actual = set(sources.numpy())
            _, nearest = th.topk(segment[offset], 3, largest=False)
            # Shift local indices back to global node ids.
            expected = set(nearest.numpy() + lo)
            assert actual == expected

    # One point set of 8 points.
    g = kg(x)
    assert_neighbours(g, 0, 8)

    # Two batched point sets of 4 points each.
    g = kg(x.view(2, 4, 3))
    assert_neighbours(g, 0, 4)
    assert_neighbours(g, 4, 8)

    # Uneven segments of 3 and 5 points.
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [3, 5])
    assert_neighbours(g, 0, 3)
    assert_neighbours(g, 3, 8)

示例7: predict

def predict(self, x):
    """Predict a value for each row of ``x`` from its best-matching key.

    The projected query is L2-normalised, scored against ``self.keys_var``
    by dot product, and the ``self.top_k`` best keys per row are kept. The
    value stored for the single best key is returned as the prediction.

    :param x: 2-D tensor ``(batch_size, dims)``
    :return: tuple ``(y_hat, softmax_score)`` where ``y_hat`` is
        ``self.values`` indexed by each row's best key, and
        ``softmax_score`` is the softmax over each row's top-k scores
        scaled by ``self.softmax_temperature``
    """
    # The unpacking doubles as a check that x is 2-D.
    _batch_size, _dims = x.size()
    query = F.normalize(self.query_proj(x), dim=1)

    # Find the k-nearest neighbors of the query.
    # NOTE(review): the scores are cosine similarities only if the rows of
    # keys_var are unit-norm -- confirm where the keys are written.
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # Softmax over each row's top-k scores. dim=1 is what the deprecated
    # implicit-dim call selected for 2-D input; stating it removes the
    # deprecation warning without changing the result.
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

    # Retrieve the memory value of the best key -- the prediction.
    y_hat_indices = topk_indices_var.data[:, 0]
    y_hat = self.values[y_hat_indices]

    return y_hat, softmax_score

示例8: select_topk

def select_topk(args, logits, force_no_eos_id=None):
    """Applies topk sampling decoding.

    Every logit smaller than the k-th largest logit of its row is set to
    ``-inf``. ``logits`` is modified in place and also returned.

    Args:
        args: object exposing an integer ``top_k`` attribute.
        logits: 2-D tensor of scores; the top-k is taken over the last
            dimension.
        force_no_eos_id: optional column index that is set to ``-inf``
            before filtering, so that token can never be selected.

    Returns:
        The same ``logits`` tensor after filtering.
    """
    if force_no_eos_id is not None:
        logits[:, force_no_eos_id] = float('-inf')

    # Never ask for more candidates than exist; torch.topk raises otherwise.
    top_k = min(args.top_k, logits.size(-1))
    if top_k > 0:
        # Threshold is the smallest of the k largest values in each row.
        kth_best = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth_best] = float('-inf')

    return logits

# implementation is from Huggingface/transformers repo 

示例9: _topk

def _topk(self, scores, K=40):
    """Select the K highest-scoring entries over all categories and positions.

    Selection is done in two stages: top-K positions inside each category,
    then top-K among the resulting ``cat * K`` candidates.

    :param scores: 4-D tensor ``(batch, cat, height, width)``
    :param K: number of entries to keep per batch element
    :return: tuple ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``;
        ``topk_score`` are the selected scores, ``topk_inds`` the flattened
        ``y * width + x`` positions, ``topk_clses`` the category index
        (int32), ``topk_ys`` / ``topk_xs`` the row / column as floats.
        All but ``topk_score`` and ``topk_clses`` are reshaped to
        ``(batch, K)`` after ``_gather_feat``.
    """
    batch, cat, height, width = scores.size()
    # Stage 1: per-category top-K over the flattened spatial grid.
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    # Integer floor division keeps the indices exact; true division followed
    # by a float->int truncation is dtype-dependent and loses precision for
    # large indices.
    topk_ys = (topk_inds // width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # Stage 2: overall top-K among the cat * K candidates.
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    # Candidates are laid out category-major, K per category.
    topk_clses = (topk_ind // K).int()
    # _gather_feat is defined elsewhere; presumably it picks the rows of its
    # first argument addressed by topk_ind -- verify against its definition.
    topk_inds = _gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

示例10: _topk

def _topk(scores, K=40):
    """Select the K highest-scoring entries over all categories and positions.

    Selection is done in two stages: top-K positions inside each category,
    then top-K among the resulting ``cat * K`` candidates.

    :param scores: 4-D tensor ``(batch, cat, height, width)``
    :param K: number of entries to keep per batch element
    :return: tuple ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``;
        ``topk_score`` are the selected scores, ``topk_inds`` the flattened
        ``y * width + x`` positions, ``topk_clses`` the category index
        (int32), ``topk_ys`` / ``topk_xs`` the row / column as floats.
        All but ``topk_score`` and ``topk_clses`` are reshaped to
        ``(batch, K)`` after ``_gather_feat``.
    """
    batch, cat, height, width = scores.size()
    # Stage 1: per-category top-K over the flattened spatial grid.
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    # Integer floor division keeps the indices exact; true division followed
    # by a float->int truncation is dtype-dependent and loses precision for
    # large indices.
    topk_ys = (topk_inds // width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # Stage 2: overall top-K among the cat * K candidates.
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    # Candidates are laid out category-major, K per category.
    topk_clses = (topk_ind // K).int()
    # _gather_feat is defined elsewhere; presumably it picks the rows of its
    # first argument addressed by topk_ind -- verify against its definition.
    topk_inds = _gather_feat(
        topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

示例11: select_next_words

def select_next_words(
    self, word_scores, bsz, beam_size, possible_translation_tokens
):
    """Pick the ``beam_size`` best (beam, token) continuations per sentence.

    ``word_scores`` is flattened to ``(bsz, -1)`` so that one ``topk`` call
    ranks every token of every beam of a sentence together. Each flat index
    is then split into the beam it came from (quotient) and the token
    position inside that beam (remainder).

    :param word_scores: scores viewable as ``(bsz, beams * tokens)``
    :param bsz: number of sentences in the batch
    :param beam_size: number of candidates to keep per sentence
    :param possible_translation_tokens: optional 1-D tensor of vocabulary
        ids (vocab reduction). When given, its length is the number of
        tokens per beam and token positions are mapped through it;
        otherwise ``self.vocab_size`` is used and positions are returned
        as they are.
    :return: tuple ``(cand_scores, cand_indices, cand_beams)``, each of
        shape ``(bsz, beam_size)``

    NOTE(review): the remainder step was absent from this snippet and is
    restored here; without it the flat indices exceed the width of the
    reduced vocabulary in the gather below -- confirm against upstream.
    """
    cand_scores, cand_indices = torch.topk(word_scores.view(bsz, -1), k=beam_size)
    possible_tokens_size = self.vocab_size
    if possible_translation_tokens is not None:
        possible_tokens_size = possible_translation_tokens.size(0)
    # Integer division: torch.div without a rounding mode performs true
    # division and would return fractional beam ids.
    cand_beams = cand_indices // possible_tokens_size
    cand_indices = cand_indices % possible_tokens_size
    # Handle vocab reduction
    if possible_translation_tokens is not None:
        possible_translation_tokens = possible_translation_tokens.view(
            1, possible_tokens_size
        ).expand(cand_indices.size(0), possible_tokens_size)
        # The result gets its own tensor: writing it into the tensor that is
        # also used as the gather index is not safe.
        cand_indices = torch.gather(
            possible_translation_tokens, dim=1, index=cand_indices
        )
    return cand_scores, cand_indices, cand_beams

示例12: get_topk_predicted_tokens

def get_topk_predicted_tokens(self, net_output, src_tokens, log_probs: bool):
    """Get the top predicted words for vocab reduction.

    ``self.topk_labels_per_source_token`` labels are requested per source
    token, so each example receives
    ``src_tokens.size(1) * self.topk_labels_per_source_token`` indices,
    capped at the number of columns of the probability matrix.

    :param net_output: passed to ``self.get_normalized_probs``
    :param src_tokens: 2-D tensor; its second dimension is the source length
    :param log_probs: passed to ``self.get_normalized_probs``
    :return: tensor of the top-k column indices per example, ordered by
        decreasing probability
    """
    assert (
        isinstance(self.topk_labels_per_source_token, int)
        and self.topk_labels_per_source_token > 0
    ), "topk_labels_per_source_token must be a positive int, or None"

    # number of labels to predict for each example in batch
    k = src_tokens.size(1) * self.topk_labels_per_source_token
    # [batch_size, vocab_size]
    probs = self.get_normalized_probs(net_output, log_probs)
    # Long sources can request more labels than there are vocabulary
    # entries; torch.topk raises in that case, so cap k.
    k = min(k, probs.size(1))
    _, topk_indices = torch.topk(probs, k, dim=1)

    return topk_indices

示例13: diversity_sibling_rank

def diversity_sibling_rank(self, logprobs, gamma):
    """Penalise each word by its rank among the continuations of its beam.

    See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation"
    for details.

    Within every row, the top-k words receive penalty ``gamma * rank``
    (rank 0 for the best word) and all remaining words receive
    ``gamma * k``, with ``k = min(2 * beam_size, vocab_size)``.

    NOTE(review): the statement that writes the ranks into the penalty
    matrix was truncated in this snippet; it is reconstructed from the
    remaining index expression -- confirm against upstream.

    :param logprobs: 3-D tensor ``(*, beam_size, vocab_size)``; its storage
        is updated in place through the flattened view
    :param gamma: strength of the diversity penalty
    :return: the penalised log-probabilities, flattened to
        ``(-1, vocab_size)``
    """
    _, beam_size, vocab_size = logprobs.size()
    logprobs = logprobs.view(-1, vocab_size)
    # Keep consistent with beamsearch class in fairseq
    k = min(2 * beam_size, vocab_size)
    _, indices = torch.topk(logprobs, k)
    # Set diverse penalty as k for all words
    diverse_penalty = torch.ones_like(logprobs) * k
    # Column j of `indices` holds the rank-j word of every row, so every row
    # is scattered the same 0..k-1 ramp.
    sibling_rank = (
        torch.arange(0, k, device=logprobs.device)
        .type_as(logprobs)
        .view(1, -1)
        .expand_as(indices)
    )
    # Set diversity penalty accordingly for top-k words
    diverse_penalty.scatter_(1, indices, sibling_rank)
    logprobs -= gamma * diverse_penalty
    return logprobs

示例14: top_k_softmax

def top_k_softmax(logits, k, n):
    """Softmax restricted to the ``k`` largest logits of every row.

    The probabilities of the selected entries are written back at their
    original column positions; every other position is zero.

    :param logits: 2-D tensor ``(batch, n)``
    :param k: number of entries that receive probability mass; at most
        ``min(k + 1, n)`` entries are available to choose from
    :param n: number of columns of the returned tensor
    :return: tensor of shape ``(batch, n)`` in the dtype and on the device
        of ``logits``
    """
    top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))

    top_k_logits = top_logits[:, :k]
    top_k_indices = top_indices[:, :k]

    probs = torch.softmax(top_k_logits, dim=-1)
    batch = top_k_logits.shape[0]
    # Re-read k: fewer than k columns remain when k exceeds n.
    k = top_k_logits.shape[1]

    # Flatten to 1D: entry j of the flat list belongs to row j // k, whose
    # slots start at (j // k) * n. The division must be an integer one; true
    # division yields fractional rows and misplaces the probabilities.
    row_ids = torch.arange(batch * k, device=logits.device) // k
    indices_flat = torch.reshape(top_k_indices, [-1]) + row_ids * n

    tensor = torch.zeros([batch * n], dtype=logits.dtype,
                         device=logits.device)
    tensor = tensor.scatter_add(0, indices_flat.long(),
                                torch.reshape(probs, [-1]))

    return torch.reshape(tensor, [batch, n])

示例15: get_roi_rel_points_test

def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
    """Get ``num_points`` most uncertain points during test.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        pred_label (list): The predication class for each instance.
        cfg (dict): Testing config of point head; only
            ``cfg.subdivision_num_points`` is read here.

    Returns:
        point_indices (Tensor): A tensor of shape (num_rois, num_points)
            that contains indices from [0, mask_height x mask_width) of the
            most uncertain points.
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains [0, 1] x [0, 1] normalized coordinates of the
            most uncertain points from the [mask_height, mask_width] grid.
    """
    num_points = cfg.subdivision_num_points
    uncertainty_map = self._get_uncertainty(mask_pred, pred_label)
    num_rois, _, mask_height, mask_width = uncertainty_map.shape
    # Size of one grid cell in normalized coordinates.
    h_step = 1.0 / mask_height
    w_step = 1.0 / mask_width

    # NOTE(review): this view only works if the second dimension of the
    # uncertainty map has size 1 -- confirm in _get_uncertainty.
    uncertainty_map = uncertainty_map.view(num_rois,
                                           mask_height * mask_width)
    # Cannot select more points than the grid contains.
    num_points = min(mask_height * mask_width, num_points)
    point_indices = uncertainty_map.topk(num_points, dim=1)[1]
    point_coords = uncertainty_map.new_zeros(num_rois, num_points, 2)
    # Flat index -> (column, row); the half-step offset places each point at
    # the centre of its cell.
    point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                            mask_width).float() * w_step
    point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                            mask_width).float() * h_step
    return point_indices, point_coords
