当前位置: 首页>>代码示例>>Python>>正文


Python torch.equal方法代码示例

本文整理汇总了Python中torch.equal方法的典型用法代码示例。如果您正苦于以下问题：Python torch.equal方法的具体用法？Python torch.equal怎么用？Python torch.equal使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。


在下文中一共展示了torch.equal方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_strides

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_strides():
    """Check the anchor boxes produced by ``AnchorGenerator.grid_anchors``.

    Two single-level generators are evaluated on a 2x2 feature map: one with a
    scalar stride and one with different strides along x and y. In both cases
    the anchors of the first (only) level must exactly equal the hand-computed
    boxes.
    """
    from mmdet.core import AnchorGenerator

    # Scalar stride: the same step of 10 is used along both axes.
    square_gen = AnchorGenerator([10], [1.], [1.], [10])
    square_anchors = square_gen.grid_anchors([(2, 2)], device='cpu')
    square_expected = torch.tensor([
        [-5., -5., 5., 5.],
        [5., -5., 15., 5.],
        [-5., 5., 5., 15.],
        [5., 5., 15., 15.],
    ])
    assert torch.equal(square_anchors[0], square_expected)

    # Stride pair (10, 20): anchors step 10 along x and 20 along y.
    rect_gen = AnchorGenerator([(10, 20)], [1.], [1.], [10])
    rect_anchors = rect_gen.grid_anchors([(2, 2)], device='cpu')
    rect_expected = torch.tensor([
        [-5., -5., 5., 5.],
        [5., -5., 15., 5.],
        [-5., 15., 5., 25.],
        [5., 15., 15., 25.],
    ])
    assert torch.equal(rect_anchors[0], rect_expected)
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:21,代码来源:test_anchor.py

示例2: test_max_pool_2d

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_max_pool_2d():
    """Check the ``MaxPool2d`` wrapper against ``nn.MaxPool2d``.

    For every combination in the parameter grid below, the test asserts that:

    * an input with an empty batch (size 0 along dim 0) yields an output with
      an empty batch whose remaining dimensions match the reference op's
      output shape, and
    * on a non-empty batch the wrapper's output equals the reference output
      exactly (``torch.equal``).
    """
    test_cases = OrderedDict([('in_w', [10, 20]), ('in_h', [10, 20]),
                              ('in_channel', [1, 3]), ('out_channel', [1, 3]),
                              ('kernel_size', [3, 5]), ('stride', [1, 2]),
                              ('padding', [0, 1]), ('dilation', [1, 2])])

    # The unpacking order must follow the key order of ``test_cases``
    # (``in_w`` before ``in_h``). Otherwise the two spatial ranges are
    # silently swapped as soon as their value lists differ.
    # ``out_channel`` has no meaning for pooling and is ignored.
    for in_w, in_h, in_cha, _out_cha, k, s, p, d in product(
            *test_cases.values()):
        # Wrapper op applied to an empty batch.
        x_empty = torch.randn(0, in_cha, in_h, in_w, requires_grad=True)
        wrapper = MaxPool2d(k, stride=s, padding=p, dilation=d)
        wrapper_out = wrapper(x_empty)

        # Reference torch op on a batch of 3, used as the shape/value oracle.
        x_normal = torch.randn(3, in_cha, in_h, in_w)
        ref = nn.MaxPool2d(k, stride=s, padding=p, dilation=d)
        ref_out = ref(x_normal)

        assert wrapper_out.shape[0] == 0
        assert wrapper_out.shape[1:] == ref_out.shape[1:]

        assert torch.equal(wrapper(x_normal), ref_out)
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:24,代码来源:test_wrappers.py

示例3: recommend

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def recommend(self, full_graph, K, h_user, h_item):
    """Recommend ``K`` items per user, returned as an (n_user, K) tensor.

    Each user's most recent interaction (the top edge by ``self.timestamp``)
    picks an anchor item. All items are scored by dot product against that
    anchor's embedding in ``h_item``. Items the user already interacted with
    are masked to ``-inf``, and the indices of the ``K`` best scores are kept.

    Users are processed in chunks of ``self.batch_size``. ``h_user`` is not
    used by this method.
    """
    item_etype = self.user_to_item_etype
    user_item_graph = full_graph.edge_type_subgraph([item_etype])
    num_users = full_graph.number_of_nodes(self.user_ntype)

    # Keep a single out-edge per user: the one with the largest timestamp.
    newest = dgl.sampling.select_topk(
        user_item_graph, 1, self.timestamp, edge_dir='out')
    src_users, newest_items = newest.all_edges(form='uv', order='srcdst')
    # Every user id 0..num_users-1 must appear exactly once, in order, so
    # that newest_items can be indexed directly by user id.
    assert torch.equal(src_users, torch.arange(num_users))

    per_batch_topk = []
    for batch in torch.arange(num_users).split(self.batch_size):
        anchors = newest_items[batch].to(device=h_item.device)
        scores = h_item[anchors] @ h_item.t()
        # Mask out items the user has already interacted with.
        for row, uid in enumerate(batch.tolist()):
            seen = full_graph.successors(uid, etype=item_etype)
            scores[row, seen] = -np.inf
        per_batch_topk.append(scores.topk(K, 1)[1])

    return torch.cat(per_batch_topk, 0)
开发者ID:dmlc,项目名称:dgl,代码行数:26,代码来源:evaluation.py

示例4: test_simple_conv_encoder_forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_simple_conv_encoder_forward(simple_conv_encoder_image_classification,
                                     simple_tensor_conv2d):
    """

    Args:
        simple_conv_encoder_image_classification (@pytest.fixture): SimpleConvEncoder
        simple_tensor_conv2d (@pytest.fixture): torch.Tensor

    Asserts: True if the output's dimension is equal to the input's one
             and that element-wise, the values have changed.

    """
    input_dim = simple_tensor_conv2d.dim()
    output = simple_conv_encoder_image_classification.forward(
        simple_tensor_conv2d)
    output_dim = output.dim()
    equivalent = torch.equal(simple_tensor_conv2d, output)
    assert input_dim == 4
    assert output_dim == 2
    assert not equivalent 
开发者ID:rdevon,项目名称:cortex,代码行数:22,代码来源:test_convnets.py

示例5: test_apply_nonlinearity

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_apply_nonlinearity(simple_tensor):
    """Check that ``apply_nonlinearity`` with ``'tanh'`` dispatches to tanh.

    Args:
        simple_tensor (@pytest.fixture): torch.Tensor

    Asserts: ``apply_nonlinearity(simple_tensor, 'tanh')`` is element-wise
    identical to the tanh of the input.
    """
    # torch.tanh computes the same input.tanh() as the functional alias.
    reference = torch.tanh(simple_tensor)
    extra_kwargs = {}
    result = apply_nonlinearity(simple_tensor, 'tanh', **extra_kwargs)

    assert torch.equal(reference, result)
开发者ID:rdevon,项目名称:cortex,代码行数:21,代码来源:test_network_utils.py

示例6: test_std_share_network_output_values

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_std_share_network_output_values(input_dim, output_dim, hidden_sizes):
    """Check mean/variance of ``GaussianMLPTwoHeadedModule`` on a single input.

    Weights are initialised to ones, hidden activations are identity and the
    input is all ones. The test expects a mean of
    ``input_dim * prod(hidden_sizes)`` per output unit and, with the 'exp'
    std parameterization, a variance of ``exp(that value) ** 2``.
    """
    module = GaussianMLPTwoHeadedModule(input_dim=input_dim,
                                        output_dim=output_dim,
                                        hidden_sizes=hidden_sizes,
                                        hidden_nonlinearity=None,
                                        std_parameterization='exp',
                                        hidden_w_init=nn.init.ones_,
                                        output_w_init=nn.init.ones_)
    dist = module(torch.ones(input_dim))

    width_product = torch.Tensor(hidden_sizes).prod()
    mean_value = input_dim * width_product.item()
    variance_value = (input_dim * width_product).exp().pow(2).item()
    out_shape = (output_dim, )

    assert dist.mean.equal(torch.full(out_shape, mean_value))
    assert dist.variance.equal(torch.full(out_shape, variance_value))
    assert dist.rsample().shape == out_shape
开发者ID:rlworkgroup,项目名称:garage,代码行数:22,代码来源:test_gaussian_mlp_module.py

示例7: test_std_share_network_output_values_with_batch

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_std_share_network_output_values_with_batch(input_dim, output_dim,
                                                    hidden_sizes):
    """Batched variant of the two-headed module mean/variance check.

    A batch of 5 all-ones inputs is fed through a module with all-ones
    weights and identity hidden activations. Each row is expected to have
    mean ``input_dim * prod(hidden_sizes)`` and variance
    ``exp(that value) ** 2``.
    """
    module = GaussianMLPTwoHeadedModule(input_dim=input_dim,
                                        output_dim=output_dim,
                                        hidden_sizes=hidden_sizes,
                                        hidden_nonlinearity=None,
                                        std_parameterization='exp',
                                        hidden_w_init=nn.init.ones_,
                                        output_w_init=nn.init.ones_)

    batch_size = 5
    out_shape = (batch_size, output_dim)
    dist = module(torch.ones([batch_size, input_dim]))

    width_product = torch.Tensor(hidden_sizes).prod()
    mean_value = input_dim * width_product.item()
    variance_value = (input_dim * width_product).exp().pow(2).item()

    assert dist.mean.equal(torch.full(out_shape, mean_value))
    assert dist.variance.equal(torch.full(out_shape, variance_value))
    assert dist.rsample().shape == out_shape
开发者ID:rlworkgroup,项目名称:garage,代码行数:26,代码来源:test_gaussian_mlp_module.py

示例8: test_std_network_output_values

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_std_network_output_values(input_dim, output_dim, hidden_sizes):
    """Check ``GaussianMLPModule`` mean and its initial (constant) variance.

    With all-ones weights, identity hidden activations and an all-ones input,
    the test expects a mean of ``input_dim * prod(hidden_sizes)`` per output
    unit and a variance of ``init_std ** 2``.
    """
    init_std = 2.
    out_shape = (output_dim, )

    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=init_std,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)
    dist = module(torch.ones(input_dim))

    mean_value = input_dim * torch.Tensor(hidden_sizes).prod().item()

    assert dist.mean.equal(torch.full(out_shape, mean_value))
    assert dist.variance.equal(torch.full(out_shape, init_std**2))
    assert dist.rsample().shape == out_shape
开发者ID:rlworkgroup,项目名称:garage,代码行数:24,代码来源:test_gaussian_mlp_module.py

示例9: test_std_adaptive_network_output_values

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_std_adaptive_network_output_values(input_dim, output_dim,
                                            hidden_sizes, std_hidden_sizes):
    """Check ``GaussianMLPIndependentStdModule`` with all-ones weights.

    Both the mean network and the separate std network use all-ones weights
    and identity hidden activations. The test expects a mean of
    ``input_dim * prod(hidden_sizes)`` and a variance of
    ``exp(input_dim * prod(hidden_sizes)) ** 2``.
    """
    module = GaussianMLPIndependentStdModule(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_sizes=hidden_sizes,
        std_hidden_sizes=std_hidden_sizes,
        hidden_nonlinearity=None,
        hidden_w_init=nn.init.ones_,
        output_w_init=nn.init.ones_,
        std_hidden_nonlinearity=None,
        std_hidden_w_init=nn.init.ones_,
        std_output_w_init=nn.init.ones_)
    dist = module(torch.ones(input_dim))

    width_product = torch.Tensor(hidden_sizes).prod()
    mean_value = input_dim * width_product.item()
    variance_value = (input_dim * width_product).exp().pow(2).item()
    out_shape = (output_dim, )

    assert dist.mean.equal(torch.full(out_shape, mean_value))
    assert dist.variance.equal(torch.full(out_shape, variance_value))
    assert dist.rsample().shape == out_shape
开发者ID:rlworkgroup,项目名称:garage,代码行数:26,代码来源:test_gaussian_mlp_module.py

示例10: test_exp_min_std

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_exp_min_std(input_dim, output_dim, hidden_sizes):
    """Check that ``min_std`` takes effect under the 'exp' parameterization.

    Weights are all zero and ``init_std`` (1.) is below ``min_std`` (10.).
    The test expects every output variance to equal ``min_std ** 2``.
    """
    min_value = 10.

    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=1.,
                               min_std=min_value,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.zeros_,
                               output_w_init=nn.init.zeros_)
    dist = module(torch.ones(input_dim))

    expected_variance = torch.full((output_dim, ), min_value**2)
    assert dist.variance.equal(expected_variance)
开发者ID:rlworkgroup,项目名称:garage,代码行数:21,代码来源:test_gaussian_mlp_module.py

示例11: test_exp_max_std

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_exp_max_std(input_dim, output_dim, hidden_sizes):
    """Check that ``max_std`` takes effect under the 'exp' parameterization.

    Weights are all zero and ``init_std`` (10.) is above ``max_std`` (1.).
    The test expects every output variance to equal ``max_std ** 2``.
    """
    max_value = 1.

    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=10.,
                               max_std=max_value,
                               hidden_nonlinearity=None,
                               std_parameterization='exp',
                               hidden_w_init=nn.init.zeros_,
                               output_w_init=nn.init.zeros_)
    dist = module(torch.ones(input_dim))

    expected_variance = torch.full((output_dim, ), max_value**2)
    assert dist.variance.equal(expected_variance)
开发者ID:rlworkgroup,项目名称:garage,代码行数:21,代码来源:test_gaussian_mlp_module.py

示例12: test_softplus_min_std

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_softplus_min_std(input_dim, output_dim, hidden_sizes):
    """Check ``min_std`` handling under the 'softplus' parameterization.

    Weights are all zero and ``init_std`` (1.) is below ``min_std`` (2.).
    The test expects every output variance to equal
    ``log(1 + exp(min_std)) ** 2``.
    """
    min_value = 2.

    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=1.,
                               min_std=min_value,
                               hidden_nonlinearity=None,
                               std_parameterization='softplus',
                               hidden_w_init=nn.init.zeros_,
                               output_w_init=nn.init.zeros_)
    dist = module(torch.ones(input_dim))

    softplus_of_min = torch.Tensor([min_value]).exp().add(1.).log()
    variance_value = softplus_of_min**2

    assert dist.variance.equal(
        torch.full((output_dim, ), variance_value[0]))
开发者ID:rlworkgroup,项目名称:garage,代码行数:21,代码来源:test_gaussian_mlp_module.py

示例13: test_softplus_max_std

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def test_softplus_max_std(input_dim, output_dim, hidden_sizes):
    """Check ``max_std`` handling under the 'softplus' parameterization.

    Weights are all ones and ``init_std`` (10) is above ``max_std`` (1.).
    The test expects every output variance to equal
    ``log(1 + exp(max_std)) ** 2``.
    """
    max_value = 1.

    module = GaussianMLPModule(input_dim=input_dim,
                               output_dim=output_dim,
                               hidden_sizes=hidden_sizes,
                               init_std=10,
                               max_std=max_value,
                               hidden_nonlinearity=None,
                               std_parameterization='softplus',
                               hidden_w_init=nn.init.ones_,
                               output_w_init=nn.init.ones_)
    dist = module(torch.ones(input_dim))

    softplus_of_max = torch.Tensor([max_value]).exp().add(1.).log()
    variance_value = softplus_of_max**2
    expected = torch.full((output_dim, ), variance_value[0])

    assert dist.variance.equal(expected)
开发者ID:rlworkgroup,项目名称:garage,代码行数:22,代码来源:test_gaussian_mlp_module.py

示例14: testRandomForwards

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def testRandomForwards(self):
    """Reference and patched networks must agree under random fast weights.

    Ten times, draw random parameters shaped like the reference network's
    parameters and a random input batch. Then require the reference forward
    (parameters passed by name) and the patched forward (the same parameters
    passed as an ordered list) to return identical tensors.
    """
    with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
        for _trial in range(10):
            named_fast = OrderedDict(
                (name, torch.rand(param.shape, requires_grad=True))
                for name, param in self.reference_net.named_parameters())
            ordered_fast = list(named_fast.values())
            batch = torch.rand(self.batch_size, self.num_in_channels,
                               self.in_h, self.in_w)

            expected = self.reference_net(batch, params=named_fast)
            actual = fnet(batch, params=ordered_fast)
            self.assertTrue(torch.equal(expected, actual))
开发者ID:facebookresearch,项目名称:higher,代码行数:24,代码来源:test_higher.py

示例15: testSubModuleDirectCall

# 需要导入模块: import torch [as 别名]
# 或者: from torch import equal [as 别名]
def testSubModuleDirectCall(self):
    """Check that patched submodules can be called directly.

    A ``_NestedEnc`` wrapping a ``Linear(3, 4)`` is monkeypatched. Calling the
    patched top-level module and calling its patched ``f`` submodule directly
    must give identical outputs on the same input.
    """
    # The previously declared local ``Module`` class was never used by this
    # test and has been removed as dead code.
    module = _NestedEnc(nn.Linear(3, 4))
    fmodule = higher.monkeypatch(module)

    xs = torch.randn(2, 3)
    fsubmodule = fmodule.f

    self.assertTrue(torch.equal(fmodule(xs), fsubmodule(xs)))
开发者ID:facebookresearch,项目名称:higher,代码行数:19,代码来源:test_patch.py


注:本文中的torch.equal方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。