当前位置: 首页>>代码示例>>Python>>正文


Python Tensor.detach方法代码示例

本文整理汇总了Python中torch.Tensor.detach方法的典型用法代码示例。如果您正苦于以下问题:Python Tensor.detach方法的具体用法?Python Tensor.detach怎么用?Python Tensor.detach使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.Tensor的用法示例。


在下文中一共展示了Tensor.detach方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_best_span

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        """Pick, per batch row, the (start, end) pair maximizing start score + end score.

        Only pairs with ``start <= end`` are considered. While scanning end positions
        left to right, the start is the running argmax of the start logits seen so far.
        Ties keep the earliest start (strict ``<``) and the earliest end (strict ``>``).
        A row whose best combined score never exceeds ``-1e20`` keeps span ``[0, 0]``.

        Parameters
        ----------
        span_start_logits : torch.Tensor
            2-D tensor of shape ``(batch_size, passage_length)``.
        span_end_logits : torch.Tensor
            2-D tensor. Presumably the same shape as ``span_start_logits``; this is
            not checked.

        Returns
        -------
        torch.Tensor
            ``(batch_size, 2)`` long tensor of ``[start, end]`` indices, created with
            ``span_start_logits.new_zeros``.

        Raises
        ------
        ValueError
            If either input is not 2-D.
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        # Scan on CPU numpy copies, cut off from the autograd graph.
        start_scores = span_start_logits.detach().cpu().numpy()
        end_scores = span_end_logits.detach().cpu().numpy()

        for row in range(batch_size):
            best_total = -1e20
            best_start = 0
            for end_pos in range(passage_length):
                # Extend the running argmax of start scores to include end_pos.
                start_score = start_scores[row, best_start]
                if start_score < start_scores[row, end_pos]:
                    best_start = end_pos
                    start_score = start_scores[row, end_pos]

                candidate = start_score + end_scores[row, end_pos]
                if candidate > best_total:
                    best_word_span[row, 0] = best_start
                    best_word_span[row, 1] = end_pos
                    best_total = candidate
        return best_word_span
开发者ID:apmoore1,项目名称:allennlp,代码行数:27,代码来源:bidaf.py

示例2: _run_mst_decoding

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
    def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run maximum-spanning-tree decoding independently on each batch instance.

        Parameters
        ----------
        batch_energy : torch.Tensor
            Edge energies per instance. Reducing each instance over its first axis
            yields a 2-D score matrix, so presumably the shape is
            ``(batch_size, num_head_tags, seq_len, seq_len)``; TODO confirm.
        lengths : torch.Tensor
            Per-instance lengths, passed unchanged to ``decode_mst``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(heads, head_tags)``, each built with ``numpy.stack`` over the
            per-instance results. Position 0 of every instance is forced to 0 in both.
        """
        decoded_heads = []
        decoded_tags = []
        cpu_energies = batch_energy.detach().cpu()
        for instance_energy, instance_length in zip(cpu_energies, lengths):
            # Collapse the first (tag) axis. Keep the best score for every pair and
            # the tag id that produced it.
            edge_scores, best_tag_ids = instance_energy.max(dim=0)
            # The root must stay in the tree, but no word should be the root's parent,
            # so the row-0 scores are zeroed before decoding.
            # NOTE(review): best_tag_ids is indexed below as [parent, child]. By that
            # convention row 0 holds edges whose *parent* is ROOT, not word -> ROOT
            # edges. Confirm the axis convention decode_mst expects.
            edge_scores[0, :] = 0
            # The scores were edited, so decode without labels and recover the tags here.
            tree_heads, _ = decode_mst(edge_scores.numpy(), instance_length, has_labels=False)

            # Label each decoded (parent -> child) edge with its best-scoring tag.
            tree_tags = [best_tag_ids[head, position].item()
                         for position, head in enumerate(tree_heads)]
            # The root's head and tag are arbitrary and can differ between batched
            # and unbatched decoding, so pin both to 0.
            tree_heads[0] = 0
            tree_tags[0] = 0
            decoded_heads.append(tree_heads)
            decoded_tags.append(tree_tags)
        return torch.from_numpy(numpy.stack(decoded_heads)), torch.from_numpy(numpy.stack(decoded_tags))
开发者ID:ziaridoy20,项目名称:allennlp,代码行数:28,代码来源:biaffine_dependency_parser.py

示例3: decode_all

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
 def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
     """Turn each row of predicted ids into a list of token strings.

     Each row is cut just before the first occurrence of ``self._end_index``. Rows
     that never contain it are decoded in full. Ids are mapped through
     ``self.vocab.get_token_from_index`` using ``self._target_namespace``.

     Args:
         predicted_indices: a numpy array is used as-is. Anything else is converted
             with ``.detach().cpu().numpy()``.

     Returns:
         One list of token strings per row.
     """
     if isinstance(predicted_indices, numpy.ndarray):
         index_rows = predicted_indices
     else:
         index_rows = predicted_indices.detach().cpu().numpy()

     decoded = []
     for row in index_rows:
         token_ids = list(row)
         try:
             cutoff = token_ids.index(self._end_index)
         except ValueError:
             # No end symbol in this row: keep every id.
             cutoff = len(token_ids)
         decoded.append([self.vocab.get_token_from_index(token_id, namespace=self._target_namespace)
                         for token_id in token_ids[:cutoff]])
     return decoded
开发者ID:apmoore1,项目名称:allennlp,代码行数:15,代码来源:event2mind.py

示例4: __init__

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
 def __init__(self,
              terminal_actions: torch.Tensor,
              checklist_target: torch.Tensor,
              checklist_mask: torch.Tensor,
              checklist: torch.Tensor) -> None:
     """Store the checklist tensors and index the terminal actions.

     Args:
         terminal_actions: tensor whose rows are iterated. The first element of each
             row is read as an action index, and rows whose first element is -1 are
             skipped (presumably padding; confirm with callers).
         checklist_target: stored unchanged.
         checklist_mask: stored unchanged.
         checklist: stored unchanged.

     ``self.terminal_indices_dict`` maps each action index to the row position where
     it appears in ``terminal_actions``. If an index repeats, the last row wins.
     """
     self.terminal_actions = terminal_actions
     self.checklist_target = checklist_target
     self.checklist_mask = checklist_mask
     self.checklist = checklist
     # The row position presumably also indexes the other checklist tensors; verify
     # against callers.
     first_column = (int(row[0]) for row in terminal_actions.detach().cpu())
     self.terminal_indices_dict: Dict[int, int] = {
         action_id: position
         for position, action_id in enumerate(first_column)
         if action_id != -1
     }
开发者ID:pyknife,项目名称:allennlp,代码行数:18,代码来源:checklist_state.py

示例5: __init__

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
 def __init__(self,
              terminal_actions: torch.Tensor,
              checklist_target: torch.Tensor,
              checklist_mask: torch.Tensor,
              checklist: torch.Tensor,
              terminal_indices_dict: Dict[int, int] = None) -> None:
     """Store the checklist tensors and the action-index lookup.

     Args:
         terminal_actions: tensor whose rows are iterated when no lookup is given.
             The first element of each row is an action index, and -1 rows are
             skipped (presumably padding; confirm with callers).
         checklist_target: stored unchanged.
         checklist_mask: stored unchanged.
         checklist: stored unchanged.
         terminal_indices_dict: optional precomputed mapping from action index to row
             position. It is stored by reference. When None, it is rebuilt from
             ``terminal_actions``, and the last row wins for repeated indices.
     """
     self.terminal_actions = terminal_actions
     self.checklist_target = checklist_target
     self.checklist_mask = checklist_mask
     self.checklist = checklist
     if terminal_indices_dict is None:
         # No precomputed mapping was supplied: derive it from each row's first entry.
         first_column = (int(row[0]) for row in terminal_actions.detach().cpu())
         terminal_indices_dict = {
             action_id: position
             for position, action_id in enumerate(first_column)
             if action_id != -1
         }
     self.terminal_indices_dict: Dict[int, int] = terminal_indices_dict
开发者ID:apmoore1,项目名称:allennlp,代码行数:21,代码来源:checklist_statelet.py

示例6: _fix_feature

# 需要导入模块: from torch import Tensor [as 别名]
# 用法: 直接在张量上调用 tensor.detach()(torch.Tensor 是类而非模块,不能写 from torch.Tensor import detach)
def _fix_feature(Z: Tensor, value: Optional[float]) -> Tensor:
    r"""Return a tensor shaped like `Z` that is cut off from autograd.

    Args:
        Z: Reference tensor.
        value: Fill value. When None, `Z`'s own values are kept.

    Returns:
        `Z.detach()` if `value` is None; that result shares storage with `Z`.
        Otherwise, `torch.full_like(Z, value)`: a new tensor with `Z`'s shape,
        dtype and device, filled with `value`.
    """
    return Z.detach() if value is None else torch.full_like(Z, value)
开发者ID:saschwan,项目名称:botorch,代码行数:7,代码来源:utils.py


注:本文中的torch.Tensor.detach方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。