当前位置: 首页>>代码示例>>Python>>正文


Python torch.rfft方法代码示例

本文整理汇总了Python中torch.rfft方法的典型用法代码示例。如果您正苦于以下问题：Python torch.rfft方法的具体用法？Python torch.rfft怎么用？Python torch.rfft使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的其他用法示例。


在下文中一共展示了torch.rfft方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: p2o

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def p2o(psf, shape):
    '''
    Convert a point-spread function (PSF) to an optical transfer function (OTF).

    The PSF is copied into the top-left corner of a zero tensor of the target
    spatial size, circularly shifted by half the PSF size along each spatial
    axis, and transformed with a two-sided 2-D FFT.

    # psf: NxCxhxw
    # shape: [H,W] -- any sequence of two ints (list, tuple or torch.Size)
    # otf: NxCxHxWx2 (last axis holds the real and imaginary parts)
    '''
    # torch.Size is a tuple subclass, so "psf.shape[:-2] + shape" raises
    # TypeError when shape is a list; normalise both sides to plain tuples.
    full_shape = tuple(psf.shape[:-2]) + tuple(shape)
    otf = torch.zeros(full_shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    otf = torch.rfft(otf, 2, onesided=False)
    # Imaginary parts whose magnitude is below n_ops * 2.22e-16 are set to
    # exactly zero; n_ops is sum(d * log2(d)) over every dimension d of psf.
    psf_dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(psf_dims * torch.log2(psf_dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf



# otf2psf: not sure where I got this one from. Maybe translated from Octave source code or whatever. It's just math. 
开发者ID:cszn,项目名称:KAIR,代码行数:20,代码来源:utils_deblur.py

示例2: p2o

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def p2o(psf, shape):
    '''
    Convert a point-spread function (PSF) to an optical transfer function (OTF).

    The PSF is zero-padded to the target spatial size, circularly shifted by
    half its own size along each spatial axis, and passed through a two-sided
    2-D FFT.

    Args:
        psf: NxCxhxw
        shape: [H,W] -- any sequence of two ints (list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2 (last axis holds the real and imaginary parts)
    '''
    # torch.Size is a tuple subclass, so adding a list to it raises TypeError;
    # convert both operands to plain tuples before concatenating.
    full_shape = tuple(psf.shape[:-2]) + tuple(shape)
    otf = torch.zeros(full_shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    otf = torch.rfft(otf, 2, onesided=False)
    # Zero out imaginary parts smaller in magnitude than n_ops * 2.22e-16,
    # where n_ops = sum(d * log2(d)) over every dimension d of psf.
    psf_dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(psf_dims * torch.log2(psf_dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
开发者ID:cszn,项目名称:KAIR,代码行数:19,代码来源:utils_sisr.py

示例3: p2o

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def p2o(psf, shape):
    '''
    Convert point-spread function to optical transfer function.
    otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
    point-spread function (PSF) array and creates the optical transfer
    function (OTF) array that is not influenced by the PSF off-centering.

    Args:
        psf: NxCxhxw
        shape: [H, W] -- any sequence of two ints (list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2 (last axis holds the real and imaginary parts)
    '''
    # torch.Size is a tuple subclass, so "torch.Size + list" raises TypeError;
    # build the padded shape from plain tuples so a list [H, W] is accepted.
    full_shape = tuple(psf.shape[:-2]) + tuple(shape)
    otf = torch.zeros(full_shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # Circularly shift by half the PSF size along each spatial axis.
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    otf = torch.rfft(otf, 2, onesided=False)
    # Imaginary parts below n_ops * 2.22e-16 in magnitude are set to zero;
    # n_ops = sum(d * log2(d)) over every dimension d of psf.
    psf_dims = torch.tensor(psf.shape).type_as(psf)
    n_ops = torch.sum(psf_dims * torch.log2(psf_dims))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
开发者ID:cszn,项目名称:KAIR,代码行数:24,代码来源:network_usrnet.py

示例4: pad_rfft3

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def pad_rfft3(f, onesided=True):
    """
    Batched real FFT over the last three axes, zeroing one bin per axis.

    The transform is done one axis at a time: a real FFT along the last axis,
    then complex FFTs along the second-to-last and third-to-last axes. After
    each pass the bin at index ``res // 2`` of the transformed axis is zeroed.

    :param f: tensor of shape [..., res0, res1, res2]
    :param onesided: forwarded to ``torch.rfft`` for the last-axis transform
    :return: transformed tensor with a trailing axis of size 2 (real, imag)
    """
    res0, res1, res2 = f.shape[-3:]
    mid0, mid1, mid2 = res0 // 2, res1 // 2, res2 // 2

    # Last axis: real-to-complex transform -> [..., res0, res1, freq2, 2]
    spec = torch.rfft(f, signal_ndim=1, onesided=onesided)
    spec[..., mid2, :] = 0

    # Middle axis: move it next to the complex axis, transform, move it back.
    spec = torch.fft(spec.transpose(-3, -2), signal_ndim=1)
    spec[..., mid1, :] = 0
    spec = spec.transpose(-2, -3)

    # First axis: same swap / transform / swap-back procedure.
    spec = torch.fft(spec.transpose(-4, -2), signal_ndim=1)
    spec[..., mid0, :] = 0
    spec = spec.transpose(-2, -4)
    return spec
开发者ID:maxjiang93,项目名称:space_time_pde,代码行数:21,代码来源:torch_spec_operator.py

示例5: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def __init__(self, shape, sd=0.01, decay_power=1, init_img=None):
        """
        Create the learnable spectrum and its fixed per-frequency scale.

        Args:
            shape: image shape, unpacked as (ch, h, w)
            sd: multiplier applied to the standard-normal initial spectrum
            decay_power: exponent applied to the clamped frequencies when
                building the scale
            init_img: optional image whose FFT (divided by the scale)
                overwrites the random initial spectrum
        """
        super(SpectralImage, self).__init__()
        self.shape = shape
        self.decay_power = decay_power

        ch, h, w = shape
        freqs = _rfft2d_freqs(h, w)
        fh, fw = freqs.shape

        # Learnable spectrum; the last axis of size 2 holds (real, imag).
        self.spectrum_var = torch.nn.Parameter(sd * torch.randn(ch, fh, fw, 2))

        # scale = sqrt(w * h) / max(freq, 1 / max(h, w)) ** decay_power
        freq_floor = 1.0 / max(h, w)
        scale = 1.0 / np.maximum(freqs, freq_floor)**self.decay_power
        scale *= np.sqrt(w * h)
        scale = torch.FloatTensor(scale).unsqueeze(-1)
        self.register_buffer('spertum_scale', scale)

        if init_img is None:
            return

        # An odd size on axis 2 gets one element of left padding on the last
        # axis before the one-sided FFT.
        # NOTE(review): presumably init_img is (ch, h, w) -- confirm.
        if init_img.shape[2] % 2 == 1:
            init_img = nn.functional.pad(init_img, (1, 0, 0, 0))
        img_spectrum = torch.rfft(init_img * 4, 2, onesided=True, normalized=False)
        self.spectrum_var.data.copy_(img_spectrum / scale)
开发者ID:Vermeille,项目名称:Torchelie,代码行数:25,代码来源:data_learning.py

示例6: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(self, z):
        """
        Filter the input signal in the frequency domain.

        ``z`` is a pair whose first element is the signal; the second element
        is unpacked but not used. The filter coefficients are turned into a
        windowed impulse response, zero-padded to ``block_size``, and its
        spectrum is complex-multiplied with the spectrum of the signal.
        """
        signal, _ = z
        n_batch = signal.shape[0]
        n_coef = self.filter_size // 2 + 1
        n_bins = self.block_size // 2 + 1

        # Coefficients as (real, imag) pairs with a zero imaginary part.
        coef = self.filter_coef.reshape([-1, n_coef, 1]).expand([-1, n_coef, 2]).contiguous()
        coef[:, :, 1] = 0

        # Windowed impulse response, zero-padded up to the block size.
        impulse = torch.irfft(coef, 1, signal_sizes=(self.filter_size,))
        windowed = self.filter_window.unsqueeze(0) * impulse
        windowed = nn.functional.pad(windowed, (0, self.block_size - self.filter_size), "constant", 0)

        # Spectra of the signal and of the filter.
        sig_spec = torch.rfft(signal, 1).reshape(n_batch, -1, n_bins, 2)
        filt_spec = torch.rfft(windowed, 1).reshape(n_batch, -1, n_bins, 2)

        # Complex product: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        filt_re, filt_im = filt_spec[..., 0], filt_spec[..., 1]
        sig_re, sig_im = sig_spec[..., 0], sig_spec[..., 1]
        out_spec = torch.zeros_like(filt_spec)
        out_spec[..., 0] = filt_re * sig_re - filt_im * sig_im
        out_spec[..., 1] = filt_re * sig_im + filt_im * sig_re
        out_spec = out_spec.reshape(-1, n_bins, 2)

        # Back to the time domain, keeping block_size samples per block.
        filtered = torch.irfft(out_spec, 1)[:, :self.block_size].reshape(n_batch, -1)
        return filtered
开发者ID:acids-ircam,项目名称:ddsp_pytorch,代码行数:23,代码来源:filters.py

示例7: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(self, z):
        """
        Convolve the input signal with an impulse response in the frequency domain.

        ``z`` is a pair whose first element is the signal; the second element
        is unpacked but not used. The impulse response is detached before it
        is transformed, so no gradient flows through it from the product.
        """
        signal, _ = z

        # Zero-pad the signal on the right by self.size samples and take its rFFT.
        padded = nn.functional.pad(signal, (0, self.size), "constant", 0)
        sig_spec = torch.rfft(padded, 1)

        # Build the impulse response: identity part plus decayed impulse part.
        dry = torch.sigmoid(self.wetdry) * self.identity
        wet = torch.sigmoid(1 - self.wetdry) * self.impulse
        ramp = torch.linspace(0,1, self.size).to(signal.device)
        envelope = torch.exp(-(torch.exp(self.decay) + 2) * ramp)
        response = dry + wet * envelope

        # Pad the impulse response, then extend it to the padded signal length.
        response = nn.functional.pad(response, (0, self.size), "constant", 0)
        n_samples = padded.shape[-1]
        if n_samples > self.size:
            response = nn.functional.pad(response, (0, n_samples - response.shape[-1]), "constant", 0)
        resp_spec = torch.rfft(response.detach(), 1).expand_as(sig_spec)

        # Complex product of the two spectra.
        out_spec = torch.zeros_like(resp_spec)
        out_spec[:,:,0] = sig_spec[:,:,0] * resp_spec[:,:,0] - sig_spec[:,:,1] * resp_spec[:,:,1]
        out_spec[:,:,1] = sig_spec[:,:,0] * resp_spec[:,:,1] + sig_spec[:,:,1] * resp_spec[:,:,0]

        # Back to the time domain at the padded length.
        return torch.irfft(out_spec, 1, signal_sizes=(n_samples,))
开发者ID:acids-ircam,项目名称:ddsp_pytorch,代码行数:25,代码来源:effects.py

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(self, z):
        """
        Generate filtered noise, one block per (batch, frame) of the input.

        ``z`` is a pair whose first element supplies the batch size, frame
        count and device; the second element is unpacked but not used. Random
        noise is scaled by ``noise_att`` and complex-multiplied in the
        frequency domain with the spectrum of the windowed filter impulse
        response.
        """
        sig, _ = z
        n_batch = sig.shape[0]
        n_coef = self.filter_size // 2 + 1
        n_bins = self.block_size // 2 + 1

        # Noise source: one block of block_size samples per (batch, frame).
        noise = torch.randn([sig.shape[0], sig.shape[1], self.block_size]).detach().to(sig.device)
        noise = noise.reshape(-1, self.block_size) * self.noise_att
        noise_spec = torch.rfft(noise, 1).reshape(n_batch, -1, n_bins, 2)

        # Coefficients as (real, imag) pairs with a zero imaginary part.
        coef = self.filter_coef.reshape([-1, n_coef, 1]).expand([-1, n_coef, 2]).contiguous()
        coef[:, :, 1] = 0

        # Windowed impulse response, zero-padded up to the block size.
        impulse = torch.irfft(coef, 1, signal_sizes=(self.filter_size,))
        windowed = self.filter_window.unsqueeze(0) * impulse
        windowed = nn.functional.pad(windowed, (0, self.block_size - self.filter_size), "constant", 0)

        # Spectral mask of the filter.
        mask = torch.rfft(windowed, 1).reshape(n_batch, -1, n_bins, 2)

        # Complex product of the mask and the noise spectrum.
        mask_re, mask_im = mask[..., 0], mask[..., 1]
        noise_re, noise_im = noise_spec[..., 0], noise_spec[..., 1]
        out_spec = torch.zeros_like(mask)
        out_spec[..., 0] = mask_re * noise_re - mask_im * noise_im
        out_spec[..., 1] = mask_re * noise_im + mask_im * noise_re
        out_spec = out_spec.reshape(-1, n_bins, 2)

        # Back to the time domain, keeping block_size samples per block.
        filtered_noise = torch.irfft(out_spec, 1)[:, :self.block_size].reshape(n_batch, -1)
        return filtered_noise
开发者ID:acids-ircam,项目名称:ddsp_pytorch,代码行数:24,代码来源:generators.py

示例9: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(self, z, x):
        """
        Compute the response map of ``x`` against ``z`` in the Fourier domain.

        Both inputs go through ``self.feature``. A per-frequency solution is
        formed from the spectrum of ``z`` and ``self.yf`` (regularised by
        ``self.config.lambda0``) and is then applied to the spectrum of ``x``.
        """
        z_feat = self.feature(z)
        x_feat = self.feature(x)

        z_spec = torch.rfft(z_feat, signal_ndim=2)
        x_spec = torch.rfft(x_feat, signal_ndim=2)

        # Numerator: z-spectrum against the stored yf.
        # Denominator: channel-summed z-spectrum against itself, plus lambda0.
        target_spec = self.yf.to(device=z_feat.device)
        kzzf = torch.sum(tensor_complex_mulconj(z_spec, z_spec), dim=1, keepdim=True)
        kzyf = tensor_complex_mulconj(z_spec, target_spec)
        solution = tensor_complex_division(kzyf, kzzf + self.config.lambda0)

        # Apply the solution to the x-spectrum, sum over channels, and invert.
        response_spec = torch.sum(tensor_complex_mulconj(x_spec, solution), dim=1, keepdim=True)
        return torch.irfft(response_spec, signal_ndim=2)
开发者ID:huanglianghua,项目名称:open-vot,代码行数:18,代码来源:dcfnet.py

示例10: normalize_var

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def normalize_var(orig):
    """
    Rescale the deviation of ``orig`` from its per-sample mean, then normalize.

    The squared deviation is passed through a 2-D rFFT / irFFT round trip, and
    the square root of its absolute value is divided by its per-sample
    maximum. The resulting gain ``(minC + 1) / (minC + imC)`` is applied to
    the deviation; the mean and the gain are detached before that step.
    """
    n = orig.size(0)
    per_sample = (n, 1, 1, 1)

    # Spectral variance
    mean = torch.mean(orig.view(n, -1), 1).view(*per_sample)
    squared_dev = torch.pow(orig - mean, 2)
    spec_var = torch.rfft(squared_dev, 2)

    # Normalization: divide in place by the per-sample maximum.
    imC = torch.sqrt(torch.irfft(spec_var, 2, signal_sizes=orig.size()[2:]).abs())
    peak = torch.max(imC.view(n, -1), 1)[0].view(*per_sample)
    imC /= peak
    minC = 0.001
    imK = (minC + 1) / (minC + imC)

    # The deviation is recomputed here against the detached mean.
    mean = mean.detach()
    imK = imK.detach()
    img = mean + (orig - mean) * imK
    return normalize(img)
开发者ID:ddkang,项目名称:advex-uar,代码行数:18,代码来源:gabor.py

示例11: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """
        Count-sketch both inputs and combine the sketches through an FFT product.

        Each input is sketched with ``CountSketchFn_forward``, the sketches are
        transformed with a 1-D rFFT, multiplied as complex numbers, and the
        product is transformed back to a real signal of length ``output_size``.
        The tensors and sizes stored on ``ctx`` are saved for the backward pass.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        # Sketch x and take its spectrum; drop the sketch once transformed.
        sketch_x = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        spec_x = torch.rfft(sketch_x, 1)
        del sketch_x
        re_fx, im_fx = spec_x[..., 0], spec_x[..., 1]

        # Same for y.
        sketch_y = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        spec_y = torch.rfft(sketch_y, 1)
        del sketch_y
        re_fy, im_fy = spec_y[..., 0], spec_y[..., 1]

        # Complex multiplication of the two spectra.
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Re-pack (real, imag) on a new last axis and return to the real domain.
        product = torch.stack((re_prod, im_prod), dim=-1)
        return torch.irfft(product, 1, signal_sizes=(output_size,))
开发者ID:Cadene,项目名称:block.bootstrap.pytorch,代码行数:33,代码来源:compactbilinearpooling.py

示例12: get_uperleft_denominator_pytorch

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    Compute the two Fourier-domain terms built from an image and a blur kernel.

    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW, the squared magnitude of the kernel's transfer function
    upperleft: NxCxHxWx2, conj(transfer function) times the FFT of img
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    otf_re, otf_im = otf[..., 0], otf[..., 1]
    denominator = otf_re**2 + otf_im**2  # Nx1xHxW
    img_spec = rfft(img)  # NxCxHxWx2
    upperleft = cmul(cconj(otf), img_spec)
    return upperleft, denominator
开发者ID:cszn,项目名称:KAIR,代码行数:13,代码来源:utils_deblur.py

示例13: rfft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def rfft(t):
    """Return the two-sided 2-D real FFT of ``t`` over its last two axes."""
    spectrum = torch.rfft(t, signal_ndim=2, onesided=False)
    return spectrum
开发者ID:cszn,项目名称:KAIR,代码行数:4,代码来源:utils_deblur.py

示例14: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf):
        """
        Compute an image estimate from Fourier-domain inputs.

        NOTE(review): presumably FB / FBC / F2B / FBFy are precomputed kernel
        and observation spectra and ``sf`` is the scale factor -- confirm
        against the caller.
        """
        # Right-hand side: FBFy plus the two-sided FFT of alpha * x.
        rhs = FBFy + torch.rfft(alpha*x, 2, onesided=False)

        # Average FB * rhs and F2B over the splits produced for scale sf.
        fb_rhs = cmul(FB, rhs)
        mean_fb_rhs = torch.mean(splits(fb_rhs, sf), dim=-1, keepdim=False)
        mean_f2b = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)

        # Divide, tile back sf times along both spatial axes, apply FBC.
        ratio = cdiv(mean_fb_rhs, csum(mean_f2b, alpha))
        correction = cmul(FBC, ratio.repeat(1, 1, sf, sf, 1))

        # Final spectrum and its inverse transform.
        est_spec = (rhs-correction)/alpha.unsqueeze(-1)
        Xest = torch.irfft(est_spec, 2, onesided=False)
        return Xest
开发者ID:cszn,项目名称:KAIR,代码行数:13,代码来源:network_usrnet.py

示例15: _get_fft_basis

# 需要导入模块: import torch [as 别名]
# 或者: from torch import rfft [as 别名]
def _get_fft_basis(self):
        """
        Build a real-valued Fourier basis matrix for ``self.filter_length``.

        The one-sided rFFT of an identity matrix is taken, and the real and
        imaginary parts of the first ``1 + filter_length // 2`` frequency bins
        are concatenated along axis 1. The result is returned as float32.
        """
        n = self.filter_length
        spectrum = torch.rfft(torch.eye(n), 1, onesided=True)
        cutoff = 1 + n // 2
        real_part = spectrum[:, :cutoff, 0]
        imag_part = spectrum[:, :cutoff, 1]
        basis = torch.cat([real_part, imag_part], dim=1)
        return basis.float()
开发者ID:nussl,项目名称:nussl,代码行数:13,代码来源:filter_bank.py


注:本文中的torch.rfft方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。