This article collects typical usage examples of the Python method torch.LongTensor.squeeze: what LongTensor.squeeze does, how to call it, and where it is used in practice. The curated examples below may help; you can also read more about the enclosing class, torch.LongTensor.
The following shows 7 code examples of LongTensor.squeeze, ordered by popularity.
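Before the examples, here is a minimal standalone sketch of the method itself: squeeze removes dimensions of size 1 and is a no-op on dimensions of any other size (the tensor values below are illustrative).

import torch

agenda = torch.LongTensor([[3], [7], [11]])  # shape (3, 1): a column vector of action indices
print(agenda.squeeze(-1))                    # tensor([ 3,  7, 11]); the trailing size-1 dim is removed
print(agenda.squeeze(0).shape)               # torch.Size([3, 1]); dim 0 has size 3, so nothing changes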
Example 1: _get_checklist_info
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def _get_checklist_info(agenda: torch.LongTensor,
all_actions: List[ProductionRule],
terminal_productions: Set[str],
max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
world, and a length to pad the checklist vectors to, and returns a target checklist against
which the checklist at each state will be compared to compute a loss, indices of
``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
are relevant for checklist loss computation.
Parameters
----------
``agenda`` : ``torch.LongTensor``
Agenda of one instance of size ``(agenda_size, 1)``.
``all_actions`` : ``List[ProductionRule]``
All actions for one instance.
``terminal_productions`` : ``Set[str]``
String representations of terminal productions in the corresponding world.
``max_num_terminals`` : ``int``
Length to which the checklist vectors will be padded. This is the maximum number of
terminal productions over all the worlds in the batch.
"""
terminal_indices = []
target_checklist_list = []
agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
# We want to return checklist target and terminal actions that are column vectors to make
# computing softmax over the difference between checklist and target easier.
for index, action in enumerate(all_actions):
# Each action is a ProductionRule, a tuple where the first item is the production
# rule string.
if action[0] in terminal_productions:
terminal_indices.append([index])
if index in agenda_indices_set:
target_checklist_list.append([1])
else:
target_checklist_list.append([0])
while len(target_checklist_list) < max_num_terminals:
target_checklist_list.append([0])
terminal_indices.append([-1])
# (max_num_terminals, 1)
terminal_actions = agenda.new_tensor(terminal_indices)
# (max_num_terminals, 1)
target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
checklist_mask = (target_checklist != 0).float()
return target_checklist, terminal_actions, checklist_mask
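To make the tensor bookkeeping above concrete, here is a minimal sketch with hypothetical toy values (the indices, padding, and max_num_terminals are made up):

import torch

agenda = torch.LongTensor([[1], [4]])               # hypothetical agenda of action indices, shape (2, 1)
terminal_indices = [[0], [2], [5], [-1]]            # padded with -1 up to max_num_terminals = 4
target_checklist_list = [[1], [0], [1], [0]]        # 1 where a terminal action is on the agenda

# new_tensor builds the result on the same device as the agenda tensor.
terminal_actions = agenda.new_tensor(terminal_indices)                           # (4, 1) LongTensor
target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)   # (4, 1) column vector
checklist_mask = (target_checklist != 0).float()    # only agenda terminals contribute to the loss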
Example 2: _get_checklist_info
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def _get_checklist_info(self,
agenda: torch.LongTensor,
all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""
Takes an agenda and a list of all actions and returns a target checklist against which the
checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
checklist loss computation. If ``self.penalize_non_agenda_actions`` is set to ``True``,
``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
``False``, the indices of all terminals that are not in the agenda will be masked.
Parameters
----------
``agenda`` : ``torch.LongTensor``
Agenda of one instance of size ``(agenda_size, 1)``.
``all_actions`` : ``List[ProductionRuleArray]``
All actions for one instance.
"""
terminal_indices = []
target_checklist_list = []
agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
for index, action in enumerate(all_actions):
# Each action is a ProductionRuleArray, a tuple where the first item is the production
# rule string.
if action[0] in self._terminal_productions:
terminal_indices.append([index])
if index in agenda_indices_set:
target_checklist_list.append([1])
else:
target_checklist_list.append([0])
# We want to return checklist target and terminal actions that are column vectors to make
# computing softmax over the difference between checklist and target easier.
# (num_terminals, 1)
terminal_actions = agenda.new_tensor(terminal_indices)
# (num_terminals, 1)
target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
if self._penalize_non_agenda_actions:
# All terminal actions are relevant
checklist_mask = torch.ones_like(target_checklist)
else:
checklist_mask = (target_checklist != 0).float()
return target_checklist, terminal_actions, checklist_mask
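The only behavioral difference from Example 1 is the mask; a toy comparison of the two settings (values are illustrative):

import torch

target_checklist = torch.tensor([[1.], [0.], [1.]])
# penalize_non_agenda_actions=True: every terminal action is scored against the checklist
mask_all = torch.ones_like(target_checklist)       # [[1.], [1.], [1.]]
# penalize_non_agenda_actions=False: terminals absent from the agenda are masked out
mask_agenda = (target_checklist != 0).float()      # [[1.], [0.], [1.]]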
Example 3: forward
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def forward(self, # type: ignore
sentence: Dict[str, torch.LongTensor],
worlds: List[List[NlvrWorld]],
actions: List[List[ProductionRule]],
identifier: List[str] = None,
target_action_sequences: torch.LongTensor = None,
labels: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
Decoder logic for producing type-constrained target sequences, trained to maximize the marginal
likelihood over a set of approximate logical forms.
"""
batch_size = len(worlds)
initial_rnn_state = self._get_initial_rnn_state(sentence)
initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
for i in range(batch_size)]
label_strings = self._get_label_strings(labels) if labels is not None else None
# TODO (pradeep): Assuming all worlds give the same set of valid actions.
initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
range(batch_size)]
initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
action_history=[[] for _ in range(batch_size)],
score=initial_score_list,
rnn_state=initial_rnn_state,
grammar_state=initial_grammar_state,
possible_actions=actions,
extras=label_strings)
if target_action_sequences is not None:
# Remove the trailing dimension (from ListField[ListField[IndexField]]).
target_action_sequences = target_action_sequences.squeeze(-1)
target_mask = target_action_sequences != self._action_padding_index
else:
target_mask = None
outputs: Dict[str, torch.Tensor] = {}
if identifier is not None:
outputs["identifier"] = identifier
if target_action_sequences is not None:
outputs = self._decoder_trainer.decode(initial_state,
self._decoder_step,
(target_action_sequences, target_mask))
if not self.training:
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
initial_state,
self._decoder_step,
keep_final_unfinished_states=False)
best_action_sequences: Dict[int, List[List[int]]] = {}
for i in range(batch_size):
# Decoding may not have terminated with any completed logical forms, if `num_steps`
# isn't long enough (or if the model is not trained enough and gets into an
# infinite action loop).
if i in best_final_states:
best_action_indices = [best_final_states[i][0].action_history[0]]
best_action_sequences[i] = best_action_indices
batch_action_strings = self._get_action_strings(actions, best_action_sequences)
batch_denotations = self._get_denotations(batch_action_strings, worlds)
if target_action_sequences is not None:
self._update_metrics(action_strings=batch_action_strings,
worlds=worlds,
label_strings=label_strings)
else:
if metadata is not None:
outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
outputs['debug_info'] = []
for i in range(batch_size):
outputs['debug_info'].append(best_final_states[i][0].debug_info[0]) # type: ignore
outputs["best_action_strings"] = batch_action_strings
outputs["denotations"] = batch_denotations
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
for action_index, action in enumerate(batch_actions):
action_mapping[(batch_index, action_index)] = action[0]
outputs['action_mapping'] = action_mapping
return outputs
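The ``squeeze(-1)`` call above strips the trailing size-1 dimension that nested ``IndexField``s introduce. A minimal shape walk-through, assuming a padding index of -1 purely for illustration:

import torch

action_padding_index = -1                                         # assumed value; the model stores its own
target_action_sequences = torch.LongTensor([[[[2], [5], [-1]]]])  # (batch=1, num_sequences=1, sequence_length=3, 1)
target_action_sequences = target_action_sequences.squeeze(-1)     # (1, 1, 3) after dropping the IndexField dim
target_mask = target_action_sequences != action_padding_index
print(target_mask)                                                # tensor([[[ True,  True, False]]])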
Example 4: forward
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def forward(self, # type: ignore
question: Dict[str, torch.LongTensor],
table: Dict[str, torch.LongTensor],
world: List[WikiTablesWorld],
actions: List[List[ProductionRule]],
example_lisp_string: List[str] = None,
target_action_sequences: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
In this method we encode the table entities, link them to words in the question, then
encode the question. Then we set up the initial state for the decoder, and pass that
state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
if we're not.
Parameters
----------
question : Dict[str, torch.LongTensor]
The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
be passed through a ``TextFieldEmbedder`` and then through an encoder.
table : ``Dict[str, torch.LongTensor]``
The output of ``KnowledgeGraphField.as_array()`` applied on the table
``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
get embeddings for each entity.
world : ``List[WikiTablesWorld]``
We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
example_lisp_string : ``List[str]``, optional (default = None)
The example (lisp-formatted) string corresponding to the given input. This comes
directly from the ``.examples`` file provided with the dataset. We pass this to SEMPRE
when evaluating denotation accuracy; it is otherwise unused.
target_action_sequences : torch.Tensor, optional (default = None)
A list of possibly valid action sequences, where each action is an index into the list
of possible actions. This tensor has shape ``(batch_size, num_action_sequences,
sequence_length)``.
metadata : ``List[Dict[str, Any]]``, optional (default = None)
Metadata containing the original tokenized question within a 'question_tokens' key.
"""
outputs: Dict[str, Any] = {}
rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
table,
world,
actions,
outputs)
batch_size = len(rnn_state)
initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
initial_score_list = [initial_score[i] for i in range(batch_size)]
initial_state = GrammarBasedState(batch_indices=list(range(batch_size)), # type: ignore
action_history=[[] for _ in range(batch_size)],
score=initial_score_list,
rnn_state=rnn_state,
grammar_state=grammar_state,
possible_actions=actions,
extras=example_lisp_string,
debug_info=None)
if target_action_sequences is not None:
# Remove the trailing dimension (from ListField[ListField[IndexField]]).
target_action_sequences = target_action_sequences.squeeze(-1)
target_mask = target_action_sequences != self._action_padding_index
else:
target_mask = None
if self.training:
return self._decoder_trainer.decode(initial_state,
self._decoder_step,
(target_action_sequences, target_mask))
else:
if target_action_sequences is not None:
outputs['loss'] = self._decoder_trainer.decode(initial_state,
self._decoder_step,
(target_action_sequences, target_mask))['loss']
num_steps = self._max_decoding_steps
# This tells the state to start keeping track of debug info, which we'll pass along in
# our output dictionary.
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._beam_search.search(num_steps,
initial_state,
self._decoder_step,
keep_final_unfinished_states=False)
for i in range(batch_size):
# Decoding may not have terminated with any completed logical forms, if `num_steps`
# isn't long enough (or if the model is not trained enough and gets into an
# infinite action loop).
if i in best_final_states:
best_action_indices = best_final_states[i][0].action_history[0]
if target_action_sequences is not None:
# Use a Tensor, not a Variable, to avoid a memory leak.
targets = target_action_sequences[i].data
sequence_in_targets = self._action_history_match(best_action_indices, targets)
self._action_sequence_accuracy(sequence_in_targets)
self._compute_validation_outputs(actions,
#......... part of this example's code is omitted here .........
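The omitted tail of this example feeds ``best_action_indices`` and the padded targets into ``self._action_history_match``, whose body is not shown on this page. Below is a hypothetical sketch of how such a match could be computed; the function, its signature, and the -1 padding value are assumptions, not the library's actual implementation:

import torch

def action_history_match(predicted: list, targets: torch.Tensor, padding: int = -1) -> int:
    # Return 1 if the predicted index sequence equals any target row (ignoring padding), else 0.
    for row in targets:
        row = row[row != padding]        # strip padding from this candidate gold sequence
        if row.tolist() == predicted:
            return 1
    return 0

targets = torch.LongTensor([[2, 5, 7, -1], [2, 5, -1, -1]])
print(action_history_match([2, 5, 7], targets))   # 1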
Example 5: forward
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def forward(self, # type: ignore
tokens: Dict[str, torch.LongTensor],
valid_actions: List[List[ProductionRule]],
action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
if we're training, or a BeamSearch for inference, if we're not.
Parameters
----------
tokens : Dict[str, torch.LongTensor]
The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
be passed through a ``TextFieldEmbedder`` and then through an encoder.
valid_actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
action_sequence : ``torch.LongTensor``, optional (default = None)
The gold action sequence, where each action is an index into the list of possible
actions. This tensor has shape ``(batch_size, sequence_length, 1)``; we remove the
trailing dimension.
"""
embedded_utterance = self._utterance_embedder(tokens)
mask = util.get_text_field_mask(tokens).float()
batch_size = embedded_utterance.size(0)
# (batch_size, num_tokens, encoder_output_dim)
encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)
if action_sequence is not None:
# Remove the trailing dimension (from ListField[ListField[IndexField]]).
action_sequence = action_sequence.squeeze(-1)
target_mask = action_sequence != self._action_padding_index
else:
target_mask = None
outputs: Dict[str, Any] = {}
if action_sequence is not None:
# action_sequence is of shape (batch_size, 1, target_sequence_length)
# here after we unsqueeze it for the MML trainer.
loss_output = self._decoder_trainer.decode(initial_state,
self._transition_function,
(action_sequence.unsqueeze(1),
target_mask.unsqueeze(1)))
outputs.update(loss_output)
if not self.training:
action_mapping = []
for batch_actions in valid_actions:
batch_action_mapping = {}
for action_index, action in enumerate(batch_actions):
batch_action_mapping[action_index] = action[0]
action_mapping.append(batch_action_mapping)
outputs['action_mapping'] = action_mapping
# This tells the state to start keeping track of debug info, which we'll pass along in
# our output dictionary.
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._beam_search.search(self._max_decoding_steps,
initial_state,
self._transition_function,
keep_final_unfinished_states=True)
outputs['best_action_sequence'] = []
outputs['debug_info'] = []
outputs['predicted_sql_query'] = []
outputs['sql_queries'] = []
for i in range(batch_size):
# Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
# isn't long enough (or if the model is not trained enough and gets into an
# infinite action loop).
if i not in best_final_states:
self._exact_match(0)
self._denotation_accuracy(0)
self._valid_sql_query(0)
self._action_similarity(0)
outputs['predicted_sql_query'].append('')
continue
best_action_indices = best_final_states[i][0].action_history[0]
action_strings = [action_mapping[i][action_index]
for action_index in best_action_indices]
predicted_sql_query = action_sequence_to_sql(action_strings)
if action_sequence is not None:
# Use a Tensor, not a Variable, to avoid a memory leak.
targets = action_sequence[i].data
sequence_in_targets = self._action_history_match(best_action_indices, targets)
self._exact_match(sequence_in_targets)
similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
self._action_similarity(similarity.ratio())
outputs['best_action_sequence'].append(action_strings)
#......... part of this example's code is omitted here .........
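This example scores partial overlap between the predicted and gold action sequences with ``difflib.SequenceMatcher``. A toy run with made-up index sequences:

import difflib

predicted = [2, 5, 7, 9]   # hypothetical best_action_indices from beam search
gold = [2, 5, 8, 9]        # hypothetical target indices
similarity = difflib.SequenceMatcher(None, predicted, gold)
print(similarity.ratio())  # 0.75: 3 matching elements ([2, 5] and [9]) out of 8 total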
Example 6: forward
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def forward(self, # type: ignore
utterance: Dict[str, torch.LongTensor],
world: List[AtisWorld],
actions: List[List[ProductionRule]],
linking_scores: torch.Tensor,
target_action_sequence: torch.LongTensor = None,
sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
"""
We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
if we're training, or a BeamSearch for inference, if we're not.
Parameters
----------
utterance : Dict[str, torch.LongTensor]
The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
be passed through a ``TextFieldEmbedder`` and then through an encoder.
world : ``List[AtisWorld]``
We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
linking_scores : ``torch.Tensor``
A matrix linking the utterance tokens and the entities. This binary matrix is
deterministically generated, where each entry indicates whether a token generated an entity.
This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
target_action_sequence : ``torch.LongTensor``, optional (default = None)
The gold action sequence, where each action is an index into the list of possible
actions. This tensor has shape ``(batch_size, sequence_length, 1)``; we remove the
trailing dimension.
sql_queries : List[List[str]], optional (default=None)
A list of the SQL queries that are given during training or validation.
"""
initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
batch_size = linking_scores.shape[0]
if target_action_sequence is not None:
# Remove the trailing dimension (from ListField[ListField[IndexField]]).
target_action_sequence = target_action_sequence.squeeze(-1)
target_mask = target_action_sequence != self._action_padding_index
else:
target_mask = None
if self.training:
# target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
# the MML trainer.
return self._decoder_trainer.decode(initial_state,
self._transition_function,
(target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
else:
# TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
for action_index, action in enumerate(batch_actions):
action_mapping[(batch_index, action_index)] = action[0]
outputs: Dict[str, Any] = {'action_mapping': action_mapping}
outputs['linking_scores'] = linking_scores
if target_action_sequence is not None:
outputs['loss'] = self._decoder_trainer.decode(initial_state,
self._transition_function,
(target_action_sequence.unsqueeze(1),
target_mask.unsqueeze(1)))['loss']
num_steps = self._max_decoding_steps
# This tells the state to start keeping track of debug info, which we'll pass along in
# our output dictionary.
initial_state.debug_info = [[] for _ in range(batch_size)]
best_final_states = self._beam_search.search(num_steps,
initial_state,
self._transition_function,
keep_final_unfinished_states=False)
outputs['best_action_sequence'] = []
outputs['debug_info'] = []
outputs['entities'] = []
outputs['predicted_sql_query'] = []
outputs['sql_queries'] = []
outputs['utterance'] = []
outputs['tokenized_utterance'] = []
for i in range(batch_size):
# Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
# isn't long enough (or if the model is not trained enough and gets into an
# infinite action loop).
if i not in best_final_states:
self._exact_match(0)
self._denotation_accuracy(0)
self._valid_sql_query(0)
self._action_similarity(0)
outputs['predicted_sql_query'].append('')
continue
best_action_indices = best_final_states[i][0].action_history[0]
action_strings = [action_mapping[(i, action_index)]
for action_index in best_action_indices]
predicted_sql_query = action_sequence_to_sql(action_strings)
if target_action_sequence is not None:
# Use a Tensor, not a Variable, to avoid a memory leak.
#......... part of this example's code is omitted here .........
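Both branches here unsqueeze the ``(batch_size, sequence_length)`` tensors back to ``(batch_size, 1, sequence_length)`` because the MML trainer expects a group dimension. A minimal shape walk-through (padding index assumed to be -1):

import torch

target_action_sequence = torch.LongTensor([[[4], [9], [-1]]])  # (batch=1, sequence_length=3, 1)
target_action_sequence = target_action_sequence.squeeze(-1)    # (1, 3): trailing IndexField dim removed
target_mask = target_action_sequence != -1                     # assumed padding index
print(target_action_sequence.unsqueeze(1).shape)               # torch.Size([1, 1, 3])
print(target_mask.unsqueeze(1).shape)                          # torch.Size([1, 1, 3])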
Example 7: forward
# Required imports: from torch import LongTensor [as alias]
# Alternatively: from torch.LongTensor import squeeze [as alias]
def forward(self, # type: ignore
question: Dict[str, torch.LongTensor],
table: Dict[str, torch.LongTensor],
world: List[QuarelWorld],
actions: List[List[ProductionRule]],
entity_bits: torch.Tensor = None,
denotation_target: torch.Tensor = None,
target_action_sequences: torch.LongTensor = None,
metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
# pylint: disable=arguments-differ
# pylint: disable=unused-argument
"""
In this method we encode the table entities, link them to words in the question, then
encode the question. Then we set up the initial state for the decoder, and pass that
state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
if we're not.
Parameters
----------
question : Dict[str, torch.LongTensor]
The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
be passed through a ``TextFieldEmbedder`` and then through an encoder.
table : ``Dict[str, torch.LongTensor]``
The output of ``KnowledgeGraphField.as_array()`` applied on the table
``KnowledgeGraphField``. This output is similar to a ``TextField`` output, where each
entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
get embeddings for each entity.
world : ``List[QuarelWorld]``
We use a ``MetadataField`` to get the ``World`` for each input instance. Because of
how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
actions : ``List[List[ProductionRule]]``
A list of all possible actions for each ``World`` in the batch, indexed into a
``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
and use the embeddings to determine which action to take at each timestep in the
decoder.
target_action_sequences : torch.Tensor, optional (default=None)
A list of possibly valid action sequences, where each action is an index into the list
of possible actions. This tensor has shape ``(batch_size, num_action_sequences,
sequence_length)``.
"""
table_text = table['text']
self._debug_count -= 1
# (batch_size, question_length, embedding_dim)
embedded_question = self._question_embedder(question)
question_mask = util.get_text_field_mask(question).float()
num_question_tokens = embedded_question.size(1)
# (batch_size, num_entities, num_entity_tokens, embedding_dim)
embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
# entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
# entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
# These encode the same information, but for efficiency reasons later it's nice
# to have one version as a tensor and one that's accessible on the cpu.
entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)
if self._use_entities:
if self._entity_similarity_mode == "dot_product":
# Compute entity and question word cosine similarity. Need to add a small value
# to the table norm since there are padding values which cause a divide by 0.
embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
num_entities * num_entity_tokens,
self._embedding_dim),
torch.transpose(embedded_question, 1, 2))
question_entity_similarity = question_entity_similarity.view(batch_size,
num_entities,
num_entity_tokens,
num_question_tokens)
# (batch_size, num_entities, num_question_tokens)
question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
linking_scores = question_entity_similarity_max_score
elif self._entity_similarity_mode == "weighted_dot_product":
embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
product = torch.mul(eqe, ete)
product = product.view(batch_size,
num_question_tokens*num_entities*num_entity_tokens,
self._embedding_dim)
question_entity_similarity = self._entity_similarity_layer(product)
question_entity_similarity = question_entity_similarity.view(batch_size,
num_entities,
num_entity_tokens,
num_question_tokens)
# (batch_size, num_entities, num_question_tokens)
question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
#......... part of this example's code is omitted here .........
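The "dot_product" branch above turns embeddings into linking scores by normalizing both sides, batch-multiplying them, and then taking a max over entity tokens. A standalone sketch with small random tensors (all dimensions are illustrative):

import torch

batch_size, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 8
embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch_size, num_question_tokens, dim)

# Normalize; the small epsilon guards against all-zero padding rows.
embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)

similarity = torch.bmm(embedded_table.view(batch_size, num_entities * num_entity_tokens, dim),
                       embedded_question.transpose(1, 2))        # (batch, entities*tokens, question_tokens)
similarity = similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens)
linking_scores, _ = similarity.max(dim=2)                        # (batch, num_entities, num_question_tokens)
print(linking_scores.shape)                                      # torch.Size([2, 3, 5])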