当前位置: 首页>>代码示例>>Python>>正文


Python torch.addbmm方法代码示例

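本文整理汇总了 Python 中 torch.addbmm 方法的典型用法代码示例。如果您正苦于以下问题：Python torch.addbmm 方法的具体用法？Python torch.addbmm 怎么用？Python torch.addbmm 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的其他用法示例。简而言之，torch.addbmm(M, batch1, batch2) 先对两个批次张量逐批做矩阵乘法，再沿批次维求和，最后与矩阵 M 相加，即 $\beta M + \alpha \sum_{i} \mathrm{batch1}_i \, \mathrm{batch2}_i$。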


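在下文中一共展示了 torch.addbmm 方法的 10 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。注意：这些示例大多来自早期版本的 PyTorch，使用了现已弃用的位置标量写法，例如 torch.addbmm(beta, M, alpha, batch1, batch2)（第一个标量缩放 M，第二个标量缩放批次乘积之和）；在当前版本中应改写为 torch.addbmm(M, batch1, batch2, beta=beta, alpha=alpha)。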

示例1: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def forward(self, add_matrix, batch1, batch2):
    """Forward pass of the (legacy-API) addbmm autograd Function.

    Saves ``batch1`` and ``batch2`` for the backward pass, obtains a result
    tensor from ``self._get_output(add_matrix)`` and returns the output of
    the legacy positional call
    ``torch.addbmm(result, self.alpha, add_matrix, self.beta, batch1, batch2)``.

    NOTE(review): in this legacy form the leading tensor is presumably the
    destination buffer, the scalar before ``add_matrix`` scales it and the
    scalar after it scales the summed batch products — confirm against the
    PyTorch version this file targets (modern PyTorch rejects this form).
    """
    self.save_for_backward(batch1, batch2)
    target = self._get_output(add_matrix)
    return torch.addbmm(
        target,
        self.alpha,
        add_matrix,
        self.beta,
        batch1,
        batch2,
    )
开发者ID:apaszke,项目名称:pytorch-dist,代码行数:7,代码来源:blas.py

示例2: test_addbmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_addbmm(self):
    """Exercise ``Tensor.addbmm_`` and ``torch.addbmm`` with positional scalars.

    The reference value is the sum over the batch dimension of the per-batch
    products computed by ``torch.bmm``. Every check (including the final
    mixed-scalar one) compares against that same reduced M x O matrix.
    """
    # Small sizes keep the test fast (a larger variant used
    # num_batches = 10 and M, N, O = 12, 8, 5).
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    # ``view_as(res[0])`` yields an M x O matrix whether or not ``sum``
    # keeps the reduced dimension, so the reference has the same shape as
    # ``res2`` on every PyTorch version.
    expected = res.sum(0).view_as(res[0])
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # res2 accumulates across the in-place calls: 1x, 2x, then 2.5x.
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, expected)

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, expected * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, expected * 2.5)

    # Functional form: the first scalar scales the matrix argument and the
    # second scales the batch-product sum.
    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, expected * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, expected)

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    # Use the same reduced reference as every other check above.
    self.assertEqual(res6, res2 * .1 + expected * .5)
开发者ID:apaszke,项目名称:pytorch-dist,代码行数:32,代码来源:test_torch.py

示例3: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1, inplace=False):
    """Forward pass of the addbmm autograd Function (ctx-style API).

    Stores ``alpha`` and ``beta`` on ``ctx``, saves ``batch1``/``batch2`` for
    the backward pass, and writes the result of the positional-scalar call
    ``torch.addbmm(alpha, add_matrix, beta, batch1, batch2)`` into the tensor
    returned by ``_get_output``. In this positional form ``alpha`` scales
    ``add_matrix`` and ``beta`` scales the summed batch products.

    NOTE(review): ``inplace`` is forwarded to ``_get_output``; presumably it
    makes the output reuse ``add_matrix`` — confirm against that helper.
    """
    ctx.alpha, ctx.beta = alpha, beta
    ctx.save_for_backward(batch1, batch2)
    destination = _get_output(ctx, add_matrix, inplace=inplace)
    return torch.addbmm(
        alpha, add_matrix, beta, batch1, batch2,
        out=destination,
    )
开发者ID:tylergenter,项目名称:pytorch,代码行数:9,代码来源:blas.py

示例4: test_functional_blas

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_functional_blas(self):
    """Check that BLAS-style torch functions agree on Variables and tensors.

    Each function is called once with ``Variable`` arguments and once with
    their unwrapped ``.data``; the ``.data`` of the first result must equal
    the second. The ``add*`` variants are also called through the scalar
    signatures ``fn(s, x, y, z)`` and ``fn(s1, x, s2, y, z)``.
    """
    def compare(fn, *args):
        plain_args = tuple(
            a.data if isinstance(a, Variable) else a for a in args
        )
        self.assertEqual(fn(*args).data, fn(*plain_args))

    def check_add_variant(fn, x, y, z):
        # Covers every signature: no scalars, one scalar, two scalars.
        for call_args in ((x, y, z), (0.5, x, y, z), (0.5, x, 0.25, y, z)):
            compare(fn, *call_args)

    def check_plain(fn, x, y):
        compare(fn, x, y)

    def var(*shape):
        return Variable(torch.randn(*shape))

    check_plain(torch.mm, var(2, 10), var(10, 4))
    check_add_variant(torch.addmm, var(2, 4), var(2, 10), var(10, 4))
    check_plain(torch.bmm, var(4, 2, 10), var(4, 10, 4))
    check_add_variant(torch.addbmm, var(2, 4), var(4, 2, 10), var(4, 10, 4))
    check_add_variant(torch.baddbmm, var(4, 2, 4), var(4, 2, 10), var(4, 10, 4))
    check_plain(torch.mv, var(2, 10), var(10))
    check_add_variant(torch.addmv, var(2), var(2, 10), var(10))
    check_plain(torch.ger, var(5), var(6))
    check_add_variant(torch.addr, var(5, 6), var(5), var(6))
开发者ID:tylergenter,项目名称:pytorch,代码行数:35,代码来源:test_autograd.py

示例5: test_addbmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_addbmm(self):
    """Exercise ``Tensor.addbmm_`` and ``torch.addbmm`` with positional scalars.

    The reference value is the sum over the batch dimension of the per-batch
    products computed by ``torch.bmm``. Every check (including the final
    mixed-scalar one) compares against that same reduced M x O matrix.
    """
    # Small sizes keep the test fast (a larger variant used
    # num_batches = 10 and M, N, O = 12, 8, 5).
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    # ``view_as(res[0])`` yields an M x O matrix whether or not ``sum``
    # keeps the reduced dimension, so the reference has the same shape as
    # ``res2`` on every PyTorch version.
    expected = res.sum(0).view_as(res[0])
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # res2 accumulates across the in-place calls: 1x, 2x, then 2.5x.
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, expected)

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, expected * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, expected * 2.5)

    # Functional form: the first scalar scales the matrix argument and the
    # second scales the batch-product sum.
    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, expected * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, expected)

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    # Use the same reduced reference as every other check above.
    self.assertEqual(res6, res2 * .1 + expected * .5)
开发者ID:tylergenter,项目名称:pytorch,代码行数:32,代码来源:test_torch.py

示例6: test_addbmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_addbmm(self):
    """Exercise ``Tensor.addbmm_`` and ``torch.addbmm`` with positional scalars.

    The reference value is the per-batch products from ``torch.bmm`` summed
    over the batch dimension with ``keepdim=False``. Every check, including
    the final mixed-scalar one, uses that same reduced matrix.
    """
    # Small sizes keep the test fast (a larger variant used
    # num_batches = 10 and M, N, O = 12, 8, 5).
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    # Explicit keepdim=False so the reference is M x O like ``res2``.
    expected = res.sum(0, False)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # res2 accumulates across the in-place calls: 1x, 2x, then 2.5x.
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, expected)

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, expected * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, expected * 2.5)

    # Functional form: the first scalar scales the matrix argument and the
    # second scales the batch-product sum.
    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, expected * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, expected)

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    # Same keepdim=False reference as the other checks (previously this
    # line relied on the default ``keepdim`` of ``sum``).
    self.assertEqual(res6, res2 * .1 + expected * .5)
开发者ID:hughperkins,项目名称:pytorch-coriander,代码行数:32,代码来源:test_torch.py

示例7: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1, inplace=False):
    """Forward pass of the addbmm autograd Function (ctx-style API).

    Stores ``alpha``, ``beta`` and ``add_matrix.size()`` on ``ctx``, saves
    ``batch1``/``batch2`` for the backward pass, and writes the result of the
    positional-scalar call ``torch.addbmm(alpha, add_matrix, beta, batch1,
    batch2)`` into the tensor returned by ``_get_output``. In this positional
    form ``alpha`` scales ``add_matrix`` and ``beta`` scales the summed batch
    products.

    NOTE(review): the recorded size is presumably used by backward to shape
    the gradient of ``add_matrix``, and ``inplace`` presumably makes
    ``_get_output`` reuse ``add_matrix`` — confirm against those helpers.
    """
    ctx.alpha, ctx.beta = alpha, beta
    ctx.add_matrix_size = add_matrix.size()
    ctx.save_for_backward(batch1, batch2)
    destination = _get_output(ctx, add_matrix, inplace=inplace)
    return torch.addbmm(
        alpha, add_matrix, beta, batch1, batch2,
        out=destination,
    )
开发者ID:ezyang,项目名称:pytorch,代码行数:10,代码来源:blas.py

示例8: test_addbmm

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_addbmm(self):
    """Exercise ``Tensor.addbmm_`` and ``torch.addbmm`` with positional scalars.

    The reference value is the per-batch products from ``torch.bmm`` summed
    over the batch dimension with ``keepdim=False``. Every check, including
    the final mixed-scalar one, uses that same reduced matrix.
    """
    # Small sizes keep the test fast (a larger variant used
    # num_batches = 10 and M, N, O = 12, 8, 5).
    num_batches = 2
    M, N, O = 2, 3, 4
    b1 = torch.randn(num_batches, M, N)
    b2 = torch.randn(num_batches, N, O)
    res = torch.bmm(b1, b2)
    # Explicit keepdim=False so the reference is M x O like ``res2``.
    expected = res.sum(0, False)
    res2 = torch.Tensor().resize_as_(res[0]).zero_()

    # res2 accumulates across the in-place calls: 1x, 2x, then 2.5x.
    res2.addbmm_(b1, b2)
    self.assertEqual(res2, expected)

    res2.addbmm_(1, b1, b2)
    self.assertEqual(res2, expected * 2)

    res2.addbmm_(1., .5, b1, b2)
    self.assertEqual(res2, expected * 2.5)

    # Functional form: the first scalar scales the matrix argument and the
    # second scales the batch-product sum.
    res3 = torch.addbmm(1, res2, 0, b1, b2)
    self.assertEqual(res3, res2)

    res4 = torch.addbmm(1, res2, .5, b1, b2)
    self.assertEqual(res4, expected * 3)

    res5 = torch.addbmm(0, res2, 1, b1, b2)
    self.assertEqual(res5, expected)

    res6 = torch.addbmm(.1, res2, .5, b1, b2)
    # Same keepdim=False reference as the other checks (previously this
    # line relied on the default ``keepdim`` of ``sum``).
    self.assertEqual(res6, res2 * .1 + expected * .5)
开发者ID:ezyang,项目名称:pytorch,代码行数:32,代码来源:test_torch.py

示例9: _test_broadcast_fused_matmul

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def _test_broadcast_fused_matmul(self, cast):
    """Check that fused add+matmul ops broadcast their additive operand.

    For each of ``baddbmm``, ``addbmm``, ``addmm``, ``addmv`` and ``addr``,
    random sizes in [1, 8] are drawn, and a reduced additive tensor is
    obtained from ``self._select_broadcastable_dims``. The op must return
    the same result for that small tensor as for its explicit ``expand`` to
    the full shape.

    Args:
        cast: callable applied to every created tensor (presumably a
            type/device conversion — confirm with callers).
    """
    for name in ["baddbmm", "addbmm", "addmm", "addmv", "addr"]:
        # Drawn in the order batch, n, m, p.
        b, n, m, p = [random.randint(1, 8) for _ in range(4)]

        # name -> (additive operand dims, first factor dims, second factor dims)
        shape_table = {
            "baddbmm": ([b, n, p], [b, n, m], [b, m, p]),
            "addbmm": ([n, p], [b, n, m], [b, m, p]),
            "addmm": ([n, p], [n, m], [m, p]),
            "addmv": ([n], [n, m], [m]),
            "addr": ([n, m], [n], [m]),
        }
        if name not in shape_table:
            raise AssertionError("unknown function")
        full_dims, lhs_dims, rhs_dims = shape_table[name]
        small_dims, _, _ = self._select_broadcastable_dims(full_dims)

        small = cast(torch.randn(*small_dims).float())
        lhs = cast(torch.randn(*lhs_dims).float())
        rhs = cast(torch.randn(*rhs_dims).float())
        full = cast(small.expand(*full_dims))

        op = getattr(torch, name)
        self.assertEqual(op(small, lhs, rhs), op(full, lhs, rhs))
开发者ID:ezyang,项目名称:pytorch,代码行数:38,代码来源:test_torch.py

示例10: test_functional_blas

# 需要导入模块: import torch [as 别名]
# 或者: from torch import addbmm [as 别名]
def test_functional_blas(self):
    """Check that BLAS-style torch functions agree on Variables and tensors.

    Each function is called once with ``Variable`` arguments and once with
    their unwrapped ``.data``; the results must match. The ``add*`` variants
    are also called through the scalar signatures ``fn(s, x, y, z)`` and
    ``fn(s1, x, s2, y, z)``, and ``torch.matmul`` is checked across 1-D to
    4-D operand combinations.
    """
    def compare(fn, *args):
        unpacked_args = tuple(arg.data if isinstance(arg, Variable) else arg
                              for arg in args)
        unpacked_result = fn(*unpacked_args)
        packed_result = fn(*args).data
        # If the non-Variable torch function returns a scalar, the Variable
        # result must be a one-element 1-D tensor; compare its element to it.
        if not torch.is_tensor(unpacked_result):
            # TestCase assertions (not bare ``assert``) so the checks are not
            # stripped under ``python -O`` and failures report the values.
            self.assertEqual(packed_result.dim(), 1)
            self.assertEqual(packed_result.nelement(), 1)
            packed_result = packed_result[0]
        self.assertEqual(packed_result, unpacked_result)

    def test_blas_add(fn, x, y, z):
        # Checks all signatures
        compare(fn, x, y, z)
        compare(fn, 0.5, x, y, z)
        compare(fn, 0.5, x, 0.25, y, z)

    def test_blas(fn, x, y):
        compare(fn, x, y)

    test_blas(torch.mm, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10, 4)))
    test_blas_add(torch.addmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.bmm, Variable(torch.randn(4, 2, 10)),
              Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.addbmm, Variable(torch.randn(2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas_add(torch.baddbmm, Variable(torch.randn(4, 2, 4)),
                  Variable(torch.randn(4, 2, 10)), Variable(torch.randn(4, 10, 4)))
    test_blas(torch.mv, Variable(torch.randn(2, 10)),
              Variable(torch.randn(10)))
    test_blas_add(torch.addmv, Variable(torch.randn(2)),
                  Variable(torch.randn(2, 10)), Variable(torch.randn(10)))
    test_blas(torch.ger, Variable(torch.randn(5)),
              Variable(torch.randn(6)))
    test_blas_add(torch.addr, Variable(torch.randn(5, 6)),
                  Variable(torch.randn(5)), Variable(torch.randn(6)))
    # matmul across operand ranks (1-D . 1-D, 2-D . 1-D, ..., batched 4-D).
    test_blas(torch.matmul, Variable(torch.randn(6)), Variable(torch.randn(6)))
    test_blas(torch.matmul, Variable(torch.randn(10, 4)), Variable(torch.randn(4)))
    test_blas(torch.matmul, Variable(torch.randn(5)), Variable(torch.randn(5, 6)))
    test_blas(torch.matmul, Variable(torch.randn(2, 10)), Variable(torch.randn(10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(5, 2, 10)), Variable(torch.randn(5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(3, 5, 10, 4)))
    test_blas(torch.matmul, Variable(torch.randn(3, 5, 2, 10)), Variable(torch.randn(10)))
    test_blas(torch.matmul, Variable(torch.randn(10)), Variable(torch.randn(3, 5, 10, 4)))
开发者ID:ezyang,项目名称:pytorch,代码行数:50,代码来源:test_autograd.py


注:本文中的torch.addbmm方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。