本文整理汇总了Python中torch.set_default_dtype方法的典型用法代码示例。如果您正苦于以下问题：Python torch.set_default_dtype方法的具体用法是什么？torch.set_default_dtype怎么用？有哪些使用torch.set_default_dtype的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.set_default_dtype方法的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    """Check ``LocallyConnected`` against an explicit NumPy loop.

    Builds a random input of shape ``[n, d, m1]`` and a weight of shape
    ``[d, m1, m2]``. The reference output is computed in NumPy with one
    ``[n, m1] @ [m1, m2]`` matmul per slice ``j``. The same weight is then
    copied into ``LocallyConnected(d, m1, m2, bias=False)``, and the function
    prints whether the two outputs are ``torch.allclose``.

    The global default dtype is switched to float64 only while the layer is
    built and run, and the previous default is restored afterwards (also on
    error), so calling ``main()`` no longer leaks a float64 default into later
    code.

    Returns:
        bool: the printed ``torch.allclose`` result.
    """
    import numpy as np

    n, d, m1, m2 = 2, 3, 5, 7

    # numpy reference
    input_numpy = np.random.randn(n, d, m1)
    weight = np.random.randn(d, m1, m2)
    output_numpy = np.zeros([n, d, m2])
    for j in range(d):
        # [n, m2] = [n, m1] @ [m1, m2]
        output_numpy[:, j, :] = input_numpy[:, j, :] @ weight[j, :, :]

    # torch: the NumPy arrays are float64. Presumably LocallyConnected
    # allocates its weight with the default dtype (verify against its
    # definition), hence the temporary float64 default.
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    try:
        input_torch = torch.from_numpy(input_numpy)
        locally_connected = LocallyConnected(d, m1, m2, bias=False)
        locally_connected.weight.data[:] = torch.from_numpy(weight)
        output_torch = locally_connected(input_torch)
    finally:
        # Restore the caller's default instead of leaving float64 installed.
        torch.set_default_dtype(saved_dtype)

    # compare
    matches = torch.allclose(output_torch, torch.from_numpy(output_numpy))
    print(matches)
    return matches
示例2: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    """Run NOTEARS (MLP variant) end to end on simulated nonlinear data.

    Steps:

    * Seeds the RNG via ``ut.set_random_seed(123)``.
    * Simulates a ground-truth graph ``B_true`` with ``ut.simulate_dag``
      (``d=5``, ``s0=9``, graph type ``'ER'``).
    * Samples ``n=200`` rows from a nonlinear SEM of type ``'mim'``.
    * Fits ``NotearsMLP(dims=[d, 10, 1])`` with ``notears_nonlinear``.
    * Writes ``W_true.csv``, ``X.csv`` and ``W_est.csv`` to the current
      directory and prints ``ut.count_accuracy`` for the estimate.

    The global torch default dtype is float64 for the duration of the run.
    It is restored to its previous value afterwards (also on error).

    Returns:
        The accuracy object printed at the end (whatever
        ``ut.count_accuracy`` returns).
    """
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    try:
        np.set_printoptions(precision=3)

        import notears.utils as ut
        ut.set_random_seed(123)

        n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim'
        B_true = ut.simulate_dag(d, s0, graph_type)
        np.savetxt('W_true.csv', B_true, delimiter=',')
        X = ut.simulate_nonlinear_sem(B_true, n, sem_type)
        np.savetxt('X.csv', X, delimiter=',')

        model = NotearsMLP(dims=[d, 10, 1], bias=True)
        W_est = notears_nonlinear(model, X, lambda1=0.01, lambda2=0.01)
        # Sanity check on the result of the demo run (not input validation).
        assert ut.is_dag(W_est)
        np.savetxt('W_est.csv', W_est, delimiter=',')
        acc = ut.count_accuracy(B_true, W_est != 0)
        print(acc)
    finally:
        # Do not leave float64 installed as the global default after main().
        torch.set_default_dtype(saved_dtype)
    return acc
示例3: withdtype
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def withdtype():
    """Temporarily make float64 the default floating-point dtype.

    This is a generator intended to be driven as a fixture or context manager;
    its decorator, if any, is not part of this snippet. Before the ``yield``
    the default dtype is switched to float64. Afterwards the dtype that was
    active on entry is restored, even when the body raises or the generator
    is closed.
    """
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        # Restore what the caller had rather than assuming it was float32.
        torch.set_default_dtype(saved_dtype)
示例4: use_floatX
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def use_floatX(request):
    """Fixture body: run under the default dtype given by ``request.param``.

    Installs ``request.param`` as torch's default dtype, yields it, and then
    restores the default that was active before.

    ``request`` is presumably a pytest ``FixtureRequest`` of a parametrized
    fixture (the decorator is not shown). Only its ``param`` attribute is read.

    Yields:
        ``request.param``, the dtype now installed as the default.
    """
    dtype_old = torch.get_default_dtype()
    torch.set_default_dtype(request.param)
    try:
        yield request.param
    finally:
        # The finally block also runs when the generator is closed or an
        # exception is thrown in at the yield, not only on a normal resume.
        torch.set_default_dtype(dtype_old)
示例5: __enter__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __enter__(self):
    """Set new dtype.

    Records the current default dtype in ``self.old_dtype`` so the exit hook
    can restore it, then installs ``self.new_dtype`` as torch's default.

    Returns:
        ``self``, so the manager can also be bound with ``with ... as ctx``.
        Previously this returned None, so existing callers are unaffected.
    """
    self.old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(self.new_dtype)
    return self
示例6: __exit__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __exit__(self, *args):
    """Restore the default dtype recorded in ``self.old_dtype``.

    The exception details passed in ``args`` are ignored. Returning None
    (falsy) means an exception raised inside the ``with`` body is not
    suppressed.
    """
    previous_dtype = self.old_dtype
    torch.set_default_dtype(previous_dtype)
    return None
示例7: test_set_torch_dtype
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def test_set_torch_dtype():
    """Test dtype context manager.

    The test checks two things:

    * Inside ``SetTorchDtype(torch.float64)``, ``torch.zeros`` produces
      float64 tensors.
    * After the block it produces float32 again.

    The default dtype that was active before the test is restored in a
    ``finally`` clause. A failing assertion therefore cannot leak a modified
    global default into later tests.
    """
    dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(torch.float32)
        with SetTorchDtype(torch.float64):
            a = torch.zeros(1)
            assert a.dtype is torch.float64
        b = torch.zeros(1)
        assert b.dtype is torch.float32
    finally:
        torch.set_default_dtype(dtype)
示例8: setUp
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def setUp(self):
    """Switch PyTorch to double-precision (float64) defaults.

    Presumably a ``unittest.TestCase`` hook run before each test (the class
    is not shown). It sets both the legacy default tensor type and the
    default dtype to their double-precision values.
    """
    # NOTE(review): torch.set_default_tensor_type is deprecated in recent
    # PyTorch releases (2.1+) in favour of torch.set_default_dtype and
    # torch.set_default_device — confirm against the torch version in use.
    torch.set_default_tensor_type(torch.DoubleTensor)
    torch.set_default_dtype(torch.float64)
示例9: tearDown
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def tearDown(self):
    """Switch PyTorch back to single-precision (float32) defaults.

    Presumably the ``unittest.TestCase`` counterpart of ``setUp`` (the class
    is not shown). It resets the legacy default tensor type to
    ``torch.FloatTensor`` and the default dtype to float32.
    """
    # NOTE(review): float32 is restored unconditionally, not whatever dtype
    # was active before setUp ran; a caller-customised default is lost.
    torch.set_default_tensor_type(torch.FloatTensor)
    torch.set_default_dtype(torch.float32)
示例10: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def __init__(self, **kwargs):
    """Configure the PyTorch backend and its dtype mapping.

    Keyword Args:
        precision: ``'64b'`` selects float64/int64. Any other value,
            including the default ``'32b'``, selects float32/int32.

    Side effect: installs the selected float dtype as torch's global default
    dtype.
    """
    self.name = 'pytorch'
    self.precision = kwargs.get('precision', '32b')

    if self.precision == '64b':
        float_dtype, int_dtype = torch.float64, torch.int64
    else:
        float_dtype, int_dtype = torch.float32, torch.int32

    self.dtypemap = {
        'float': float_dtype,
        'int': int_dtype,
        'bool': torch.bool,
    }
    torch.set_default_dtype(float_dtype)
示例11: main
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def main():
    """Demo: fit an ``nn.Linear`` with ``LBFGSBScipy`` under box constraints.

    Setup:

    * A random linear map ``w_true`` of shape ``[d, out]`` is drawn, and its
      row ``j`` is zeroed.
    * Targets are ``inputs @ w_true``.
    * Every weight of ``nn.Linear(d, out)`` is bounded to ``(0, None)``.
    * The weights that multiply input feature ``j`` are pinned to ``(0, 0)``.

    The model parameters are printed before and after ``optimizer.step``,
    followed by ``w_true.t()`` for comparison. The loss is printed on every
    closure evaluation.
    """
    import torch.nn as nn

    # To run this demo in double precision, call
    # torch.set_default_dtype(torch.double) at this point.
    n, d, out, j = 10000, 3000, 10, 0
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.randn(n, d)
    w_true = torch.rand(d, out)
    w_true[j, :] = 0
    target = torch.matmul(inputs, w_true)

    linear = nn.Linear(d, out)
    # HACK: presumably LBFGSBScipy reads per-element (lower, upper) bounds
    # from a ``bounds`` attribute on each parameter; verify against its
    # definition. linear.weight has shape [out, d], so with row-major
    # flattening the flat index m * d + j addresses weight[m, j], the weight
    # of input feature j in output m.
    bounds = [(0, None)] * (d * out)
    for m in range(out):
        bounds[m * d + j] = (0, 0)
    linear.weight.bounds = bounds

    criterion = nn.MSELoss()
    optimizer = LBFGSBScipy(linear.parameters())
    print(list(linear.parameters()))

    def closure():
        optimizer.zero_grad()
        output = linear(inputs)
        loss = criterion(output, target)
        print('loss:', loss.item())
        loss.backward()
        return loss

    optimizer.step(closure)
    print(list(linear.parameters()))
    print(w_true.t())
示例12: _test_backward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import set_default_dtype [as 别名]
def _test_backward(self, state, eps=2e-8, atol=1e-5, rtol=1e-3, max_num_per_param=5):
    """Gradient-check ``Trainer.backward`` against central finite differences.

    The check runs with float64 as the default dtype and float64 models.
    Analytical gradients are accumulated into ``p.grad`` via
    ``trainer.backward`` / ``trainer.accumulate_grad``.

    For each parameter ``p`` in ``trainer.params``, up to
    ``max_num_per_param`` randomly chosen elements are perturbed by
    ``-eps`` and ``+eps``. The numerical gradient
    ``(L(p + eps) - L(p - eps)) / (2 * eps)`` is compared with the analytical
    one. A check passes when ``|ag - ng| <= atol + rtol * |ng|``, i.e.
    ``rel_err <= 1``.

    Args:
        state: presumably the experiment state/options object; it must
            provide ``train_loader`` and ``device``, and is passed to
            ``networks.get_networks`` and ``Trainer``. TODO confirm.
        eps: finite-difference step.
        atol, rtol: absolute and relative tolerances of the comparison.
        max_num_per_param: maximum number of elements checked per parameter.
    """
    @contextlib.contextmanager
    def double_prec():
        saved_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.double)
        try:
            yield
        finally:
            # Restore even when a gradcheck assertion fails inside the block.
            # Otherwise later tests would silently run with a float64 default.
            torch.set_default_dtype(saved_dtype)

    with double_prec():
        models = [m.to(torch.double) for m in networks.get_networks(state, 1)]
        trainer = Trainer(state, models)
        model = trainer.models[0]

        rdata, rlabel = next(iter(state.train_loader))
        rdata = rdata.to(state.device, torch.double, non_blocking=True)
        rlabel = rlabel.to(state.device, non_blocking=True)
        steps = trainer.get_steps()

        # Analytical gradients; only the saved state is needed for backward.
        _, saved = trainer.forward(model, rdata, rlabel, steps)
        grad_info = trainer.backward(model, rdata, rlabel, steps, saved)
        trainer.accumulate_grad([grad_info])

        with torch.no_grad():
            for p_idx, p in enumerate(trainer.params):
                pdata = p.data
                N = p.numel()
                for flat_i in np.random.choice(N, min(N, max_num_per_param), replace=False):
                    # Row-major unravel of flat_i into a multi-index over p.size().
                    i = []
                    for s in reversed(p.size()):
                        i.insert(0, flat_i % s)
                        flat_i //= s
                    i = tuple(i)

                    ag = p.grad[i].item()
                    orig = pdata[i].item()

                    # Central difference: evaluate at orig - eps and orig + eps.
                    pdata[i] -= eps
                    steps = trainer.get_steps()
                    lm, _ = trainer.forward(model, rdata, rlabel, steps)
                    pdata[i] += eps * 2
                    steps = trainer.get_steps()
                    lp, _ = trainer.forward(model, rdata, rlabel, steps)
                    ng = (lp - lm).item() / (2 * eps)
                    pdata[i] = orig

                    rel_err = abs(ag - ng) / (atol + rtol * abs(ng))
                    info_msg = "testing param {} with shape [{}] at ({}):\trel_err={:.4f}\t" \
                               "analytical={:+.6f}\tnumerical={:+.6f}".format(
                                   p_idx, format_intlist(p.size()),
                                   format_intlist(i), rel_err, ag, ng)
                    if unittest_verbosity() > 0:
                        print(info_msg)
                    self.assertTrue(rel_err <= 1, "gradcheck failed when " + info_msg)