Python torch.stft方法代码示例

本文整理汇总了Python中torch.stft方法的典型用法代码示例。


示例1: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None) -> None:
        """Store the spectrogram settings and register the analysis window.

        Args:
            n_fft: Number of FFT bins.
            win_length: Window size; falls back to ``n_fft`` when ``None``.
            hop_length: Hop between windows; falls back to half of the
                resolved window size when ``None``.
            pad: Stored unchanged on the instance.
            window_fn: Callable that builds the window tensor from the
                resolved window size.
            power: Stored unchanged on the instance.
            normalized: Stored unchanged on the instance.
            wkwargs: Optional keyword arguments forwarded to ``window_fn``.
        """
        super(Spectrogram, self).__init__()

        # Resolve the optional sizes before anything is stored.
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 2

        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length

        # The window is registered as a buffer rather than set as a plain attribute.
        extra_window_args = {} if wkwargs is None else wkwargs
        self.register_buffer('window', window_fn(win_length, **extra_window_args))

        self.pad = pad
        self.power = power
        self.normalized = normalized

示例2: test_istft_requires_nola

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_requires_nola(self):
        """Check that istft accepts a NOLA-compliant window and rejects one that is not.

        The same all-zero spectrogram is inverted twice: once with a window of
        ones (expected to succeed) and once with a window of zeros (expected
        to raise ``RuntimeError``).
        """
        stft = torch.zeros((3, 5, 2))
        kwargs_ok = {
            'n_fft': 4,
            'win_length': 4,
            'window': torch.ones(4),
        }

        kwargs_not_ok = {
            'n_fft': 4,
            'win_length': 4,
            'window': torch.zeros(4),
        }

        # A window of ones meets NOLA but a window of zeros does not. This should
        # throw an error.
        torchaudio.functional.istft(stft, **kwargs_ok)
        self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)

示例3: _test_istft_of_sine

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def _test_istft_of_sine(self, amplitude, L, n):
        """Invert a hand-built sine spectrum and compare it with the sine itself.

        The reference signal is ``amplitude * sin(2*pi/L * n * x)`` over
        ``2 * L + 1`` samples; the spectrum is written out by hand for a hop
        length and window size both equal to ``L``.
        """
        samples = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        expected = amplitude * torch.sin(2 * math.pi / L * samples * n)

        # Hand-written stand-in for:
        #   torch.stft(sound, L, hop_length=L, win_length=L,
        #              window=torch.ones(L), center=False, normalized=False)
        num_bins = L // 2 + 1
        spectrum = torch.zeros((num_bins, 2, 2))
        peak = (amplitude * L) / 2.0

        if n < num_bins:
            spectrum[n, :, 1] = -peak

        mirrored_bin = L - n
        if 0 <= mirrored_bin < num_bins:
            # symmetric about L // 2
            spectrum[mirrored_bin, :, 1] = peak

        istft_options = {
            'hop_length': L,
            'win_length': L,
            'window': torch.ones(L),
            'center': False,
            'normalized': False,
        }
        estimate = torchaudio.functional.istft(spectrum, L, **istft_options)
        # There is a larger error due to the scaling of amplitude
        _compare_estimate(expected, estimate, atol=1e-3)

示例4: compute_torch_stft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def compute_torch_stft(audio, descriptor):
    """Return the STFT magnitude of ``audio`` as configured by ``descriptor``.

    Args:
        audio: Real-valued tensor accepted by ``torch.stft``.
        descriptor: Underscore-separated string whose second and third fields
            are the FFT size and the hop size, e.g. ``"stft_512_128"``. The
            first field and any trailing fields are ignored.

    Returns:
        Real tensor holding the magnitude of every STFT bin.

    Raises:
        ValueError: If ``descriptor`` has fewer than three fields or the size
            fields are not integers.
    """
    _name, n_fft, hop_size, *_rest = descriptor.split("_")
    n_fft = int(n_fft)
    hop_size = int(hop_size)

    # return_complex=True is required by current PyTorch for real input; the
    # magnitude of the complex result equals sqrt(real**2 + imag**2).
    spectrum = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        window=torch.hann_window(n_fft, device=audio.device),
        return_complex=True,
    )

    return spectrum.abs()

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
        # Compute a normalized, dB-scaled mel spectrogram from x.
        # Optional pre-emphasis: self.preemp is applied with an extra dim
        # inserted at position 1, which is squeezed away again afterwards.
        if self.preemp is not None:
            x = x.unsqueeze(1)
            x = self.preemp(x)
            x = x.squeeze(1)
        # NOTE(review): the torch.stft call below is cut off after its first
        # argument in this listing; the remaining arguments and the closing
        # parenthesis must be restored from the original project.
        stft = torch.stft(x,
        # The indexing below treats stft as 4-D with (real, imaginary) in the
        # last dim.
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        # Magnitude: sqrt(real^2 + imag^2).
        spec = torch.sqrt(torch.pow(real, 2) + torch.pow(im, 2))

        # convert linear spec to mel
        mel = torch.matmul(spec, self.mel_basis)
        # convert to db
        mel = _amp_to_db(mel) - hparams.ref_level_db
        return _normalize(mel) 

示例6: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def __call__(self, wav):
        """Return the window-normalized STFT of ``wav`` as a real tensor.

        The STFT is computed with ``self.nfft``, ``self.window_shift``,
        ``self.window_size`` and ``self.window``, then divided by the square
        root of the window's energy. The result is laid out as
        ``2 x n_freq_bin x n_frame`` with real parts at index 0 and imaginary
        parts at index 1 (assumes ``wav`` is 1-D; TODO confirm against callers).
        """
        with torch.no_grad():
            # STFT. Current PyTorch requires return_complex for real input;
            # view_as_real restores the trailing (real, imag) axis that the
            # layout change below relies on.
            spec = torch.stft(wav, n_fft=self.nfft, hop_length=self.window_shift,
                              win_length=self.window_size, window=self.window,
                              return_complex=True)
            data = torch.view_as_real(spec)
            data /= self.window.pow(2).sum().sqrt_()
            ## FxTx2 -> 2xFxT
            data = data.transpose(1, 2).transpose(0, 1)
            return data

# transformer: frame splitter 

示例7: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, audio):
        # Compute a log10 mel spectrogram from audio.
        # Pad the last dim by (n_fft - hop_length) // 2 on each side in
        # "reflect" mode, then drop dim 1 if it has size 1.
        # NOTE(review): F is presumably torch.nn.functional; confirm against
        # the file's imports.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audio, (p, p), "reflect").squeeze(1)
        # NOTE(review): the torch.stft call below is cut off in this listing;
        # its arguments and closing parenthesis must be restored from the
        # original project.
        fft = torch.stft(
        # Split the last dim into real and imaginary parts and take the magnitude.
        real_part, imag_part = fft.unbind(-1)
        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        mel_output = torch.matmul(self.mel_basis, magnitude)
        # Clamp at 1e-5 before log10 so the logarithm never sees zero.
        log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
        return log_mel_spec 

示例8: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
        """Compute the one-sided STFT of every channel of a batch.

        Input: (nb_samples, nb_channels, nb_timesteps)
        Output:(nb_samples, nb_channels, nb_bins, nb_frames, 2)

        The last output dim holds the real and imaginary parts.
        """
        nb_samples, nb_channels, nb_timesteps = x.size()

        # merge nb_samples and nb_channels for multichannel stft
        x = x.reshape(nb_samples*nb_channels, -1)

        # compute stft with parameters as close as possible scipy settings.
        # return_complex=True is required by current PyTorch for real input;
        # view_as_real restores the trailing (real, imag) dim.
        stft_f = torch.stft(
            x,
            n_fft=self.n_fft, hop_length=self.n_hop,
            window=self.window, center=self.center,
            normalized=False, onesided=True,
            return_complex=True,
        )
        stft_f = torch.view_as_real(stft_f)

        # reshape back to channel dimension
        stft_f = stft_f.contiguous().view(
            nb_samples, nb_channels, self.n_fft // 2 + 1, -1, 2
        )
        return stft_f

示例9: test_istft

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft(self):
        """Run the batch-consistency check on ``F.istft`` with a fixed spectrogram.

        The spectrogram has 3 frequency bins and 5 frames, with the value
        (4, 0) in every frame of the first bin and zeros elsewhere.
        """
        stft = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ])
        self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)

示例10: test_batch_TimeStretch

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_batch_TimeStretch(self):
        """Check that TimeStretch gives the same result before and after batching.

        The spectrogram of a test asset is stretched once and then repeated
        three times, and separately repeated three times and then stretched;
        both results must match within tolerance.
        """
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
        rate = 2

        complex_specgrams = torch.stft(waveform, **kwargs)

        # TimeStretch must be told the stretch rate and must agree with the
        # STFT settings above: n_fft // 2 + 1 one-sided bins and the same hop.
        stretch_kwargs = {
            'fixed_rate': rate,
            'n_freq': kwargs['n_fft'] // 2 + 1,
            'hop_length': kwargs['hop_length'],
        }

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            **stretch_kwargs
        )(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            **stretch_kwargs
        )(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

示例11: test_istft_of_ones

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_of_ones(self):
        """Invert the spectrogram of a length-4 signal of ones and expect ones back."""
        # stft = torch.stft(torch.ones(4), 4)
        stft = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ])

        estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
        _compare_estimate(torch.ones(4), estimate)

示例12: test_istft_of_zeros

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_of_zeros(self):
        """Invert an all-zero spectrogram and expect an all-zero signal of length 4."""
        # Stand-in for: torch.stft(torch.zeros(4), 4)
        zero_spectrum = torch.zeros((3, 5, 2))
        expected_signal = torch.zeros(4)

        reconstructed = torchaudio.functional.istft(zero_spectrum, n_fft=4, length=4)
        _compare_estimate(expected_signal, reconstructed)

示例13: test_istft_requires_overlap_windows

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_requires_overlap_windows(self):
        """Check that istft raises when consecutive windows leave a gap."""
        # The window has size 1 but the hop is 20, so the windows do not
        # overlap and istft is expected to raise.
        spectrum = torch.zeros((3, 5, 2))
        gapped_options = {
            'n_fft': 4,
            'hop_length': 20,
            'win_length': 1,
            'window': torch.ones(1),
        }
        self.assertRaises(RuntimeError, torchaudio.functional.istft, spectrum, **gapped_options)

示例14: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
        """Apply STFT magnitude, the ``f2m`` projection and PCEN to ``x``.

        The STFT magnitude is taken as the 2-norm over the last dim, permuted
        with ``(0, 2, 1)`` and passed through ``self.f2m``. PCEN is then
        applied by exactly one of the two implementations, chosen by
        ``self.use_cuda_kernel``. The returned state is detached and stored in
        ``self.last_state``, and ``self.empty`` is cleared.
        """
        x = torch.stft(x, self.n_fft, **self.stft_kwargs).norm(dim=-1, p=2)
        x = self.f2m(x.permute(0, 2, 1))
        if self.use_cuda_kernel:
            x, ls = pcen_cuda_kernel(x, self.eps, self.s, self.alpha, self.delta, self.r, self.trainable, self.last_state, self.empty)
        else:
            # Without this branch the non-CUDA path never defines ``ls`` and
            # the CUDA result would be overwritten by a second PCEN pass.
            x, ls = pcen(x, self.eps, self.s, self.alpha, self.delta, self.r, self.training and self.trainable, self.last_state, self.empty)
        self.last_state = ls.detach()
        self.empty = False
        return x

示例15: __call__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def __call__(self, pkg, cached_file=None):
        # Attach a log-power spectrum feature to pkg, either sliced from a
        # cached tensor or computed from pkg['chunk'], and return pkg.
        pkg = format_package(pkg)
        wav = pkg['chunk']
        # Number of whole hops that fit in the chunk; used to trim the frames.
        max_frames = wav.size(0) // self.hop
        if cached_file is not None:
            # load pre-computed data
            X = torch.load(cached_file)
            # Convert the chunk's sample offsets into frame indices.
            beg_i = pkg['chunk_beg_i'] // self.hop
            end_i = pkg['chunk_end_i'] // self.hop
            X = X[:, beg_i:end_i]
            pkg['lps'] = X
            # NOTE(review): an `else:` appears to be missing here in this
            # listing; as written the lines below also run on the cached path
            # and X is never computed when cached_file is None. Confirm
            # against the original project.
            #print ('Chunks wav shape is {}'.format(wav.shape))
            wav = wav.to(self.device)
            X = torch.stft(wav, self.n_fft,
                           self.hop, self.win)
            # 2-norm over dim 2, moved to CPU and trimmed to max_frames columns.
            X = torch.norm(X, 2, dim=2).cpu()[:, :max_frames]
            # 10 * log10 of the squared magnitude; the small constant avoids log10(0).
            X = 10 * torch.log10(X ** 2 + 10e-20).cpu()
            if self.der_order > 0 :
                # NOTE(review): the body of this loop is missing in this
                # listing, which makes the block unparsable; restore it from
                # the original project.
                for n in range(1,self.der_order+1):
            pkg[self.name] = X
        # Overwrite resolution to hop length
        pkg['dec_resolution'] = self.hop
        return pkg 
