Python LongTensor.squeeze Method Code Examples

This article collects typical usage examples of the torch.LongTensor.squeeze method in Python. If you are wondering how LongTensor.squeeze is used in practice, the selected examples below may help. You can also look further into usage examples of torch.LongTensor, the class this method belongs to.


Seven code examples of the LongTensor.squeeze method are shown below, sorted by popularity by default.

Example 1: _get_checklist_info

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def _get_checklist_info(agenda: torch.LongTensor,
                            all_actions: List[ProductionRule],
                            terminal_productions: Set[str],
                            max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist against
        which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
        are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the max number of
            terminal productions in all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
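        # agenda has shape (agenda_size, 1); squeeze(-1) drops the trailing dimension so that
        # each x below is a scalar. squeeze(0) would only do that when agenda_size is 1.
        agenda_indices_set = {int(x) for x in agenda.squeeze(-1).detach().cpu().numpy()}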
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Developer: apmoore1, Project: allennlp, Lines of code: 48, Source file: wikitables_erm_semantic_parser.py
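
The docstring of `_get_checklist_info` gives `agenda` the shape `(agenda_size, 1)`. For that shape, the dimension passed to `squeeze` matters. The following is a minimal sketch with made-up index values (not taken from the project) showing what each call returns and one way to collect the indices into a set:

```python
import torch

# Hypothetical agenda with three action indices, shaped (agenda_size, 1).
agenda = torch.tensor([[4], [7], [9]], dtype=torch.long)

# squeeze(0) only drops dimension 0 when its size is 1, so the shape is unchanged here.
print(agenda.squeeze(0).shape)   # torch.Size([3, 1])

# squeeze(-1) drops the trailing size-1 dimension and yields a flat vector.
flat_agenda = agenda.squeeze(-1)
print(flat_agenda.shape)         # torch.Size([3])

# tolist() converts the vector to plain Python ints, which can go straight into a set.
agenda_indices_set = set(flat_agenda.tolist())
print(sorted(agenda_indices_set))  # [4, 7, 9]
```

Using `tolist()` here is a choice made for the sketch; the original goes through `detach().cpu().numpy()` instead.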

Example 2: _get_checklist_info

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self.penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
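        # agenda has shape (agenda_size, 1); squeeze(-1) drops the trailing dimension so that
        # each x below is a scalar. squeeze(0) would only do that when agenda_size is 1.
        agenda_indices_set = {int(x) for x in agenda.squeeze(-1).detach().cpu().numpy()}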
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Developer: pyknife, Project: allennlp, Lines of code: 46, Source file: nlvr_coverage_semantic_parser.py
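
The two masking modes selected by `self._penalize_non_agenda_actions` can be illustrated in isolation. This is a small sketch under assumed inputs: the agenda and the four-entry checklist are invented for illustration, and only `new_tensor`, `ones_like`, and the comparison come from the snippet above.

```python
import torch

# Hypothetical agenda and checklist for four terminal actions.
agenda = torch.tensor([[1], [3]], dtype=torch.long)
target_checklist_list = [[0], [1], [0], [1]]

# new_tensor creates the tensor on the same device as `agenda`;
# the dtype argument overrides agenda's integer type.
target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)

# Mode 1: only terminals that appear in the agenda count towards the loss.
mask_agenda_only = (target_checklist != 0).float()

# Mode 2: every terminal action counts towards the loss.
mask_all = torch.ones_like(target_checklist)

print(target_checklist.shape)      # torch.Size([4, 1])
print(mask_agenda_only.tolist())   # [[0.0], [1.0], [0.0], [1.0]]
print(mask_all.tolist())           # [[1.0], [1.0], [1.0], [1.0]]
```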

Example 3: forward

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRule]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                                 initial_state,
                                                                 self._decoder_step,
                                                                 keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [best_final_states[i][0].action_history[0]]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings, worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs
Developer: apmoore1, Project: allennlp, Lines of code: 81, Source file: nlvr_direct_semantic_parser.py
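
The `squeeze(-1)` call followed by the padding comparison can be tried on a small tensor. The sketch below assumes a padding index of -1 and a batch of size 1; both are choices made for illustration. It also shows why naming the dimension is safer than calling `squeeze()` with no argument.

```python
import torch

action_padding_index = -1  # assumed padding value for this sketch

# Shape (batch_size=1, num_action_sequences=2, sequence_length=3, 1).
target_action_sequences = torch.tensor(
    [[[[2], [5], [-1]],
      [[1], [-1], [-1]]]]
)

# Only the trailing size-1 dimension is removed; the batch dimension survives.
squeezed = target_action_sequences.squeeze(-1)
print(squeezed.shape)        # torch.Size([1, 2, 3])

# True where a real action index is present, False at padding positions.
target_mask = squeezed != action_padding_index
print(target_mask.tolist())  # [[[True, True, False], [True, False, False]]]

# squeeze() with no argument removes every size-1 dimension,
# so the batch dimension disappears as well when batch_size is 1.
over_squeezed = target_action_sequences.squeeze()
print(over_squeezed.shape)   # torch.Size([2, 3])
```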

Example 4: forward

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRule]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=example_lisp_string,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions,
#......... part of the code is omitted here .........
Developer: apmoore1, Project: allennlp, Lines of code: 103, Source file: wikitables_mml_semantic_parser.py

Example 5: forward

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence has shape (batch_size, 1, target_sequence_length)
            # after we unsqueeze it for the MML trainer.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = 0
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
#......... part of the code is omitted here .........
Developer: apmoore1, Project: allennlp, Lines of code: 103, Source file: text2sql_parser.py
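
The action-similarity metric in `forward` relies on `difflib.SequenceMatcher`, which matches elements by hash and equality. The sketch below uses invented index values and converts the target tensor to a list of plain Python ints before comparing. That conversion is an assumption of this sketch and does not appear in the original snippet.

```python
import difflib

import torch

best_action_indices = [2, 5, 7, 1]      # hypothetical decoded action indices
targets = torch.tensor([2, 5, 3, 1])    # hypothetical target action sequence

# Compare plain ints on both sides so that equal indices are recognised as matches.
matcher = difflib.SequenceMatcher(None, best_action_indices, targets.tolist())

# ratio() is 2 * matches / total length: 3 matching elements out of 8 in total.
similarity = matcher.ratio()
print(similarity)  # 0.75
```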

Example 6: forward

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def forward(self,  # type: ignore
                utterance: Dict[str, torch.LongTensor],
                world: List[AtisWorld],
                actions: List[List[ProductionRule]],
                linking_scores: torch.Tensor,
                target_action_sequence: torch.LongTensor = None,
                sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The correct action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
            # the MML trainer.
            return self._decoder_trainer.decode(initial_state,
                                                self._transition_function,
                                                (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            outputs['linking_scores'] = linking_scores
            if target_action_sequence is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._transition_function,
                                                               (target_action_sequence.unsqueeze(1),
                                                                target_mask.unsqueeze(1)))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            outputs['utterance'] = []
            outputs['tokenized_utterance'] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in best_action_indices]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
#......... part of the code is omitted here .........
Developer: apmoore1, Project: allennlp, Lines of code: 103, Source file: atis_semantic_parser.py
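
After the trailing dimension is squeezed away, `forward` calls `unsqueeze(1)` so that the single target sequence looks like a set containing one sequence. The sketch below shows only that shape change, using made-up values and -1 as an assumed padding index.

```python
import torch

# Hypothetical batch after squeeze(-1): shape (batch_size=2, sequence_length=3).
# -1 stands in for the padding index (an assumption of this sketch).
target_action_sequence = torch.tensor([[4, 2, -1], [1, -1, -1]])
target_mask = target_action_sequence != -1

# unsqueeze(1) inserts a size-1 "number of sequences" dimension at position 1.
trainer_targets = target_action_sequence.unsqueeze(1)
trainer_mask = target_mask.unsqueeze(1)
print(trainer_targets.shape)  # torch.Size([2, 1, 3])
print(trainer_mask.shape)     # torch.Size([2, 1, 3])

# squeeze(1) reverses the unsqueeze and recovers the original tensor.
restored = trainer_targets.squeeze(1)
print(torch.equal(restored, target_action_sequence))  # True
```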

Example 7: forward

# Requires: import torch (the snippet refers to torch.LongTensor)
# Method shown: torch.LongTensor.squeeze, called on a tensor instance
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
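                # Broadcast both embeddings to a common shape of
                # (batch_size, num_entities * num_entity_tokens, num_question_tokens, embedding_dim).
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                # Element-wise product of every (entity token, question token) embedding pair.
                product = torch.mul(eqe, ete)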
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
#.........part of the code omitted here.........
Developer: apmoore1, Project: allennlp, Lines of code: 103, Source: quarel_semantic_parser.py
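
The "dot_product" branch of the example above can be tried in isolation. The following is only a minimal sketch of that step, not the project's code: the function name `cosine_linking_scores`, the `eps` parameter, and the toy tensor sizes are assumptions made for illustration. It normalises both embeddings, scores every (entity token, question token) pair with one batched matrix product, and keeps the best-matching entity token per entity.

```python
import torch


def cosine_linking_scores(embedded_table: torch.Tensor,
                          embedded_question: torch.Tensor,
                          eps: float = 1e-13) -> torch.Tensor:
    """Hypothetical helper: returns scores of shape
    (batch_size, num_entities, num_question_tokens)."""
    batch_size, num_entities, num_entity_tokens, embedding_dim = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # L2-normalise; eps keeps all-zero padding vectors from dividing by zero.
    table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + eps)
    question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + eps)

    # Flatten the entity tokens so a single bmm scores every token pair.
    flat_table = table.reshape(batch_size, num_entities * num_entity_tokens, embedding_dim)
    similarity = torch.bmm(flat_table, question.transpose(1, 2))
    similarity = similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens)

    # Keep the best-matching entity token for each (entity, question token) pair.
    scores, _ = similarity.max(dim=2)
    return scores


# Toy inputs (sizes chosen arbitrarily for the sketch).
torch.manual_seed(0)
table = torch.randn(2, 3, 4, 5)     # (batch, entities, entity tokens, embedding dim)
table[0, 2] = 0.0                   # simulate a fully padded entity
question = torch.randn(2, 6, 5)     # (batch, question tokens, embedding dim)

linking_scores = cosine_linking_scores(table, question)
print(linking_scores.shape)
```

Because the padded entity consists only of zero vectors, its normalised embeddings stay zero and its scores come out as zero instead of NaN, which is the reason for the small constant added to the norm.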


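Note: The torch.LongTensor.squeeze examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their developers, and copyright in the source code remains with the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.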