Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_scriptmodule_MFCC(self):
    # Verify the MFCC transform survives TorchScript scripting on CUDA input.
    waveform = torch.rand((1, 1000), device="cuda")
    self._test_script_module(waveform, transforms.MFCC)
# check defaults
torch_mfcc = mfcc_transform(audio_scaled) # (1, 40, 321)
self.assertTrue(torch_mfcc.dim() == 3)
self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
self.assertTrue(torch_mfcc.shape[2] == 321)
# check melkwargs are passed through
melkwargs = {'win_length': 200}
mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
torch_mfcc2 = mfcc_transform2(audio_scaled) # (1, 40, 641)
self.assertTrue(torch_mfcc2.shape[2] == 641)
# check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321)
norm_check = torch_mfcc.clone()
norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
# test s2db
db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
# test MFCC
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
# kwargs for melspectrogram and mfcc are the same. We just follow the
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:
# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
# sr=sample_rate,
# n_mfcc = n_mfcc,
# hop_length=hop_length,
# n_fft=n_fft,
# htk=True,
# norm=None,
mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
mel_kwargs['win_length'] = n_window_size
mel_kwargs['hop_length'] = n_window_stride
# Set window_fn. None defaults to torch.ones.
window_fn = self.torch_windows.get(window, None)
if window_fn is None:
raise ValueError(
f"Window argument for AudioProcessor is invalid: {window}."
f"For no window function, use 'ones' or None.")
mel_kwargs['window_fn'] = window_fn
# Use torchaudio's implementation of MFCCs as featurizer
self.featurizer = torchaudio.transforms.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
dct_type=dct_type,
norm=norm,
log_mels=log,
melkwargs=mel_kwargs
)
self.featurizer.to(self._device)
mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
mel_kwargs['win_length'] = n_window_size
mel_kwargs['hop_length'] = n_window_stride
# Set window_fn. None defaults to torch.ones.
window_fn = self.torch_windows.get(window, None)
if window_fn is None:
raise ValueError(
f"Window argument for AudioProcessor is invalid: {window}."
f"For no window function, use 'ones' or None.")
mel_kwargs['window_fn'] = window_fn
# Use torchaudio's implementation of MFCCs as featurizer
self.featurizer = torchaudio.transforms.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
dct_type=dct_type,
norm=norm,
log_mels=log,
melkwargs=mel_kwargs
)
self.featurizer.to(self._device)
class AmplitudeToDB:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio AmplitudeToDB transform as a class attribute
    # (presumably so it can be checked as a script method — TODO confirm).
    forward = torchaudio.transforms.AmplitudeToDB().forward
class MelScale:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio MelScale transform as a class attribute.
    forward = torchaudio.transforms.MelScale().forward
class MelSpectrogram:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio MelSpectrogram transform as a class attribute.
    forward = torchaudio.transforms.MelSpectrogram().forward
class MFCC:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio MFCC transform as a class attribute.
    forward = torchaudio.transforms.MFCC().forward
class MuLawEncoding:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio MuLawEncoding transform as a class attribute.
    forward = torchaudio.transforms.MuLawEncoding().forward
class MuLawDecoding:
    # Bind the bound `forward` method of a default-constructed
    # torchaudio MuLawDecoding transform as a class attribute.
    forward = torchaudio.transforms.MuLawDecoding().forward
class Resample:
    # Resample isn't a script_method
    # NOTE: unlike the classes above, the UNBOUND function is taken from the
    # class itself (no instance is constructed here).
    forward = torchaudio.transforms.Resample.forward
def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
melkwargs=None):
super(MFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported'.format(dct_type))
self.sample_rate = sample_rate
self.n_mfcc = n_mfcc
self.dct_type = dct_type
self.norm = norm
self.top_db = 80.0
self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
if melkwargs is not None:
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
else:
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)
if self.n_mfcc > self.MelSpectrogram.n_mels: