本文整理汇总了Python中torch.dsmm方法的典型用法代码示例。如果您正苦于以下问题:Python torch.dsmm方法的具体用法?Python torch.dsmm怎么用?Python torch.dsmm使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.dsmm方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_dsmm
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_dsmm(self):
    """Check that torch.dsmm(sparse, dense) matches a dense torch.mm reference.

    For each (rows, inner, cols) triple, a sparse rows x inner operand is built
    with the test-suite helper ``_gen_sparse`` and multiplied by a random dense
    inner x cols matrix. The product must equal ``torch.mm`` applied to the
    densified (``to_dense``) sparse operand.
    """
    def check(n_rows, n_inner, n_cols):
        sparse_mat = self._gen_sparse(2, 20, [n_rows, n_inner])[0]
        dense_mat = self.randn(n_inner, n_cols)
        product = torch.dsmm(sparse_mat, dense_mat)
        reference = torch.mm(sparse_mat.to_dense(), dense_mat)
        self.assertEqual(product, reference)

    # Small, tall-square and wide-output shapes, checked in this order.
    for dims in ((7, 5, 3), (1000, 100, 100), (3000, 64, 300)):
        check(*dims)
示例2: test_interpolation
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_interpolation():
    """Interpolating f(t) = t**2 from a grid onto sample points is near-exact.

    A sparse interpolation matrix W is built from ``grid`` to the 100 sample
    points. It is applied to f evaluated on the grid with torch.dsmm. The
    relative error against f at the sample points must be below 1e-5 everywhere.
    """
    samples = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)
    indices, coefficients = Interpolation().interpolate(grid, samples)
    interp_matrix = utils.toeplitz.index_coef_to_sparse(indices, coefficients, len(grid))

    exact = samples.pow(2)
    # f(grid) as a column vector so it can serve as the dense right operand.
    grid_values = grid.pow(2).unsqueeze(1)
    interpolated = torch.dsmm(interp_matrix, grid_values).squeeze()

    # The 1e-10 keeps the denominator away from zero.
    relative_error = torch.abs(interpolated - exact) / (exact + 1e-10)
    assert all(relative_error < 1e-5)
示例3: _derivative_quadratic_form_factory
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def _derivative_quadratic_form_factory(self, *args):
    """Return a closure that computes gradients of a Toeplitz-based quadratic form.

    Presumably, judging by the name of ``kp_sym_toeplitz_derivative_quadratic_form``,
    the closure differentiates ``left^T K right``, where K is a Kronecker product
    of symmetric Toeplitz matrices described by ``columns``. Confirm this against
    that helper.

    ``args`` selects the operator layout. The closure returns one entry per
    element of ``args``:

    * ``(columns,)`` -> ``(grad_columns,)``
    * ``(columns, W_left, W_right)`` -> ``(grad_columns, None, None)``. The
      factors are first mapped through the sparse interpolation matrices as
      ``(W^T @ factor^T)^T``.
    * ``(columns, W_left, W_right, added_diag)`` -> as above, plus
      ``grad_added_diag``. Its first entry is ``sum(left * right)`` over the
      un-interpolated factors, and its remaining entries are zero.

    The returned closure raises ``ValueError`` for any other number of ``args``.
    """
    def closure(left_vectors, right_vectors):
        """Compute the gradient tuple for the given left/right vectors.

        1-D inputs are treated as a single row, so every branch sees 2-D factors.
        """
        if left_vectors.ndimension() == 1:
            left_factor = left_vectors.unsqueeze(0)
            right_factor = right_vectors.unsqueeze(0)
        else:
            left_factor = left_vectors
            right_factor = right_vectors

        if len(args) == 1:
            columns, = args
            # Trailing comma: autograd expects a tuple with one gradient per input.
            return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),

        if len(args) == 3:
            columns, W_left, W_right = args
            diag_grad = None
        elif len(args) == 4:
            columns, W_left, W_right, added_diag = args
            # Must use the factors *before* interpolation below.
            diag_grad = columns.new(len(added_diag)).zero_()
            diag_grad[0] = (left_factor * right_factor).sum()
        else:
            raise ValueError(
                "expected 1, 3 or 4 arguments (columns[, W_left, W_right[, added_diag]]), "
                "got {}".format(len(args))
            )

        # Map each factor through its interpolation matrix: (W^T f^T)^T.
        left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
        right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()
        res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)

        # The interpolation matrices themselves receive no gradient.
        if diag_grad is None:
            return res, None, None
        return res, None, None, diag_grad
    return closure
示例4: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def forward(self, dense):
    """Multiply the stored sparse tensor ``self.sparse`` by ``dense``.

    A 3-D ``self.sparse`` is dispatched to ``bdsmm`` (presumably a batched
    sparse @ dense product). Any other rank goes through ``torch.dsmm``.
    """
    is_batched = self.sparse.ndimension() == 3
    if not is_batched:
        return torch.dsmm(self.sparse, dense)
    return bdsmm(self.sparse, dense)
示例5: backward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def backward(self, grad_output):
    """Propagate ``grad_output`` back to the dense operand of ``forward``.

    For ``sparse @ dense``, the gradient with respect to ``dense`` is
    ``sparse^T @ grad_output``. For a 3-D sparse tensor, only the last two
    dimensions are transposed and the product goes through ``bdsmm``.
    """
    sparse = self.sparse
    if sparse.ndimension() != 3:
        return torch.dsmm(sparse.t(), grad_output)
    return bdsmm(sparse.transpose(1, 2), grad_output)
示例6: test_dsmm
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_dsmm(self):
    """torch.dsmm(sparse, dense) must agree with torch.mm on the densified input.

    The reference product densifies the sparse operand with the test case's
    ``safeToDense`` helper.
    """
    shapes = [(7, 5, 3), (1000, 100, 100), (3000, 64, 300)]

    def assert_matches_dense(rows, inner, cols):
        sparse_lhs = self._gen_sparse(2, 20, [rows, inner])[0]
        dense_rhs = self.randn(inner, cols)
        actual = torch.dsmm(sparse_lhs, dense_rhs)
        self.assertEqual(actual, torch.mm(self.safeToDense(sparse_lhs), dense_rhs))

    for rows, inner, cols in shapes:
        assert_matches_dense(rows, inner, cols)
示例7: kp_interpolated_toeplitz_matmul
# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    r"""
    Multiply ``tensor`` by an interpolated Kronecker product of symmetric Toeplitz matrices.

    Computes ``(W_l (T_1 \otimes ... \otimes T_d) W_r^T + diag(noise_diag)) @ tensor``,
    where ``W_l = interp_left`` and ``W_r = interp_right``. Either interpolation matrix
    may be omitted, in which case it is treated as the identity. The diagonal term is
    omitted when ``noise_diag`` is None.

    Args:
        - toeplitz_columns (d x m matrix) - first columns of the d Toeplitz matrices T_i
          (length n_i). They are also passed as the rows, so each T_i is assumed symmetric.
        - tensor (matrix p x k, or vector p) - Vector (k=1) or matrix (k>1) to multiply with.
        - interp_left (sparse matrix n x m, optional) - Left interpolation matrix W_l.
        - interp_right (sparse matrix p x m, optional) - Right interpolation matrix W_r;
          its transpose is applied to ``tensor``.
        - noise_diag (tensor p, optional) - If given, ``diag(noise_diag) @ tensor`` is added
          to the result (requires the result to have p rows).
    Returns:
        - tensor with the same number of dimensions as the input ``tensor``
    """
    output_dims = tensor.ndimension()
    if output_dims == 1:
        # Work with a p x 1 matrix internally; squeezed back before returning.
        tensor = tensor.unsqueeze(1)

    noise_term = None
    if noise_diag is not None:
        # diag(noise_diag) @ tensor, computed as a row-wise scaling.
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    # interp_{r}^{T} tensor (identity when no right interpolation is given).
    rhs = tensor if interp_right is None else torch.dsmm(interp_right.t(), tensor)
    # (T interp_{r}^{T}) tensor; the columns double as rows because each T_i is symmetric.
    output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, rhs)
    # (interp_{l} T interp_{r}^{T}) tensor (identity when no left interpolation is given).
    if interp_left is not None:
        output = torch.dsmm(interp_left, output)

    if noise_term is not None:
        # (interp_{l} T interp_{r}^{T} + diag(noise_diag)) tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output