当前位置: 首页>>代码示例>>Python>>正文


Python torch.set_default_dtype方法代码示例

本文整理汇总了Python中torch.set_default_dtype方法的典型用法代码示例。如果您正苦于以下问题:Python torch.set_default_dtype方法的具体用法?Python torch.set_default_dtype怎么用?Python torch.set_default_dtype使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.set_default_dtype方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    n, d, m1, m2 = 2, 3, 5, 7

    # numpy
    import numpy as np
    input_numpy = np.random.randn(n, d, m1)
    weight = np.random.randn(d, m1, m2)
    output_numpy = np.zeros([n, d, m2])
    for j in range(d):
        # [n, m2] = [n, m1] @ [m1, m2]
        output_numpy[:, j, :] = input_numpy[:, j, :] @ weight[j, :, :]

    # torch
    torch.set_default_dtype(torch.double)
    input_torch = torch.from_numpy(input_numpy)
    locally_connected = LocallyConnected(d, m1, m2, bias=False)
    locally_connected.weight.data[:] = torch.from_numpy(weight)
    output_torch = locally_connected(input_torch)

    # compare
    print(torch.allclose(output_torch, torch.from_numpy(output_numpy))) 
开发者ID:xunzheng,项目名称:notears,代码行数:23,代码来源:locally_connected.py

示例2: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    torch.set_default_dtype(torch.double)
    np.set_printoptions(precision=3)

    import notears.utils as ut
    ut.set_random_seed(123)

    n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim'
    B_true = ut.simulate_dag(d, s0, graph_type)
    np.savetxt('W_true.csv', B_true, delimiter=',')

    X = ut.simulate_nonlinear_sem(B_true, n, sem_type)
    np.savetxt('X.csv', X, delimiter=',')

    model = NotearsMLP(dims=[d, 10, 1], bias=True)
    W_est = notears_nonlinear(model, X, lambda1=0.01, lambda2=0.01)
    assert ut.is_dag(W_est)
    np.savetxt('W_est.csv', W_est, delimiter=',')
    acc = ut.count_accuracy(B_true, W_est != 0)
    print(acc) 
开发者ID:xunzheng,项目名称:notears,代码行数:22,代码来源:nonlinear.py

示例3: withdtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def withdtype():
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        torch.set_default_dtype(torch.float32) 
开发者ID:geoopt,项目名称:geoopt,代码行数:8,代码来源:test_rhmc.py

示例4: use_floatX

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def use_floatX(request):
    dtype_old = torch.get_default_dtype()
    torch.set_default_dtype(request.param)
    yield request.param
    torch.set_default_dtype(dtype_old) 
开发者ID:geoopt,项目名称:geoopt,代码行数:7,代码来源:test_manifold_basic.py

示例5: __enter__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __enter__(self):
        """Set new dtype."""
        self.old_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self.new_dtype) 
开发者ID:befelix,项目名称:safe-exploration,代码行数:6,代码来源:utilities.py

示例6: __exit__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __exit__(self, *args):
        """Restor old dtype."""
        torch.set_default_dtype(self.old_dtype) 
开发者ID:befelix,项目名称:safe-exploration,代码行数:5,代码来源:utilities.py

示例7: test_set_torch_dtype

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def test_set_torch_dtype():
    """Test dtype context manager."""
    dtype = torch.get_default_dtype()

    torch.set_default_dtype(torch.float32)
    with SetTorchDtype(torch.float64):
        a = torch.zeros(1)

    assert a.dtype is torch.float64
    b = torch.zeros(1)
    assert b.dtype is torch.float32

    torch.set_default_dtype(dtype) 
开发者ID:befelix,项目名称:safe-exploration,代码行数:15,代码来源:test_utilities.py

示例8: setUp

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def setUp(self):
        torch.set_default_tensor_type(torch.DoubleTensor)
        torch.set_default_dtype(torch.float64) 
开发者ID:learnables,项目名称:cherry,代码行数:5,代码来源:spinup_ddpg_tests.py

示例9: tearDown

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def tearDown(self):
        torch.set_default_tensor_type(torch.FloatTensor)
        torch.set_default_dtype(torch.float32) 
开发者ID:learnables,项目名称:cherry,代码行数:5,代码来源:spinup_ddpg_tests.py

示例10: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __init__(self, **kwargs):
        self.name = 'pytorch'
        self.precision = kwargs.get('precision', '32b')
        self.dtypemap = {
            'float': torch.float64 if self.precision == '64b' else torch.float32,
            'int': torch.int64 if self.precision == '64b' else torch.int32,
            'bool': torch.bool,
        }
        torch.set_default_dtype(self.dtypemap["float"]) 
开发者ID:scikit-hep,项目名称:pyhf,代码行数:11,代码来源:pytorch_backend.py

示例11: main

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    import torch.nn as nn
    # torch.set_default_dtype(torch.double)

    n, d, out, j = 10000, 3000, 10, 0
    input = torch.randn(n, d)
    w_true = torch.rand(d, out)
    w_true[j, :] = 0
    target = torch.matmul(input, w_true)
    linear = nn.Linear(d, out)
    linear.weight.bounds = [(0, None)] * d * out  # hack
    for m in range(out):
        linear.weight.bounds[m * d + j] = (0, 0)
    criterion = nn.MSELoss()
    optimizer = LBFGSBScipy(linear.parameters())
    print(list(linear.parameters()))

    def closure():
        optimizer.zero_grad()
        output = linear(input)
        loss = criterion(output, target)
        print('loss:', loss.item())
        loss.backward()
        return loss
    optimizer.step(closure)
    print(list(linear.parameters()))
    print(w_true.t()) 
开发者ID:xunzheng,项目名称:notears,代码行数:29,代码来源:lbfgsb_scipy.py

示例12: _test_backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def _test_backward(self, state, eps=2e-8, atol=1e-5, rtol=1e-3, max_num_per_param=5):
        @contextlib.contextmanager
        def double_prec():
            saved_dtype = torch.get_default_dtype()
            torch.set_default_dtype(torch.double)
            yield
            torch.set_default_dtype(saved_dtype)

        with double_prec():
            models = [m.to(torch.double) for m in networks.get_networks(state, 1)]
            trainer = Trainer(state, models)

            model = trainer.models[0]

            rdata, rlabel = next(iter(state.train_loader))
            rdata = rdata.to(state.device, torch.double, non_blocking=True)
            rlabel = rlabel.to(state.device, non_blocking=True)
            steps = trainer.get_steps()

            l, saved = trainer.forward(model, rdata, rlabel, steps)
            grad_info = trainer.backward(model, rdata, rlabel, steps, saved)
            trainer.accumulate_grad([grad_info])

            with torch.no_grad():
                for p_idx, p in enumerate(trainer.params):
                    pdata = p.data
                    N = p.numel()
                    for flat_i in np.random.choice(N, min(N, max_num_per_param), replace=False):
                        i = []
                        for s in reversed(p.size()):
                            i.insert(0, flat_i % s)
                            flat_i //= s
                        i = tuple(i)
                        ag = p.grad[i].item()
                        orig = pdata[i].item()
                        pdata[i] -= eps
                        steps = trainer.get_steps()
                        lm, _ = trainer.forward(model, rdata, rlabel, steps)
                        pdata[i] += eps * 2
                        steps = trainer.get_steps()
                        lp, _ = trainer.forward(model, rdata, rlabel, steps)
                        ng = (lp - lm).item() / (2 * eps)
                        pdata[i] = orig
                        rel_err = abs(ag - ng) / (atol + rtol * abs(ng))
                        info_msg = "testing param {} with shape [{}] at ({}):\trel_err={:.4f}\t" \
                                   "analytical={:+.6f}\tnumerical={:+.6f}".format(
                                       p_idx, format_intlist(p.size()),
                                       format_intlist(i), rel_err, ag, ng)
                        if unittest_verbosity() > 0:
                            print(info_msg)
                        self.assertTrue(rel_err <= 1, "gradcheck failed when " + info_msg) 
开发者ID:SsnL,项目名称:dataset-distillation,代码行数:53,代码来源:test_train_distilled_image.py


注:本文中的torch.set_default_dtype方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。