当前位置: 首页>>代码示例>>Python>>正文


Python torch.device函数代码示例

本文整理汇总了 Python 中 torch.device 的典型用法代码示例。如果您正苦于以下问题:torch.device 的具体用法是什么?torch.device 怎么用?有哪些 torch.device 的使用示例?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。


在下文中一共展示了 torch.device 的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: test_degenerate_GPyTorchPosterior

    def test_degenerate_GPyTorchPosterior(self, cuda=False):
        """Check GPyTorchPosterior behavior for a singular covariance matrix.

        For both float and double, the test verifies the basic posterior
        attributes, the shapes returned by ``rsample`` (with and without
        explicit ``base_samples``), and that sampling records exactly one
        ``RuntimeWarning`` whose message contains "not p.d.".

        Args:
            cuda: If True, tensors are created on the "cuda" device;
                otherwise on "cpu".
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            # singular covariance matrix (its first two rows are identical)
            degenerate_covar = torch.tensor(
                [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
            )
            mean = torch.rand(3, dtype=dtype, device=device)
            mvn = MultivariateNormal(mean, lazify(degenerate_covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            # basics: device/dtype follow the inputs, and mean/variance carry
            # an extra trailing dimension of size 1
            self.assertEqual(posterior.device.type, device.type)
            self.assertTrue(posterior.dtype == dtype)
            self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
            self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
            variance_exp = degenerate_covar.diag().unsqueeze(-1)
            self.assertTrue(torch.equal(posterior.variance, variance_exp))

            # rsample
            with warnings.catch_warnings(record=True) as w:
                # we check that the p.d. warning is emitted - this only
                # happens once per posterior, so we need to check only once
                samples = posterior.rsample(sample_shape=torch.Size([4]))
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            # the sample shape is prepended to the event shape
            self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
            samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
            self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
            # rsample w/ base samples: passing the same base samples twice
            # must give (numerically) the same draws
            base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
            samples_b1 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            samples_b2 = posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=base_samples
            )
            self.assertTrue(torch.allclose(samples_b1, samples_b2))
            base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            samples2_b1 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            samples2_b2 = posterior.rsample(
                sample_shape=torch.Size([4, 2]), base_samples=base_samples2
            )
            self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
            # collapse_batch_dims: repeat the check on a batched posterior
            # whose covariance is the same singular matrix expanded to a
            # leading batch dimension of size 2
            b_mean = torch.rand(2, 3, dtype=dtype, device=device)
            b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
            b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
            b_posterior = GPyTorchPosterior(mvn=b_mvn)
            b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
            with warnings.catch_warnings(record=True) as w:
                # a new posterior object, so the warning is expected again
                b_samples = b_posterior.rsample(
                    sample_shape=torch.Size([4]), base_samples=b_base_samples
                )
                self.assertEqual(len(w), 1)
                self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
                self.assertTrue("not p.d." in str(w[-1].message))
            self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
开发者ID:saschwan,项目名称:botorch,代码行数:60,代码来源:test_gpytorch.py

示例2: test_joint_optimize

 def test_joint_optimize(
     self,
     mock_get_best_candidates,
     mock_gen_candidates,
     mock_gen_batch_initial_conditions,
     cuda=False,
 ):
     """Check that joint_optimize returns what get_best_candidates yields.

     The three ``mock_*`` arguments are mock objects (presumably injected by
     patch decorators that sit outside this snippet - confirm against the
     full file). Each is given a canned return value, and the test asserts
     that the candidates returned by ``joint_optimize`` equal the tensor
     configured on ``mock_get_best_candidates``.

     Args:
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     q, num_restarts, raw_samples = 3, 2, 10
     options = {}
     acqf = MockAcquisitionFunction()
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         # canned outputs for the mocked collaborators
         mock_gen_batch_initial_conditions.return_value = torch.zeros(
             num_restarts, q, 3, **tkwargs
         )
         per_restart = [
             i * torch.ones(1, q, 3, **tkwargs) for i in range(num_restarts)
         ]
         mock_gen_candidates.return_value = torch.cat(per_restart, dim=0)
         expected_candidates = torch.ones(1, q, 3, **tkwargs)
         mock_get_best_candidates.return_value = expected_candidates
         # box bounds: lower row of zeros, upper row of fours
         lower = torch.zeros(3, **tkwargs)
         upper = 4 * torch.ones(3, **tkwargs)
         candidates = joint_optimize(
             acq_function=acqf,
             bounds=torch.stack([lower, upper]),
             q=q,
             num_restarts=num_restarts,
             raw_samples=raw_samples,
             options=options,
         )
         self.assertTrue(torch.equal(candidates, expected_candidates))
开发者ID:saschwan,项目名称:botorch,代码行数:35,代码来源:test_optimize.py

示例3: test_upper_confidence_bound

    def test_upper_confidence_bound(self, cuda=False):
        """Check UpperConfidenceBound values and its multi-output error.

        With a mocked posterior of mean 0 and variance 1 and ``beta=1.0`` the
        acquisition value is expected to be 1.0, and -1.0 when
        ``maximize=False``. A two-output mock model must raise
        ``UnsupportedError``.

        Args:
            cuda: If True, tensors are created on the "cuda" device;
                otherwise on "cpu".
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": device, "dtype": dtype}
            single_output_model = MockModel(
                MockPosterior(
                    mean=torch.tensor([[0.0]], **tkwargs),
                    variance=torch.tensor([[1.0]], **tkwargs),
                )
            )

            # default (maximizing) case first, then the minimizing case
            for extra_kwargs, expected_value in (
                ({}, 1.0),
                ({"maximize": False}, -1.0),
            ):
                module = UpperConfidenceBound(
                    model=single_output_model, beta=1.0, **extra_kwargs
                )
                X = torch.zeros(1, 1, **tkwargs)
                ucb = module(X)
                ucb_expected = torch.tensor([expected_value], **tkwargs)
                self.assertTrue(torch.allclose(ucb, ucb_expected, atol=1e-4))

            # check for proper error if multi-output model
            two_output_model = MockModel(
                MockPosterior(
                    mean=torch.rand(1, 2, **tkwargs),
                    variance=torch.rand(1, 2, **tkwargs),
                )
            )
            module2 = UpperConfidenceBound(model=two_output_model, beta=1.0)
            with self.assertRaises(UnsupportedError):
                module2(X)
开发者ID:saschwan,项目名称:botorch,代码行数:26,代码来源:test_analytic.py

示例4: test_constrained_expected_improvement_batch

 def test_constrained_expected_improvement_batch(self, cuda=False):
     """Check ConstrainedExpectedImprovement on a batch of three points.

     The mocked posterior has four outputs: output 0 is the objective and
     outputs 1-3 are constrained. The expected value is the unconstrained
     expected improvement scaled by 0.5 once per constraint.

     Args:
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         mean_rows = [
             [-0.5, 0.0, 5.0, 0.0],
             [0.0, 0.0, 5.0, 0.0],
             [0.5, 0.0, 5.0, 0.0],
         ]
         mean = torch.tensor(mean_rows, **tkwargs).unsqueeze(dim=-2)
         variance = torch.ones(3, 4, **tkwargs).unsqueeze(dim=-2)
         std_normal = torch.distributions.Normal(loc=0.0, scale=1.0)
         # choose a such that P(-a <= N <= a) = 0.5 for a standard normal N
         a = std_normal.icdf(torch.tensor(0.75))
         module = ConstrainedExpectedImprovement(
             model=MockModel(MockPosterior(mean=mean, variance=variance)),
             best_f=0.0,
             objective_index=0,
             constraints={1: [None, 0], 2: [5.0, None], 3: [-a, a]},
         )
         # the contents of X are irrelevant because the posterior is mocked
         X = torch.empty(3, 1, 1, **tkwargs)
         ei = module(X)
         ei_expected_unconstrained = torch.tensor(
             [0.19780, 0.39894, 0.69780], **tkwargs
         )
         ei_expected = ei_expected_unconstrained * 0.5 * 0.5 * 0.5
         self.assertTrue(torch.allclose(ei, ei_expected, atol=1e-4))
开发者ID:saschwan,项目名称:botorch,代码行数:25,代码来源:test_analytic.py

示例5: test_FixedNoiseMultiTaskGP_single_output

    def test_FixedNoiseMultiTaskGP_single_output(self, cuda=False):
        """Smoke-test a single-output FixedNoiseMultiTaskGP.

        For float and double, checks the types of the model's components,
        the rank of the task covariance factor, that one fitting iteration
        runs, and that ``posterior`` returns a ``GPyTorchPosterior`` wrapping
        a ``MultivariateNormal`` for both plain and batched inputs.

        Args:
            cuda: If True, tensors are created on the "cuda" device;
                otherwise on "cpu".
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": device, "dtype": dtype}
            model = _get_fixed_noise_model_single_output(**tkwargs)

            # component types
            self.assertIsInstance(model, FixedNoiseMultiTaskGP)
            self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
            self.assertIsInstance(model.mean_module, ConstantMean)
            self.assertIsInstance(model.covar_module, ScaleKernel)
            base_kernel = model.covar_module.base_kernel
            self.assertIsInstance(base_kernel, MaternKernel)
            self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior)
            self.assertIsInstance(model.task_covar_module, IndexKernel)
            self.assertEqual(model._rank, 2)
            factor_rank = model.task_covar_module.covar_factor.shape[-1]
            self.assertEqual(factor_rank, model._rank)

            # test model fitting
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            mll = fit_gpytorch_model(mll, options={"maxiter": 1})

            # test posterior: plain evaluation first, then batch evaluation
            for x_shape in ((2, 1), (3, 2, 1)):
                test_x = torch.rand(*x_shape, **tkwargs)
                posterior_f = model.posterior(test_x)
                self.assertIsInstance(posterior_f, GPyTorchPosterior)
                self.assertIsInstance(posterior_f.mvn, MultivariateNormal)
开发者ID:saschwan,项目名称:botorch,代码行数:35,代码来源:test_multitask.py

示例6: test_probability_of_improvement

    def test_probability_of_improvement(self, cuda=False):
        """Check ProbabilityOfImprovement values and its multi-output error.

        With a mocked posterior of mean 0 and variance 1 and ``best_f=1.96``
        the expected value is 0.0250, and 0.9750 when ``maximize=False``.
        A two-output mock model must raise ``UnsupportedError``.

        Args:
            cuda: If True, tensors are created on the "cuda" device;
                otherwise on "cpu".
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": device, "dtype": dtype}
            mean = torch.tensor([0.0], **tkwargs).view(1, 1)
            variance = torch.ones(1, 1, **tkwargs)
            single_output_model = MockModel(
                MockPosterior(mean=mean, variance=variance)
            )

            # default (maximizing) case first, then the minimizing case
            for extra_kwargs, expected_value in (
                ({}, 0.0250),
                ({"maximize": False}, 0.9750),
            ):
                module = ProbabilityOfImprovement(
                    model=single_output_model, best_f=1.96, **extra_kwargs
                )
                X = torch.zeros(1, 1, **tkwargs)
                pi = module(X)
                pi_expected = torch.tensor(expected_value, **tkwargs)
                self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))

            # check for proper error if multi-output model
            mean2 = torch.rand(1, 2, **tkwargs)
            two_output_model = MockModel(
                MockPosterior(mean=mean2, variance=torch.ones_like(mean2))
            )
            module2 = ProbabilityOfImprovement(model=two_output_model, best_f=0.0)
            with self.assertRaises(UnsupportedError):
                module2(X)
开发者ID:saschwan,项目名称:botorch,代码行数:26,代码来源:test_analytic.py

示例7: set_opt

 def set_opt(self):
     """Replace the string-valued options on ``self`` with torch objects.

     Reads ``self.model_name``, ``self.initializer``, ``self.optimizer`` and
     ``self.device``, and sets ``self.model_class``, ``self.inputs_cols``,
     ``self.initializer``, ``self.optimizer`` and ``self.device``.

     ``self.device`` may be ``None``, in which case "cuda" is selected when
     ``torch.cuda.is_available()`` and "cpu" otherwise; any other value is
     passed to ``torch.device``.

     Raises:
         KeyError: If the model name, initializer name or optimizer name is
             not one of the keys of the lookup tables below.
     """
     # The fully qualified attribute paths are long, so short names are
     # mapped to them here to spare the user from typing them out.
     model_classes = {
         'lstm': torch.nn.LSTM,
         'cnn': torch.nn.Conv2d,
     }
     input_colses = {
         'lstm': ['text_raw_indices'],
         'cnn': ['text_raw_indices', 'aspect_indices'],
     }
     initializers = {
         'xavier_uniform_': torch.nn.init.xavier_uniform_,
         # Use the in-place initializer that matches the key; the name
         # without the trailing underscore is only a deprecated alias.
         'xavier_normal_': torch.nn.init.xavier_normal_,
         'orthogonal_': torch.nn.init.orthogonal_,
     }
     optimizers = {
         'adadelta': torch.optim.Adadelta,  # default lr=1.0
         'adagrad': torch.optim.Adagrad,    # default lr=0.01
         'adam': torch.optim.Adam,          # default lr=0.001
         'adamax': torch.optim.Adamax,      # default lr=0.002
         'asgd': torch.optim.ASGD,          # default lr=0.01
         'rmsprop': torch.optim.RMSprop,    # default lr=0.01
         'sgd': torch.optim.SGD,
     }
     self.model_class = model_classes[self.model_name]
     self.inputs_cols = input_colses[self.model_name]
     self.initializer = initializers[self.initializer]
     self.optimizer = optimizers[self.optimizer]
     if self.device is None:
         # No explicit device requested: prefer the GPU when one is usable.
         self.device = torch.device(
             'cuda' if torch.cuda.is_available() else 'cpu'
         )
     else:
         self.device = torch.device(self.device)
开发者ID:coder352,项目名称:shellscript,代码行数:29,代码来源:l44_config_setattr.py

示例8: test_FixedNoiseGP

    def test_FixedNoiseGP(self, cuda=False):
        """Smoke-test FixedNoiseGP over batch shapes, output counts and dtypes.

        For every combination, checks the types of the model's components,
        that one fitting iteration runs, that calling the model returns a
        ``MultivariateNormal``, and that ``posterior`` returns a
        ``GPyTorchPosterior`` whose mean has the expected shape for both
        plain and additionally batched inputs.

        Args:
            cuda: If True, tensors are created on the "cuda" device;
                otherwise on "cpu".
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        extra_batch = torch.Size([2])
        for batch_shape in (torch.Size([]), torch.Size([2])):
            x_shape = batch_shape + torch.Size([3, 1])
            for num_outputs in (1, 2):
                mean_shape = batch_shape + torch.Size([3, num_outputs])
                for dtype in (torch.float, torch.double):
                    tkwargs = {"device": device, "dtype": dtype}
                    model = self._get_model(
                        batch_shape=batch_shape,
                        num_outputs=num_outputs,
                        n=10,
                        **tkwargs
                    )

                    # component types
                    self.assertIsInstance(model, FixedNoiseGP)
                    self.assertIsInstance(
                        model.likelihood, FixedNoiseGaussianLikelihood
                    )
                    self.assertIsInstance(model.mean_module, ConstantMean)
                    self.assertIsInstance(model.covar_module, ScaleKernel)
                    base_kernel = model.covar_module.base_kernel
                    self.assertIsInstance(base_kernel, MaternKernel)
                    self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior)

                    # test model fitting
                    mll = ExactMarginalLogLikelihood(model.likelihood, model)
                    mll = fit_gpytorch_model(mll, options={"maxiter": 1})

                    # Test forward
                    forward_out = model(torch.rand(x_shape, **tkwargs))
                    self.assertIsInstance(forward_out, MultivariateNormal)

                    # TODO: Pass observation noise into posterior and check
                    # that posterior(test_x, observation_noise=True) has a
                    # variance of posterior_f.variance + 0.01.

                    # test posterior: non batch evaluation
                    posterior = model.posterior(torch.rand(x_shape, **tkwargs))
                    self.assertIsInstance(posterior, GPyTorchPosterior)
                    self.assertEqual(posterior.mean.shape, mean_shape)

                    # test posterior: batch evaluation (extra leading dim)
                    posterior = model.posterior(
                        torch.rand(extra_batch + x_shape, **tkwargs)
                    )
                    self.assertIsInstance(posterior, GPyTorchPosterior)
                    self.assertEqual(
                        posterior.mean.shape, extra_batch + mean_shape
                    )
开发者ID:saschwan,项目名称:botorch,代码行数:60,代码来源:test_gp_regression.py

示例9: _setUp

 def _setUp(self, double=False, cuda=False):
     """Create training data and fit the two models used by the tests.

     Sets ``train_x``, ``train_y``, ``train_yvar`` and ``bounds`` on
     ``self``, then builds and fits (5 iterations each) a ``SingleTaskGP``
     (``model_st`` / ``mll_st``) and a ``FixedNoiseGP``
     (``model_fn`` / ``mll_fn``).

     Args:
         double: If True use ``torch.double``, otherwise ``torch.float``.
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     dtype = torch.double if double else torch.float
     tkwargs = {"device": device, "dtype": dtype}

     # one period of a sine wave on [0, 1], perturbed by the NOISE constant
     inputs = torch.linspace(0, 1, 10, **tkwargs).unsqueeze(-1)
     clean_targets = torch.sin(inputs * (2 * math.pi)).squeeze(-1)
     # NOTE(review): created without dtype, so this does not follow
     # `double` - confirm whether that is intended.
     observation_var = torch.tensor(0.1 ** 2, device=device)
     perturbation = torch.tensor(NOISE, **tkwargs)
     self.train_x = inputs
     self.train_y = clean_targets + perturbation
     self.train_yvar = observation_var
     self.bounds = torch.tensor([[0.0], [1.0]], **tkwargs)

     # single-task model with inferred noise
     self.model_st = SingleTaskGP(self.train_x, self.train_y).to(**tkwargs)
     self.mll_st = ExactMarginalLogLikelihood(
         self.model_st.likelihood, self.model_st
     )
     self.mll_st = fit_gpytorch_model(self.mll_st, options={"maxiter": 5})

     # fixed-noise model: the scalar variance is expanded to train_y's shape
     yvar_full = self.train_yvar.expand_as(self.train_y)
     self.model_fn = FixedNoiseGP(self.train_x, self.train_y, yvar_full).to(
         **tkwargs
     )
     self.mll_fn = ExactMarginalLogLikelihood(
         self.model_fn.likelihood, self.model_fn
     )
     self.mll_fn = fit_gpytorch_model(self.mll_fn, options={"maxiter": 5})
开发者ID:saschwan,项目名称:botorch,代码行数:25,代码来源:test_end_to_end.py

示例10: stylize

def stylize(**kwargs):
    """Apply a trained style-transfer network to one image and save the result.

    A ``Config`` object is created and every keyword argument is set on it
    as an attribute. The function then reads ``use_gpu``, ``content_path``,
    ``model_path`` and ``result_path`` from that object.

    Args:
        **kwargs: Attribute overrides applied to the ``Config`` instance.
    """
    opt = Config()
    for name, value in kwargs.items():
        setattr(opt, name, value)
    device = t.device('cuda' if opt.use_gpu else 'cpu')

    # Image preparation: load, convert to a tensor and scale by 255.
    to_scaled_tensor = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x.mul(255)),
    ])
    raw_image = tv.datasets.folder.default_loader(opt.content_path)
    content_image = to_scaled_tensor(raw_image)
    # add a leading batch dimension before moving to the target device
    content_image = content_image.unsqueeze(0).to(device).detach()

    # Model: map_location returns each storage unchanged while loading;
    # the model is moved to the target device afterwards.
    style_model = TransformerNet().eval()
    state_dict = t.load(
        opt.model_path, map_location=lambda storage, _location: storage
    )
    style_model.load_state_dict(state_dict)
    style_model.to(device)

    # Style transfer and saving: undo the 255 scaling and clamp to [0, 1].
    output = style_model(content_image)
    output_data = output.cpu().data[0]
    result_image = (output_data / 255).clamp(min=0, max=1)
    tv.utils.save_image(result_image, opt.result_path)
开发者ID:672401341,项目名称:pytorch-book,代码行数:25,代码来源:main.py

示例11: test_gen_batch_initial_conditions_simple_warning

 def test_gen_batch_initial_conditions_simple_warning(self, cuda=False):
     """Check the warning raised when all raw samples are identical zeros.

     ``draw_sobol_samples`` is patched to return an all-zero tensor. The
     test expects exactly one recorded warning, of category
     ``BadInitialCandidatesWarning``, and all-zero initial conditions.

     Args:
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         bounds = torch.tensor([[0, 0], [1, 1]], **tkwargs)
         sobol_patch = mock.patch(
             "botorch.optim.optimize.draw_sobol_samples",
             return_value=torch.zeros(10, 1, 2, **tkwargs),
         )
         with warnings.catch_warnings(record=True) as ws, sobol_patch:
             batch_initial_conditions = gen_batch_initial_conditions(
                 acq_function=MockAcquisitionFunction(),
                 bounds=bounds,
                 q=1,
                 num_restarts=2,
                 raw_samples=10,
             )
             self.assertEqual(len(ws), 1)
             self.assertTrue(
                 issubclass(ws[-1].category, BadInitialCandidatesWarning)
             )
             expected_conditions = torch.zeros(2, 1, 2, **tkwargs)
             self.assertTrue(
                 torch.equal(batch_initial_conditions, expected_conditions)
             )
开发者ID:saschwan,项目名称:botorch,代码行数:26,代码来源:test_optimize.py

示例12: test_MultivariateNormalQMCEngineDegenerate

 def test_MultivariateNormalQMCEngineDegenerate(self, cuda=False):
     """Check QMC samples drawn from a degenerate multivariate normal.

     The distribution is that of (X, Y, Z) with X, Y iid standard normal and
     Z = X + Y. The test checks dtype/device of the samples, their sample
     means and standard deviations, a Shapiro normality test per
     coordinate, selected sample covariances, and that X + Y matches Z.

     Args:
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         # X, Y iid standard Normal and Z = X + Y, random vector (X, Y, Z)
         mean = torch.zeros(3, **tkwargs)
         cov = torch.tensor([[1, 0, 1], [0, 1, 1], [1, 1, 2]], **tkwargs)
         engine = MultivariateNormalQMCEngine(mean=mean, cov=cov, seed=12345)
         samples = engine.draw(n=2000)
         self.assertEqual(samples.dtype, dtype)
         self.assertEqual(samples.device.type, device.type)

         # first and second moments
         self.assertTrue(torch.all(torch.abs(samples.mean(dim=0)) < 1e-2))
         for column, expected_std in ((0, 1), (1, 1), (2, math.sqrt(2))):
             std_error = torch.abs(torch.std(samples[:, column]) - expected_std)
             self.assertTrue(std_error < 1e-2)

         # marginal normality of every coordinate
         for i in (0, 1, 2):
             _, pval = shapiro(samples[:, i].cpu().numpy())
             self.assertGreater(pval, 0.9)

         # sample covariance between X and Y, and between X and Z
         sample_cov = np.cov(samples.cpu().numpy().transpose())
         self.assertLess(np.abs(sample_cov[0, 1]), 1e-2)
         self.assertLess(np.abs(sample_cov[0, 2] - 1), 1e-2)

         # check to see if X + Y = Z almost exactly
         residual = torch.abs(samples[:, 0] + samples[:, 1] - samples[:, 2])
         self.assertTrue(torch.all(residual < 1e-5))
开发者ID:saschwan,项目名称:botorch,代码行数:28,代码来源:test_normal.py

示例13: data_from_config

    def data_from_config(cls, config, train=False, model=None):
        """Build the data loader(s) described by ``config``.

        Args:
            config: Object exposing ``device`` and ``data``; ``data`` provides
                ``train`` and ``valid`` sections (training) or a ``test``
                section (evaluation).
            train: If True, return the training and validation loaders;
                otherwise return the test loader.
            model: Passed to the configured ``transform`` and
                ``target_transform`` factories. When None, it is obtained
                from ``cls.model_from_config(config, load_checkpoint=False)``.

        Returns:
            ``(train_loader, valid_loader)`` when ``train`` is True, else the
            single test loader.
        """
        if model is None:
            model = cls.model_from_config(config, load_checkpoint=False)

        if not train:
            test_cfg = config.data.test
            test_set = TransformedDataset.of(
                test_cfg.dataset(**vars(test_cfg.config)),
                transform=test_cfg.transform(model),
                target_transform=test_cfg.target_transform(model),
            )
            use_cuda = torch.device(config.device).type == 'cuda'
            return cls.data_loader(
                test_set,
                test_cfg.batch_size,
                train=False,
                cuda=use_cuda,
                workers=test_cfg.num_loaders,
            )

        # Training: one dataset is split into a training and a validation
        # part, each with its own transforms.
        train_cfg = config.data.train
        full_dataset = train_cfg.dataset(**vars(train_cfg.config))
        valid_cfg = config.data.valid
        train_set, valid_set = cls.random_split(
            full_dataset,
            valid_cfg.split,
            train_transform=train_cfg.transform(model),
            train_target_transform=train_cfg.target_transform(model),
            valid_transform=valid_cfg.transform(model),
            valid_target_transform=valid_cfg.target_transform(model),
        )

        use_cuda = torch.device(config.device).type == 'cuda'
        train_loader = cls.data_loader(
            train_set,
            train_cfg.batch_size,
            train=True,
            cuda=use_cuda,
            workers=train_cfg.num_loaders,
        )
        valid_loader = cls.data_loader(
            valid_set,
            valid_cfg.batch_size,
            train=False,
            cuda=use_cuda,
            workers=valid_cfg.num_loaders,
        )
        return train_loader, valid_loader
开发者ID:ModarTensai,项目名称:network_moments,代码行数:33,代码来源:train.py

示例14: __init__

    def __init__(
        self,
        cfg,
        confidence_threshold=0.7,
        show_mask_heatmaps=False,
        masks_per_dim=2,
        min_image_size=224,
    ):
        """Build the detection model from ``cfg`` and load its weights.

        Args:
            cfg: Configuration object; a clone is stored on ``self.cfg``.
                ``cfg.MODEL.DEVICE`` selects the device and
                ``cfg.MODEL.WEIGHT`` is passed to the checkpointer's ``load``.
            confidence_threshold: Stored on ``self.confidence_threshold``.
            show_mask_heatmaps: Stored on ``self.show_mask_heatmaps``; when
                true the masker threshold is -1 instead of 0.5.
            masks_per_dim: Stored on ``self.masks_per_dim``.
            min_image_size: Stored on ``self.min_image_size`` before the
                transforms are built.
        """
        self.cfg = cfg.clone()

        # model in evaluation mode, placed on the configured device
        detector = build_detection_model(cfg)
        self.model = detector
        detector.eval()
        target_device = torch.device(cfg.MODEL.DEVICE)
        self.device = target_device
        detector.to(target_device)
        self.min_image_size = min_image_size

        # load the weights after the model has been moved to its device
        DetectronCheckpointer(cfg, self.model).load(cfg.MODEL.WEIGHT)

        self.transforms = self.build_transform()

        if show_mask_heatmaps:
            mask_threshold = -1
        else:
            mask_threshold = 0.5
        self.masker = Masker(threshold=mask_threshold, padding=1)

        # used to make colors for each class
        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])

        self.cpu_device = torch.device("cpu")
        self.confidence_threshold = confidence_threshold
        self.show_mask_heatmaps = show_mask_heatmaps
        self.masks_per_dim = masks_per_dim
开发者ID:laycoding,项目名称:maskrcnn-benchmark,代码行数:30,代码来源:predictor.py

示例15: test_add_output_dim

 def test_add_output_dim(self, cuda=False):
     """Check add_output_dim for invalid, matching and extended batch shapes.

     With ``original_batch_shape`` of size ``[2]``: a 2-dim input must raise
     ``ValueError``; a 3-dim input gets the new dimension at index 0; a
     4-dim input gets it at index 1.

     Args:
         cuda: If True, tensors are created on the "cuda" device;
             otherwise on "cpu".
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     for dtype in (torch.float, torch.double):
         tkwargs = {"device": device, "dtype": dtype}
         original_batch_shape = torch.Size([2])

         # check exception is raised
         X = torch.rand(2, 1, **tkwargs)
         with self.assertRaises(ValueError):
             add_output_dim(X=X, original_batch_shape=original_batch_shape)

         # first case: no new batch dims; second case: one new batch dim
         for x_shape, expected_idx in (((2, 2, 1), 0), ((3, 2, 2, 1), 1)):
             X = torch.rand(*x_shape, **tkwargs)
             X_out, output_dim_idx = add_output_dim(
                 X=X, original_batch_shape=original_batch_shape
             )
             self.assertTrue(torch.equal(X_out, X.unsqueeze(expected_idx)))
             self.assertEqual(output_dim_idx, expected_idx)
开发者ID:saschwan,项目名称:botorch,代码行数:25,代码来源:test_utils.py


注:本文中的 torch.device 示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的 License;未经允许,请勿转载。