This article collects typical usage examples of the torch.Tensor.bmm method in Python. If you are wondering what Tensor.bmm does, how to use it, or what it looks like in practice, the curated code examples below may help. You can also explore further usage examples of the containing class, torch.Tensor.
Three code examples of the Tensor.bmm method are shown below, sorted by popularity by default.
Example 1: weighted_sum
# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
"""
Takes a matrix of vectors and a set of weights over the rows in the matrix (which we call an
"attention" vector), and returns a weighted sum of the rows in the matrix. This is the typical
computation performed after an attention mechanism.
Note that while we call this a "matrix" of vectors and an attention "vector", we also handle
higher-order tensors. We always sum over the second-to-last dimension of the "matrix", and we
assume that all dimensions in the "matrix" prior to the last dimension are matched in the
"vector". Non-matched dimensions in the "vector" must be `directly after the batch dimension`.
For example, say I have a "matrix" with dimensions ``(batch_size, num_queries, num_words,
embedding_dim)``. The attention "vector" then must have at least those dimensions, and could
have more. Both:
- ``(batch_size, num_queries, num_words)`` (distribution over words for each query)
- ``(batch_size, num_documents, num_queries, num_words)`` (distribution over words in a
query for each document)
are valid input "vectors", producing tensors of shape:
``(batch_size, num_queries, embedding_dim)`` and
``(batch_size, num_documents, num_queries, embedding_dim)`` respectively.
"""
# We'll special-case a few settings here, where there are efficient (but poorly-named)
# operations in pytorch that already do the computation we need.
if attention.dim() == 2 and matrix.dim() == 3:
return attention.unsqueeze(1).bmm(matrix).squeeze(1)
if attention.dim() == 3 and matrix.dim() == 3:
return attention.bmm(matrix)
if matrix.dim() - 1 < attention.dim():
expanded_size = list(matrix.size())
for i in range(attention.dim() - matrix.dim() + 1):
matrix = matrix.unsqueeze(1)
expanded_size.insert(i + 1, attention.size(i + 1))
matrix = matrix.expand(*expanded_size)
intermediate = attention.unsqueeze(-1).expand_as(matrix) * matrix
return intermediate.sum(dim=-2)
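As a quick illustration (a minimal sketch that is not part of the original example; all shapes are made up), the call below exercises the 2-D attention / 3-D matrix fast path:

import torch

batch_size, num_words, embedding_dim = 2, 5, 4
matrix = torch.randn(batch_size, num_words, embedding_dim)
attention = torch.softmax(torch.randn(batch_size, num_words), dim=-1)

# A 2-D attention over a 3-D matrix takes the first special case above,
# i.e. attention.unsqueeze(1).bmm(matrix).squeeze(1).
result = weighted_sum(matrix, attention)
print(result.shape)  # torch.Size([2, 4])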
Example 2: forward
# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
    # Batched dot products between every row of matrix_1 and every row of matrix_2:
    # (batch, rows_1, dim) x (batch, dim, rows_2) -> (batch, rows_1, rows_2).
    return matrix_1.bmm(matrix_2.transpose(2, 1))
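The enclosing class is not shown in the original, so the standalone sketch below (with assumed shapes) traces the same shape arithmetic:

import torch

matrix_1 = torch.randn(2, 3, 4)  # (batch_size, num_rows_1, embedding_dim)
matrix_2 = torch.randn(2, 5, 4)  # (batch_size, num_rows_2, embedding_dim)

# Entry [b, i, j] is the dot product of matrix_1[b, i] with matrix_2[b, j].
similarities = matrix_1.bmm(matrix_2.transpose(2, 1))
print(similarities.shape)  # torch.Size([2, 3, 5])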
Example 3: _forward_internal
# Required import: from torch import Tensor [as alias]
# Or: from torch.Tensor import bmm [as alias]
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    # Batched matrix-vector product: scores each row of the matrix against the vector.
    return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)
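Here bmm performs a batched matrix-vector product. A minimal sketch with assumed shapes:

import torch

matrix = torch.randn(2, 5, 4)  # (batch_size, num_rows, embedding_dim)
vector = torch.randn(2, 4)     # (batch_size, embedding_dim)

# unsqueeze(-1) turns the vector into a (batch_size, embedding_dim, 1) column,
# bmm yields (batch_size, num_rows, 1), and squeeze(-1) drops the trailing dim.
scores = matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)
print(scores.shape)  # torch.Size([2, 5])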