當前位置: 首頁>>代碼示例>>Python>>正文


Python autograd.Function方法代碼示例

本文整理匯總了Python中torch.autograd.Function方法的典型用法代碼示例。如果您正苦於以下問題:Python autograd.Function方法的具體用法?Python autograd.Function怎麽用?Python autograd.Function使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.autograd的用法示例。


在下文中一共展示了autograd.Function方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def forward(self, input, target, mask):
        self.loss = self.criterion(input, target*mask)
        return self.loss


# class DepthTo3D(Function):
#
#     def forward(self, input, pix_inv, R_inv, T):
#         self.save_for_backward(input, pix_inv, R_inv, T)
#         return torch.bmm(R_inv, input.resize(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv - T_var.repeat(1, 1,sx * sy)).resize(bs, 3, sx, sy)
#
#     def backward(self):
#
#         pix_inv, R_inv, T = self.saved_tensors
#
#         return grad_input, 
開發者ID:krematas,項目名稱:soccerontable,代碼行數:18,代碼來源:losses.py

示例2: test_symbolic_mismatch

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = Variable(torch.randn(2, 2).fill_(1.0))
        y = Variable(torch.randn(2, 2).fill_(1.0))
        # NB: Don't use expect test here, the type error wobbles depending
        # on Python version
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_string(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn style test for these 
開發者ID:onnxbot,項目名稱:onnx-fb-universe,代碼行數:22,代碼來源:test_operators.py

示例3: test_at_op

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def test_at_op(self):
        x = Variable(torch.randn(3, 4))

        class MyFun(Function):

            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(MyModule(), x) 
開發者ID:onnxbot,項目名稱:onnx-fb-universe,代碼行數:20,代碼來源:test_operators.py

示例4: test_result_different

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def test_result_different(self):
        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                return a.sub(b) # yahaha! you found me!

        class MyModel(Module):
            def forward(self, x, y):
                return BrokenAdd().apply(x, y)

        x = Variable(torch.Tensor([1,2]))
        y = Variable(torch.Tensor([3,4]))
        self.assertVerifyExpectFail(MyModel(), (x, y), backend) 
開發者ID:onnxbot,項目名稱:onnx-fb-universe,代碼行數:19,代碼來源:test_verify.py

示例5: straight_backprop

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def straight_backprop(function):
    """ A function whose `derivative` is as linear

    >>> straight_backprop_relu = straight_backprop(F.relu)
    >>> straight_backprop_relu(tensor)

    :param function: original function
    :return: modified function
    """

    class _StraightBackprop(Function):
        @staticmethod
        def forward(ctx, inputs):
            return function(inputs)

        @staticmethod
        def backward(ctx, grad_outputs):
            return grad_outputs

    return _StraightBackprop.apply 
開發者ID:moskomule,項目名稱:homura,代碼行數:22,代碼來源:miscs.py

示例6: grid_pooling_auto

# 需要導入模塊: from torch import autograd [as 別名]
# 或者: from torch.autograd import Function [as 別名]
def grid_pooling_auto(pts, feat):
    xv_value, yv_value, zv_value = np.meshgrid(x_grids[:-1], y_grids[:-1], z_grids[:-1], indexing='ij')
    xv_value = xv_value.flatten()
    yv_value = yv_value.flatten()
    zv_value = zv_value.flatten()

    
    feat_cell = Variable(torch.zeros((len(x_grids)-1) * (len(y_grids)-1) * (len(z_grids)-1), C).type(dtype))
    #for k in range(batchsize):
    for i_,(x_,y_,z_) in enumerate(zip(xv_value, yv_value, zv_value)): 
        pts_index = pts_in_cell(pts.unsqueeze(0),[x_,y_,z_,
            x_+len_cell, y_+len_cell, z_+len_cell])
        if len(pts_index)>0:
            pts_index = torch.LongTensor(pts_index).type(dtype_long)
            #pts_feat = feat.index_select(0, pts_index)
            pts_feat = feat[pts_index,:]
            # max pooling
            #pts_feat,_ = torch.max(pts_feat, 0)
            m = nn.MaxPool1d(len(pts_index))
            pts_feat = m(pts_feat.t().unsqueeze(0))
            feat_cell[i_, :] = pts_feat.squeeze()
    return feat_cell

#class GridPooling(Function):
#    def forward(self, points, feat_points):
#        feat_cells = torch.zeros(W*H*D, C).type(dtype)
#        indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
#        shape = torch.LongTensor([W, H, D]).type(dtype_long)
#        forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices) 
#        self.saved_indices = indices
#        return feat_cells 
#
#    def backward(self, grad_output):
#        grad_points = torch.zeros(N, C).type(torch.FloatTensor)
#        forward_utils.grid_pooling_backward( grad_output, self.saved_indices, grad_points) 
#        return None, grad_points 
開發者ID:autonomousvision,項目名稱:occupancy_networks,代碼行數:38,代碼來源:test_gridpooling.py


注:本文中的torch.autograd.Function方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。