本文整理汇总了Python中torch.Tensor.detach方法的典型用法代码示例。如果您正苦于以下问题：Python Tensor.detach方法的具体用法？Python Tensor.detach怎么用？Python Tensor.detach使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类torch.Tensor的用法示例。
在下文中一共展示了Tensor.detach方法的6个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_best_span
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    """Find the highest-scoring ``(start, end)`` span with ``start <= end`` for every batch row.

    The score of a span is ``span_start_logits[b, start] + span_end_logits[b, end]``.
    All spans are scored at once with broadcasting instead of a Python double loop.
    There is no magic "minimum score" sentinel, so rows whose scores are all very
    negative still get their true best span.

    Parameters
    ----------
    span_start_logits, span_end_logits : torch.Tensor
        Two 2-D tensors of identical shape ``(batch_size, passage_length)``.

    Returns
    -------
    torch.Tensor
        A ``torch.long`` tensor of shape ``(batch_size, 2)`` on the inputs' device. Each
        row holds the inclusive ``[start, end]`` of the best span. Ties are resolved
        toward the earliest start, then the earliest end. Degenerate input (an empty
        batch or empty passages) yields all zeros.

    Raises
    ------
    ValueError
        If either input is not 2-D, or the two inputs differ in shape.

    Notes
    -----
    Memory use is O(batch_size * passage_length ** 2) because every span is scored.
    """
    if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
        raise ValueError("Input shapes must be (batch_size, passage_length)")
    if span_start_logits.size() != span_end_logits.size():
        # Without this check, broadcasting could silently pair mismatched positions.
        raise ValueError("span_start_logits and span_end_logits must have the same shape")
    batch_size, passage_length = span_start_logits.size()
    if batch_size == 0 or passage_length == 0:
        # Nothing to search; keep the historical all-zeros result for degenerate input.
        return span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

    # Span selection is a discrete decision; no gradient should flow through it.
    start_logits = span_start_logits.detach()
    end_logits = span_end_logits.detach()

    # span_scores[b, i, j] = start_logits[b, i] + end_logits[b, j]
    span_scores = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)

    # Only spans with start <= end are legal; give every other pair a score of -inf.
    positions = torch.arange(passage_length, device=span_scores.device)
    legal = positions.unsqueeze(1) <= positions.unsqueeze(0)
    span_scores = span_scores.masked_fill(~legal, float("-inf"))

    # argmax returns the first maximal entry in row-major (start, end) order.
    best_flat = span_scores.reshape(batch_size, -1).argmax(dim=-1)
    best_starts = best_flat // passage_length
    best_ends = best_flat % passage_length
    return torch.stack([best_starts, best_ends], dim=-1)
示例2: _run_mst_decoding
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode a maximum spanning tree and its edge labels for each instance in a batch.

    Each instance's energy is reduced over dim 0 with ``max``, which gives a 2-D edge
    score matrix and the index that achieved each maximum (used as the edge label).
    The score matrix goes to ``decode_mst``. Labels are then read back for the chosen
    edges.

    Parameters
    ----------
    batch_energy : torch.Tensor
        Iterated along dim 0 in lockstep with ``lengths``. Each element must be 3-D,
        because its max over dim 0 is indexed as a 2-D matrix below. Presumably
        ``(batch, num_labels, seq_len, seq_len)``; TODO confirm against the caller.
    lengths : torch.Tensor
        Per-instance length, passed straight through to ``decode_mst``.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        Stacked per-instance heads and head tags. ``numpy.stack`` requires every
        instance's result to have the same length. Presumably ``decode_mst`` pads to
        the full sequence length; verify. The root position (index 0) is forced to 0
        in both.
    """
    all_heads = []
    all_head_tags = []
    # Work on a gradient-free CPU copy, since decoding happens in numpy.
    cpu_energy = batch_energy.detach().cpu()
    for instance_energy, instance_length in zip(cpu_energy, lengths):
        # Collapse dim 0: keep the best score per edge and the label index that produced it.
        edge_scores, best_labels = instance_energy.max(dim=0)
        # Intent (per the original author): forbid any word from being the parent of ROOT.
        # NOTE(review): this zeroes row 0, but the label lookup below indexes
        # [parent, child], which would make row 0 the ROOT-as-parent edges. Confirm the
        # orientation decode_mst expects.
        edge_scores[0, :] = 0
        # The scores were edited above, so labels are recovered manually rather than
        # by decode_mst.
        tree_heads, _ = decode_mst(edge_scores.numpy(), instance_length, has_labels=False)
        tree_tags = [best_labels[parent, child].item() for child, parent in enumerate(tree_heads)]
        # The root's head/tag is meaningless and is not guaranteed to match between
        # batched and unbatched decoding, so pin both to zero.
        tree_heads[0] = 0
        tree_tags[0] = 0
        all_heads.append(tree_heads)
        all_head_tags.append(tree_tags)
    heads_array = numpy.stack(all_heads)
    tags_array = numpy.stack(all_head_tags)
    return torch.from_numpy(heads_array), torch.from_numpy(tags_array)
示例3: decode_all
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
    """Turn a batch of predicted index sequences into token strings.

    Each row is truncated just before the first occurrence of ``self._end_index``
    (if present). Each remaining index is mapped through
    ``self.vocab.get_token_from_index`` in the ``self._target_namespace`` namespace.

    Parameters
    ----------
    predicted_indices : torch.Tensor or numpy.ndarray
        Iterable of index rows. A tensor is detached and copied to a CPU numpy array
        first; a numpy array is used as-is.

    Returns
    -------
    List[List[str]]
        One token list per row, in input order.
    """
    if isinstance(predicted_indices, numpy.ndarray):
        index_rows = predicted_indices
    else:
        index_rows = predicted_indices.detach().cpu().numpy()
    end_index = self._end_index
    decoded = []
    for row in index_rows:
        sequence = list(row)
        # Keep everything before the first end symbol; keep the whole row if it has none.
        cutoff = sequence.index(end_index) if end_index in sequence else len(sequence)
        tokens = [
            self.vocab.get_token_from_index(x, namespace=self._target_namespace)
            for x in sequence[:cutoff]
        ]
        decoded.append(tokens)
    return decoded
示例4: __init__
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def __init__(self,
             terminal_actions: torch.Tensor,
             checklist_target: torch.Tensor,
             checklist_mask: torch.Tensor,
             checklist: torch.Tensor) -> None:
    """Store the checklist tensors and index the terminal actions they refer to.

    ``terminal_indices_dict`` maps each batch action index, taken from column 0 of a
    ``terminal_actions`` row, to that row's position. That position is shared by all
    four stored tensors. Rows whose action index is ``-1`` are skipped; they appear to
    be padding, which is an assumption to confirm. If an action index repeats, its
    last row wins.
    """
    self.terminal_actions = terminal_actions
    self.checklist_target = checklist_target
    self.checklist_mask = checklist_mask
    self.checklist = checklist
    # Read the action ids from a gradient-free CPU copy; enumerate every row so row
    # numbers stay aligned with the stored tensors even when -1 rows are dropped.
    action_ids = (int(row[0]) for row in terminal_actions.detach().cpu())
    self.terminal_indices_dict: Dict[int, int] = {
        action_id: row_number
        for row_number, action_id in enumerate(action_ids)
        if action_id != -1
    }
示例5: __init__
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def __init__(self,
             terminal_actions: torch.Tensor,
             checklist_target: torch.Tensor,
             checklist_mask: torch.Tensor,
             checklist: torch.Tensor,
             terminal_indices_dict: Dict[int, int] = None) -> None:
    """Store the checklist tensors and the action-index lookup table.

    If ``terminal_indices_dict`` is given, it is stored as-is: the same object, not a
    copy. Otherwise the table is built from ``terminal_actions``. It maps each
    batch action index, read from column 0 of a row, to that row's position, skipping
    rows whose action index is ``-1``. If an action index repeats, its last row wins.

    NOTE(review): ``None`` is a valid default here, so the annotation should really be
    ``Optional[Dict[int, int]]``.
    """
    self.terminal_actions = terminal_actions
    self.checklist_target = checklist_target
    self.checklist_mask = checklist_mask
    self.checklist = checklist
    if terminal_indices_dict is None:
        # Enumerate every row so row numbers stay aligned with the stored tensors,
        # then drop the -1 entries.
        action_ids = (int(row[0]) for row in terminal_actions.detach().cpu())
        terminal_indices_dict = {
            action_id: row_number
            for row_number, action_id in enumerate(action_ids)
            if action_id != -1
        }
    self.terminal_indices_dict: Dict[int, int] = terminal_indices_dict
示例6: _fix_feature
# 需要导入模块: from torch import Tensor [as 别名]
# 说明: detach 是 torch.Tensor 的实例方法，直接以 tensor.detach() 调用即可（torch.Tensor 是类而非模块，不能 from torch.Tensor import detach）
def _fix_feature(Z: Tensor, value: Optional[float]) -> Tensor:
    r"""Return a gradient-free stand-in for `Z`.

    Args:
        Z: The reference tensor.
        value: Fill value. If ``None``, `Z` itself is returned detached from the
            autograd graph; the result shares storage with `Z`. Otherwise a new tensor
            with `Z`'s shape, dtype and device is created and filled with `value`.

    Returns:
        A tensor that does not require grad.
    """
    return Z.detach() if value is None else torch.full_like(Z, value)