當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.topk方法代碼示例

本文整理匯總了Python中torch.topk方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.topk方法的具體用法?Python torch.topk怎麼用?Python torch.topk使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.topk方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: greedy_decode

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, flag):
        """
        Greedily decode a response for ``self.max_ts`` time steps.

        At every step the selected decoder is run once and the single
        highest-scoring entry of its output is fed back as the next input.

        :param pz_dec_outs: forwarded unchanged to the decoder at every step
        :param pz_proba: forwarded unchanged to the decoder at every step
        :param u_enc_out: forwarded unchanged to the decoder at every step
        :param m_tm1: decoder input for the first step; replaced by the
            chosen indices afterwards
        :param last_hidden: initial decoder hidden state
        :param flag: when truthy ``self.p_decoder`` is used, otherwise
            ``self.m_decoder``
        :return: nested list with one inner list per batch row, each holding
            the ``self.max_ts`` indices chosen for that row
        """
        step_decoder = self.p_decoder if flag else self.m_decoder
        chosen_per_step = []
        for _step in range(self.max_ts):
            proba, last_hidden = step_decoder(pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden)
            # topk with k=1 is an argmax; only the index half is needed.
            _, best = torch.topk(proba, 1)  # [B,1]
            best = best.data.view(-1)
            chosen_per_step.append(best)
            # Feed the chosen indices back in as a [1, B] input.
            m_tm1 = cuda_(Variable(best).view(1, -1))
        # Stacked as [T, B]; transpose so that each row is one sample.
        per_sample = torch.stack(chosen_per_step, dim=0).transpose(0, 1)
        return [list(row) for row in list(per_sample)]
開發者ID:AuCson,項目名稱:SEDST,代碼行數:22,代碼來源:unsup_net.py

示例2: pz_selective_sampling

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def pz_selective_sampling(self, pz_proba):
        """
        Max-sample one token per step from ``pz_proba`` without repeats.

        For every batch row and every step, the candidates are visited from
        highest to lowest score and the first one not already chosen for that
        row is kept.

        :param pz_proba: 3-D score tensor; dim 2 is ranked by ``topk`` and
            dims 0/1 are swapped into [B,Tz] order before sampling
        :return: list (one entry per batch row) of lists of chosen tokens
        """
        scores = pz_proba.data
        # Rank as many candidates per step as dim 0 is long, so a step can
        # fall back to a lower-ranked token when better ones are taken.
        _, ranked = torch.topk(scores, scores.size(0), dim=2)
        ranked = ranked.transpose(0, 1)  # [B,Tz,top_Tz]
        all_sampled_z = []
        for sample in ranked:
            picked = []
            for step_candidates in sample:
                for token in step_candidates:
                    if token in picked:
                        continue
                    picked.append(token)
                    break
            all_sampled_z.append(picked)
        return all_sampled_z
開發者ID:AuCson,項目名稱:SEDST,代碼行數:19,代碼來源:unsup_net.py

示例3: greedy_decode

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def greedy_decode(self, pz_dec_outs, pz_proba, u_enc_out, m_tm1, last_hidden, degree_input):
        """
        Greedily decode a response for ``self.max_ts`` time steps.

        ``self.m_decoder`` is run once per step; the single highest-scoring
        entry of its first output becomes the next decoder input.

        :param pz_dec_outs: forwarded unchanged to the decoder at every step
        :param pz_proba: forwarded unchanged to the decoder at every step
        :param u_enc_out: forwarded unchanged to the decoder at every step
        :param m_tm1: decoder input for the first step; replaced by the
            chosen indices afterwards
        :param last_hidden: initial decoder hidden state
        :param degree_input: forwarded unchanged to the decoder at every step
        :return: nested list with one inner list per batch row, each holding
            the ``self.max_ts`` indices chosen for that row
        """
        chosen_per_step = []
        for _step in range(self.max_ts):
            # The decoder's third return value is not used here.
            proba, last_hidden, _ = self.m_decoder(
                pz_dec_outs, pz_proba, u_enc_out, m_tm1, degree_input, last_hidden)
            _, best = torch.topk(proba, 1)  # [B,1]
            best = best.data.view(-1)
            chosen_per_step.append(best)
            # Feed the chosen indices back in as a [1, B] input.
            m_tm1 = cuda_(Variable(best).view(1, -1))
        # Stacked as [T, B]; transpose so that each row is one sample.
        per_sample = torch.stack(chosen_per_step, dim=0).transpose(0, 1)
        return [list(row) for row in list(per_sample)]
開發者ID:AuCson,項目名稱:SEDST,代碼行數:21,代碼來源:semi_sup_net.py

示例4: greedy_decode

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def greedy_decode(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
        """
        Greedily decode a response for ``self.max_ts`` time steps.

        Column 2 of the decoder output (the unk id, per the fallback below)
        is removed before the argmax so it can never be emitted. The indices
        returned to the caller are recorded before out-of-vocabulary ids are
        replaced; only the value fed back into the decoder is clamped.

        :param pz_dec_outs: forwarded unchanged to the decoder at every step
        :param u_enc_out: forwarded unchanged to the decoder at every step
        :param m_tm1: decoder input for the first step; replaced by the
            chosen indices afterwards
        :param u_input_np: forwarded unchanged to the decoder at every step
        :param last_hidden: initial decoder hidden state
        :param degree_input: forwarded unchanged to the decoder at every step
        :param bspan_index: passed through ``pad_sequences`` and transposed
            once, then forwarded to the decoder at every step
        :return: nested list with one inner list per batch row, each holding
            the ``self.max_ts`` indices chosen for that row
        """
        decoded = []
        bspan_index_np = pad_sequences(bspan_index).transpose((1, 0))
        for _step in range(self.max_ts):
            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
                                                   degree_input, last_hidden, bspan_index_np)
            # Drop column 2 so it cannot win the argmax.
            proba = torch.cat((proba[:, :2], proba[:, 3:]), 1)
            _, mt_index = torch.topk(proba, 1)  # [B,1]
            # Undo the shift caused by the removed column: ids >= 2 move up by one.
            mt_index.add_(mt_index.ge(2).long())
            mt_index = mt_index.data.view(-1)
            # Record the raw choice before clamping it for the next input.
            decoded.append(mt_index.clone())
            # Vectorized replacement of the former per-element loop: any id at
            # or beyond the vocabulary size is fed back as 2 (unk).
            mt_index.masked_fill_(mt_index >= cfg.vocab_size, 2)
            m_tm1 = cuda_(Variable(mt_index).view(1, -1))
        # Stacked as [T, B]; transpose so that each row is one sample.
        decoded = torch.stack(decoded, dim=0).transpose(0, 1)
        return [list(row) for row in decoded]
開發者ID:ConvLab,項目名稱:ConvLab,代碼行數:20,代碼來源:tsd_net.py

示例5: greedy_decode

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def greedy_decode(self, decoder_hidden, encoder_outputs, target_tensor):
        """
        Greedily decode ``self.max_len`` tokens per sample and map them to words.

        :param decoder_hidden: initial hidden state passed to ``self.decoder``
        :param encoder_outputs: forwarded unchanged to ``self.decoder`` at
            every step
        :param target_tensor: 2-D tensor; only its first dimension (the batch
            size) is used
        :return: list with one space-joined string per sample, cut off before
            the first token whose word equals the EOS word
        """
        batch_size, _ = target_tensor.size()
        # torch.full honours the ``device`` argument; the legacy
        # ``torch.LongTensor(data, device=...)`` constructor only accepts CPU.
        decoder_input = torch.full((batch_size, 1), SOS_token,
                                   dtype=torch.long, device=self.device)

        # Token ids are integers, so keep them in an integer buffer.
        decoded_words = torch.zeros((batch_size, self.max_len), dtype=torch.long)
        for t in range(self.max_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)

            _, topi = decoder_output.data.topk(1)  # get candidates
            topi = topi.view(-1)

            decoded_words[:, t] = topi
            decoder_input = topi.detach().view(-1, 1)

        # The EOS word does not change between tokens; look it up once.
        eos_word = self.output_index2word(str(EOS_token))
        decoded_sentences = []
        for sentence in decoded_words.tolist():
            sent = []
            for ind in sentence:
                word = self.output_index2word(str(ind))
                if word == eos_word:
                    break
                sent.append(word)
            decoded_sentences.append(' '.join(sent))

        return decoded_sentences
開發者ID:ConvLab,項目名稱:ConvLab,代碼行數:26,代碼來源:model.py

示例6: test_knn

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def test_knn():
    """Check dgl's KNN graph builders against a brute-force 3-nearest search.

    The reference neighbours come from ``th.topk(..., largest=False)`` on the
    pairwise distance matrix, restricted to the segment a node belongs to.
    """
    x = th.randn(8, 3)
    kg = dgl.nn.KNNGraph(3)
    d = th.cdist(x, x)

    def check_knn(g, x, start, end):
        # Nodes start..end-1 form one segment; neighbours may only come from it.
        for v in range(start, end):
            found, _ = g.in_edges(v)
            # Row v of the distance matrix, limited to the segment's columns.
            nearest = th.topk(d[v, start:end], 3, largest=False)[1]
            # topk indices are segment-relative; shift back to global node ids.
            expected = set(nearest.numpy() + start)
            assert set(found.numpy()) == expected

    # One point set: all 8 points share a single segment.
    g = kg(x)
    check_knn(g, x, 0, 8)

    # Batched input: two independent sets of 4 points each.
    g = kg(x.view(2, 4, 3))
    check_knn(g, x, 0, 4)
    check_knn(g, x, 4, 8)

    # Explicit segment sizes 3 and 5.
    kg = dgl.nn.SegmentedKNNGraph(3)
    g = kg(x, [3, 5])
    check_knn(g, x, 0, 3)
    check_knn(g, x, 3, 8)
開發者ID:dmlc,項目名稱:dgl,代碼行數:26,代碼來源:test_geometry.py

示例7: predict

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def predict(self, x):
        """
        Look up the stored value of the best-matching key for each input row.

        :param x: 2-D input batch, one row per query (``dim=1`` is used as
            the feature dimension throughout)
        :return: tuple ``(y_hat, softmax_score)`` where ``y_hat`` is
            ``self.values`` indexed by the best-scoring key of each row and
            ``softmax_score`` is the softmax over each row's
            ``self.top_k`` best scores, scaled by ``self.softmax_temperature``
        """
        query = F.normalize(self.query_proj(x), dim=1)

        # Find the k-nearest neighbors of the query
        scores = torch.matmul(query, torch.t(self.keys_var))
        cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

        # softmax of cosine similarities - embedding
        # ``dim`` is given explicitly: the implicit choice is deprecated and
        # emits a warning on every call.
        softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

        # retrieve memory values - prediction (column 0 is the best match)
        y_hat_indices = topk_indices_var.data[:, 0]
        y_hat = self.values[y_hat_indices]

        return y_hat, softmax_score
開發者ID:RUSH-LAB,項目名稱:LSH_Memory,代碼行數:18,代碼來源:memory.py

示例8: select_topk

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def select_topk(args, logits, force_no_eos_id=None):
    """
    Apply top-k filtering to ``logits`` in place and return it.

    Every entry strictly below the k-th largest value along the last
    dimension is set to ``-inf``.

    :param args: object exposing ``top_k``. Values larger than the size of
        the last dimension are clamped to it; values <= 0 disable filtering.
    :param logits: score tensor, modified in place. It is indexed as
        ``logits[:, id]`` when ``force_no_eos_id`` is given.
    :param force_no_eos_id: optional column index that is set to ``-inf``
        before filtering, so it can never survive.
    :return: the same ``logits`` tensor object
    """
    if force_no_eos_id is not None:
        logits[:, force_no_eos_id] = float('-inf')

    # torch.topk raises if k exceeds the dimension size, so clamp it.
    top_k = min(args.top_k, logits.size(-1))
    if top_k > 0:
        # Smallest of the k best values per row, kept broadcastable.
        kth_best = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth_best] = float('-inf')

    return logits


# implementation is from Huggingface/transformers repo 
開發者ID:bme-chatbots,項目名稱:dialogue-generation,代碼行數:19,代碼來源:interact.py

示例9: _topk

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def _topk(self, scores, K=40):
        """
        Select the ``K`` best-scoring locations of each sample across all categories.

        Two passes are made: first the ``K`` best locations inside every
        category, then the ``K`` best of those ``cat * K`` candidates.

        :param scores: 4-D tensor unpacked as (batch, cat, height, width)
        :param K: number of locations to keep per sample
        :return: tuple ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``;
            the scores, the flattened spatial indices, the category index
            (position in the ``cat * K`` candidate list divided by ``K``) and
            the row / column coordinates as floats
        """
        batch, cat, height, width = scores.size()

        # Pass 1: best K locations inside each category.
        per_cat_scores, per_cat_inds = torch.topk(scores.view(batch, cat, -1), K)
        per_cat_inds = per_cat_inds % (height * width)
        # Split the flat index into row / column; .int() truncates the quotient.
        per_cat_ys = (per_cat_inds / width).int().float()
        per_cat_xs = (per_cat_inds % width).int().float()

        # Pass 2: best K among the cat * K per-category candidates.
        topk_score, winners = torch.topk(per_cat_scores.view(batch, -1), K)
        topk_clses = (winners / K).int()

        def _select(candidates):
            # Pick the pass-2 winners out of the flattened pass-1 candidates.
            return _gather_feat(candidates.view(batch, -1, 1), winners).view(batch, K)

        return (topk_score, _select(per_cat_inds), topk_clses,
                _select(per_cat_ys), _select(per_cat_xs))
開發者ID:tensorboy,項目名稱:centerpose,代碼行數:19,代碼來源:centernet_tensorrt_engine.py

示例10: _topk

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def _topk(scores, K=40):
    """
    Select the ``K`` best-scoring locations of each sample across all categories.

    The search is done per category first and then across the resulting
    ``cat * K`` candidates.

    :param scores: 4-D tensor unpacked as (batch, cat, height, width)
    :param K: number of locations to keep per sample
    :return: tuple ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``;
        the scores, the flattened spatial indices, the category index
        (position in the ``cat * K`` candidate list divided by ``K``) and
        the row / column coordinates as floats
    """
    batch, cat, height, width = scores.size()
    flat_scores = scores.view(batch, cat, -1)

    # Stage 1: K best locations for every category separately.
    cat_scores, cat_inds = torch.topk(flat_scores, K)
    cat_inds = cat_inds % (height * width)
    # Row / column of each flat index; .int() truncates the quotient.
    cat_ys = (cat_inds / width).int().float()
    cat_xs = (cat_inds % width).int().float()

    # Stage 2: K best among all per-category candidates of a sample.
    topk_score, chosen = torch.topk(cat_scores.view(batch, -1), K)
    topk_clses = (chosen / K).int()

    # Carry the stage-1 attributes of the stage-2 winners along.
    topk_inds, topk_ys, topk_xs = [
        _gather_feat(attr.view(batch, -1, 1), chosen).view(batch, K)
        for attr in (cat_inds, cat_ys, cat_xs)
    ]

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
開發者ID:tensorboy,項目名稱:centerpose,代碼行數:19,代碼來源:decode.py

示例11: select_next_words

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def select_next_words(
        self, word_scores, bsz, beam_size, possible_translation_tokens
    ):
        """
        Pick the ``beam_size`` best (beam, token) candidates for each sentence.

        ``word_scores`` is flattened to one row per sentence, so each selected
        flat position encodes both the originating beam (quotient) and the
        token slot (remainder) with respect to the number of token slots.

        :param word_scores: score tensor, viewed as ``(bsz, -1)``
        :param bsz: number of rows after flattening
        :param beam_size: number of candidates kept per row
        :param possible_translation_tokens: optional 1-D tensor of token ids
            (reduced vocabulary). When given, its length is the number of
            token slots and the returned indices are mapped through it;
            otherwise ``self.vocab_size`` slots are assumed.
        :return: tuple ``(cand_scores, cand_indices, cand_beams)``
        """
        cand_scores, cand_indices = torch.topk(word_scores.view(bsz, -1), k=beam_size)
        possible_tokens_size = self.vocab_size
        if possible_translation_tokens is not None:
            possible_tokens_size = possible_translation_tokens.size(0)
        # Beam ids must stay integral: plain torch.div performs true division
        # on integer tensors and would yield fractional values.
        cand_beams = torch.div(
            cand_indices, possible_tokens_size, rounding_mode="floor"
        )
        cand_indices.fmod_(possible_tokens_size)
        # Handle vocab reduction
        if possible_translation_tokens is not None:
            possible_translation_tokens = possible_translation_tokens.view(
                1, possible_tokens_size
            ).expand(cand_indices.size(0), possible_tokens_size)
            # Plain assignment: the result is no longer written into the
            # index tensor that gather is reading from.
            cand_indices = torch.gather(
                possible_translation_tokens, dim=1, index=cand_indices
            )
        return cand_scores, cand_indices, cand_beams
開發者ID:pytorch,項目名稱:translate,代碼行數:20,代碼來源:competing_completed.py

示例12: get_topk_predicted_tokens

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def get_topk_predicted_tokens(self, net_output, src_tokens, log_probs: bool):
        """
        Return the indices of the highest-probability words for vocab reduction.

        The number of words kept per example grows with the source length:
        ``src_tokens.size(1) * self.topk_labels_per_source_token``.

        :param net_output: passed to ``self.get_normalized_probs``
        :param src_tokens: tensor whose second dimension is used as the
            source length
        :param log_probs: passed to ``self.get_normalized_probs``
        :return: the index half of ``torch.topk`` taken along dim 1 of the
            normalized probabilities
        """
        per_token = self.topk_labels_per_source_token
        assert (
            isinstance(per_token, int)
            and per_token > 0
        ), "topk_labels_per_source_token must be a positive int, or None"

        # how many labels are predicted for every example in the batch
        num_labels = src_tokens.size(1) * per_token
        # [batch_size, vocab_size]
        normalized = self.get_normalized_probs(net_output, log_probs)
        return torch.topk(normalized, num_labels, dim=1)[1]
開發者ID:pytorch,項目名稱:translate,代碼行數:19,代碼來源:word_predictor.py

示例13: diversity_sibling_rank

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def diversity_sibling_rank(self, logprobs, gamma):
        """
        Penalize each word by its rank among the candidates of the same row.

        See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation"
        for details.

        :param logprobs: 3-D tensor whose last two dimensions are
            (beam_size, vocab_size); it is flattened to 2-D and updated in
            place
        :param gamma: multiplier applied to the rank penalty
        :return: the flattened ``(-1, vocab_size)`` view of ``logprobs`` after
            subtracting ``gamma * penalty``, where the penalty of a row's j-th
            best word is ``j`` and every word outside the top ``k`` gets ``k``
        """
        _, beam_size, vocab_size = logprobs.size()
        logprobs = logprobs.view(-1, vocab_size)
        # Keep consistent with beamsearch class in fairseq
        k = min(2 * beam_size, vocab_size)
        best_words = torch.topk(logprobs, k)[1]
        # Default penalty for every word is k ...
        penalty = torch.full_like(logprobs, k)
        # ... and the top-k words of each row receive their rank 0..k-1.
        ranks = torch.arange(0, k).type_as(logprobs).view(1, -1).expand_as(best_words)
        penalty.scatter_(1, best_words, ranks)
        logprobs -= gamma * penalty
        return logprobs
開發者ID:pytorch,項目名稱:translate,代碼行數:23,代碼來源:beam_decode.py

示例14: top_k_softmax

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def top_k_softmax(logits, k, n):
        """
        Softmax restricted to the ``k`` largest logits of every row.

        :param logits: 2-D tensor of scores, one row per example
        :param k: number of entries per row that receive probability mass
        :param n: width of the returned tensor; the selected column indices
            must be smaller than ``n``
        :return: tensor of shape ``[batch, n]`` that is zero everywhere except
            at each row's top-``k`` columns, which hold the softmax over
            those ``k`` logits
        """
        # One extra candidate is requested and then sliced away again.
        # NOTE(review): the extra candidate is never used - confirm it can go.
        top_logits, top_indices = torch.topk(logits, k=min(k + 1, n))

        top_k_logits = top_logits[:, :k]
        top_k_indices = top_indices[:, :k]

        probs = torch.softmax(top_k_logits, dim=-1)
        batch = top_k_logits.shape[0]

        # Scatter row by row in 2-D. The former flat-offset arithmetic relied
        # on torch.div, which is true division on integer tensors and so
        # produced wrong target positions.
        dense = torch.zeros([batch, n], dtype=logits.dtype,
                            device=logits.device)
        return dense.scatter_add(1, top_k_indices, probs)
開發者ID:XMUNLP,項目名稱:Tagger,代碼行數:23,代碼來源:recurrent.py

示例15: get_roi_rel_points_test

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import topk [as 別名]
def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
        """Pick the ``num_points`` most uncertain mask locations at test time.

        Args:
            mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            pred_label (list): The predicted class for each instance.
            cfg (dict): Testing config of point head; its
                ``subdivision_num_points`` is the requested number of points,
                capped at ``mask_height * mask_width``.

        Returns:
            point_indices (Tensor): A tensor of shape (num_rois, num_points)
                that contains indices from [0, mask_height x mask_width) of the
                most uncertain points.
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
                that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.
        """
        requested = cfg.subdivision_num_points
        uncertainty = self._get_uncertainty(mask_pred, pred_label)
        num_rois, _, grid_h, grid_w = uncertainty.shape
        cell_h = 1.0 / grid_h
        cell_w = 1.0 / grid_w

        num_cells = grid_h * grid_w
        flat_uncertainty = uncertainty.view(num_rois, num_cells)
        # Never ask for more points than the grid has cells.
        num_points = min(num_cells, requested)
        _, point_indices = flat_uncertainty.topk(num_points, dim=1)

        # Split each flat index into its grid column and row.
        cols = (point_indices % grid_w).float()
        rows = (point_indices // grid_w).float()

        # Coordinates are cell centres: half a cell plus the cell offset.
        point_coords = flat_uncertainty.new_zeros(num_rois, num_points, 2)
        point_coords[:, :, 0] = cell_w / 2.0 + cols * cell_w
        point_coords[:, :, 1] = cell_h / 2.0 + rows * cell_h
        return point_indices, point_coords
開發者ID:open-mmlab,項目名稱:mmdetection,代碼行數:36,代碼來源:mask_point_head.py


注:本文中的torch.topk方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。