This article collects typical usage examples of the torch.Tensor.exp method in Python. If you have been wondering what Tensor.exp does, how to use it, or want to see it in real code, the curated example below may help. You can also explore the other methods of the torch.Tensor class.
One code example of the Tensor.exp method is shown below.
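Before the full example, here is a minimal sketch of what Tensor.exp does: it returns the element-wise exponential of a tensor, so it inverts a log. The tensor values below are illustrative only, not taken from the example:

import torch

log_probs = torch.log(torch.tensor([0.25, 0.75]))
probs = log_probs.exp()  # tensor([0.2500, 0.7500]); exp() undoes the log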
Example 1: _compute_new_states
# Required imports: the original page listed only "from torch import Tensor",
# but the function below also needs these. WikiTablesDecoderState and RnnState
# come from the surrounding AllenNLP code base; their import paths depend on
# the AllenNLP version, so they are not spelled out here.
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

import torch
def _compute_new_states(state: WikiTablesDecoderState,
log_probs: torch.Tensor,
hidden_state: torch.Tensor,
memory_cell: torch.Tensor,
action_embeddings: torch.Tensor,
attended_question: torch.Tensor,
attention_weights: torch.Tensor,
considered_actions: List[List[int]],
allowed_actions: List[Set[int]],
                        max_actions: Optional[int] = None) -> List[WikiTablesDecoderState]:
# Each group index here might get accessed multiple times, and doing the slicing operation
# each time is more expensive than doing it once upfront. These three lines give about a
# 10% speedup in training time. I also tried this with sorted_log_probs and
# action_embeddings, but those get accessed for _each action_, so doing the splits there
# didn't help.
hidden_state = [x.squeeze(0) for x in hidden_state.split(1, 0)]
memory_cell = [x.squeeze(0) for x in memory_cell.split(1, 0)]
attended_question = [x.squeeze(0) for x in attended_question.split(1, 0)]
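    # sort() returns both the sorted log probs and the indices that produced
    # them; the indices are what let us map back into the unsorted tensors below.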
sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
if max_actions is not None:
# We might need a version of `sorted_log_probs` on the CPU later, but only if we need
# to truncate the best states to `max_actions`.
sorted_log_probs_cpu = sorted_log_probs.detach().cpu().numpy()
if state.debug_info is not None:
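        # exp() turns the log probabilities back into probabilities; this is only
        # for human-readable debug output, so we can afford the copy to the CPU.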
probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
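    # Group candidate next states by batch index, as
    # (group_index, action_index, action) triples.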
best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices,
sorted_actions)):
for action_index, action in enumerate(group_actions):
# `action` is currently the index in `log_probs`, not the actual action ID. To get
# the action ID, we need to go through `considered_actions`.
action = considered_actions[group_index][action]
if action == -1:
# This was padding.
continue
if allowed_actions is not None and action not in allowed_actions[group_index]:
# This happens when our _decoder trainer_ wants us to only evaluate certain
# actions, likely because they are the gold actions in this state. We just skip
# emitting any state that isn't allowed by the trainer, because constructing the
# new state can be expensive.
continue
best_next_states[batch_index].append((group_index, action_index, action))
new_states = []
for batch_index, best_states in sorted(best_next_states.items()):
if max_actions is not None:
# We sorted previously by _group_index_, but we then combined by _batch_index_. We
# need to get the top next states for each _batch_ instance, so we sort all of the
# instance's states again (across group index) by score. We don't need to do this
# if `max_actions` is None, because we'll be keeping all of the next states,
# anyway.
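            # x[:2] is the (group_index, action_index) pair, used as a 2-D index
            # into the numpy array of sorted log probs.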
best_states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]], reverse=True)
best_states = best_states[:max_actions]
for group_index, action_index, action in best_states:
# We'll yield a bunch of states here that all have a `group_size` of 1, so that the
# learning algorithm can decide how many of these it wants to keep, and it can just
# regroup them later, as that's a really easy operation.
batch_index = state.batch_indices[group_index]
new_action_history = state.action_history[group_index] + [action]
new_score = sorted_log_probs[group_index, action_index]
# `action_index` is the index in the _sorted_ tensors, but the action embedding
# matrix is _not_ sorted, so we need to get back the original, non-sorted action
# index before we get the action embedding.
action_embedding_index = sorted_actions[group_index][action_index]
action_embedding = action_embeddings[group_index, action_embedding_index, :]
production_rule = state.possible_actions[batch_index][action][0]
new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
if state.checklist_state[0] is not None:
new_checklist_state = [state.checklist_state[group_index].update(action)]
else:
new_checklist_state = None
if state.debug_info is not None:
debug_info = {
'considered_actions': considered_actions[group_index],
'question_attention': attention_weights[group_index],
'probabilities': probs_cpu[group_index],
}
new_debug_info = [state.debug_info[group_index] + [debug_info]]
else:
new_debug_info = None
new_rnn_state = RnnState(hidden_state[group_index],
memory_cell[group_index],
action_embedding,
attended_question[group_index],
state.rnn_state[group_index].encoder_outputs,
state.rnn_state[group_index].encoder_output_mask)
new_state = WikiTablesDecoderState(batch_indices=[batch_index],
action_history=[new_action_history],
score=[new_score],
rnn_state=[new_rnn_state],
grammar_state=[new_grammar_state],
action_embeddings=state.action_embeddings,
output_action_embeddings=state.output_action_embeddings,
action_biases=state.action_biases,
action_indices=state.action_indices,
possible_actions=state.possible_actions,
flattened_linking_scores=state.flattened_linking_scores,
#......... the rest of this code is omitted here .........
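Note that exp appears only once in the listing above: the decoder keeps every score as a log probability for numerical stability, and log_probs.exp() is used solely to produce human-readable probabilities for debug_info. A minimal sketch of that pattern (the names below are illustrative, not from the original code):

import torch

# Summing log probabilities avoids the underflow that multiplying many small
# probabilities would cause; exp() is applied only when a readable probability
# is actually needed.
step_log_probs = torch.log_softmax(torch.randn(5), dim=-1)
running_score = step_log_probs[2] + step_log_probs[4]  # log(p2 * p4)
readable_prob = running_score.exp()  # back to probability space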