本文整理汇总了Python中torch.fft方法的典型用法代码示例。如果您正苦于以下问题：Python torch.fft方法的具体用法是什么？torch.fft怎么用？有哪些使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.fft方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：torch.fft(...)、torch.ifft、torch.rfft、torch.irfft 这类函数式接口已在 PyTorch 1.8 中被 torch.fft 模块（如 torch.fft.fft、torch.fft.rfft）取代并移除，旧写法在新版本中无法直接运行。
示例1: fft2
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft2(data):
    """
    Apply a centered, orthonormal 2-dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex-valued input in the "real view" layout with
            at least 3 dimensions: dimensions -3 and -2 are the spatial
            dimensions and dimension -1 has size 2 (real, imaginary). All other
            dimensions are treated as batch dimensions.

    Returns:
        torch.Tensor: The centered FFT of the input, in the same
        ``[..., H, W, 2]`` layout and real dtype as ``data``.
    """
    assert data.size(-1) == 2
    # The legacy ``torch.fft(data, 2, normalized=True)`` call was removed in
    # PyTorch 1.8 (``torch.fft`` is now a module). Work on a complex view
    # instead: real-view dims (-3, -2) are complex dims (-2, -1), and
    # ``norm="ortho"`` is the equivalent of ``normalized=True``.
    cdata = torch.view_as_complex(data.contiguous())
    cdata = torch.fft.ifftshift(cdata, dim=(-2, -1))
    cdata = torch.fft.fftn(cdata, dim=(-2, -1), norm="ortho")
    cdata = torch.fft.fftshift(cdata, dim=(-2, -1))
    return torch.view_as_real(cdata)
示例2: pad_rfft3
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_rfft3(f, onesided=True):
    """
    Padded batch real 3-D FFT.

    Computes the unnormalized 3-D FFT of a real signal over its last three
    dimensions. Along each transformed axis it zeroes index ``n // 2``, which is
    the Nyquist bin when ``n`` is even.

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: if True, keep only the non-negative frequencies of the last
        axis (res2 // 2 + 1 entries); otherwise keep all res2 entries.
    :return: tensor of shape [..., res0, res1, res2 // 2 + 1 (or res2), 2]
        holding the real and imaginary parts in the last dimension.
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2
    # torch.rfft / torch.fft(signal_ndim=...) were removed in PyTorch 1.8. The
    # torch.fft module transforms any axis directly, so no transposes are needed.
    if onesided:
        spec = torch.fft.rfft(f, dim=-1)
    else:
        spec = torch.fft.fft(f, dim=-1)
    spec[..., h2] = 0
    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h1, :] = 0
    spec = torch.fft.fft(spec, dim=-3)
    spec[..., h0, :, :] = 0
    return torch.view_as_real(spec)
示例3: pad_fft2
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_fft2(f):
    """
    Padded batch 2-D FFT of a real signal.

    Computes the full (two-sided) unnormalized complex FFT over the last two
    dimensions and zeroes index ``n // 2`` along each transformed axis.

    :param f: real tensor of shape [..., res0, res1]
    :return: tensor of shape [..., res0, res1, 2] holding the real and imaginary
        parts in the last dimension (same real dtype as ``f``).
    """
    n0, n1 = f.shape[-2:]
    h0, h1 = n0 // 2, n1 // 2
    # torch.fft(signal_ndim=1) was removed in PyTorch 1.8. torch.fft.fft accepts
    # real input directly, so there is no need to stack a zero imaginary channel.
    spec = torch.fft.fft(f, dim=-1)
    spec[..., h1] = 0
    spec = torch.fft.fft(spec, dim=-2)
    spec[..., h0, :] = 0
    return torch.view_as_real(spec)
示例4: rfftfreqs
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return integer frequency grids matching a real FFT.

    :param res: n_dims int tuple of number of frequency modes per axis
    :param dtype: dtype of the returned tensor
    :param exact: if True, the last axis keeps all ``res[-1] // 2 + 1`` rfft
        frequencies; if False, the last (highest) one is dropped.
    :return: frequency tensor of shape [n_dims, res[0], ..., res[-2], m], where m is
        the number of last-axis frequencies, e.g. [3, res, res, res/2+1] for a
        cubic 3-D grid with exact=True.
    """
    # fftfreq(r, d=1/r) yields integer frequencies 0, 1, ..., -1 for each axis.
    freqs = [torch.tensor(np.fft.fftfreq(r, d=1 / r), dtype=dtype) for r in res[:-1]]
    last = np.fft.rfftfreq(res[-1], d=1 / res[-1])
    if not exact:
        last = last[:-1]
    freqs.append(torch.tensor(last, dtype=dtype))
    # Explicit "ij" indexing keeps the historical default and silences the
    # torch>=1.10 warning about a missing indexing argument.
    omega = torch.meshgrid(*freqs, indexing="ij")
    return torch.stack(omega, dim=0)
示例5: test_butterfly_fft
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def test_butterfly_fft():
    """Check that two complex Butterfly factors reproduce the size-4 DFT matrix.

    The product of the factors is compared three ways: against DFT with the
    explicit permutation matrix P, against DFT with columns reordered by the
    bit-reversal permutation, and against DFT @ P^T.
    """
    # DFT matrix for n = 4
    size = 4
    # The legacy torch.fft(x, 1) function was removed in PyTorch 1.8. Transform
    # each row through a complex view, then convert back to the [..., 2]
    # (real, imag) layout that complex_matmul results are compared in.
    # Assumes real_to_complex returns a tensor whose last dim has size 2, which
    # the legacy torch.fft call also required.
    eye = torch.view_as_complex(real_to_complex(torch.eye(size)).contiguous())
    DFT = torch.view_as_real(torch.fft.fft(eye, dim=-1))
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.],
                                      [0., 0., 1., 0.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
示例6: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def forward(self, x):
    """
    Pool a feature map into one vector per image using FFT-based circular convolution.

    Each spatial feature vector (length ``dim``) is projected with
    ``self.sparseM[0]`` and ``self.sparseM[1]``. Each matrix is presumably
    (dim, self.output_dim); TODO confirm. The two projections are circularly
    convolved (FFT product, inverse FFT, real part) and summed over all
    ``h * w`` positions of the image. The pooled result is then passed through
    ``self._signed_sqrt`` and ``self._l2norm``.

    NOTE(review): this looks like compact bilinear pooling (Tensor Sketch); confirm.

    :param x: tensor of shape [batch, dim, h, w]
    :return: result of ``self._l2norm`` applied to the [batch, self.output_dim]
        pooled tensor
    """
    bsn = 1  # images processed per loop iteration
    batch_size, dim, h, w = x.shape
    # [batch, dim, h, w] -> [batch * h * w, dim]: one row per spatial position.
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)
    y = torch.ones(batch_size, self.output_dim, device=x.device)
    # Loop-invariant device transfers, hoisted out of the loop.
    proj0 = self.sparseM[0].to(x.device)
    proj1 = self.sparseM[1].to(x.device)
    seg_len = bsn * h * w
    upper = batch_size * h * w
    for img in range(batch_size // bsn):
        rows = torch.arange(img * seg_len, min(upper, (img + 1) * seg_len), dtype=torch.long)
        # Image indices are bounded by batch_size. The original clamped them by
        # the pixel count, which is the wrong limit for this index.
        imgs = torch.arange(img * bsn, min(batch_size, (img + 1) * bsn), dtype=torch.long)
        batch_x = x_flat[rows, :]
        # The legacy torch.fft / torch.ifft functions were removed in PyTorch 1.8;
        # native complex tensors replace the manual real/imag product.
        sketch1 = torch.fft.fft(batch_x.mm(proj0), dim=1)
        sketch2 = torch.fft.fft(batch_x.mm(proj1), dim=1)
        tmp_y = torch.fft.ifft(sketch1 * sketch2, dim=1).real
        y[imgs, :] = tmp_y.reshape(imgs.numel(), h, w, self.output_dim).sum(dim=(1, 2))
    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
示例7: fft
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft(t):
    """
    Complex-to-complex 2-D DFT (unnormalized) in the ``[..., H, W, 2]`` layout.

    The last dimension of ``t`` holds (real, imaginary) parts, and the
    transform runs over dimensions -3 and -2 of ``t``. This replaces the legacy
    ``torch.fft(t, 2)`` call, which was removed in PyTorch 1.8.
    """
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
示例8: get_uperleft_denominator
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def get_uperleft_denominator(img, kernel):
    '''
    Build frequency-domain numerator and denominator terms for img and kernel.

    The kernel is converted with ``psf2otf`` to an array ``V`` sized to the
    image's first two dimensions. Returns ``conj(V) * FFT2(img)`` (FFT over
    axes 0 and 1) and ``|V|**2``, each with an axis inserted at position 2 so
    that it broadcasts against the channel axis.

    img: HxWxC
    kernel: hxw
    Returns:
        upperleft: HxWxC
        denominator: HxWx1
    '''
    otf = psf2otf(kernel, img.shape[:2])
    power = np.abs(otf) ** 2
    denominator = power[:, :, np.newaxis]
    img_spectrum = np.fft.fft2(img, axes=[0, 1])
    upperleft = np.conj(otf)[:, :, np.newaxis] * img_spectrum
    return upperleft, denominator
示例9: otf2psf
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def otf2psf(otf, outsize=None):
    """Convert an optical transfer function (OTF) back to a spatial kernel (PSF).

    Inverse-FFTs ``otf`` over axes 0 and 1, circularly shifts the result by
    ``floor(size / 2)`` along every axis of ``otf``, and optionally crops it
    around its center toward ``outsize``. Finally it drops a negligible
    imaginary part via ``np.real_if_close``.

    Args:
        otf: numpy array holding the OTF; the inverse FFT is taken over axes 0 and 1.
        outsize: optional target size. When given, the PSF is center-cropped
            toward this size; see the NOTE(review) comments below on which
            axes are actually cropped.

    Returns:
        The PSF as a numpy array: real if the imaginary part is within
        tolerance, otherwise complex.
    """
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    for axis, axis_size in enumerate(insize):
        # NOTE(review): every axis of otf is rolled, including any axes beyond 0
        # and 1 that ifftn did not transform. Confirm this is intended for 3-D input.
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if type(outsize) != type(None):
        # outsize given: center-crop the PSF toward the requested size.
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        # Length both size vectors are padded to before comparing them.
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        # NOTE(review): the MATLAB lines above pad with 1, but np.pad(mode="constant")
        # below pads with 0.
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")
        # Amount to trim per axis, split between the leading and trailing side.
        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            # NOTE(review): this only prints; execution continues with a negative pad
            # instead of raising.
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        # NOTE(review): dims_start has shape (n, 1), so len(dims_start.shape) == 2
        # and only axes 0 and 1 are cropped, regardless of n.
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    # Rough FFT operation count, used as the tolerance for np.real_if_close
    # (numpy treats tol > 1 as a multiple of machine epsilon).
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf
# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py
示例10: fft
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft(t):
    """
    Complex-to-complex Discrete Fourier Transform (2-D, unnormalized).

    ``t`` uses the ``[..., H, W, 2]`` layout: the last dimension holds the
    (real, imaginary) parts, and the transform runs over dimensions -3 and -2.
    This replaces the legacy ``torch.fft(t, 2)`` call, which was removed in
    PyTorch 1.8.
    """
    return torch.view_as_real(torch.fft.fft2(torch.view_as_complex(t.contiguous())))
示例11: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def forward(self, head, rel, tail):
    """
    Score (head, rel, tail) triples using circular correlation of embeddings.

    ``self.embed`` supplies the head, relation and tail embeddings. The relation
    embedding is L2-normalized along the last dimension. The circular
    correlation ``e_k = sum_i h_i * t_{(i + k) mod d}`` is computed via the FFT,
    and the score is ``-sigmoid(sum(r * e, dim=1))``.

    :return: tensor of negated sigmoid scores, one per row of the batch
    """
    h_e, r_e, t_e = self.embed(head, rel, tail)
    r_e = F.normalize(r_e, p=2, dim=-1)
    # Correlation theorem: corr(h, t) = ifft(conj(fft(h)) * fft(t)).
    # The legacy version stacked a zero imaginary channel and applied
    # torch.conj and `*` to [..., 2] real tensors. That neither conjugated nor
    # performed complex multiplication, so it did not compute a correlation.
    h_f = torch.fft.fft(h_e, dim=-1)
    t_f = torch.fft.fft(t_e, dim=-1)
    e = torch.fft.ifft(torch.conj(h_f) * t_f, dim=-1).real
    return -torch.sigmoid(torch.sum(r_e * e, 1))
示例12: pad_irfft3
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_irfft3(F, res2=None):
    """
    Padded batch inverse real 3-D FFT (unnormalized forward, 1/n on inverse).

    :param F: tensor of shape [..., res0, res1, res2/2+1, 2] holding the real and
        imaginary parts of a one-sided spectrum.
    :param res2: length of the reconstructed last axis. It defaults to
        ``F.shape[-3]`` (res1), i.e. the grid is assumed to have res2 == res1.
        Pass it explicitly for non-cubic grids.
    :return: real tensor of shape [..., res0, res1, res2]
    """
    if res2 is None:
        res2 = F.shape[-3]
    # torch.ifft / torch.irfft were removed in PyTorch 1.8; use the torch.fft
    # module on a complex view: dims -3 and -2 are res0 and res1.
    spec = torch.view_as_complex(F.contiguous())
    spec = torch.fft.ifft(spec, dim=-3)
    spec = torch.fft.ifft(spec, dim=-2)
    return torch.fft.irfft(spec, n=res2, dim=-1)
示例13: pad_ifft2
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_ifft2(F):
    """
    Batch inverse 2-D FFT (complex to complex, with 1/n scaling).

    :param F: tensor of shape [..., res0, res1, 2] holding real and imaginary
        parts; the inverse transform runs over res0 and res1.
    :return: tensor of shape [..., res0, res1, 2]. For the spectrum of a real
        signal, ``[..., 0]`` is the recovered signal and ``[..., 1]`` is ~0.
    """
    # The original ended with `return f2`, an undefined name (NameError), and
    # used torch.ifft, which was removed in PyTorch 1.8. Applying ifft over res0
    # and then over res1 is exactly an ifft2 over the last two complex dims.
    spec = torch.view_as_complex(F.contiguous())
    return torch.view_as_real(torch.fft.ifft2(spec, dim=(-2, -1)))
示例14: fftshift
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors.

    Rolls ``x`` by ``size // 2`` along each selected dimension so that the
    zero-frequency entry moves to the center.

    Args:
        x (torch.Tensor): input tensor.
        dim (int | Sequence[int] | None): dimension(s) to shift; all dimensions
            when None.

    Returns:
        torch.Tensor: the shifted tensor, with the same shape and dtype as ``x``.
    """
    # Delegate to the built-in (PyTorch >= 1.8), which applies exactly this
    # n // 2 roll, instead of depending on a separate hand-written roll helper.
    if dim is not None and not isinstance(dim, int):
        dim = tuple(dim)
    return torch.fft.fftshift(x, dim=dim)
示例15: ifftshift
# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors.

    Rolls ``x`` by ``(size + 1) // 2`` along each selected dimension, which
    undoes ``fftshift`` for both even and odd sizes.

    Args:
        x (torch.Tensor): input tensor.
        dim (int | Sequence[int] | None): dimension(s) to shift; all dimensions
            when None.

    Returns:
        torch.Tensor: the shifted tensor, with the same shape and dtype as ``x``.
    """
    # Delegate to the built-in (PyTorch >= 1.8). Rolling by -(n // 2) is the same
    # as rolling by (n + 1) // 2 modulo n, so this matches the hand-written version.
    if dim is not None and not isinstance(dim, int):
        dim = tuple(dim)
    return torch.fft.ifftshift(x, dim=dim)