This article collects typical usage examples of the torch.LongTensor.dim method in Python. If you are wondering how LongTensor.dim is used in practice, or are looking for working examples of it, the selected code examples below may help. You can also look further into usage examples of torch.LongTensor, the class this method belongs to.
Two code examples of the LongTensor.dim method are shown below, sorted by popularity by default.
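Before the two examples, here is a minimal sketch of what dim() returns. It assumes only that PyTorch is installed; the variable names are illustrative and do not come from the examples below.

```python
import torch

# dim() returns the number of dimensions (the rank), not the size of any axis.
scalar = torch.tensor(7)                  # shape: ()
vector = torch.tensor([1, 2, 3])          # shape: (3,)
matrix = torch.tensor([[1, 2], [3, 4]])   # shape: (2, 2)

ranks = [scalar.dim(), vector.dim(), matrix.dim()]
print(ranks)  # [0, 1, 2]

# Tensors built from Python integers are int64, the dtype that
# torch.LongTensor denotes.
assert matrix.dtype == torch.int64
# dim() always equals the length of the shape.
assert matrix.dim() == len(matrix.shape)
```

Both examples below use dim() the same way: as a guard on the rank of an index or mask tensor before reshaping it.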
Example 1: flattened_index_select
# Required imports (ConfigurationError is AllenNLP's exception class):
import torch
from allennlp.common.checks import ConfigurationError


def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The ``target`` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    # Only 2-dimensional index tensors are supported; reject anything else early.
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))
    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
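To see the shapes in Example 1 concretely, the following is a minimal sketch that depends only on torch. The function name flattened_index_select_sketch and the toy data are hypothetical, and ValueError is swapped in for AllenNLP's ConfigurationError so that the sketch runs without AllenNLP installed.

```python
import torch


def flattened_index_select_sketch(target: torch.Tensor,
                                  indices: torch.Tensor) -> torch.Tensor:
    # ValueError stands in for AllenNLP's ConfigurationError (an assumption
    # made so that this sketch only depends on torch).
    if indices.dim() != 2:
        raise ValueError("expected 2-dimensional indices, got shape {}"
                         .format(tuple(indices.size())))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.reshape(-1))
    # Shape: (batch_size, set_size, subset_size, embedding_size)
    return flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)


# Toy data: batch_size=2, sequence_length=4, embedding_size=3.
target = torch.arange(24, dtype=torch.float).view(2, 4, 3)
# set_size=2, subset_size=2
indices = torch.tensor([[0, 1], [2, 3]])

selected = flattened_index_select_sketch(target, indices)
print(selected.shape)  # torch.Size([2, 2, 2, 3])

# selected[b, i, j] is the embedding target[b, indices[i, j]].
assert torch.equal(selected[1, 1, 0], target[1, 2])

# A 1-dimensional index tensor fails the dim() check.
try:
    flattened_index_select_sketch(target, torch.tensor([0, 1]))
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

The dim() guard matters because the final view relies on indices.size(0) and indices.size(1); with any other rank the reshape would either fail or silently produce the wrong layout.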
Example 2: forward
# Required imports. The allennlp.nn.util helpers are assumed to match the older AllenNLP
# release this snippet was written against; last_dim_softmax may be missing from later releases.
from typing import Any, Dict, List

import torch
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.util import last_dim_softmax, get_lengths_from_binary_sequence_mask


def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size, num_spans, 2)`` representing the
        inclusive start and end indices of all possible spans in the sentence.
    metadata : List[Dict[str, Any]], required.
        A dictionary of metadata for each batch element which has keys:
            tokens : ``List[str]``, required.
                The original string tokens in the sentence.
            gold_tree : ``nltk.Tree``, optional (default = None)
                Gold NLTK trees for use in evaluation.
            pos_tags : ``List[str]``, optional.
                The POS tags for the sentence. These can be used in the
                model as embedded features, but they are passed here
                in addition for use in constructing the tree.
    pos_tags : ``torch.LongTensor``, optional (default = None)
        The output of a ``SequenceLabelField`` containing POS tags.
    span_labels : ``torch.LongTensor``, optional (default = None)
        A torch tensor representing the integer gold class labels for all possible
        spans, of shape ``(batch_size, num_spans)``.

    Returns
    -------
    An output dictionary consisting of:
    class_probabilities : ``torch.FloatTensor``
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    spans : ``torch.LongTensor``
        The original spans tensor.
    tokens : ``List[List[str]]``, required.
        A list of tokens in the sentence for each element in the batch.
    pos_tags : ``List[List[str]]``, required.
        A list of POS tags in the sentence for each element in the batch.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size), representing the lengths of non-padded spans
        in ``enumerated_spans``.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    embedded_text_input = self.text_field_embedder(tokens)
    if pos_tags is not None and self.pos_tag_embedding is not None:
        embedded_pos_tags = self.pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
    elif self.pos_tag_embedding is not None:
        raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

    mask = get_text_field_mask(tokens)
    # Looking at the span start index is enough to know if
    # this is padding or not. Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
    if span_mask.dim() == 1:
        # This happens if you use batch_size 1 and encounter
        # a length 1 sentence in PTB, which do exist. -.-
        span_mask = span_mask.unsqueeze(-1)
    if span_labels is not None and span_labels.dim() == 1:
        span_labels = span_labels.unsqueeze(-1)

    num_spans = get_lengths_from_binary_sequence_mask(span_mask)

    encoded_text = self.encoder(embedded_text_input, mask)
    span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)

    if self.feedforward_layer is not None:
        span_representations = self.feedforward_layer(span_representations)

    logits = self.tag_projection_layer(span_representations)
    class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

    output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans
    }
    if span_labels is not None:
        loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
        self.tag_accuracy(class_probabilities, span_labels, span_mask)
        output_dict["loss"] = loss

    # The evalb score is expensive to compute, so we only compute
    # it for the validation and test sets.
    batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
    if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
        gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
# ......... part of the code is omitted here .........
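The dim() check on span_mask in Example 2 is easiest to follow on toy tensors. The sketch below isolates just that step and depends only on torch. The helper name span_mask_from_spans and the sample spans are hypothetical, padded spans are assumed to be marked with -1, and the plain sum at the end is only an assumed stand-in for get_lengths_from_binary_sequence_mask.

```python
import torch


def span_mask_from_spans(spans: torch.Tensor) -> torch.Tensor:
    # A negative start index marks a padded span (assumption: padding is -1).
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
    if span_mask.dim() == 1:
        # squeeze(-1) removed a size-1 num_spans axis; restore it so the
        # mask is always (batch_size, num_spans).
        span_mask = span_mask.unsqueeze(-1)
    return span_mask


# batch_size=1, num_spans=1: the case the dim() check guards against.
single = torch.tensor([[[0, 0]]])
# batch_size=2, num_spans=3; the second element has two padded spans.
padded = torch.tensor([[[0, 0], [0, 1], [1, 1]],
                       [[0, 0], [-1, -1], [-1, -1]]])

single_mask = span_mask_from_spans(single)
padded_mask = span_mask_from_spans(padded)

print(single_mask.shape)     # torch.Size([1, 1])
print(padded_mask.tolist())  # [[1, 1, 1], [1, 0, 0]]

# Number of real (non-padded) spans per batch element.
num_spans = padded_mask.sum(-1)
print(num_spans.tolist())    # [3, 1]
```

When num_spans is larger than 1, squeeze(-1) leaves the mask untouched, so the dim() branch only fires when the last axis has size 1.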