当前位置: 首页>>代码示例>>Python>>正文


Python torch.allclose方法代码示例

本文整理汇总了Python中torch.allclose方法的典型用法代码示例。如果您想了解torch.allclose方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的其他用法示例。


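下文一共展示了torch.allclose方法的15个代码示例，这些示例默认按受欢迎程度排序。torch.allclose(input, other, rtol=1e-05, atol=1e-08) 逐元素检查 |input − other| ≤ atol + rtol × |other| 是否对所有元素成立（两个张量的形状需可广播），并返回一个Python布尔值，因此常用于测试中的近似相等断言。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。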

示例1: test_ce_loss

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_ce_loss():
    """Check CrossEntropyLoss config validation and its (class-weighted) values."""
    # A config that enables both use_mask and use_sigmoid must be rejected.
    with pytest.raises(AssertionError):
        build_loss(
            dict(
                type='CrossEntropyLoss',
                use_mask=True,
                use_sigmoid=True,
                loss_weight=1.0,
            )
        )

    # A confidently wrong prediction: large logit for class 0, label is class 1.
    pred = torch.Tensor([[100, -100]])
    label = torch.Tensor([1]).long()

    # With class weights the expected loss is 40 (the unweighted value 200
    # scaled by the 0.2 weight of the target class).
    weighted_loss = build_loss(
        dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=[0.8, 0.2],
            loss_weight=1.0,
        )
    )
    assert torch.allclose(weighted_loss(pred, label), torch.tensor(40.))

    # Without class weights the expected loss is 200.
    plain_loss = build_loss(
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    assert torch.allclose(plain_loss(pred, label), torch.tensor(200.))
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:27,代码来源:test_losses.py

示例2: _check_point_on_manifold

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def _check_point_on_manifold(self, x, *, atol=1e-4, rtol=1e-4):
    """Check whether ``x`` is a doubly stochastic matrix.

    The sums over the last dimension and over the second-to-last dimension
    must both be close to one within ``atol``/``rtol``.

    Returns:
        ``(True, None)`` when both conditions hold, otherwise
        ``(False, message)`` describing the tolerances used.
    """
    marginals = (x.sum(dim=-1), x.sum(dim=-2))
    # Compare each marginal against a same-shaped tensor of ones.
    if all(
        torch.allclose(total, torch.ones_like(total), atol=atol, rtol=rtol)
        for total in marginals
    ):
        return True, None
    reason = "illegal doubly stochastic matrix with atol={}, rtol={}".format(
        atol, rtol
    )
    return False, reason
开发者ID:geoopt,项目名称:geoopt,代码行数:20,代码来源:birkhoff_polytope.py

示例3: _check_point_on_manifold

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def _check_point_on_manifold(
    self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
) -> Tuple[bool, Optional[str]]:
    """Check that ``x`` has unit norm and lies in the manifold's subspace.

    Two conditions are tested in order, and the first failing one determines
    the message:

    1. ``x.norm(dim=-1)`` is close to one.
    2. ``self._project_on_subspace(x)`` is close to ``x``.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, message)``.
    """
    lengths = x.norm(dim=-1)
    if not torch.allclose(lengths, torch.ones_like(lengths), atol=atol, rtol=rtol):
        return False, "`norm(x) != 1` with atol={}, rtol={}".format(atol, rtol)

    projected = self._project_on_subspace(x)
    if torch.allclose(projected, x, atol=atol, rtol=rtol):
        return True, None
    return (
        False,
        "`x` is not in the subspace of the manifold with atol={}, rtol={}".format(
            atol, rtol
        ),
    )
开发者ID:geoopt,项目名称:geoopt,代码行数:18,代码来源:sphere.py

示例4: test_arbitrary_dimension

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_arbitrary_dimension(dim):
    """Bisection entmax along ``dim`` must match a per-slice computation.

    ``alpha`` varies over every axis except ``dim`` (where it has size 1).
    Each 1-D slice along ``dim`` is therefore re-computed with a scalar alpha
    and compared against the batched result.

    Args:
        dim: Axis to normalize over; negative values are accepted.
    """
    shape = [3, 4, 2, 5]
    X = torch.randn(*shape, dtype=torch.float64)

    # Resolve a possibly negative ``dim`` to its non-negative axis index so the
    # ``i == axis`` comparison below works for e.g. ``dim=-1`` too.
    axis = dim % len(shape)

    # Copy before editing: ``alpha_shape = shape`` would alias the list and
    # silently set ``shape[dim] = 1`` as well.
    alpha_shape = list(shape)
    alpha_shape[axis] = 1

    # One alpha per slice, drawn from [1.05, 2.05).
    alphas = 1.05 + torch.rand(alpha_shape, dtype=torch.float64)

    P = entmax_bisect(X, alpha=alphas, dim=dim)

    # Enumerate every position on the other axes; take ``axis`` whole.
    ranges = [
        [slice(None)] if i == axis else list(range(k))
        for i, k in enumerate(shape)
    ]

    for ix in product(*ranges):
        x = X[ix].unsqueeze(0)
        alpha = alphas[ix].item()
        p_true = entmax_bisect(x, alpha=alpha, dim=-1)
        assert torch.allclose(P[ix], p_true)
开发者ID:deep-spin,项目名称:entmax,代码行数:23,代码来源:test_root_finding.py

示例5: test_polar

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_polar():
    """Check the Polar transform's repr and its edge attributes with/without norm."""
    assert Polar().__repr__() == 'Polar(norm=True, max_value=None)'

    pos = torch.Tensor([[0, 0], [1, 0]])
    edge_index = torch.tensor([[0, 1], [1, 0]])
    edge_attr = torch.Tensor([1, 1])

    def check(transformed, expected_attr):
        # pos and edge_index must pass through unchanged; only edge_attr is
        # (re)computed by the transform.
        assert len(transformed) == 3
        assert transformed.pos.tolist() == pos.tolist()
        assert transformed.edge_index.tolist() == edge_index.tolist()
        assert torch.allclose(transformed.edge_attr, expected_attr, atol=1e-04)

    # Unnormalized: expected rows read as (distance, angle); the reverse edge
    # points at angle PI.
    check(
        Polar(norm=False)(Data(edge_index=edge_index, pos=pos)),
        torch.Tensor([[1, 0], [1, PI]]),
    )

    # Normalized, with the pre-existing edge_attr column kept in front; the
    # angle PI appears scaled to 0.5.
    check(
        Polar(norm=True)(
            Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)),
        torch.Tensor([[1, 1, 0], [1, 1, 0.5]]),
    )
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:24,代码来源:test_polar.py

示例6: test_permuted_global_pool

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_permuted_global_pool():
    """Global pooling must be invariant to the order of nodes in the batch."""
    n_first, n_second = 4, 6
    x = torch.randn(n_first + n_second, 4)
    batch = torch.cat(
        [torch.zeros(n_first), torch.ones(n_second)]).to(torch.long)

    # Shuffle nodes so members of the two graphs are interleaved.
    order = torch.randperm(n_first + n_second)
    shuffled_x, shuffled_batch = x[order], batch[order]
    members = [shuffled_x[shuffled_batch == graph] for graph in (0, 1)]

    # Each pooling op paired with the dense reduction it must reproduce.
    cases = [
        (global_add_pool, lambda t: t.sum(dim=0)),
        (global_mean_pool, lambda t: t.mean(dim=0)),
        (global_max_pool, lambda t: t.max(dim=0)[0]),
    ]
    for pool, reference in cases:
        out = pool(shuffled_x, shuffled_batch)
        assert out.size() == (2, 4)
        for graph, nodes in enumerate(members):
            assert torch.allclose(out[graph], reference(nodes))
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:27,代码来源:test_glob.py

示例7: test_static_graph

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_static_graph():
    """A batched disjoint graph and a stacked static graph give equal outputs."""
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    feats = [torch.randn(3, 8), torch.randn(3, 8)]

    batch = Batch.from_data_list(
        [Data(edge_index=edge_index, x=f) for f in feats])
    stacked = torch.stack(feats, dim=0)

    for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]:
        # Disjoint-union path: 2 graphs x 3 nodes -> 6 rows.
        out_batched = conv(batch.x, batch.edge_index)
        assert out_batched.size(0) == 6

        # Static-graph path: nodes are indexed along dim 1 of the stack.
        conv.node_dim = 1
        out_static = conv(stacked, edge_index)
        assert out_static.size()[:2] == (2, 3)

        flat = out_static.view(-1, out_static.size(-1))
        assert torch.allclose(out_batched, flat)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:18,代码来源:test_static_graph.py

示例8: test_appnp

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_appnp():
    """APPNP: edge_index vs. SparseTensor agreement and TorchScript parity."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = APPNP(K=10, alpha=0.1)
    assert conv.__repr__() == 'APPNP(K=10, alpha=0.1)'
    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Run the scripted module, not the eager ``conv``; otherwise the
    # SparseTensor TorchScript path is compiled but never executed.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:21,代码来源:test_appnp.py

示例9: test_cluster_gcn_conv

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_cluster_gcn_conv():
    """ClusterGCNConv: edge_index vs. SparseTensor agreement and TorchScript parity."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj = SparseTensor(
        row=edge_index[0], col=edge_index[1], sparse_sizes=(4, 4))

    conv = ClusterGCNConv(16, 32, diag_lambda=1.)
    assert conv.__repr__() == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'
    out = conv(x, edge_index)
    expected = out.tolist()
    assert out.size() == (4, 32)
    assert conv(x, edge_index, size=(4, 4)).tolist() == expected
    assert torch.allclose(conv(x, adj.t()), out)

    # Scripted module with a dense edge_index must reproduce the eager output.
    jit = torch.jit.script(conv.jittable('(Tensor, Tensor, Size) -> Tensor'))
    assert jit(x, edge_index).tolist() == expected
    assert jit(x, edge_index, size=(4, 4)).tolist() == expected

    # Scripted module with a SparseTensor adjacency.
    jit = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, Size) -> Tensor'))
    assert torch.allclose(jit(x, adj.t()), out)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:23,代码来源:test_cluster_gcn_conv.py

示例10: test_rgcn_conv_equality

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_rgcn_conv_equality(conf):
    """RGCNConv and FastRGCNConv built from the same seed give equal outputs.

    Args:
        conf: ``(num_bases, num_blocks)`` pair forwarded to both layers.
    """
    num_bases, num_blocks = conf

    x1 = torch.randn(4, 4)
    # The six base edges appear twice: once with relation types 0/1 and once
    # with 2/3, so edge_type values span 0..3.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],
        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],
    ])
    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3])

    # Re-seed before each constructor so both layers draw the same random
    # initial parameters (assumes both initialize in the same order).
    torch.manual_seed(12345)
    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks)

    torch.manual_seed(12345)
    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks)

    out1 = conv1(x1, edge_index, edge_type)
    out2 = conv2(x1, edge_index, edge_type)
    assert torch.allclose(out1, out2, atol=1e-6)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:24,代码来源:test_rgcn_conv.py

示例11: test_agnn_conv

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_agnn_conv(requires_grad):
    """AGNNConv: edge_index vs. SparseTensor agreement and TorchScript parity.

    Args:
        requires_grad: Forwarded to ``AGNNConv``.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = AGNNConv(requires_grad=requires_grad)
    assert conv.__repr__() == 'AGNNConv()'
    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Run the scripted module, not the eager ``conv``; otherwise the
    # SparseTensor TorchScript path is compiled but never executed.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:21,代码来源:test_agnn_conv.py

示例12: test_arma_conv

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_arma_conv():
    """ARMAConv: edge_index vs. SparseTensor agreement and TorchScript parity."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)
    assert conv.__repr__() == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    assert conv(x, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Run the scripted module, not the eager ``conv``; otherwise the
    # SparseTensor TorchScript path is compiled but never executed.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:21,代码来源:test_arma_conv.py

示例13: test_grad1

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def test_grad1():
    """Per-sample gradients from autograd_hacks agree with plain autograd.

    For every supported layer, ``param.grad1`` must average to ``param.grad``
    and must equal the Jacobian of the per-sample losses w.r.t. ``param``.
    """
    torch.manual_seed(1)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    n = 4
    data = torch.rand(n, 1, 28, 28)
    targets = torch.LongTensor(n).random_(0, 10)

    autograd_hacks.add_hooks(model)
    output = model(data)
    # retain_graph: ``output`` is reused below to build per-sample losses
    # that are differentiated again by ``jacobian``.
    loss_fn(output, targets).backward(retain_graph=True)
    autograd_hacks.compute_grad1(model)
    autograd_hacks.disable_hooks()

    # Reference values: the loss of each sample taken on its own.
    losses = torch.stack([
        loss_fn(output[i:i + 1], targets[i:i + 1]) for i in range(len(data))
    ])

    supported = (m for m in model.modules() if autograd_hacks.is_supported(m))
    for layer in supported:
        for param in layer.parameters():
            assert torch.allclose(param.grad, param.grad1.mean(dim=0))
            assert torch.allclose(jacobian(losses, param), param.grad1)
开发者ID:cybertronai,项目名称:autograd-hacks,代码行数:26,代码来源:autograd_hacks_test.py

示例14: _compare_momentum_values

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def _compare_momentum_values(self, optim1, optim2):
    """Assert two optimizer state dicts carry matching momentum buffers.

    Param groups and their params are paired by position. The number of
    groups and of params per group must match. Momentum buffers are compared
    with ``torch.allclose`` only when ``self._check_momentum_buffer()`` is
    truthy; it is evaluated once per group.
    """
    groups1, groups2 = optim1["param_groups"], optim2["param_groups"]
    self.assertEqual(len(groups1), len(groups2))

    for group_idx, group1 in enumerate(groups1):
        params1 = group1["params"]
        params2 = groups2[group_idx]["params"]
        self.assertEqual(len(params1), len(params2))
        if not self._check_momentum_buffer():
            continue
        # Entries in "params" are keys into the "state" mapping.
        for pos, id1 in enumerate(params1):
            id2 = params2[pos]
            self.assertTrue(
                torch.allclose(
                    optim1["state"][id1]["momentum_buffer"],
                    optim2["state"][id2]["momentum_buffer"],
                )
            )
开发者ID:facebookresearch,项目名称:ClassyVision,代码行数:20,代码来源:optim_test_util.py

示例15: compare_batches

# 需要导入模块: import torch [as 别名]
# 或者: from torch import allclose [as 别名]
def compare_batches(test_fixture, batch1, batch2):
    """Compare two batches element by element without recursing.

    Both batches must have the same type.

    - Tuples/lists: the lengths are compared, then elements pairwise by index.
    - Dicts: the key sets are compared, then values pairwise by key.
    - Other types: only the types are compared.

    Each element pair must share a type. Tensors are compared with
    ``torch.allclose``; everything else, including nested containers, is
    compared with ``assertEqual``.
    """
    test_fixture.assertEqual(type(batch1), type(batch2))

    if isinstance(batch1, (tuple, list)):
        test_fixture.assertEqual(len(batch1), len(batch2))
        pairs = ((batch1[n], batch2[n]) for n in range(len(batch1)))
    elif isinstance(batch1, dict):
        test_fixture.assertEqual(batch1.keys(), batch2.keys())
        pairs = ((value, batch2[key]) for key, value in batch1.items())
    else:
        return

    for left, right in pairs:
        test_fixture.assertEqual(type(left), type(right))
        if torch.is_tensor(left):
            test_fixture.assertTrue(torch.allclose(left, right))
        else:
            test_fixture.assertEqual(left, right)
开发者ID:facebookresearch,项目名称:ClassyVision,代码行数:25,代码来源:utils.py


注:本文中的torch.allclose方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。