当前位置: 首页>>代码示例>>Python>>正文


Python torch.fft方法代码示例

本文整理汇总了Python中torch.fft方法的典型用法代码示例。如果您正苦于以下问题:Python torch.fft方法的具体用法?Python torch.fft怎么用?Python torch.fft使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.fft方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: fft2

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft2(data):
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.

    Returns:
        torch.Tensor: The FFT of the input.
    """
    assert data.size(-1) == 2
    data = ifftshift(data, dim=(-3, -2))
    data = torch.fft(data, 2, normalized=True)
    data = fftshift(data, dim=(-3, -2))
    return data 
开发者ID:facebookresearch,项目名称:fastMRI,代码行数:19,代码来源:transforms.py

示例2: pad_rfft3

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_rfft3(f, onesided=True):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1, res2]
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = int(n0/2), int(n1/2), int(n2/2)

    F2 = torch.rfft(f, signal_ndim=1, onesided=onesided) # [..., res0, res1, res2/2+1, 2]
    F2[..., h2, :] = 0

    F1 = torch.fft(F2.transpose(-3,-2), signal_ndim=1)
    F1[..., h1,:] = 0
    F1 = F1.transpose(-2,-3)

    F0 = torch.fft(F1.transpose(-4,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-4)
    return F0 
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:21,代码来源:torch_spec_operator.py

示例3: pad_fft2

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_fft2(f):
    """
    padded batch real fft
    :param f: tensor of shape [..., res0, res1]
    """
    n0, n1 = f.shape[-2:]
    h0, h1 = int(n0/2), int(n1/2)
    # turn f into complex signal
    f = torch.stack((f, torch.zeros_like(f)), dim=-1) # [..., res0, res1, 2]

    F1 = torch.fft(f, signal_ndim=1) # [..., res0, res1, 2]
    F1[..., h1,:] = 0 # [..., res0, res1, 2]

    F0 = torch.fft(F1.transpose(-3,-2), signal_ndim=1)
    F0[..., h0,:] = 0
    F0 = F0.transpose(-2,-3)
    return F0 
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:19,代码来源:torch_spec_operator.py

示例4: rfftfreqs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def rfftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors
    :param res: n_dims int tuple of number of frequency modes
    :return: frequency tensor of shape [dim, res, res, res/2+1]
    """
    # print("res",res)
    n_dims = len(res)
    freqs = []
    for dim in range(n_dims - 1):
        r_ = res[dim]
        freq = np.fft.fftfreq(r_, d=1/r_)
        freqs.append(torch.tensor(freq, dtype=dtype))
    r_ = res[-1]
    if exact:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
    else:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
    omega = torch.meshgrid(freqs)
    omega = list(omega)
    omega = torch.stack(omega, dim=0)

    # print("omega.shape",omega.shape)
    return omega 
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:26,代码来源:torch_spec_operator.py

示例5: test_butterfly_fft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def test_butterfly_fft():
    # DFT matrix for n = 4
    size = 4
    DFT = torch.fft(real_to_complex(torch.eye(size)), 1)
    P = real_to_complex(torch.tensor([[1., 0., 0., 0.],
                                      [0., 0., 1., 0.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size,
                   diagonal=2,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True))
    assert torch.allclose(complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D) 
开发者ID:HazyResearch,项目名称:learning-circuits,代码行数:27,代码来源:butterfly_old.py

示例6: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def forward(self, x):
         bsn = 1
         batchSize, dim, h, w = x.data.shape
         x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # batchsize,h, w, dim,
         y = torch.ones(batchSize, self.output_dim, device=x.device)

         for img in range(batchSize // bsn):
             segLen = bsn * h * w
             upper = batchSize * h * w
             interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)
             interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)
             batch_x = x_flat[interLarge, :]

             sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
             sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)

             sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
             sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)

             Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
             Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

             tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

             y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

         y = self._signed_sqrt(y)
         y = self._l2norm(y)
         return y 
开发者ID:jiangtaoxie,项目名称:fast-MPN-COV,代码行数:31,代码来源:CBP.py

示例7: fft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft(t):
    return torch.fft(t, 2) 
开发者ID:cszn,项目名称:KAIR,代码行数:4,代码来源:utils_deblur.py

示例8: get_uperleft_denominator

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def get_uperleft_denominator(img, kernel):
    '''
    img: HxWxC
    kernel: hxw
    denominator: HxWx1
    upperleft: HxWxC
    '''
    V = psf2otf(kernel, img.shape[:2])
    denominator = np.expand_dims(np.abs(V)**2, axis=2)
    upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1])
    return upperleft, denominator 
开发者ID:cszn,项目名称:KAIR,代码行数:13,代码来源:utils_deblur.py

示例9: otf2psf

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def otf2psf(otf, outsize=None):
    insize = np.array(otf.shape)
    psf = np.fft.ifftn(otf, axes=(0, 1))
    for axis, axis_size in enumerate(insize):
        psf = np.roll(psf, np.floor(axis_size / 2).astype(int), axis=axis)
    if type(outsize) != type(None):
        insize = np.array(otf.shape)
        outsize = np.array(outsize)
        n = max(np.size(outsize), np.size(insize))
        # outsize = postpad(outsize(:), n, 1);
        # insize = postpad(insize(:) , n, 1);
        colvec_out = outsize.flatten().reshape((np.size(outsize), 1))
        colvec_in = insize.flatten().reshape((np.size(insize), 1))
        outsize = np.pad(colvec_out, ((0, max(0, n - np.size(colvec_out))), (0, 0)), mode="constant")
        insize = np.pad(colvec_in, ((0, max(0, n - np.size(colvec_in))), (0, 0)), mode="constant")

        pad = (insize - outsize) / 2
        if np.any(pad < 0):
            print("otf2psf error: OUTSIZE must be smaller than or equal than OTF size")
        prepad = np.floor(pad)
        postpad = np.ceil(pad)
        dims_start = prepad.astype(int)
        dims_end = (insize - postpad).astype(int)
        for i in range(len(dims_start.shape)):
            psf = np.take(psf, range(dims_start[i][0], dims_end[i][0]), axis=i)
    n_ops = np.sum(otf.size * np.log2(otf.shape))
    psf = np.real_if_close(psf, tol=n_ops)
    return psf


# psf2otf copied/modified from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py 
开发者ID:cszn,项目名称:KAIR,代码行数:33,代码来源:utils_deblur.py

示例10: fft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fft(t):
    # Complex-to-complex Discrete Fourier Transform
    return torch.fft(t, 2) 
开发者ID:cszn,项目名称:KAIR,代码行数:5,代码来源:network_usrnet.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def forward(self, head, rel, tail):
        h_e, r_e, t_e = self.embed(head, rel, tail)
        r_e = F.normalize(r_e, p=2, dim=-1)
        h_e = torch.stack((h_e, torch.zeros_like(h_e)), -1)
        t_e = torch.stack((t_e, torch.zeros_like(t_e)), -1)
        e, _ = torch.unbind(torch.ifft(torch.conj(torch.fft(h_e, 1)) * torch.fft(t_e, 1), 1), -1)
        return -F.sigmoid(torch.sum(r_e * e, 1)) 
开发者ID:Sujit-O,项目名称:pykg2vec,代码行数:9,代码来源:pairwise.py

示例12: pad_irfft3

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_irfft3(F):
    """
    padded batch inverse real fft
    :param f: tensor of shape [..., res0, res1, res2/2+1, 2]
    """
    res = F.shape[-3]
    f0 = torch.ifft(F.transpose(-4,-2), signal_ndim=1).transpose(-2,-4)
    f1 = torch.ifft(f0.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    f2 = torch.irfft(f1, signal_ndim=1, signal_sizes=[res]) # [..., res0, res1, res2]
    return f2 
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:12,代码来源:torch_spec_operator.py

示例13: pad_ifft2

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def pad_ifft2(F):
    """
    padded batch inverse real fft
    :param f: tensor of shape [..., res0, res1, res2/2+1, 2]
    """
    f0 = torch.ifft(F.transpose(-3,-2), signal_ndim=1).transpose(-2,-3)
    f1 = torch.ifft(f0, signal_ndim=1)
    return f2 
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:10,代码来源:torch_spec_operator.py

示例14: fftshift

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def fftshift(x, dim=None):
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors
    """
    if dim is None:
        dim = tuple(range(x.dim()))
        shift = [dim // 2 for dim in x.shape]
    elif isinstance(dim, int):
        shift = x.shape[dim] // 2
    else:
        shift = [x.shape[i] // 2 for i in dim]
    return roll(x, shift, dim) 
开发者ID:facebookresearch,项目名称:fastMRI,代码行数:14,代码来源:transforms.py

示例15: ifftshift

# 需要导入模块: import torch [as 别名]
# 或者: from torch import fft [as 别名]
def ifftshift(x, dim=None):
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors
    """
    if dim is None:
        dim = tuple(range(x.dim()))
        shift = [(dim + 1) // 2 for dim in x.shape]
    elif isinstance(dim, int):
        shift = (x.shape[dim] + 1) // 2
    else:
        shift = [(x.shape[i] + 1) // 2 for i in dim]
    return roll(x, shift, dim) 
开发者ID:facebookresearch,项目名称:fastMRI,代码行数:14,代码来源:transforms.py


注:本文中的torch.fft方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。