

Python Variable.send_ Method Code Examples

This article collects typical usage examples of the Python method torch.autograd.Variable.send_. If you are wondering what Variable.send_ does, how to call it, or what real-world uses look like, the curated examples below should help. You can also explore other usage examples of the containing class, torch.autograd.Variable.


Two code examples of the Variable.send_ method are shown below, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the system recommend better Python code examples.
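
Before the test examples, here is a minimal sketch of the send_ workflow they exercise. It assumes the PySyft-0.1-era API these tests were written against (TorchHook, VirtualWorker); the two syft import paths are assumptions and may differ between versions.

# A minimal sketch of the send_ workflow, under the assumptions stated above.
import torch
from torch.autograd import Variable as Var
from syft.core.hooks import TorchHook        # assumed import path
from syft.core.workers import VirtualWorker  # assumed import path

hook = TorchHook(verbose=False)   # monkey-patches torch with send_/pointer logic
local = hook.local_worker
remote = VirtualWorker(id=1, hook=hook, verbose=False)
local.add_worker(remote)

x = Var(torch.FloatTensor([[1, 2], [3, 4]]), requires_grad=True)
x.send_(remote)                   # in-place: the underlying data moves to the remote worker

assert x.is_pointer               # the local variable is now a pointer...
assert x.id in remote._objects    # ...and the real tensor lives on the remote worker

The trailing underscore follows the PyTorch in-place convention: send_ mutates the variable into a pointer rather than returning a new object, which is exactly what the two tests below verify.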

Example 1: test_var_gradient_keeps_id_during_send_

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import send_ [as alias]
# (the snippet also relies on torch, Var = Variable, and PySyft's TorchHook / VirtualWorker)
    def test_var_gradient_keeps_id_during_send_(self):
        # PyTorch has a tendency to delete var.grad python objects
        # and re-initialize them (resulting in new/random ids).
        # We fixed this bug, and the fix (along with the creation of
        # this unit test) is recorded in the following video,
        # roughly between 1:50:00 and 2:00:00:
        # https://www.twitch.tv/videos/275838386

        # this is our hook
        hook = TorchHook(verbose=False)
        local = hook.local_worker
        local.verbose = False

        remote = VirtualWorker(id=1, hook=hook, verbose=False)
        local.add_worker(remote)

        data = Var(torch.FloatTensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
        target = Var(torch.FloatTensor([[0], [0], [1], [1]]))

        model = Var(torch.zeros(2, 1), requires_grad=True)

        # generates grad objects on model
        pred = data.mm(model)
        loss = ((pred - target)**2).sum()
        loss.backward()

        # record the original ids of the data and grad tensors
        # (+ 0 takes an explicit copy of the value)
        original_data_id = model.data.id + 0
        original_grad_id = model.grad.data.id + 0

        model.send_(remote)

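        # after sending, the recorded ids must be unchanged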
        assert model.data.id == original_data_id
        assert model.grad.data.id == original_grad_id
Author: TanayGahlot, Project: PySyft, Lines: 36, Source: torch_test.py

Example 2: test_send_var_with_gradient

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import send_ [as alias]
    def test_send_var_with_gradient(self):

        # previously, there was a bug involving sending variables with gradients
        # to remote tensors. This bug was documented in Issue 1350:
        # https://github.com/OpenMined/PySyft/issues/1350

        # this is our hook
        hook = TorchHook(verbose=False)
        local = hook.local_worker
        local.verbose = False

        remote = VirtualWorker(id=1, hook=hook, verbose=False)
        local.add_worker(remote)

        data = Var(torch.FloatTensor([[0, 0], [0, 1], [1, 0], [1, 1]]))
        target = Var(torch.FloatTensor([[0], [0], [1], [1]]))

        model = Var(torch.zeros(2, 1), requires_grad=True)

        # generates grad objects on model
        pred = data.mm(model)
        loss = ((pred - target)**2).sum()
        loss.backward()

        # ensure that model and all (grand)children are owned by the local worker
        assert model.owners[0].id == local.id
        assert model.data.owners[0].id == local.id

        # if you get a failure here saying that model.grad.owners does not exist
        # check in hooks.py - _hook_new_grad(). self.grad_backup has probably either
        # been deleted or is being run at the wrong time (see comments there)
        assert model.grad.owners[0].id == local.id
        assert model.grad.data.owners[0].id == local.id

        # ensure that the objects are not yet pointers (we haven't sent them yet)
        assert not model.is_pointer
        assert not model.data.is_pointer
        assert not model.grad.is_pointer
        assert not model.grad.data.is_pointer

        model.send_(remote)

        # ensure that ownership has transferred to the remote worker
        assert model.owners[0].id == remote.id
        assert model.data.owners[0].id == remote.id
        assert model.grad.owners[0].id == remote.id
        assert model.grad.data.owners[0].id == remote.id

        # ensures that all local objects are now pointers
        assert model.is_pointer
        assert model.data.is_pointer
        assert model.grad.is_pointer
        assert model.grad.data.is_pointer

        # make sure the tensors were actually sent to the remote worker
        assert model.id in remote._objects
        assert model.data.id in remote._objects
        assert model.grad.id in remote._objects
        assert model.grad.data.id in remote._objects
Author: TanayGahlot, Project: PySyft, Lines: 61, Source: torch_test.py

