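This article collects typical usage examples of the Python method torch.LongTensor.cpu. If you want to know how LongTensor.cpu is used in Python, or are looking for examples of it in real code, the selected examples here may help. You can also read more about usage examples of the class the method belongs to, torch.LongTensor.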
One code example of LongTensor.cpu is shown below. Examples are sorted by popularity by default.
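As a quick reference for the method itself, here is a minimal sketch (assuming a recent PyTorch release; all variable names are illustrative). On a LongTensor that already lives in CPU memory, .cpu() returns the same object without copying; on a GPU tensor it returns a CPU copy with the same dtype. Combining .detach() with .cpu() is the generally recommended replacement for the older .cpu().data idiom that the example uses. The legacy torch.LongTensor(...) constructor still works, although torch.tensor(..., dtype=torch.long) is the usual modern spelling.

```python
import torch

# A legacy-style LongTensor (dtype torch.int64) created in CPU memory.
indices = torch.LongTensor([[0, 2], [3, 5]])
print(indices.dtype)  # torch.int64

# Already on the CPU: .cpu() performs no copy and returns the same object.
same = indices.cpu()
print(same is indices)  # True

# If a GPU is available, move the tensor there and back; .cpu() then
# returns a new CPU tensor with the same dtype. Without a GPU this
# round trip simply stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
on_device = indices.to(device)
back_on_cpu = on_device.cpu()
print(back_on_cpu.device, back_on_cpu.dtype)  # cpu torch.int64

# Modern form of the `.cpu().data` idiom: detach from the autograd graph,
# then move to host memory (e.g. before converting to Python objects).
scores = torch.randn(2, 3, requires_grad=True)
host_scores = scores.detach().cpu()
print(host_scores.requires_grad)  # False
print(indices.cpu().tolist())  # [[0, 2], [3, 5]]
```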
Example 1: forward
# Module to import: from torch import LongTensor [as alias]
# Method used: torch.LongTensor.cpu, called on a tensor instance as tensor.cpu()
#......... part of the code is omitted here .........
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size,)``, representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)

        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
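The span-mask lines in forward can be reproduced in isolation. Below is a minimal sketch with made-up data, assuming (as the ">= 0" test implies) that padded spans are filled with -1. It builds the mask, counts the real spans per batch element with a plain sum over the last dimension (we assume this matches what AllenNLP's get_lengths_from_binary_sequence_mask returns), and shows why the dim() == 1 guard exists: .squeeze(-1) drops the num_spans dimension whenever that dimension has size 1.

```python
import torch

# Hypothetical batch: 2 sentences, up to 3 spans each; padded spans are (-1, -1).
spans = torch.LongTensor([
    [[0, 0], [0, 1], [1, 1]],
    [[0, 0], [-1, -1], [-1, -1]],
])

# A span is real if its start index is non-negative. Shape: (batch_size, num_spans)
span_mask = (spans[:, :, 0] >= 0).long()

# Number of non-padded spans per batch element (a sum over the mask;
# assumed equivalent to get_lengths_from_binary_sequence_mask).
num_spans = span_mask.sum(-1)

print(span_mask.tolist())  # [[1, 1, 1], [1, 0, 0]]
print(num_spans.tolist())  # [3, 1]

# With a single span (num_spans == 1), .squeeze(-1) removes the span
# dimension, which is the case the original dim() == 1 guard repairs.
single = torch.LongTensor([[[0, 0]]])          # shape (1, 1, 2)
squeezed = (single[:, :, 0] >= 0).squeeze(-1)  # shape (1,)
print(tuple(squeezed.shape))                   # (1,)
restored = squeezed.unsqueeze(-1)              # back to (1, 1)
print(tuple(restored.shape))                   # (1, 1)
```

As far as this snippet shows, the comparison alone already yields shape (batch_size, num_spans); it is the extra .squeeze(-1) that makes the unsqueeze guard necessary.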
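The .cpu().data calls at the end of forward move the predictions to host memory because tree construction happens in plain Python. The model's construct_trees method is not shown on this page, so the following is only a hypothetical, simplified stand-in: greedy_span_labels is our own name, not AllenNLP API. It takes the argmax label for each non-padded span and keeps spans whose label is not a "no constituent" index. It assumes that real spans precede padding, so num_spans can serve as a prefix length, and that label index 0 means "no constituent"; both are assumptions for illustration.

```python
from typing import List, Tuple

import torch


def greedy_span_labels(class_probabilities: torch.Tensor,
                       spans: torch.Tensor,
                       num_spans: torch.Tensor,
                       no_label_index: int = 0) -> List[List[Tuple[int, int, int]]]:
    """Hypothetical helper: pick the argmax label for each real span.

    class_probabilities: (batch_size, num_spans, num_labels)
    spans:               (batch_size, num_spans, 2), inclusive start/end
    num_spans:           (batch_size,), count of non-padded spans
    Spans whose best label is `no_label_index` are dropped.
    """
    # Detach from autograd and move to host memory before building
    # Python objects (works for both CPU and GPU inputs).
    probs = class_probabilities.detach().cpu()
    host_spans = spans.detach().cpu()
    lengths = num_spans.detach().cpu().tolist()

    best_labels = probs.argmax(dim=-1)  # (batch_size, num_spans)
    batch_output = []
    for b, length in enumerate(lengths):
        kept = []
        for i in range(length):
            label = int(best_labels[b, i])
            if label != no_label_index:
                start, end = host_spans[b, i].tolist()
                kept.append((start, end, label))
        batch_output.append(kept)
    return batch_output


# Made-up batch of one sentence with two real spans and one padded span.
example_probs = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]])
example_spans = torch.LongTensor([[[0, 1], [0, 0], [-1, -1]]])
example_num_spans = torch.LongTensor([2])
print(greedy_span_labels(example_probs, example_spans, example_num_spans))
# [[(0, 1, 1)]]
```

A real constituency parser would additionally enforce that the kept spans form a well-nested tree; this sketch only illustrates why the tensors are moved to the CPU first.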