

Python autograd.Function Usage Examples

This article collects typical usage examples of torch.autograd.Function in Python. If you are wondering what autograd.Function is for, how to use it, or what calling it looks like in practice, the curated examples below may help. You can also explore further usage examples from the torch.autograd module.


The following presents 6 code examples of autograd.Function, sorted by popularity by default.
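
Before the examples, here is a minimal, self-contained sketch of the pattern they all build on: subclass Function, implement the static forward/backward pair, and invoke the class through .apply. The Square class is a hypothetical illustration, not taken from any of the projects below:

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # stash inputs needed by backward
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output    # d(x^2)/dx = 2x

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor([6.])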

Example 1: forward

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def forward(self, input, target, mask):
        # mask out ignored targets before applying the criterion
        # (self.criterion is assumed to be set in __init__, e.g. an nn.MSELoss)
        self.loss = self.criterion(input, target * mask)
        return self.loss


# class DepthTo3D(Function):
#     """Unfinished draft using the deprecated instance-method Function API;
#     the backward pass was never completed in the original source."""
#
#     def forward(self, input, pix_inv, R_inv, T):
#         self.save_for_backward(input, pix_inv, R_inv, T)
#         # bs, sx, sy are assumed batch/spatial sizes from the enclosing scope
#         return torch.bmm(R_inv, input.resize(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv
#                          - T.repeat(1, 1, sx * sy)).resize(bs, 3, sx, sy)
#
#     def backward(self):
#         input, pix_inv, R_inv, T = self.saved_tensors
#         # gradient computation was left unimplemented (grad_input is undefined)
#         return grad_input,
Author: krematas, Project: soccerontable, Lines: 18, Source: losses.py
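
The commented-out DepthTo3D draft above uses the deprecated instance-method Function API (self.save_for_backward). Here is a minimal sketch of the same save/restore pattern in the current static-method API, using a hypothetical masked scale rather than the unfinished depth-to-3D math:

import torch
from torch.autograd import Function

class MaskedScale(Function):
    @staticmethod
    def forward(ctx, input, mask):
        ctx.save_for_backward(mask)
        return input * mask

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        # gradient flows only through unmasked positions; the mask itself
        # receives no gradient
        return grad_output * mask, None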

Example 2: test_symbolic_mismatch

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = Variable(torch.randn(2, 2).fill_(1.0))
        y = Variable(torch.randn(2, 2).fill_(1.0))
        # NB: Don't use expect test here, the type error wobbles depending
        # on Python version
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_string(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn style test for these 
Author: onnxbot, Project: onnx-fb-universe, Lines: 22, Source: test_operators.py
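
For contrast, a sketch of a correctly matched pair: symbolic takes exactly the same arguments as forward, so the exporter can translate it. MyAdd is a hypothetical name; g.op("Add", ...) follows the call used in Example 4 below:

class MyAdd(Function):
    @staticmethod
    def symbolic(g, x, y):
        # arity matches forward, so translation succeeds
        return g.op("Add", x, y)

    @staticmethod
    def forward(ctx, x, y):
        return x + y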

Example 3: test_at_op

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def test_at_op(self):
        x = Variable(torch.randn(3, 4))

        class MyFun(Function):

            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(MyModule(), x) 
Author: onnxbot, Project: onnx-fb-universe, Lines: 20, Source: test_operators.py
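
A hypothetical usage sketch: on the PyTorch versions this test targets, g.at emits an ATen fallback node (an "ATen" op carrying the operator name) rather than a standard ONNX op, which shows up when the module is exported:

import io
import torch

buf = io.BytesIO()
torch.onnx.export(MyModule(), (x,), buf)
# the exported graph contains an ATen "add" node instead of an ONNX Add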

Example 4: test_result_different

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def test_result_different(self):
        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                return a.sub(b) # yahaha! you found me!

        class MyModel(Module):
            def forward(self, x, y):
                return BrokenAdd().apply(x, y)

        x = Variable(torch.Tensor([1,2]))
        y = Variable(torch.Tensor([3,4]))
        self.assertVerifyExpectFail(MyModel(), (x, y), backend) 
Author: onnxbot, Project: onnx-fb-universe, Lines: 19, Source: test_verify.py
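
The mismatch can be seen without an ONNX backend at all; a quick sketch of the discrepancy that assertVerifyExpectFail is designed to catch:

x = Variable(torch.Tensor([1, 2]))
y = Variable(torch.Tensor([3, 4]))
print(BrokenAdd().apply(x, y))  # forward really computes a - b -> [-2., -2.]
print(x + y)                    # what the exported Add computes -> [4., 6.]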

Example 5: straight_backprop

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def straight_backprop(function):
    """Return a version of ``function`` whose backward pass is the identity
    (a straight-through gradient).

    >>> straight_backprop_relu = straight_backprop(F.relu)
    >>> straight_backprop_relu(tensor)

    :param function: original function
    :return: modified function
    """

    class _StraightBackprop(Function):
        @staticmethod
        def forward(ctx, inputs):
            return function(inputs)

        @staticmethod
        def backward(ctx, grad_outputs):
            return grad_outputs

    return _StraightBackprop.apply 
Author: moskomule, Project: homura, Lines: 22, Source: miscs.py
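
A usage sketch confirming the straight-through behavior: gradients pass through unchanged even where the input is negative, where a true ReLU gradient would be zero:

import torch
import torch.nn.functional as F

st_relu = straight_backprop(F.relu)
x = torch.tensor([-1.0, 2.0], requires_grad=True)
st_relu(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) -- identity gradient everywhere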

Example 6: grid_pooling_auto

# Required import: from torch import autograd [as alias]
# Or alternatively: from torch.autograd import Function [as alias]
def grid_pooling_auto(pts, feat):
    # x_grids, y_grids, z_grids, len_cell, C, dtype, dtype_long and
    # pts_in_cell are module-level names defined elsewhere in the test file
    xv_value, yv_value, zv_value = np.meshgrid(x_grids[:-1], y_grids[:-1], z_grids[:-1], indexing='ij')
    xv_value = xv_value.flatten()
    yv_value = yv_value.flatten()
    zv_value = zv_value.flatten()

    feat_cell = Variable(torch.zeros((len(x_grids) - 1) * (len(y_grids) - 1) * (len(z_grids) - 1), C).type(dtype))
    for i_, (x_, y_, z_) in enumerate(zip(xv_value, yv_value, zv_value)):
        pts_index = pts_in_cell(pts.unsqueeze(0),
                                [x_, y_, z_, x_ + len_cell, y_ + len_cell, z_ + len_cell])
        if len(pts_index) > 0:
            pts_index = torch.LongTensor(pts_index).type(dtype_long)
            pts_feat = feat[pts_index, :]
            # max-pool the features of all points that fall into this cell
            m = nn.MaxPool1d(len(pts_index))
            pts_feat = m(pts_feat.t().unsqueeze(0))
            feat_cell[i_, :] = pts_feat.squeeze()
    return feat_cell

#class GridPooling(Function):
#    def forward(self, points, feat_points):
#        feat_cells = torch.zeros(W*H*D, C).type(dtype)
#        indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
#        shape = torch.LongTensor([W, H, D]).type(dtype_long)
#        forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices) 
#        self.saved_indices = indices
#        return feat_cells 
#
#    def backward(self, grad_output):
#        grad_points = torch.zeros(N, C).type(torch.FloatTensor)
#        forward_utils.grid_pooling_backward( grad_output, self.saved_indices, grad_points) 
#        return None, grad_points 
Author: autonomousvision, Project: occupancy_networks, Lines: 38, Source: test_gridpooling.py
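
A hedged sketch of the commented-out GridPooling class above, rewritten in the modern static-method API; W, H, D, C, N, dtype, dtype_long and forward_utils are assumed to be defined elsewhere in the test file, as in the original draft:

class GridPooling(Function):
    @staticmethod
    def forward(ctx, points, feat_points):
        feat_cells = torch.zeros(W * H * D, C).type(dtype)
        indices = -1 * torch.ones(W * H * D, C).type(dtype_long)
        shape = torch.LongTensor([W, H, D]).type(dtype_long)
        forward_utils.grid_pooling_forward(points, feat_points, shape,
                                           feat_cells, indices)
        ctx.save_for_backward(indices)  # stash indices for the backward pass
        return feat_cells

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        grad_points = torch.zeros(N, C)
        forward_utils.grid_pooling_backward(grad_output, indices, grad_points)
        return None, grad_points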


Note: The torch.autograd.Function examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors, and distribution or use should follow the corresponding project's license. Please do not reproduce without permission.