当前位置: 首页>>代码示例>>Python>>正文


Python torch.equal函数代码示例

本文整理汇总了Python中torch.equal函数的典型用法代码示例。如果您正苦于以下问题：Python equal函数的具体用法？Python equal怎么用？Python equal使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了equal函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_degenerate_GPyTorchPosterior

    def test_degenerate_GPyTorchPosterior(self, cuda=False):
        """Exercise GPyTorchPosterior on an MVN with a singular covariance.

        For float and double tensors this checks the posterior's basic
        attributes (device type, dtype, event shape, mean, variance), the
        shapes returned by ``rsample`` with and without ``base_samples``, and
        that exactly one "not p.d." ``RuntimeWarning`` is recorded while
        sampling inside each ``catch_warnings`` block.

        Args:
            cuda: If True, tensors are created on the "cuda" device, otherwise
                on the "cpu" device.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # singular covariance matrix (the first two rows are identical)
            degenerate_covar = torch.tensor(
                [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
            )
            mean = torch.rand(3, dtype=dtype, device=device)
            mvn = MultivariateNormal(mean, lazify(degenerate_covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics
            self.assertEqual(posterior.device.type, device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            # expected variance is the covariance diagonal as a column
            variance_exp = degenerate_covar.diag().unsqueeze(-1)
            self.assertTrue(torch.equal(posterior.variance, variance_exp))

            # rsample
            with warnings.catch_warnings(record=True) as w:
                # we check that the p.d. warning is emitted - this only
                # happens once per posterior, so we need to check only once
                samples = posterior.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
            # rsample w/ base samples: the same base samples must give
            # (numerically close) identical draws on repeated calls
            base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims: same checks on a batch of two copies of
            # the singular covariance
            b_mean = torch.rand(2, 3, dtype=dtype, device=device)
            b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
            b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w:
                # a fresh posterior, so the warning is expected again here
                b_samples = b_posterior.rsample(
                    sample_shape=torch.Size([4]), base_samples=b_base_samples
                )
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
开发者ID:saschwan,项目名称:botorch,代码行数:60,代码来源:test_gpytorch.py

示例2: test_remote_tensor_multi_var_methods

    def test_remote_tensor_multi_var_methods(self):
        """Check multi-output torch functions on tensors sent to a worker.

        Tensors are sent to a ``VirtualWorker``; ``torch.max``, ``torch.qr``,
        ``torch.kthvalue``, ``torch.eig`` and ``torch.svd`` are then called on
        them and every returned piece is fetched with ``.get()`` and compared
        against hard-coded expected values.
        """
        hook = TorchHook(verbose=False)
        local = hook.local_worker
        remote = VirtualWorker(hook, 1)
        local.add_worker(remote)

        x = torch.FloatTensor([[1, 2], [4, 3], [5, 6]])
        # NOTE(review): unlike the cases below, the return value of send() is
        # discarded here; presumably send() updates x in place - confirm.
        x.send(remote)
        # row-wise max: values and their column indices
        y, z = torch.max(x, 1)
        assert torch.equal(y.get(), torch.FloatTensor([2, 4, 6]))
        assert torch.equal(z.get(), torch.LongTensor([1, 0, 1]))

        x = torch.FloatTensor([[0, 0], [1, 0]]).send(remote)
        y, z = torch.qr(x)
        assert (y.get() == torch.FloatTensor([[0, -1], [-1, 0]])).all()
        assert (z.get() == torch.FloatTensor([[-1, 0], [0, 0]])).all()

        x = torch.arange(1, 6).send(remote)
        # 4th smallest of 1..5 is the value 4 at index 3
        y, z = torch.kthvalue(x, 4)
        assert (y.get() == torch.FloatTensor([4])).all()
        assert (z.get() == torch.LongTensor([3])).all()

        x = torch.FloatTensor([[0, 0], [1, 1]]).send(remote)
        y, z = torch.eig(x, True)
        assert (y.get() == torch.FloatTensor([[1, 0], [0, 0]])).all()
        # the elementwise-equality mask itself is compared: the first column
        # must match the reference and the second column must differ
        assert ((z.get() == torch.FloatTensor([[0, 0], [1, 0]])) == torch.ByteTensor([[1, 0], [1, 0]])).all()

        x = torch.zeros(3, 3).send(remote)
        w, y, z = torch.svd(x)
        assert (w.get() == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
        assert (y.get() == torch.FloatTensor([0, 0, 0])).all()
        assert (z.get() == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:32,代码来源:torch_test.py

示例3: test_regex_matches_are_initialized_correctly

    def test_regex_matches_are_initialized_correctly(self):
        """Initializer regexes containing a ``.`` must hit the right parameters.

        Builds a small module, applies an ``InitializerApplicator`` parsed
        from a pyhocon config, and checks the constant values written into the
        conv parameters and into the bias of the oddly named linear layer.
        """

        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                self.conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        # pyhocon does funny things if there's a . in a key.  This test makes sure that we
        # handle these kinds of regexes correctly.
        config_text = """{"initializer": [
        ["conv", {"type": "constant", "val": 5}],
        ["funky_na.*bi", {"type": "constant", "val": 7}]
        ]}
        """
        parsed = pyhocon.ConfigFactory.parse_string(config_text)
        applicator = InitializerApplicator.from_params(Params(parsed)['initializer'])
        net = Net()
        applicator(net)

        # every conv parameter was matched by the "conv" regex -> constant 5
        for conv_param in net.conv.parameters():
            expected_conv = torch.ones(conv_param.size()) * 5
            assert torch.equal(conv_param.data, expected_conv)

        # only the bias of the funky linear layer matches "funky_na.*bi" -> constant 7
        funky_bias = net.linear_1_with_funky_name.bias
        expected_bias = torch.ones(funky_bias.size()) * 7
        assert torch.equal(funky_bias.data, expected_bias)
开发者ID:Jordan-Sauchuk,项目名称:allennlp,代码行数:28,代码来源:initializers_test.py

示例4: test_add_output_dim

 def test_add_output_dim(self, cuda=False):
     """Check ``add_output_dim`` for invalid, unbatched and batched inputs.

     Args:
         cuda: If True, tensors are created on the "cuda" device, otherwise
             on the "cpu" device.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     original_batch_shape = torch.Size([2])
     # (input shape, expected index of the inserted output dimension)
     valid_cases = (
         ((2, 2, 1), 0),  # no new batch dims
         ((3, 2, 2, 1), 1),  # one new batch dim
     )
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         # an input with too few dimensions must be rejected
         X = torch.rand(2, 1, **tkwargs)
         with self.assertRaises(ValueError):
             add_output_dim(X=X, original_batch_shape=original_batch_shape)
         for shape, expected_idx in valid_cases:
             X = torch.rand(*shape, **tkwargs)
             X_out, output_dim_idx = add_output_dim(
                 X=X, original_batch_shape=original_batch_shape
             )
             self.assertTrue(torch.equal(X_out, X.unsqueeze(expected_idx)))
             self.assertEqual(output_dim_idx, expected_idx)
开发者ID:saschwan,项目名称:botorch,代码行数:25,代码来源:test_utils.py

示例5: test_q_noisy_expected_improvement

    def test_q_noisy_expected_improvement(self, cuda=False):
        """Check qNoisyExpectedImprovement values and base-sample handling.

        For each dtype the acquisition value on a mocked model must be 1.0
        for every sampler; for the seeded / QMC samplers the test also checks
        whether a second evaluation keeps or redraws the base samples.

        Args:
            cuda: If True, tensors are created on the "cuda" device, otherwise
                on the "cpu" device.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        base_samples_shape = torch.Size([2, 1, 2, 1])
        # (sampler factory, expectation for the base samples after a 2nd call)
        #   None  -> base samples are not inspected
        #   True  -> base samples must be unchanged (no resample)
        #   False -> base samples must differ (resample)
        sampler_cases = (
            (lambda: IIDNormalSampler(num_samples=2), None),
            (lambda: IIDNormalSampler(num_samples=2, seed=12345), True),
            (lambda: SobolQMCNormalSampler(num_samples=2), True),
            (
                lambda: SobolQMCNormalSampler(
                    num_samples=2, resample=True, seed=12345
                ),
                False,
            ),
        )
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": device, "dtype": dtype}
            # the event shape is `b x q x t` = 1 x 2 x 1
            samples_noisy = torch.tensor([1.0, 0.0], **tkwargs).view(1, 2, 1)
            # X_baseline is `q' x d` = 1 x 1
            X_baseline = torch.zeros(1, 1, **tkwargs)
            mm_noisy = MockModel(MockPosterior(samples=samples_noisy))
            # X is `q x d` = 1 x 1
            X = torch.zeros(1, 1, **tkwargs)

            for make_sampler, keeps_base_samples in sampler_cases:
                acqf = qNoisyExpectedImprovement(
                    model=mm_noisy, X_baseline=X_baseline, sampler=make_sampler()
                )
                self.assertEqual(acqf(X).item(), 1.0)
                if keeps_base_samples is None:
                    continue
                self.assertEqual(acqf.sampler.base_samples.shape, base_samples_shape)
                before = acqf.sampler.base_samples.clone()
                acqf(X)
                unchanged = torch.equal(acqf.sampler.base_samples, before)
                if keeps_base_samples:
                    self.assertTrue(unchanged)
                else:
                    self.assertFalse(unchanged)
开发者ID:saschwan,项目名称:botorch,代码行数:55,代码来源:test_monte_carlo.py

示例6: test_local_tensor_iterable_methods

    def test_local_tensor_iterable_methods(self):

        x = torch.FloatTensor([1, 2, 3])
        y = torch.FloatTensor([2, 3, 4])
        z = torch.FloatTensor([5, 6, 7])
        assert(torch.equal(torch.stack([x, y, z]), torch.FloatTensor([[1, 2, 3], [2, 3, 4], [5, 6, 7]])))

        x = torch.FloatTensor([1, 2, 3])
        y = torch.FloatTensor([2, 3, 4])
        z = torch.FloatTensor([5, 6, 7])
        assert (torch.equal(torch.cat([x, y, z]), torch.FloatTensor([1, 2, 3, 2, 3, 4, 5, 6, 7])))
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:11,代码来源:torch_test.py

示例7: test_MockPosterior

 def test_MockPosterior(self):
     """MockPosterior must hand back the mean, variance and samples it was given."""
     mean, variance, samples = torch.rand(2), torch.eye(2), torch.rand(1, 2)
     mp = MockPosterior(mean=mean, variance=variance, samples=samples)

     self.assertTrue(torch.equal(mp.mean, mean))
     self.assertTrue(torch.equal(mp.variance, variance))

     # elementwise comparison, kept as in the original rather than
     # torch.equal, which would additionally require identical shapes
     default_draw = mp.sample()
     self.assertTrue(torch.all(default_draw == samples.unsqueeze(0)))

     batched_draw = mp.sample(torch.Size([2]))
     self.assertTrue(torch.all(batched_draw == samples.repeat(2, 1, 1)))
开发者ID:saschwan,项目名称:botorch,代码行数:11,代码来源:test_mock.py

示例8: test_GPyTorchPosterior

 def test_GPyTorchPosterior(self, cuda=False):
     """Check GPyTorchPosterior attributes and ``rsample`` shapes.

     Uses a diagonal covariance and covers sampling with and without
     ``base_samples``, an incompatible ``base_samples`` shape, and a batched
     posterior.

     Args:
         cuda: If True, tensors are created on the "cuda" device, otherwise
             on the "cpu" device.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"dtype": dtype, "device": device}
         mean = torch.rand(3, **tkwargs)
         variance = 1 + torch.rand(3, **tkwargs)
         mvn = MultivariateNormal(mean, lazify(variance.diag()))
         posterior = GPyTorchPosterior(mvn=mvn)

         # basics
         self.assertEqual(posterior.device.type, device.type)
         self.assertTrue(posterior.dtype == dtype)
         self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
         self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
         self.assertTrue(torch.equal(posterior.variance, variance.unsqueeze(-1)))

         # rsample: default sample shape first, then explicit ones
         self.assertEqual(posterior.rsample().shape, torch.Size([1, 3, 1]))
         for dims in ([4], [4, 2]):
             drawn = posterior.rsample(sample_shape=torch.Size(dims))
             self.assertEqual(drawn.shape, torch.Size(dims + [3, 1]))

         # rsample w/ base samples
         base_samples = torch.randn(4, 3, 1, **tkwargs)
         # incompatible shapes
         with self.assertRaises(RuntimeError):
             posterior.rsample(
                 sample_shape=torch.Size([3]), base_samples=base_samples
             )
         # identical base samples must reproduce the same draw
         shape_4 = torch.Size([4])
         first = posterior.rsample(sample_shape=shape_4, base_samples=base_samples)
         second = posterior.rsample(sample_shape=shape_4, base_samples=base_samples)
         self.assertTrue(torch.allclose(first, second))

         base_samples2 = torch.randn(4, 2, 3, 1, **tkwargs)
         shape_4x2 = torch.Size([4, 2])
         first2 = posterior.rsample(
             sample_shape=shape_4x2, base_samples=base_samples2
         )
         second2 = posterior.rsample(
             sample_shape=shape_4x2, base_samples=base_samples2
         )
         self.assertTrue(torch.allclose(first2, second2))

         # collapse_batch_dims
         b_mean = torch.rand(2, 3, **tkwargs)
         b_variance = 1 + torch.rand(2, 3, **tkwargs)
         identity = torch.eye(3).type_as(b_variance)
         b_covar = b_variance.unsqueeze(-1) * identity
         b_posterior = GPyTorchPosterior(
             mvn=MultivariateNormal(b_mean, lazify(b_covar))
         )
         b_base_samples = torch.randn(4, 1, 3, 1, **tkwargs)
         b_samples = b_posterior.rsample(
             sample_shape=torch.Size([4]), base_samples=b_base_samples
         )
         self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
开发者ID:saschwan,项目名称:botorch,代码行数:54,代码来源:test_gpytorch.py

示例9: test_make_grid_not_inplace

    def test_make_grid_not_inplace(self):
        """``make_grid`` must leave its input untouched for every normalize mode."""
        t = torch.rand(5, 3, 10, 10)
        t_clone = t.clone()

        option_sets = (
            {"normalize": False},
            {"normalize": True, "scale_each": False},
            {"normalize": True, "scale_each": True},
        )
        for options in option_sets:
            utils.make_grid(t, **options)
            assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
开发者ID:WENXINGEVIN,项目名称:vision,代码行数:12,代码来源:test_utils.py

示例10: test_generic_mc_objective

 def test_generic_mc_objective(self, cuda=False):
     """GenericMCObjective must return exactly what the wrapped callable returns.

     Args:
         cuda: If True, tensors are created on the "cuda" device, otherwise
             on the "cpu" device.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     sample_shapes = ((1,), (2,), (3, 1), (3, 2))
     for dtype in (torch.float, torch.double):
         obj = GenericMCObjective(generic_obj)
         for shape in sample_shapes:
             samples = torch.randn(*shape, device=device, dtype=dtype)
             self.assertTrue(torch.equal(obj(samples), generic_obj(samples)))
开发者ID:saschwan,项目名称:botorch,代码行数:12,代码来源:test_objective.py

示例11: test_q_expected_improvement

    def test_q_expected_improvement(self, cuda=False):
        """Check qExpectedImprovement values and base-sample handling.

        Covers the acquisition value for two ``best_f`` settings and, for the
        seeded / QMC samplers, whether a second evaluation keeps or redraws
        the sampler's base samples.

        Args:
            cuda: If True, tensors are created on the "cuda" device, otherwise
                on the "cpu" device.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        base_samples_shape = torch.Size([2, 1, 1, 1])
        # (sampler class, constructor kwargs, base samples kept on 2nd call?)
        sampler_cases = (
            (IIDNormalSampler, {"num_samples": 2, "seed": 12345}, True),
            (SobolQMCNormalSampler, {"num_samples": 2}, True),
            (SobolQMCNormalSampler, {"num_samples": 2, "resample": True}, False),
        )
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": device, "dtype": dtype}
            # the event shape is `b x q x t` = 1 x 1 x 1
            samples = torch.zeros(1, 1, 1, **tkwargs)
            mm = MockModel(MockPosterior(samples=samples))
            # X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
            X = torch.zeros(1, 1, **tkwargs)

            # basic test, then the same sampler with a shifted best_f value
            shared_sampler = IIDNormalSampler(num_samples=2)
            for best_f, expected in ((0, 0.0), (-1, 1.0)):
                acqf = qExpectedImprovement(
                    model=mm, best_f=best_f, sampler=shared_sampler
                )
                self.assertEqual(acqf(X).item(), expected)

            for sampler_cls, sampler_kwargs, keeps_base_samples in sampler_cases:
                sampler = sampler_cls(**sampler_kwargs)
                acqf = qExpectedImprovement(model=mm, best_f=0, sampler=sampler)
                self.assertEqual(acqf(X).item(), 0.0)
                self.assertEqual(acqf.sampler.base_samples.shape, base_samples_shape)
                before = acqf.sampler.base_samples.clone()
                acqf(X)
                unchanged = torch.equal(acqf.sampler.base_samples, before)
                if keeps_base_samples:
                    self.assertTrue(unchanged)
                else:
                    self.assertFalse(unchanged)
开发者ID:saschwan,项目名称:botorch,代码行数:49,代码来源:test_monte_carlo.py

示例12: do_test_per_param_optim

    def do_test_per_param_optim(self, fixed_param, free_param):
        """Check per-parameter learning rates and optimizer save / load.

        Two SVI objects share the same model and guide. The first optimizer
        is stepped once, saved, and stepped again; the second optimizer loads
        the saved state and is stepped once. The test asserts the recorded
        Adam step counts (1, 2 and 2) and that only ``free_param`` moved away
        from its initial value of zero.

        The optimizer state is written into a temporary directory that is
        removed afterwards, so the test leaves no file behind in the current
        working directory.

        Args:
            fixed_param: Name of the parameter given learning rate 0.0.
            free_param: Name of the parameter given learning rate 0.01.
        """
        import os
        import shutil
        import tempfile

        pyro.clear_param_store()

        def model():
            prior_dist = Normal(self.mu0, torch.pow(self.lam0, -0.5))
            mu_latent = pyro.sample("mu_latent", prior_dist)
            x_dist = Normal(mu_latent, torch.pow(self.lam, -0.5))
            pyro.observe("obs", x_dist, self.data)
            return mu_latent

        def guide():
            mu_q = pyro.param(
                "mu_q",
                Variable(
                    torch.zeros(1),
                    requires_grad=True))
            log_sig_q = pyro.param(
                "log_sig_q", Variable(
                    torch.zeros(1), requires_grad=True))
            sig_q = torch.exp(log_sig_q)
            pyro.sample("mu_latent", Normal(mu_q, sig_q))

        def optim_params(module_name, param_name, tags):
            # NOTE(review): returns None for any other parameter name;
            # presumably only the two guide parameters are ever passed.
            if param_name == fixed_param:
                return {'lr': 0.00}
            elif param_name == free_param:
                return {'lr': 0.01}

        def step_count(optimizer):
            # step counter stored in the first state entry of 'mu_q'
            return list(optimizer.get_state()['mu_q']['state'].items())[0][1]['step']

        adam = optim.Adam(optim_params)
        adam2 = optim.Adam(optim_params)
        svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)
        svi2 = SVI(model, guide, adam2, loss="ELBO", trace_graph=True)

        # save into a private temp dir so no artifact is left in the cwd
        save_dir = tempfile.mkdtemp()
        try:
            save_path = os.path.join(save_dir, 'adam.unittest.save')
            svi.step()
            adam_initial_step_count = step_count(adam)
            adam.save(save_path)
            svi.step()
            adam_final_step_count = step_count(adam)
            adam2.load(save_path)
        finally:
            shutil.rmtree(save_dir, ignore_errors=True)
        svi2.step()
        adam2_step_count_after_load_and_step = step_count(adam2)

        assert adam_initial_step_count == 1
        assert adam_final_step_count == 2
        assert adam2_step_count_after_load_and_step == 2

        free_param_unchanged = torch.equal(pyro.param(free_param).data, torch.zeros(1))
        fixed_param_unchanged = torch.equal(pyro.param(fixed_param).data, torch.zeros(1))
        assert fixed_param_unchanged and not free_param_unchanged
开发者ID:Magica-Chen,项目名称:pyro,代码行数:49,代码来源:test_optim.py

示例13: test_match_batch_shape

    def test_match_batch_shape(self):
        """``match_batch_shape`` expands X to Y's batch shape or raises."""
        # (X shape, Y shape, how the expected result is derived from X)
        expandable_cases = (
            ((3, 2), (1, 3, 2), lambda x: x.unsqueeze(0)),
            ((1, 3, 2), (2, 3, 2), lambda x: x.repeat(2, 1, 1)),
        )
        for x_shape, y_shape, expected_from in expandable_cases:
            X = torch.rand(*x_shape)
            Y = torch.rand(*y_shape)
            X_tf = match_batch_shape(X, Y)
            self.assertTrue(torch.equal(X_tf, expected_from(X)))

        # a batch of 2 cannot be matched to a batch of 1
        X = torch.rand(2, 3, 2)
        Y = torch.rand(1, 3, 2)
        with self.assertRaises(RuntimeError):
            match_batch_shape(X, Y)
开发者ID:saschwan,项目名称:botorch,代码行数:15,代码来源:test_transforms.py

示例14: test_standardize

 def test_standardize(self, cuda=False):
     """Check ``standardize`` on constant, 1-d and column-wise 2-d inputs.

     Args:
         cuda: If True, tensors are created on the "cuda" device, otherwise
             on the "cpu" device.
     """
     device = torch.device("cuda" if cuda else "cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}

         # an all-zero input is expected to come back unchanged
         X = torch.tensor([0.0, 0.0], **tkwargs)
         self.assertTrue(torch.equal(X, standardize(X)))

         column = [0.0, 1.0, 1.0, 1.0]
         X2 = torch.tensor(column, **tkwargs)
         expected_X2_stdized = torch.tensor([-1.5, 0.5, 0.5, 0.5], **tkwargs)
         self.assertTrue(torch.equal(expected_X2_stdized, standardize(X2)))

         # same data as the first column of a 4 x 2 input, zeros as the second
         X3 = torch.tensor([column, [0.0, 0.0, 0.0, 0.0]], **tkwargs).transpose(1, 0)
         X3_stdized = standardize(X3)
         first_col, second_col = X3_stdized[:, 0], X3_stdized[:, 1]
         self.assertTrue(torch.equal(first_col, expected_X2_stdized))
         self.assertTrue(torch.equal(second_col, torch.zeros(4, **tkwargs)))
开发者ID:saschwan,项目名称:botorch,代码行数:15,代码来源:test_transforms.py

示例15: test_local_tensor_multi_var_methods

    def test_local_tensor_multi_var_methods(self):
        """Check multi-output torch functions on local tensors.

        Calls ``torch.max``, ``torch.eig``, ``torch.qr``, ``torch.kthvalue``
        and ``torch.svd`` on small hard-coded tensors and compares every
        returned piece against hard-coded expected values.
        """
        x = torch.FloatTensor([[1, 2], [2, 3], [5, 6]])
        # row-wise max: values and their column indices
        t, s = torch.max(x, 1)
        # count the matching elements; all three entries must agree
        assert (t == torch.FloatTensor([2, 3, 6])).float().sum() == 3
        assert (s == torch.LongTensor([1, 1, 1])).float().sum() == 3

        x = torch.FloatTensor([[0, 0], [1, 1]])
        y, z = torch.eig(x, True)
        assert (y == torch.FloatTensor([[1, 0], [0, 0]])).all()
        # the elementwise-equality mask itself is compared: the first column
        # must match the reference and the second column must differ
        assert (torch.equal(z == torch.FloatTensor([[0, 0], [1, 0]]), torch.ByteTensor([[1, 0], [1, 0]])))

        x = torch.FloatTensor([[0, 0], [1, 0]])
        y, z = torch.qr(x)
        assert (y == torch.FloatTensor([[0, -1], [-1, 0]])).all()
        assert (z == torch.FloatTensor([[-1, 0], [0, 0]])).all()

        x = torch.arange(1, 6)
        # 4th smallest of 1..5 is the value 4 at index 3
        y, z = torch.kthvalue(x, 4)
        assert (y == torch.FloatTensor([4])).all()
        assert (z == torch.LongTensor([3])).all()

        x = torch.zeros(3, 3)
        w, y, z = torch.svd(x)
        assert (w == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
        assert (y == torch.FloatTensor([0, 0, 0])).all()
        assert (z == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:26,代码来源:torch_test.py


注:本文中的torch.equal函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。