Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_resample_waveform_multi_channel(self):
    """Batched (c, n) resampling must equal resampling each channel on its own."""
    num_channels = 3
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)  # (1, 8000)

    def halve_rate(signal):
        # Resample `signal` from the file's rate down to half of it.
        return kaldi.resample_waveform(signal, sample_rate, sample_rate // 2)

    # Row k of `stacked` is the mono signal scaled by (k + 1) * 1.5.
    stacked = sound.repeat(num_channels, 1)  # (num_channels, 8000)
    for row in range(num_channels):
        stacked[row, :] *= (row + 1) * 1.5
    stacked_resampled = halve_rate(stacked)

    # Every row of the batched output should match resampling that channel alone.
    for row in range(num_channels):
        expected = halve_rate(sound * (row + 1) * 1.5)
        self.assertTrue(torch.allclose(stacked_resampled[row, :], expected, rtol=1e-4))
def test_resample_waveform_downsample_size(self):
    """Halving the sample rate should halve the sample count (floor division)."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    expected_len = sound.size(-1) // 2
    downsampled = kaldi.resample_waveform(sound, sample_rate, sample_rate // 2)
    self.assertTrue(downsampled.size(-1) == expected_len)
# Fragment of a resample-accuracy check; its enclosing `def` line is not visible here.
# NOTE(review): `up_scale_factor`, `down_scale_factor`, `atol` and `rtol` are not
# defined in this span — presumably parameters of the enclosing helper; confirm.
# resample the signal and compare it to the ground truth
n_to_trim = 20  # samples dropped from each end before comparing
sample_rate = 1000
# Target rate: optionally multiplied by up_scale_factor, then floor-divided by
# down_scale_factor. NOTE(review): `//=` floors, so a factor that does not divide
# the rate evenly yields a truncated rate — confirm callers pass exact divisors.
new_sample_rate = sample_rate
if up_scale_factor is not None:
new_sample_rate *= up_scale_factor
if down_scale_factor is not None:
new_sample_rate //= down_scale_factor
duration = 5 # seconds
# Synthetic input: a 3 Hz cosine of amplitude 123 sampled at `sample_rate`,
# given a leading dimension of size 1 via unsqueeze(0).
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
# squeeze() removes size-1 dimensions, presumably leaving a 1-D tensor to compare
# against the 1-D reference below.
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()
# Reference timestamps at the new rate, cut to at most the resampled length so the
# two tensors line up element-wise.
new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
# trim the first/last n samples as these points have boundary effects
ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
estimate = estimate[..., n_to_trim:-n_to_trim]
self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol))
def test_resample_waveform_identity_size(self):
    """Resampling to the same rate must leave the sample count unchanged."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    original_len = sound.size(-1)
    resampled = kaldi.resample_waveform(sound, sample_rate, sample_rate)
    self.assertTrue(resampled.size(-1) == original_len)
def test_resample_waveform_upsample_size(self):
    """Doubling the sample rate should double the sample count."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    expected_len = sound.size(-1) * 2
    upsampled = kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)
    self.assertTrue(upsampled.size(-1) == expected_len)
def get_output_fn(sound, args):
    """Resample ``sound`` from the rate in ``args[1]`` to the rate in ``args[2]``.

    ``args[0]`` is not read here.
    """
    orig_freq = args[1]
    new_freq = args[2]
    return kaldi.resample_waveform(sound, orig_freq, new_freq)
def forward(self, waveform):
    r"""Resample ``waveform`` from ``self.orig_freq`` to ``self.new_freq``.

    Args:
        waveform (torch.Tensor): The input signal of dimension (channel, time)
    Returns:
        torch.Tensor: Output signal of dimension (channel, time)
    Raises:
        ValueError: if ``self.resampling_method`` is not ``'sinc_interpolation'``.
    """
    method = self.resampling_method
    if method != 'sinc_interpolation':
        raise ValueError('Invalid resampling method: %s' % (method))
    return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)