当前位置: 首页>>代码示例>>Python>>正文


Python torch.dsmm方法代码示例

本文整理汇总了Python中torch.dsmm方法的典型用法代码示例。如果您正苦于以下问题:Python torch.dsmm方法的具体用法?Python torch.dsmm怎么用?Python torch.dsmm使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.dsmm方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_dsmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_dsmm(self):
        def test_shape(di, dj, dk):
            x = self._gen_sparse(2, 20, [di, dj])[0]
            y = self.randn(dj, dk)

            res = torch.dsmm(x, y)
            expected = torch.mm(x.to_dense(), y)
            self.assertEqual(res, expected)

        test_shape(7, 5, 3)
        test_shape(1000, 100, 100)
        test_shape(3000, 64, 300) 
开发者ID:tylergenter,项目名称:pytorch,代码行数:14,代码来源:test_sparse.py

示例2: test_interpolation

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_interpolation():
    x = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)
    J, C = Interpolation().interpolate(grid, x)
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))
    test_func_grid = grid.pow(2)
    test_func_x = x.pow(2)

    interp_func_x = torch.dsmm(W, test_func_grid.unsqueeze(1)).squeeze()

    assert all(torch.abs(interp_func_x - test_func_x) / (test_func_x + 1e-10) < 1e-5) 
开发者ID:jrg365,项目名称:gpytorch,代码行数:13,代码来源:test_cubic_interpolation.py

示例3: _derivative_quadratic_form_factory

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def _derivative_quadratic_form_factory(self, *args):
        def closure(left_vectors, right_vectors):
            if left_vectors.ndimension() == 1:
                left_factor = left_vectors.unsqueeze(0)
                right_factor = right_vectors.unsqueeze(0)
            else:
                left_factor = left_vectors
                right_factor = right_vectors
            if len(args) == 1:
                columns, = args
                return kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor),
            elif len(args) == 3:
                columns, W_left, W_right = args
                left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

                res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
                return res, None, None
            elif len(args) == 4:
                columns, W_left, W_right, added_diag, = args
                diag_grad = columns.new(len(added_diag)).zero_()
                diag_grad[0] = (left_factor * right_factor).sum()

                left_factor = torch.dsmm(W_left.t(), left_factor.t()).t()
                right_factor = torch.dsmm(W_right.t(), right_factor.t()).t()

                res = kp_sym_toeplitz_derivative_quadratic_form(columns, left_factor, right_factor)
                return res, None, None, diag_grad

        return closure 
开发者ID:jrg365,项目名称:gpytorch,代码行数:32,代码来源:kronecker_product_lazy_variable.py

示例4: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def forward(self, dense):
        if self.sparse.ndimension() == 3:
            return bdsmm(self.sparse, dense)
        else:
            return torch.dsmm(self.sparse, dense) 
开发者ID:jrg365,项目名称:gpytorch,代码行数:7,代码来源:dsmm.py

示例5: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def backward(self, grad_output):
        if self.sparse.ndimension() == 3:
            return bdsmm(self.sparse.transpose(1, 2), grad_output)
        else:
            return torch.dsmm(self.sparse.t(), grad_output) 
开发者ID:jrg365,项目名称:gpytorch,代码行数:7,代码来源:dsmm.py

示例6: test_dsmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def test_dsmm(self):
        def test_shape(di, dj, dk):
            x = self._gen_sparse(2, 20, [di, dj])[0]
            y = self.randn(dj, dk)

            res = torch.dsmm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(res, expected)

        test_shape(7, 5, 3)
        test_shape(1000, 100, 100)
        test_shape(3000, 64, 300) 
开发者ID:pytorch,项目名称:pytorch,代码行数:14,代码来源:test_sparse.py

示例7: kp_interpolated_toeplitz_matmul

# 需要导入模块: import torch [as 别名]
# 或者: from torch import dsmm [as 别名]
def kp_interpolated_toeplitz_matmul(toeplitz_columns, tensor, interp_left=None, interp_right=None, noise_diag=None):
    """
    Given an interpolated matrix interp_left * T_1 \otimes ... \otimes T_d * interp_right, plus possibly an additional
    diagonal component s*I, compute a product with some tensor or matrix tensor, where T_i is
    symmetric Toeplitz matrices.

    Args:
        - toeplitz_columns (d x m matrix) - columns of d toeplitz matrix T_i with
          length n_i
        - interp_left (sparse matrix nxm) - Left interpolation matrix
        - interp_right (sparse matrix pxm) - Right interpolation matrix
        - tensor (matrix p x k) - Vector (k=1) or matrix (k>1) to multiply WKW with
        - noise_diag (tensor p) - If not none, add (s*I)tensor to WKW at the end.

    Returns:
        - tensor
    """
    output_dims = tensor.ndimension()
    noise_term = None

    if output_dims == 1:
        tensor = tensor.unsqueeze(1)

    if noise_diag is not None:
        noise_term = noise_diag.unsqueeze(1).expand_as(tensor) * tensor

    if interp_left is not None:
        # Get interp_{r}^{T} tensor
        interp_right_tensor = torch.dsmm(interp_right.t(), tensor)
        # Get (T interp_{r}^{T}) tensor
        rhs = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, interp_right_tensor)

        # Get (interp_{l} T interp_{r}^{T})tensor
        output = torch.dsmm(interp_left, rhs)
    else:
        output = kronecker_product_toeplitz_matmul(toeplitz_columns, toeplitz_columns, tensor)

    if noise_term is not None:
        # Get (interp_{l} T interp_{r}^{T} + \sigma^{2}I)tensor
        output = output + noise_term

    if output_dims == 1:
        output = output.squeeze(1)
    return output 
开发者ID:jrg365,项目名称:gpytorch,代码行数:46,代码来源:kronecker_product.py


注:本文中的torch.dsmm方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。