当前位置: 首页>>代码示例>>Python>>正文


Python LongTensor.dim方法代码示例

本文整理汇总了Python中torch.LongTensor.dim方法的典型用法代码示例。如果您正苦于以下问题:Python LongTensor.dim方法的具体用法?Python LongTensor.dim怎么用?Python LongTensor.dim使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.LongTensor的用法示例。


在下文中一共展示了LongTensor.dim方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: flattened_index_select

# 需要导入模块: from torch import LongTensor [as 别名]
# 或者: from torch.LongTensor import dim [as 别名]
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The `target` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
开发者ID:pyknife,项目名称:allennlp,代码行数:32,代码来源:util.py

示例2: forward

# 需要导入模块: from torch import LongTensor [as 别名]
# 或者: from torch.LongTensor import dim [as 别名]
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
#.........这里部分代码省略.........
开发者ID:Jordan-Sauchuk,项目名称:allennlp,代码行数:103,代码来源:constituency_parser.py


注:本文中的torch.LongTensor.dim方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。