

Python Tensor.sort Method Code Examples

This page collects typical usage examples of the torch.Tensor.sort method in Python. If you are wondering how Tensor.sort is used in practice, the selected examples below may help. You can also browse further usage examples of torch.Tensor, the class this method belongs to.


Two code examples of Tensor.sort are shown below, ordered by popularity by default.
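Before the two project examples, here is a minimal sketch of the call itself. It assumes a reasonably recent PyTorch release, in which Tensor.sort(dim=-1, descending=False) returns a (values, indices) named tuple; the tensor x is made-up illustration data.

```python
import torch

# Made-up 2x4 tensor; with dim=-1 each row is sorted independently.
x = torch.tensor([[3.0, 1.0, 2.0, 0.5],
                  [0.1, 4.0, 1.5, 2.5]])

# Tensor.sort returns (values, indices); indices point back into x.
values, indices = x.sort(dim=-1, descending=True)

# indices[i, j] is the column of x[i] that landed at sorted position j,
# so gathering x with indices reproduces the sorted values.
assert torch.equal(x.gather(-1, indices), values)

# With dim=0 (ascending by default) each column is sorted instead.
col_values, col_indices = x.sort(dim=0)
```

The indices tensor is what both examples below rely on: it records where each sorted element came from, so other tensors can be reordered the same way with index_select or gather.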

Example 1: sort_batch_by_length

# Required imports:
import torch
from allennlp.common.checks import ConfigurationError
def sort_batch_by_length(tensor: torch.Tensor, sequence_lengths: torch.Tensor):
    """
    Sort a batch-first tensor by the specified lengths.

    Parameters
    ----------
    tensor : torch.FloatTensor, required.
        A batch-first PyTorch tensor.
    sequence_lengths : torch.LongTensor, required.
        A tensor representing the lengths of some dimension of the tensor which
        we want to sort by.

    Returns
    -------
    sorted_tensor : torch.FloatTensor
        The original tensor sorted along the batch dimension with respect to sequence_lengths.
    sorted_sequence_lengths : torch.LongTensor
        The original sequence_lengths sorted by decreasing size.
    restoration_indices : torch.LongTensor
        Indices into the sorted_tensor such that
        ``sorted_tensor.index_select(0, restoration_indices) == original_tensor``
    permutation_index : torch.LongTensor
        The indices used to sort the tensor. This is useful if you want to sort many
        tensors using the same ordering.
    """

    if not isinstance(tensor, torch.Tensor) or not isinstance(sequence_lengths, torch.Tensor):
        raise ConfigurationError("Both the tensor and sequence lengths must be torch.Tensors.")

    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)

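    # 0 .. batch_size - 1 on the same device as sequence_lengths (calling new_tensor on an
    # existing tensor triggers a copy-construct UserWarning in recent PyTorch versions).
    index_range = torch.arange(0, len(sequence_lengths), device=sequence_lengths.device)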
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index
Developer ID: pyknife, Project: allennlp, Lines of code: 40, Source file: util.py
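The core of sort_batch_by_length can be tried without AllenNLP. The sketch below is a stripped-down re-implementation; the name sort_by_length is hypothetical, and the type check and ConfigurationError are dropped. Because index_range is just 0..n-1, index_select on it returns reverse_mapping unchanged, so the sketch returns that inverse permutation directly. Sorting a permutation index in ascending order yields its inverse, which is why the restoration indices undo the sort.

```python
import torch


def sort_by_length(tensor: torch.Tensor, lengths: torch.Tensor):
    """Hypothetical minimal version of the sorting logic in sort_batch_by_length."""
    # Longest sequence first; permutation_index[j] = original row now at position j.
    sorted_lengths, permutation_index = lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)
    # Argsort of a permutation is its inverse permutation.
    _, restoration_indices = permutation_index.sort(0, descending=False)
    return sorted_tensor, sorted_lengths, restoration_indices, permutation_index


# Made-up batch-first data: 4 sequences, max length 5, 2 features each.
batch = torch.arange(4 * 5 * 2, dtype=torch.float).view(4, 5, 2)
lengths = torch.tensor([2, 5, 1, 3])

sorted_batch, sorted_lengths, restore, perm = sort_by_length(batch, lengths)
print(sorted_lengths.tolist())  # [5, 3, 2, 1]
print(perm.tolist())            # [1, 3, 0, 2]
print(restore.tolist())         # [2, 0, 3, 1]

# Selecting with the restoration indices recovers the original batch order.
assert torch.equal(sorted_batch.index_select(0, restore), batch)
```

A common reason to sort a batch by length is torch.nn.utils.rnn.pack_padded_sequence, which by default (enforce_sorted=True) expects lengths in decreasing order; in PyTorch versions that provide enforce_sorted=False, the function can perform this sorting itself.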

Example 2: _compute_new_states

# Required imports (WikiTablesDecoderState and RnnState are AllenNLP classes defined elsewhere in the project):
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

import torch

    @staticmethod
    def _compute_new_states(state: WikiTablesDecoderState,
                            log_probs: torch.Tensor,
                            hidden_state: torch.Tensor,
                            memory_cell: torch.Tensor,
                            action_embeddings: torch.Tensor,
                            attended_question: torch.Tensor,
                            attention_weights: torch.Tensor,
                            considered_actions: List[List[int]],
                            allowed_actions: List[Set[int]],
                            max_actions: Optional[int] = None) -> List[WikiTablesDecoderState]:
        # Each group index here might get accessed multiple times, and doing the slicing operation
        # each time is more expensive than doing it once upfront.  These three lines give about a
        # 10% speedup in training time.  I also tried this with sorted_log_probs and
        # action_embeddings, but those get accessed for _each action_, so doing the splits there
        # didn't help.
        hidden_state = [x.squeeze(0) for x in hidden_state.split(1, 0)]
        memory_cell = [x.squeeze(0) for x in memory_cell.split(1, 0)]
        attended_question = [x.squeeze(0) for x in attended_question.split(1, 0)]

        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
        if max_actions is not None:
            # We might need a version of `sorted_log_probs` on the CPU later, but only if we need
            # to truncate the best states to `max_actions`.
            sorted_log_probs_cpu = sorted_log_probs.detach().cpu().numpy()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices,
                                                                       sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if action == -1:
                    # This was padding.
                    continue
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append((group_index, action_index, action))
        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            if max_actions is not None:
                # We sorted previously by _group_index_, but we then combined by _batch_index_.  We
                # need to get the top next states for each _batch_ instance, so we sort all of the
                # instance's states again (across group index) by score.  We don't need to do this
                # if `max_actions` is None, because we'll be keeping all of the next states,
                # anyway.
                best_states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]], reverse=True)
                best_states = best_states[:max_actions]
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                batch_index = state.batch_indices[group_index]
                new_action_history = state.action_history[group_index] + [action]
                new_score = sorted_log_probs[group_index, action_index]

                # `action_index` is the index in the _sorted_ tensors, but the action embedding
                # matrix is _not_ sorted, so we need to get back the original, non-sorted action
                # index before we get the action embedding.
                action_embedding_index = sorted_actions[group_index][action_index]
                action_embedding = action_embeddings[group_index, action_embedding_index, :]
                production_rule = state.possible_actions[batch_index][action][0]
                new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
                if state.checklist_state[0] is not None:
                    new_checklist_state = [state.checklist_state[group_index].update(action)]
                else:
                    new_checklist_state = None
                if state.debug_info is not None:
                    debug_info = {
                            'considered_actions': considered_actions[group_index],
                            'question_attention': attention_weights[group_index],
                            'probabilities': probs_cpu[group_index],
                            }
                    new_debug_info = [state.debug_info[group_index] + [debug_info]]
                else:
                    new_debug_info = None

                new_rnn_state = RnnState(hidden_state[group_index],
                                         memory_cell[group_index],
                                         action_embedding,
                                         attended_question[group_index],
                                         state.rnn_state[group_index].encoder_outputs,
                                         state.rnn_state[group_index].encoder_output_mask)
                new_state = WikiTablesDecoderState(batch_indices=[batch_index],
                                                   action_history=[new_action_history],
                                                   score=[new_score],
                                                   rnn_state=[new_rnn_state],
                                                   grammar_state=[new_grammar_state],
                                                   action_embeddings=state.action_embeddings,
                                                   output_action_embeddings=state.output_action_embeddings,
                                                   action_biases=state.action_biases,
                                                   action_indices=state.action_indices,
                                                   possible_actions=state.possible_actions,
                                                   flattened_linking_scores=state.flattened_linking_scores,
#......... part of the code is omitted here .........
Developer ID: pyknife, Project: allennlp, Lines of code: 103, Source file: wikitables_decoder_step.py
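The sort-then-regroup pattern in _compute_new_states can be reduced to a small standalone sketch. The log_probs values, batch_indices and max_actions below are made up, and the original's considered_actions mapping, padding check and allowed_actions filter are left out, so the third tuple element here is the raw column index in log_probs rather than an AllenNLP action ID.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch

# Made-up scores: 3 decoder groups (rows) x 4 candidate actions (columns).
log_probs = torch.tensor([[-0.1, -2.0, -1.0, -3.0],
                          [-0.5, -0.2, -4.0, -1.5],
                          [-2.5, -0.3, -0.9, -0.7]])
batch_indices = [0, 0, 1]  # groups 0 and 1 belong to batch instance 0
max_actions = 2

# Sort each group's candidates once, best first.
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
sorted_log_probs_cpu = sorted_log_probs.detach().cpu().numpy()
sorted_actions_list = sorted_actions.detach().cpu().numpy().tolist()

# Regroup (group_index, position_in_sorted_row, original_column) by batch instance.
best_next: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
for group_index, (batch_index, group_actions) in enumerate(zip(batch_indices,
                                                               sorted_actions_list)):
    for action_index, action in enumerate(group_actions):
        best_next[batch_index].append((group_index, action_index, action))

top: Dict[int, List[Tuple[int, int, int]]] = {}
for batch_index, states in sorted(best_next.items()):
    # Rows were sorted per group, so candidates from several groups of the same
    # instance must be re-sorted by score before truncating to max_actions.
    states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]], reverse=True)
    top[batch_index] = states[:max_actions]
```

Here instance 0 keeps the best candidate of group 0 (score -0.1) and of group 1 (score -0.2), which a purely per-row cut to two actions would not have produced.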


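Note: The torch.Tensor.sort examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their developers, and copyright of the source code belongs to the original authors. Refer to each project's license before distributing or using the code, and do not republish without permission.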