本文整理汇总了 Python 中 torch.stft 方法的典型用法代码示例。如果您正苦于以下问题：Python torch.stft 方法的具体用法？Python torch.stft 怎么用？Python torch.stft 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torch 的用法示例。
在下文中一共展示了 torch.stft 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def __init__(self,
             n_fft: int = 400,
             win_length: Optional[int] = None,
             hop_length: Optional[int] = None,
             pad: int = 0,
             window_fn: Callable[..., Tensor] = torch.hann_window,
             power: Optional[float] = 2.,
             normalized: bool = False,
             wkwargs: Optional[dict] = None) -> None:
    """Store the spectrogram settings and build the analysis window.

    Args:
        n_fft: Size of the FFT.
        win_length: Window length; falls back to ``n_fft`` when ``None``.
        hop_length: Hop between frames; falls back to half of the resolved
            window length when ``None``.
        pad: Stored unchanged on the instance.
        window_fn: Callable invoked with the resolved window length (plus
            ``wkwargs`` when given) to create the window tensor.
        power: Stored unchanged on the instance.
        normalized: Stored unchanged on the instance.
        wkwargs: Optional keyword arguments forwarded to ``window_fn``.
    """
    super(Spectrogram, self).__init__()

    # Number of FFT bins. The returned STFT result will have
    # n_fft // 2 + 1 frequencies due to onesided=True in torch.stft.
    self.n_fft = n_fft

    if win_length is None:
        win_length = n_fft
    self.win_length = win_length

    if hop_length is None:
        hop_length = self.win_length // 2
    self.hop_length = hop_length

    # An absent ``wkwargs`` is equivalent to passing no extra keywords.
    window_options = {} if wkwargs is None else wkwargs
    self.register_buffer('window', window_fn(self.win_length, **window_options))

    self.pad = pad
    self.power = power
    self.normalized = normalized
示例2: test_istft_requires_nola
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_requires_nola(self):
    """Check that ``istft`` accepts a NOLA-compliant window and rejects an all-zero one."""
    stft = torch.zeros((3, 5, 2))
    common = {'n_fft': 4, 'win_length': 4}
    kwargs_ok = dict(common, window=torch.ones(4))
    kwargs_not_ok = dict(common, window=torch.zeros(4))

    # A window of ones meets NOLA, so this call must succeed.
    torchaudio.functional.istft(stft, **kwargs_ok)

    # A window of zeros does not meet NOLA, so this call must raise.
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(stft, **kwargs_not_ok)
示例3: _test_istft_of_sine
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def _test_istft_of_sine(self, amplitude, L, n):
    """Invert a hand-built STFT of ``amplitude * sin(2*pi/L * n * x)``.

    The hop length and the window size both equal ``L``. The spectrogram is
    written out by hand instead of being produced by ``torch.stft`` and the
    ``istft`` output is compared against the time-domain sine.
    """
    positions = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
    sound = amplitude * torch.sin(2 * math.pi / L * positions * n)

    # Stand-in for:
    #   torch.stft(sound, L, hop_length=L, win_length=L,
    #              window=torch.ones(L), center=False, normalized=False)
    n_bins = L // 2 + 1
    stft = torch.zeros((n_bins, 2, 2))
    peak = (amplitude * L) / 2.0
    if n < n_bins:
        stft[n, :, 1] = -peak
    mirrored_bin = L - n
    if 0 <= mirrored_bin < n_bins:
        # symmetric about L // 2
        stft[mirrored_bin, :, 1] = peak

    estimate = torchaudio.functional.istft(
        stft, L, hop_length=L, win_length=L,
        window=torch.ones(L), center=False, normalized=False)

    # There is a larger error due to the scaling of amplitude
    _compare_estimate(sound, estimate, atol=1e-3)
示例4: compute_torch_stft
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def compute_torch_stft(audio, descriptor):
    """Return the STFT magnitude of ``audio`` for the settings in ``descriptor``.

    Args:
        audio: Real-valued waveform tensor handed directly to ``torch.stft``.
        descriptor: Underscore-separated string of the form
            ``"<name>_<n_fft>_<hop_size>[_...]"``. Only the FFT size and the
            hop size are used; the name and any trailing fields are ignored.

    Returns:
        Tensor holding ``sqrt(real**2 + imag**2)`` of the STFT, i.e. the
        STFT output shape without a real/imaginary axis.

    Raises:
        ValueError: If ``descriptor`` has fewer than three fields or the
            size fields are not integers.
    """
    _name, n_fft, hop_size, *_unused = descriptor.split("_")
    n_fft = int(n_fft)
    hop_size = int(hop_size)

    # Current PyTorch requires ``return_complex`` for real inputs; ask for
    # the complex result and view it as (..., 2) so the magnitude formula
    # below stays exactly the same as with the old real-valued layout.
    spectrum = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        window=torch.hann_window(n_fft, device=audio.device),
        return_complex=True,
    )
    real_imag = torch.view_as_real(spectrum)
    return torch.sqrt((real_imag ** 2).sum(-1))
示例5: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
    """Convert ``x`` into a normalised mel spectrogram expressed in dB.

    Steps: optional pre-emphasis, STFT, magnitude, projection through
    ``self.mel_basis``, amplitude-to-dB conversion offset by
    ``hparams.ref_level_db``, and a final ``_normalize`` call.
    """
    if self.preemp is not None:
        # The pre-emphasis module is applied with an extra axis at position 1.
        x = self.preemp(x.unsqueeze(1)).squeeze(1)

    # NOTE(review): ``fft_size`` is not a keyword of the current
    # ``torch.stft`` signature; this call appears to target a legacy
    # PyTorch release -- confirm the pinned torch version before upgrading.
    spectrum = torch.stft(x,
                          self.win_length,
                          self.hop_length,
                          fft_size=self.n_fft,
                          window=self.win)

    # The last axis of the 4-D result holds the real and imaginary parts.
    real_part = spectrum[:, :, :, 0]
    imag_part = spectrum[:, :, :, 1]
    magnitude = (real_part.pow(2) + imag_part.pow(2)).sqrt()

    # convert linear spec to mel
    mel = torch.matmul(magnitude, self.mel_basis)
    # convert to db
    mel_db = _amp_to_db(mel) - hparams.ref_level_db
    return _normalize(mel_db)
示例6: __call__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def __call__(self, wav):
    """Compute the STFT of ``wav`` scaled by the window energy.

    The STFT uses ``self.nfft``, ``self.window_shift``, ``self.window_size``
    and ``self.window``. The result is divided by the square root of the
    summed squared window and rearranged so that the real/imaginary axis
    comes first.

    Args:
        wav: Real-valued waveform tensor. The axis shuffle below assumes a
            1-D waveform (STFT laid out as F x T x 2) -- TODO confirm with
            callers.

    Returns:
        Tensor laid out as 2 x F x T (real/imag, frequency bin, frame).
    """
    with torch.no_grad():
        # Current PyTorch requires ``return_complex`` for real inputs; take
        # the complex STFT and expose it as a trailing (real, imag) axis so
        # the rest of the pipeline sees the same F x T x 2 layout as before.
        spectrum = torch.stft(wav, n_fft=self.nfft, hop_length=self.window_shift,
                              win_length=self.window_size, window=self.window,
                              return_complex=True)
        data = torch.view_as_real(spectrum)
        # Normalise by the window energy.
        data /= self.window.pow(2).sum().sqrt_()
        # FxTx2 -> 2xFxT
        data = data.transpose(1, 2).transpose(0, 1)
        return data
# transformer: frame splitter
示例7: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, audio):
    """Compute a log10 mel spectrogram of ``audio``.

    The signal is reflect-padded by ``(n_fft - hop_length) // 2`` samples on
    each side, transformed with a non-centred STFT, converted to magnitude,
    projected through ``self.mel_basis`` and passed through ``log10`` after
    clamping to a minimum of ``1e-5``.

    Args:
        audio: Waveform tensor; axis 1 is squeezed after padding, so a
            (batch, 1, time) layout is presumably expected -- confirm with
            callers.

    Returns:
        ``log10(clamp(mel_basis @ |STFT|, min=1e-5))``.
    """
    p = (self.n_fft - self.hop_length) // 2
    audio = F.pad(audio, (p, p), "reflect").squeeze(1)

    # Current PyTorch requires ``return_complex`` for real inputs; request
    # the complex result and split it into real/imaginary parts explicitly.
    fft = torch.stft(
        audio,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window,
        center=False,
        return_complex=True,
    )
    real_part, imag_part = torch.view_as_real(fft).unbind(-1)
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)

    mel_output = torch.matmul(self.mel_basis, magnitude)
    # Clamp before the log so silent frames map to a finite floor (-5).
    log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
    return log_mel_spec
示例8: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
    """
    Compute a one-sided STFT for every channel of every sample.

    Input: (nb_samples, nb_channels, nb_timesteps)
    Output:(nb_samples, nb_channels, nb_bins, nb_frames, 2)

    The last output axis holds the real and imaginary parts.
    """
    # Unpacking three sizes also rejects inputs that are not 3-D.
    nb_samples, nb_channels, _ = x.size()

    # merge nb_samples and nb_channels for multichannel stft
    x = x.reshape(nb_samples * nb_channels, -1)

    # compute stft with parameters as close as possible scipy settings.
    # Current PyTorch requires ``return_complex`` for real inputs, so the
    # complex result is requested and then viewed as a trailing
    # (real, imag) axis to keep the documented output layout.
    stft_f = torch.view_as_real(torch.stft(
        x,
        n_fft=self.n_fft, hop_length=self.n_hop,
        window=self.window, center=self.center,
        normalized=False, onesided=True,
        pad_mode='reflect',
        return_complex=True,
    ))

    # reshape back to channel dimension
    stft_f = stft_f.contiguous().view(
        nb_samples, nb_channels, self.n_fft // 2 + 1, -1, 2
    )
    return stft_f
示例9: test_istft
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft(self):
    """Run the batch-consistency check on ``F.istft`` with a 3 x 5 x 2 input."""
    n_frames = 5
    # Only the first frequency row is non-zero: real part 4, imaginary part 0.
    stft = torch.tensor([
        [[4., 0.]] * n_frames,
        [[0., 0.]] * n_frames,
        [[0., 0.]] * n_frames,
    ])
    self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)
示例10: test_batch_TimeStretch
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_batch_TimeStretch(self):
    """Check that ``TimeStretch`` gives the same result before and after batching."""
    test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
    waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

    kwargs = dict(
        n_fft=2048,
        hop_length=512,
        win_length=2048,
        window=torch.hann_window(2048),
        center=True,
        pad_mode='reflect',
        normalized=True,
        onesided=True,
    )
    rate = 2
    batch_shape = (3, 1, 1, 1, 1)

    def make_stretch():
        # A fresh transform per use, exactly as many as the comparison needs.
        return torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )

    complex_specgrams = torch.stft(waveform, **kwargs)

    # Single then transform then batch
    expected = make_stretch()(complex_specgrams).repeat(*batch_shape)

    # Batch then transform
    computed = make_stretch()(complex_specgrams.repeat(*batch_shape))

    self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
示例11: test_istft_of_ones
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_of_ones(self):
    """Invert the hand-written STFT of a length-4 signal of ones."""
    # Same values as torch.stft(torch.ones(4), 4): the first frequency row
    # has real part 4 in every frame, everything else is zero.
    stft = torch.zeros((3, 5, 2))
    stft[0, :, 0] = 4.
    estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
    _compare_estimate(torch.ones(4), estimate)
示例12: test_istft_of_zeros
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_of_zeros(self):
    """Invert an all-zero STFT and expect a length-4 signal of zeros."""
    # Stand-in for torch.stft(torch.zeros(4), 4)
    silent_stft = torch.zeros((3, 5, 2))
    reconstructed = torchaudio.functional.istft(silent_stft, n_fft=4, length=4)
    _compare_estimate(torch.zeros(4), reconstructed)
示例13: test_istft_requires_overlap_windows
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def test_istft_requires_overlap_windows(self):
    """Check that ``istft`` raises when the hop leaves gaps between windows."""
    stft = torch.zeros((3, 5, 2))
    # The window is size 1 but it hops 20, so there is a gap between frames.
    tiny_window = torch.ones(1)
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(stft, n_fft=4, hop_length=20,
                                    win_length=1, window=tiny_window)
示例14: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
def forward(self, x):
    """Apply STFT magnitude, the ``f2m`` projection and PCEN to ``x``.

    The PCEN state returned by the kernel is detached and kept on the
    instance as ``last_state``, and ``empty`` is cleared, so consecutive
    calls continue from the previous state.
    """
    # NOTE(review): the norm over the last axis assumes ``torch.stft``
    # returns a real tensor with a trailing (real, imag) axis; whether
    # ``self.stft_kwargs`` arranges that is not visible here -- confirm.
    magnitude = torch.stft(x, self.n_fft, **self.stft_kwargs).norm(dim=-1, p=2)
    features = self.f2m(magnitude.permute(0, 2, 1))

    if self.use_cuda_kernel:
        pcen_fn = pcen_cuda_kernel
        trainable = self.trainable
    else:
        pcen_fn = pcen
        # The reference implementation only trains while in training mode.
        trainable = self.training and self.trainable

    output, state = pcen_fn(features, self.eps, self.s, self.alpha, self.delta,
                            self.r, trainable, self.last_state, self.empty)

    self.last_state = state.detach()
    self.empty = False
    return output
示例15: __call__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import stft [as 别名]
# Attach log-power-spectrum features to a data package.
#
# The package is normalised with format_package and its 'chunk' waveform is
# read. When `cached_file` is given, pre-computed features are loaded with
# torch.load and sliced along axis 1 to the chunk boundaries expressed in
# hops. Otherwise the features are computed from the waveform with torch.stft.
# The package gets its 'dec_resolution' entry set to the hop length and is
# returned.
# NOTE(review): the original indentation of this snippet was lost, so the
# exact nesting of the trailing statements should be checked against the
# upstream source.
def __call__(self, pkg, cached_file=None):
pkg = format_package(pkg)
wav = pkg['chunk']
# Number of whole hops that fit in the chunk; used to trim the STFT frames.
max_frames = wav.size(0) // self.hop
if cached_file is not None:
# load pre-computed data
X = torch.load(cached_file)
# Convert the chunk's begin/end sample indices into frame indices.
beg_i = pkg['chunk_beg_i'] // self.hop
end_i = pkg['chunk_end_i'] // self.hop
X = X[:, beg_i:end_i]
pkg['lps'] = X
else:
#print ('Chunks wav shape is {}'.format(wav.shape))
wav = wav.to(self.device)
# Positional arguments after `wav`: self.n_fft, self.hop, self.win.
X = torch.stft(wav, self.n_fft,
self.hop, self.win)
# 2-norm over axis 2, moved to CPU and trimmed to at most max_frames frames.
# NOTE(review): assumes axis 2 is the (real, imag) axis of a real-valued
# STFT result -- confirm against the installed torch version.
X = torch.norm(X, 2, dim=2).cpu()[:, :max_frames]
# 10 * log10 of the squared magnitude; 10e-20 keeps the log argument positive.
X = 10 * torch.log10(X ** 2 + 10e-20).cpu()
if self.der_order > 0 :
# Stack the base features with their deltas of order 1..der_order.
deltas=[X]
for n in range(1,self.der_order+1):
deltas.append(librosa.feature.delta(X.numpy(),order=n))
X=torch.from_numpy(np.concatenate(deltas))
pkg[self.name] = X
# Overwrite resolution to hop length
pkg['dec_resolution'] = self.hop
return pkg