This article collects and summarizes typical usage examples of the torch.LongTensor.unsqueeze method in Python. If you have been wondering what exactly LongTensor.unsqueeze does, how to use it, or where to find examples of it, the curated code samples here may help. You can also read further about its containing class, torch.LongTensor.
The following shows 7 code examples of LongTensor.unsqueeze, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code samples.
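Before the examples, here is a minimal, self-contained sketch (the tensor values are made up for illustration) of what unsqueeze itself does: it inserts a new dimension of size 1 at the given position, which is how the examples below turn a (batch_size, num_items) mask into a (batch_size, num_items, 1) mask that broadcasts against embeddings.

import torch

mask = torch.LongTensor([[1, 1, 0], [1, 0, 0]])  # shape (2, 3)
expanded = mask.unsqueeze(-1)                    # shape (2, 3, 1)
embeddings = torch.randn(2, 3, 4)                # shape (2, 3, 4)
masked = embeddings * expanded.float()           # broadcasts over the embedding dim
print(expanded.shape, masked.shape)              # torch.Size([2, 3, 1]) torch.Size([2, 3, 4])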
Example 1: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # shape (batch_size, num_spans)
    span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

    if span_indices_mask is not None:
        # It's not strictly necessary to multiply the span indices by the mask here,
        # but it's possible that the span representation was padded with something other
        # than 0 (such as -1, which would be an invalid index), so we do so anyway to
        # be safe.
        span_starts = span_starts * span_indices_mask
        span_ends = span_ends * span_indices_mask

    if not self._use_exclusive_start_indices:
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
    else:
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any():
            raise ValueError(f"Adjusted span indices must lie inside the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}.")

        start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        # We're using sentinels, so we need to replace all the elements which were
        # outside the dimensions of the sequence_tensor with the start sentinel.
        float_start_sentinel_mask = start_sentinel_mask.float()
        start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                           + float_start_sentinel_mask * self._start_sentinel

    combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
    if self._span_width_embedding is not None:
        # Embed the span widths and concatenate to the rest of the representations.
        if self._bucket_widths:
            span_widths = util.bucket_values(span_ends - span_starts,
                                             num_total_buckets=self._num_width_embeddings)
        else:
            span_widths = span_ends - span_starts

        span_width_embeddings = self._span_width_embedding(span_widths)
        return torch.cat([combined_tensors, span_width_embeddings], -1)

    if span_indices_mask is not None:
        return combined_tensors * span_indices_mask.unsqueeze(-1).float()

    return combined_tensors
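To make the sentinel trick above easier to follow, here is a standalone sketch of just that step, with made-up shapes and a random stand-in for the learned self._start_sentinel parameter (this is not AllenNLP's implementation):

import torch

batch_size, num_spans, embedding_dim = 2, 3, 5
span_starts = torch.LongTensor([[0, 2, 4], [1, 0, 3]])
start_sentinel = torch.randn(1, 1, embedding_dim)  # a learned nn.Parameter in the real extractor

exclusive_starts = span_starts - 1
# Spans starting at token 0 have no exclusive start; mark them. Shape: (batch, num_spans, 1)
sentinel_mask = (exclusive_starts == -1).long().unsqueeze(-1)
exclusive_starts = exclusive_starts * (1 - sentinel_mask.squeeze(-1))  # clamp -1 up to 0

# Stand-in for util.batched_index_select(sequence_tensor, exclusive_starts).
start_embeddings = torch.randn(batch_size, num_spans, embedding_dim)
float_mask = sentinel_mask.float()
start_embeddings = start_embeddings * (1 - float_mask) + float_mask * start_sentinel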
Example 2: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self, # pylint: disable=arguments-differ
            embeddings: torch.FloatTensor,
            mask: torch.LongTensor,
            num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                             torch.LongTensor, torch.FloatTensor]:
    """
    Extracts the top-k scoring items with respect to the scorer. We additionally return
    the indices of the top-k in their original order, not ordered by score, so that downstream
    components can rely on the original ordering (e.g., for knowing what spans are valid
    antecedents in a coreference resolution model).

    Parameters
    ----------
    embeddings : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
        each item in the list that we want to prune.
    mask : ``torch.LongTensor``, required.
        A tensor of shape (batch_size, num_items), denoting unpadded elements of
        ``embeddings``.
    num_items_to_keep : ``int``, required.
        The number of items to keep when pruning.

    Returns
    -------
    top_embeddings : ``torch.FloatTensor``
        The representations of the top-k scoring items.
        Has shape (batch_size, num_items_to_keep, embedding_size).
    top_mask : ``torch.LongTensor``
        The corresponding mask for ``top_embeddings``.
        Has shape (batch_size, num_items_to_keep).
    top_indices : ``torch.LongTensor``
        The indices of the top-k scoring items into the original ``embeddings``
        tensor. This is returned because it can be useful to retain pointers to
        the original items, if each item is being scored by multiple distinct
        scorers, for instance. Has shape (batch_size, num_items_to_keep).
    top_item_scores : ``torch.FloatTensor``
        The values of the top-k scoring items.
        Has shape (batch_size, num_items_to_keep, 1).
    """
    mask = mask.unsqueeze(-1)
    num_items = embeddings.size(1)
    # Shape: (batch_size, num_items, 1)
    scores = self._scorer(embeddings)

    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape "
                         f"(batch_size, num_items, 1), but found shape {scores.size()}")
    # Make sure that we don't select any masked items by setting their scores to be very
    # negative. These are logits, typically, so -1e20 should be plenty negative.
    scores = util.replace_masked_values(scores, mask, -1e20)

    # Shape: (batch_size, num_items_to_keep, 1)
    _, top_indices = scores.topk(num_items_to_keep, 1)

    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size, num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Shape: (batch_size * num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Shape: (batch_size, num_items_to_keep, embedding_size)
    top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
    # Shape: (batch_size, num_items_to_keep)
    top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

    # Shape: (batch_size, num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

    return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
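The same pruning pattern can be reproduced in plain PyTorch by replacing the util.* helpers with torch.gather; the sketch below makes that substitution explicit (shapes are hypothetical):

import torch

batch_size, num_items, dim, k = 2, 5, 4, 3
embeddings = torch.randn(batch_size, num_items, dim)
mask = torch.LongTensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
scores = torch.randn(batch_size, num_items, 1)

# Push masked scores very negative so they are never selected, as above.
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e20)
_, top_indices = scores.topk(k, dim=1)           # (batch, k, 1)
top_indices, _ = torch.sort(top_indices, dim=1)  # restore the original item order
top_embeddings = torch.gather(embeddings, 1, top_indices.expand(-1, -1, dim))
top_mask = torch.gather(mask, 1, top_indices.squeeze(-1))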
Example 3: _get_linking_probabilities
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def _get_linking_probabilities(self,
                               worlds: List[WikiTablesWorld],
                               linking_scores: torch.FloatTensor,
                               question_mask: torch.LongTensor,
                               entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
    """
    Produces the probability of an entity given a question word and type. The logic below
    separates the entities by type since the softmax normalization term sums over entities
    of a single type.

    Parameters
    ----------
    worlds : ``List[WikiTablesWorld]``
    linking_scores : ``torch.FloatTensor``
        Has shape (batch_size, num_question_tokens, num_entities).
    question_mask : ``torch.LongTensor``
        Has shape (batch_size, num_question_tokens).
    entity_type_dict : ``Dict[int, int]``
        This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

    Returns
    -------
    batch_probabilities : ``torch.FloatTensor``
        Has shape ``(batch_size, num_question_tokens, num_entities)``.
        Contains all the probabilities for an entity given a question word.
    """
    _, num_question_tokens, num_entities = linking_scores.size()
    batch_probabilities = []
    for batch_index, world in enumerate(worlds):
        all_probabilities = []
        num_entities_in_instance = 0

        # NOTE: The way that we're doing this here relies on the fact that entities are
        # implicitly sorted by their types when we sort them by name, and that numbers come
        # before "fb:cell", and "fb:cell" comes before "fb:row". This is not a great
        # assumption, and could easily break later, but it should work for now.
        for type_index in range(self._num_entity_types):
            # This index of 0 is for the null entity for each type, representing the case where a
            # word doesn't link to any entity.
            entity_indices = [0]
            entities = world.table_graph.entities
            for entity_index, _ in enumerate(entities):
                if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                    entity_indices.append(entity_index)

            if len(entity_indices) == 1:
                # No entities of this type; move along...
                continue

            # We're subtracting one here because of the null entity we added above.
            num_entities_in_instance += len(entity_indices) - 1

            # We separate the scores by type, since normalization is done per type. There's an
            # extra "null" entity per type, also, so we have `num_entities_per_type + 1`. We're
            # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
            # so we get back something of shape (num_question_tokens,) for each index we're
            # selecting. All of the selected indices together then make a tensor of shape
            # (num_question_tokens, num_entities_per_type + 1).
            indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
            entity_scores = linking_scores[batch_index].index_select(1, indices)

            # We used index 0 for the null entity, so this will actually have some values in it.
            # But we want the null entity's score to be 0, so we set that here.
            entity_scores[:, 0] = 0

            # No need for a mask here, as this is done per batch instance, with no padding.
            type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
            all_probabilities.append(type_probabilities[:, 1:])

        # We need to add padding here if we don't have the right number of entities.
        if num_entities_in_instance != num_entities:
            zeros = linking_scores.new_zeros(num_question_tokens,
                                             num_entities - num_entities_in_instance)
            all_probabilities.append(zeros)

        # (num_question_tokens, num_entities)
        probabilities = torch.cat(all_probabilities, dim=1)
        batch_probabilities.append(probabilities)
    batch_probabilities = torch.stack(batch_probabilities, dim=0)
    return batch_probabilities * question_mask.unsqueeze(-1).float()
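A toy version of the per-type normalization in the loop above (entity counts and scores are made up): a zero-scored null entity is prepended, the softmax runs over that type's entities plus the null, and the null column is dropped afterwards.

import torch

num_question_tokens = 4
entity_scores = torch.randn(num_question_tokens, 3)  # scores for 3 entities of one type
null_scores = torch.zeros(num_question_tokens, 1)    # the null entity's score is fixed at 0
scores_with_null = torch.cat([null_scores, entity_scores], dim=1)
type_probabilities = torch.nn.functional.softmax(scores_with_null, dim=1)
entity_probabilities = type_probabilities[:, 1:]     # drop the null column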
Example 4: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self, # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    spans : ``torch.LongTensor``, required.
        A tensor of shape ``(batch_size, num_spans, 2)`` representing the
        inclusive start and end indices of all possible spans in the sentence.
    metadata : List[Dict[str, Any]], required.
        A dictionary of metadata for each batch element which has keys:
            tokens : ``List[str]``, required.
                The original string tokens in the sentence.
            gold_tree : ``nltk.Tree``, optional (default = None)
                Gold NLTK trees for use in evaluation.
            pos_tags : ``List[str]``, optional.
                The POS tags for the sentence. These can be used in the
                model as embedded features, but they are passed here
                in addition for use in constructing the tree.
    pos_tags : ``torch.LongTensor``, optional (default = None)
        The output of a ``SequenceLabelField`` containing POS tags.
    span_labels : ``torch.LongTensor``, optional (default = None)
        A torch tensor representing the integer gold class labels for all possible
        spans, of shape ``(batch_size, num_spans)``.

    Returns
    -------
    An output dictionary consisting of:
    class_probabilities : ``torch.FloatTensor``
        A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
        representing a distribution over the label classes per span.
    spans : ``torch.LongTensor``
        The original spans tensor.
    tokens : ``List[List[str]]``, required.
        A list of tokens in the sentence for each element in the batch.
    pos_tags : ``List[List[str]]``, required.
        A list of POS tags in the sentence for each element in the batch.
    num_spans : ``torch.LongTensor``, required.
        A tensor of shape (batch_size), representing the lengths of non-padded spans
        in ``enumerated_spans``.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.
    """
    embedded_text_input = self.text_field_embedder(tokens)
    if pos_tags is not None and self.pos_tag_embedding is not None:
        embedded_pos_tags = self.pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
    elif self.pos_tag_embedding is not None:
        raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

    mask = get_text_field_mask(tokens)
    # Looking at the span start index is enough to know if
    # this is padding or not. Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
    if span_mask.dim() == 1:
        # This happens if you use batch_size 1 and encounter
        # a length 1 sentence in PTB, which do exist. -.-
        span_mask = span_mask.unsqueeze(-1)
    if span_labels is not None and span_labels.dim() == 1:
        span_labels = span_labels.unsqueeze(-1)

    num_spans = get_lengths_from_binary_sequence_mask(span_mask)

    encoded_text = self.encoder(embedded_text_input, mask)
    span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
    if self.feedforward_layer is not None:
        span_representations = self.feedforward_layer(span_representations)
    logits = self.tag_projection_layer(span_representations)
    class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

    output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans
    }
    if span_labels is not None:
        loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
        self.tag_accuracy(class_probabilities, span_labels, span_mask)
        output_dict["loss"] = loss

    # The evalb score is expensive to compute, so we only compute
    # it for the validation and test sets.
    batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
    if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
        gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
    #.........some code omitted here.........
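The example is truncated above, but its key unsqueeze usage is the masked softmax over span labels. last_dim_softmax is AllenNLP's helper; for a row-level mask like span_mask.unsqueeze(-1), its effect can be reproduced in plain PyTorch as below (shapes are made up):

import torch

batch_size, num_spans, num_labels = 2, 3, 4
logits = torch.randn(batch_size, num_spans, num_labels)
span_mask = torch.LongTensor([[1, 1, 0], [1, 0, 0]])

# Valid spans get an ordinary softmax over labels; padded spans get all zeros.
class_probabilities = torch.softmax(logits, dim=-1) * span_mask.unsqueeze(-1).float()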
Example 5: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            sequence_mask: torch.LongTensor = None,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
    # both of shape (batch_size, num_spans, 1)
    span_starts, span_ends = span_indices.split(1, dim=-1)

    # shape (batch_size, num_spans, 1)
    # These span widths are off by 1, because the span ends are `inclusive`.
    span_widths = span_ends - span_starts

    # We need to know the maximum span width so we can
    # generate indices to extract the spans from the sequence tensor.
    # These indices will then get masked below, such that if the length
    # of a given span is smaller than the max, the rest of the values
    # are masked.
    max_batch_span_width = span_widths.max().item() + 1

    # shape (batch_size, sequence_length, 1)
    global_attention_logits = self._global_attention(sequence_tensor)

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                   util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_logits = util.batched_index_select(global_attention_logits,
                                                      span_indices,
                                                      flat_span_indices).squeeze(-1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

    # Do a weighted sum of the embedded spans with
    # respect to the normalised attention distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

    if span_indices_mask is not None:
        # Above we were masking the widths of spans with respect to the max
        # span width in the batch. Here we are masking the spans which were
        # originally passed in as padding.
        return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

    return attended_text_embeddings
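The broadcasted span-masking step above is easy to isolate; this sketch uses torch.arange in place of util.get_range_vector and made-up span indices:

import torch

span_starts = torch.LongTensor([[0, 2], [1, 3]]).unsqueeze(-1)  # (batch, num_spans, 1)
span_ends = torch.LongTensor([[1, 4], [1, 5]]).unsqueeze(-1)    # inclusive ends
span_widths = span_ends - span_starts
max_batch_span_width = span_widths.max().item() + 1

# (1, 1, max_width) range vector broadcasts against the (batch, num_spans, 1) widths.
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()        # drop negative positions
span_indices = torch.relu(raw_span_indices.float()).long()     # clamp them to 0 for indexing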
Example 6: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self, # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
    if we're training, or a BeamSearch for inference, if we're not.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    valid_actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    action_sequence : ``torch.LongTensor``, optional (default=None)
        The action sequence for the correct SQL query, where each action is an index into the
        list of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``.
        We remove the trailing dimension.
    """
    embedded_utterance = self._utterance_embedder(tokens)
    mask = util.get_text_field_mask(tokens).float()
    batch_size = embedded_utterance.size(0)

    # (batch_size, num_tokens, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
    initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        target_mask = action_sequence != self._action_padding_index
    else:
        target_mask = None

    outputs: Dict[str, Any] = {}
    if action_sequence is not None:
        # action_sequence is of shape (batch_size, 1, target_sequence_length)
        # here after we unsqueeze it for the MML trainer.
        loss_output = self._decoder_trainer.decode(initial_state,
                                                   self._transition_function,
                                                   (action_sequence.unsqueeze(1),
                                                    target_mask.unsqueeze(1)))
        outputs.update(loss_output)

    if not self.training:
        action_mapping = []
        for batch_actions in valid_actions:
            batch_action_mapping = {}
            for action_index, action in enumerate(batch_actions):
                batch_action_mapping[action_index] = action[0]
            action_mapping.append(batch_action_mapping)

        outputs['action_mapping'] = action_mapping
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]

        best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                     initial_state,
                                                     self._transition_function,
                                                     keep_final_unfinished_states=True)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['predicted_sql_query'] = []
        outputs['sql_queries'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]
            action_strings = [action_mapping[i][action_index]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings)

            if action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = action_sequence[i].data
                sequence_in_targets = self._action_history_match(best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                self._action_similarity(similarity.ratio())

            outputs['best_action_sequence'].append(action_strings)
    #.........some code omitted here.........
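The unsqueeze(1) calls passed to the MML trainer above add a candidate-sequence dimension, turning a (batch_size, sequence_length) target into (batch_size, 1, sequence_length). A tiny sketch of that reshaping, with a made-up padding index of 0:

import torch

action_sequence = torch.LongTensor([[3, 7, 2, 0], [5, 1, 0, 0]])  # 0 = padding index here
target_mask = action_sequence != 0
targets_for_trainer = (action_sequence.unsqueeze(1), target_mask.unsqueeze(1))
print(targets_for_trainer[0].shape)  # torch.Size([2, 1, 4])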
Example 7: forward
# Required import: from torch import LongTensor [as alias]
# Or: from torch.LongTensor import unsqueeze [as alias]
def forward(self, # type: ignore
            utterance: Dict[str, torch.LongTensor],
            world: List[AtisWorld],
            actions: List[List[ProductionRule]],
            linking_scores: torch.Tensor,
            target_action_sequence: torch.LongTensor = None,
            sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
    if we're training, or a BeamSearch for inference, if we're not.

    Parameters
    ----------
    utterance : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    world : ``List[AtisWorld]``
        We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
        how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
    actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    linking_scores : ``torch.Tensor``
        A matrix linking the utterance tokens and the entities. This is a binary matrix that
        is deterministically generated, where each entry indicates whether a token generated
        an entity. This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
    target_action_sequence : ``torch.LongTensor``, optional (default=None)
        The action sequence for the correct SQL query, where each action is an index into the
        list of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``.
        We remove the trailing dimension.
    sql_queries : List[List[str]], optional (default=None)
        A list of the SQL queries that are given during training or validation.
    """
    initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
    batch_size = linking_scores.shape[0]
    if target_action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequence = target_action_sequence.squeeze(-1)
        target_mask = target_action_sequence != self._action_padding_index
    else:
        target_mask = None

    if self.training:
        # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
        # unsqueeze it for the MML trainer.
        return self._decoder_trainer.decode(initial_state,
                                            self._transition_function,
                                            (target_action_sequence.unsqueeze(1),
                                             target_mask.unsqueeze(1)))
    else:
        # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
        action_mapping = {}
        for batch_index, batch_actions in enumerate(actions):
            for action_index, action in enumerate(batch_actions):
                action_mapping[(batch_index, action_index)] = action[0]
        outputs: Dict[str, Any] = {'action_mapping': action_mapping}
        outputs['linking_scores'] = linking_scores
        if target_action_sequence is not None:
            outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                           self._transition_function,
                                                           (target_action_sequence.unsqueeze(1),
                                                            target_mask.unsqueeze(1)))['loss']
        num_steps = self._max_decoding_steps
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(num_steps,
                                                     initial_state,
                                                     self._transition_function,
                                                     keep_final_unfinished_states=False)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['entities'] = []
        outputs['predicted_sql_query'] = []
        outputs['sql_queries'] = []
        outputs['utterance'] = []
        outputs['tokenized_utterance'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]
            action_strings = [action_mapping[(i, action_index)]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings)
            if target_action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                #.........some code omitted here.........
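Both parsers end by scoring predicted action sequences against the gold ones; the action-similarity metric is just difflib.SequenceMatcher on the two index lists, as this small sketch (with made-up indices) shows:

import difflib

best_action_indices = [3, 7, 2, 9]
targets = [3, 7, 4, 9]
similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
print(similarity.ratio())  # 0.75: ratio = 2 * 3 matches / (4 + 4)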