本文整理汇总了Python中torch.autograd.Function方法的典型用法代码示例。如果您正苦于以下问题:Python autograd.Function方法的具体用法?Python autograd.Function怎么用?Python autograd.Function使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.autograd的用法示例。
在下文中一共展示了autograd.Function方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: forward
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def forward(self, input, target, mask):
    """Compute and cache the loss between ``input`` and the masked ``target``.

    ``target`` is multiplied elementwise by ``mask`` before being handed to
    ``self.criterion``; ``input`` is passed through unmasked. The result is
    stored on ``self.loss`` and also returned.

    NOTE(review): only the target is masked, so wherever ``mask`` is zero the
    criterion compares ``input`` against 0 instead of ignoring that position.
    Confirm this asymmetry is intended.
    """
    masked_target = target * mask
    self.loss = self.criterion(input, masked_target)
    return self.loss
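# NOTE(review): disabled sketch written against the legacy (non-static)
# autograd.Function API. As written, backward() lacks its grad_output
# argument, saves four tensors but unpacks three, and uses bs, sx, sy,
# T_var and grad_input, none of which are defined in this snippet.
# class DepthTo3D(Function):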
#
# def forward(self, input, pix_inv, R_inv, T):
# self.save_for_backward(input, pix_inv, R_inv, T)
# return torch.bmm(R_inv, input.resize(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv - T_var.repeat(1, 1,sx * sy)).resize(bs, 3, sx, sy)
#
# def backward(self):
#
# pix_inv, R_inv, T = self.saved_tensors
#
# return grad_input,
示例2: test_symbolic_mismatch
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def test_symbolic_mismatch(self):
    """Exporting a Function whose ``symbolic`` arity mismatches ``forward``
    must raise a TypeError that names the Function.

    ``MyFun.symbolic`` accepts one input while ``forward`` accepts two, so the
    exporter is expected to fail on the argument mismatch before
    ``symbolic`` is ever entered.
    """
    class MyFun(Function):
        @staticmethod
        def symbolic(g, x):
            # The inside of this function should never be invoked, because
            # we will fail due to an argument mismatch first. Raise explicitly
            # (instead of ``assert False``) so the guard still fires under -O.
            raise AssertionError("MyFun.symbolic should never be invoked")

        @staticmethod
        def forward(ctx, x, y):
            return x + y

    x = Variable(torch.randn(2, 2).fill_(1.0))
    y = Variable(torch.randn(2, 2).fill_(1.0))
    # NB: Don't use expect test here, the type error wobbles depending
    # on Python version
    with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
        # Call ``apply`` on the class: autograd Functions are not meant to be
        # instantiated (recent PyTorch warns and plans to make it an error).
        export_to_string(FuncModule(MyFun.apply), (x, y))
# TODO: Do an nn style test for these
示例3: test_at_op
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def test_at_op(self):
    """Export a module whose custom Function lowers through ``g.at("add", ...)``."""
    class MyFun(Function):
        @staticmethod
        def forward(ctx, x):
            # Eager result: the input added to itself.
            return x + x

        @staticmethod
        def symbolic(g, x):
            # Export-time lowering: emit "add" through the graph's ``at`` helper.
            return g.at("add", x, x)

    class MyModule(Module):
        def forward(self, x):
            return MyFun.apply(x)

    model = MyModule()
    sample = Variable(torch.randn(3, 4))
    self.assertONNX(model, sample)
示例4: test_result_different
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def test_result_different(self):
    """A Function whose export disagrees with its eager ``forward`` should be caught.

    ``BrokenAdd`` exports as ONNX ``Add`` but deliberately subtracts in
    ``forward``. The test presumably relies on ``assertVerifyExpectFail``
    reporting that mismatch against ``backend``.
    """
    class BrokenAdd(Function):
        @staticmethod
        def symbolic(g, a, b):
            return g.op("Add", a, b)

        @staticmethod
        def forward(ctx, a, b):
            # Intentionally inconsistent with the "Add" symbolic above.
            return a.sub(b)  # yahaha! you found me!

    class MyModel(Module):
        def forward(self, x, y):
            # Call ``apply`` on the class itself; instantiating an autograd
            # Function is deprecated in recent PyTorch releases.
            return BrokenAdd.apply(x, y)

    x = Variable(torch.Tensor([1, 2]))
    y = Variable(torch.Tensor([3, 4]))
    self.assertVerifyExpectFail(MyModel(), (x, y), backend)
示例5: straight_backprop
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def straight_backprop(function):
    """Wrap ``function`` so that its backward pass is the identity (straight-through).

    In the forward pass the wrapper returns ``function(inputs)``. During
    backpropagation the incoming gradient is handed back to ``inputs``
    unchanged, exactly as if ``function`` had been the identity map.

    >>> straight_backprop_relu = straight_backprop(F.relu)
    >>> straight_backprop_relu(tensor)

    :param function: original function; since the gradient is forwarded
        as-is, its output is expected to match its input's shape
    :return: the ``apply`` callable of a freshly created autograd Function
    """
    class _StraightBackprop(Function):
        @staticmethod
        def forward(ctx, inputs):
            # Plain evaluation of the wrapped function.
            result = function(inputs)
            return result

        @staticmethod
        def backward(ctx, grad_outputs):
            # Ignore function's true derivative: treat d(out)/d(in) as 1.
            return grad_outputs

    wrapped = _StraightBackprop.apply
    return wrapped
示例6: grid_pooling_auto
# 需要导入模块: from torch import autograd [as 别名]
# 或者: from torch.autograd import Function [as 别名]
def grid_pooling_auto(pts, feat):
    """Max-pool per-point features into the cells of a regular 3-D grid.

    Cell origins are the ``indexing='ij'`` meshgrid of all but the last entry
    of ``x_grids``/``y_grids``/``z_grids``. Each cell spans ``len_cell`` along
    every axis. For every cell, the rows of ``feat`` selected by
    ``pts_in_cell`` are max-pooled into one output row. Cells that receive no
    points keep an all-zero row.

    Relies on names defined elsewhere in the module: ``x_grids``, ``y_grids``,
    ``z_grids``, ``len_cell``, ``C``, ``dtype``, ``dtype_long`` and
    ``pts_in_cell``.

    :param pts: point coordinates; assumed (N, 3) -- TODO confirm against
        ``pts_in_cell``
    :param feat: per-point features indexed by row; assumed (N, C)
    :return: tensor of shape (num_cells, C), one pooled feature row per cell
    """
    origins = np.meshgrid(x_grids[:-1], y_grids[:-1], z_grids[:-1], indexing='ij')
    xv_value, yv_value, zv_value = (axis.flatten() for axis in origins)
    num_cells = (len(x_grids) - 1) * (len(y_grids) - 1) * (len(z_grids) - 1)
    feat_cell = Variable(torch.zeros(num_cells, C).type(dtype))

    for cell_idx, (x_, y_, z_) in enumerate(zip(xv_value, yv_value, zv_value)):
        bounds = [x_, y_, z_, x_ + len_cell, y_ + len_cell, z_ + len_cell]
        pts_index = pts_in_cell(pts.unsqueeze(0), bounds)
        if len(pts_index) == 0:
            # Empty cell: its row stays zero.
            continue
        pts_index = torch.LongTensor(pts_index).type(dtype_long)
        pts_feat = feat[pts_index, :]
        # A MaxPool1d whose kernel spans all n points in the cell reduces the
        # (1, channels, n) view to (1, channels, 1): a per-channel max.
        pool = nn.MaxPool1d(len(pts_index))
        pooled = pool(pts_feat.t().unsqueeze(0))
        feat_cell[cell_idx, :] = pooled.squeeze()
    return feat_cell
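# NOTE(review): disabled sketch using the legacy (non-static) autograd.Function
# API (forward/backward take self and stash state on it), which current
# PyTorch no longer supports. W, H, D, C and N are not defined in this snippet.
#class GridPooling(Function):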
# def forward(self, points, feat_points):
# feat_cells = torch.zeros(W*H*D, C).type(dtype)
# indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
# shape = torch.LongTensor([W, H, D]).type(dtype_long)
# forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices)
# self.saved_indices = indices
# return feat_cells
#
# def backward(self, grad_output):
# grad_points = torch.zeros(N, C).type(torch.FloatTensor)
# forward_utils.grid_pooling_backward( grad_output, self.saved_indices, grad_points)
# return None, grad_points