

Python Tensor.bmm Method Code Examples

This article collects typical usage examples of the torch.Tensor.bmm method in Python. If you are wondering what Tensor.bmm does, how to use it, or are looking for concrete examples, the curated code samples below may help. You can also explore further usage examples of torch.Tensor, the class this method belongs to.


Three code examples of the Tensor.bmm method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
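
Before the examples, a quick sketch of what Tensor.bmm itself does may be useful (the tensor names below are illustrative, not taken from the examples): bmm performs a batched matrix multiplication over two 3-D tensors of shapes (b, n, m) and (b, m, p), returning a tensor of shape (b, n, p).

import torch

a = torch.randn(4, 2, 3)   # batch of 4 matrices, each 2 x 3
b = torch.randn(4, 3, 5)   # batch of 4 matrices, each 3 x 5
out = a.bmm(b)             # batched matrix product, shape (4, 2, 5)
print(out.shape)           # torch.Size([4, 2, 5])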

Example 1: weighted_sum

# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """
    Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an
    "attention" vector), and returns a weighted sum of the rows in the matrix.  This is the typical
    computation performed after an attention mechanism.

    Note that while we call this a "matrix" of vectors and an attention "vector", we also handle
    higher-order tensors.  We always sum over the second-to-last dimension of the "matrix", and we
    assume that all dimensions in the "matrix" prior to the last dimension are matched in the
    "vector".  Non-matched dimensions in the "vector" must be `directly after the batch dimension`.

    For example, say I have a "matrix" with dimensions ``(batch_size, num_queries, num_words,
    embedding_dim)``.  The attention "vector" then must have at least those dimensions, and could
    have more. Both:

        - ``(batch_size, num_queries, num_words)`` (distribution over words for each query)
        - ``(batch_size, num_documents, num_queries, num_words)`` (distribution over words in a
          query for each document)

    are valid input "vectors", producing tensors of shape:
    ``(batch_size, num_queries, embedding_dim)`` and
    ``(batch_size, num_documents, num_queries, embedding_dim)`` respectively.
    """
    # We'll special-case a few settings here, where there are efficient (but poorly-named)
    # operations in pytorch that already do the computation we need.
    if attention.dim() == 2 and matrix.dim() == 3:
        return attention.unsqueeze(1).bmm(matrix).squeeze(1)
    if attention.dim() == 3 and matrix.dim() == 3:
        return attention.bmm(matrix)
    if matrix.dim() - 1 < attention.dim():
        expanded_size = list(matrix.size())
        for i in range(attention.dim() - matrix.dim() + 1):
            matrix = matrix.unsqueeze(1)
            expanded_size.insert(i + 1, attention.size(i + 1))
        matrix = matrix.expand(*expanded_size)
    intermediate = attention.unsqueeze(-1).expand_as(matrix) * matrix
    return intermediate.sum(dim=-2)
Developer ID: cyzhangAThit, Project: GLUE-baselines, Lines of code: 39, Source file: util.py
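
A hypothetical usage sketch (the tensor shapes below are chosen for illustration and are not part of the original example): calling weighted_sum with a 2-D attention vector and a 3-D matrix takes the first bmm special case.

import torch

matrix = torch.randn(2, 5, 7)                          # (batch_size, num_words, embedding_dim)
attention = torch.softmax(torch.randn(2, 5), dim=-1)   # (batch_size, num_words)
result = weighted_sum(matrix, attention)               # uses attention.unsqueeze(1).bmm(matrix)
print(result.shape)                                    # torch.Size([2, 7])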

Example 2: forward

# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
    return matrix_1.bmm(matrix_2.transpose(2, 1))
Developer ID: apmoore1, Project: allennlp, Lines of code: 4, Source file: dot_product_matrix_attention.py
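
A hypothetical call sketch (the input shapes are assumptions for illustration): both arguments are (batch_size, num_rows, embedding_dim) tensors, and the result is a matrix of dot-product similarities between every row pair.

import torch

matrix_1 = torch.randn(2, 4, 8)                    # (batch_size, 4 rows, embedding_dim)
matrix_2 = torch.randn(2, 6, 8)                    # (batch_size, 6 rows, embedding_dim)
scores = matrix_1.bmm(matrix_2.transpose(2, 1))    # (batch_size, 4, 6) similarity matrix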

Example 3: _forward_internal

# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)
Developer ID: apmoore1, Project: allennlp, Lines of code: 4, Source file: dot_product_attention.py
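
A hypothetical call sketch (shapes are assumptions for illustration): the vector is (batch_size, embedding_dim), the matrix is (batch_size, num_rows, embedding_dim), and the result is one dot-product score per row.

import torch

vector = torch.randn(2, 8)                               # (batch_size, embedding_dim)
matrix = torch.randn(2, 5, 8)                            # (batch_size, num_rows, embedding_dim)
scores = matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)    # (batch_size, num_rows) attention scores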


Note: The torch.Tensor.bmm method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors. For distribution and use, please refer to the corresponding project's License. Do not reproduce without permission.