當前位置: 首頁>>代碼示例>>Python>>正文


Python LongTensor.clone方法代碼示例

本文整理匯總了Python中torch.LongTensor.clone方法的典型用法代碼示例。如果您正苦於以下問題:Python LongTensor.clone方法的具體用法?Python LongTensor.clone怎麽用?Python LongTensor.clone使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.LongTensor的用法示例。


在下文中一共展示了LongTensor.clone方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: construct_trees

# 需要導入模塊: from torch import LongTensor [as 別名]
# 或者: from torch.LongTensor import clone [as 別名]
    def construct_trees(self,
                        predictions: torch.FloatTensor,
                        all_spans: torch.LongTensor,
                        num_spans: torch.LongTensor,
                        sentences: List[List[str]],
                        pos_tags: List[List[str]] = None) -> List[Tree]:
        """
        Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
        The trees use exclusive end indices, which contrasts with how spans are
        represented in the rest of the model.

        Parameters
        ----------
        predictions : ``torch.FloatTensor``, required.
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        all_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the span
            indices we scored.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        sentences : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, optional (default = None).
            A list of POS tags for each word in the sentence for each element
            in the batch.

        Returns
        -------
        A ``List[Tree]`` containing the decoded trees for each element in the batch.
        """
        # Switch to using exclusive end spans.
        exclusive_end_spans = all_spans.clone()
        exclusive_end_spans[:, :, -1] += 1
        no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

        trees: List[Tree] = []
        for batch_index, (scored_spans, spans, sentence) in enumerate(zip(predictions,
                                                                          exclusive_end_spans,
                                                                          sentences)):
            selected_spans = []
            for prediction, span in zip(scored_spans[:num_spans[batch_index]],
                                        spans[:num_spans[batch_index]]):
                start, end = span
                no_label_prob = prediction[no_label_id]
                label_prob, label_index = torch.max(prediction, -1)

                # Does the span have a label != NO-LABEL or is it the root node?
                # If so, include it in the spans that we consider.
                if int(label_index) != no_label_id or (start == 0 and end == len(sentence)):
                    # TODO(Mark): Remove this once pylint sorts out named tuples.
                    # https://github.com/PyCQA/pylint/issues/1418
                    selected_spans.append(SpanInformation(start=int(start), # pylint: disable=no-value-for-parameter
                                                          end=int(end),
                                                          label_prob=float(label_prob),
                                                          no_label_prob=float(no_label_prob),
                                                          label_index=int(label_index)))

            # The spans we've selected might overlap, which causes problems when we try
            # to construct the tree as they won't nest properly.
            consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans)

            spans_to_labels = {(span.start, span.end):
                               self.vocab.get_token_from_index(span.label_index, "labels")
                               for span in consistent_spans}
            sentence_pos = pos_tags[batch_index] if pos_tags is not None else None
            trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

        return trees
開發者ID:Jordan-Sauchuk,項目名稱:allennlp,代碼行數:72,代碼來源:constituency_parser.py


注:本文中的torch.LongTensor.clone方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。