How to use the torchaudio.compliance.kaldi.resample_waveform function in torchaudio

To help you get started, we’ve selected a few torchaudio examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pytorch / audio / test / test_compliance_kaldi.py View on Github external
def test_resample_waveform_multi_channel(self):
        """Resampling a stacked (c, n) tensor must match resampling each channel alone.

        The fixture waveform is repeated into ``num_channels`` rows, each row is
        scaled by a distinct gain so the channels differ, and the whole tensor is
        resampled to half the original rate in a single call.  Every row of that
        result is then compared against resampling the equivalently scaled
        single-channel waveform on its own.
        """
        num_channels = 3

        waveform, orig_rate = torchaudio.load_wav(self.test_8000_filepath)  # (1, 8000)
        target_rate = orig_rate // 2

        # Build the multi-channel input: one differently scaled copy per row.
        stacked = waveform.repeat(num_channels, 1)  # (num_channels, 8000)
        for ch in range(num_channels):
            stacked[ch, :] *= (ch + 1) * 1.5

        stacked_resampled = kaldi.resample_waveform(stacked, orig_rate, target_rate)

        # Each row of the batched result must agree with the per-channel result.
        for ch in range(num_channels):
            scaled = waveform * (ch + 1) * 1.5
            expected = kaldi.resample_waveform(scaled, orig_rate, target_rate)
            rows_match = torch.allclose(stacked_resampled[ch, :], expected, rtol=1e-4)
            self.assertTrue(rows_match)
github pytorch / audio / test / test_compliance_kaldi.py View on Github external
def test_resample_waveform_downsample_size(self):
        """Halving the sample rate must halve the number of output samples."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        downsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate // 2)
        # assertEqual reports both lengths on failure; assertTrue(a == b) would
        # only say "False is not true".
        self.assertEqual(downsample_sound.size(-1), sound.size(-1) // 2)
github pytorch / audio / test / test_compliance_kaldi.py View on Github external
# resample the signal and compare it to the ground truth
        n_to_trim = 20
        sample_rate = 1000
        new_sample_rate = sample_rate

        if up_scale_factor is not None:
            new_sample_rate *= up_scale_factor

        if down_scale_factor is not None:
            new_sample_rate //= down_scale_factor

        duration = 5  # seconds
        original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

        sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
        estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()

        new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
        ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)

        # trim the first/last n samples as these points have boundary effects
        ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
        estimate = estimate[..., n_to_trim:-n_to_trim]

        self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol))
github pytorch / audio / test / test_compliance_kaldi.py View on Github external
def test_resample_waveform_identity_size(self):
        """Resampling to the same rate must keep the number of samples unchanged."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        # Source and target rates are identical here, so nothing is downsampled;
        # the local is named accordingly.
        resampled_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate)
        # assertEqual reports both lengths on failure; assertTrue(a == b) would
        # only say "False is not true".
        self.assertEqual(resampled_sound.size(-1), sound.size(-1))
github pytorch / audio / test / test_compliance_kaldi.py View on Github external
def test_resample_waveform_upsample_size(self):
        """Doubling the sample rate must double the number of output samples."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        upsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)
        # assertEqual reports both lengths on failure; assertTrue(a == b) would
        # only say "False is not true".
        self.assertEqual(upsample_sound.size(-1), sound.size(-1) * 2)
github pytorch / audio / test / test_compliance_kaldi.py View on Github external
def get_output_fn(sound, args):
            """Resample ``sound`` with ``kaldi.resample_waveform``.

            ``args[1]`` and ``args[2]`` are forwarded as the second and third
            positional arguments (the original and new frequency); ``args[0]``
            is not used here.
            """
            return kaldi.resample_waveform(sound, args[1], args[2])
github pytorch / audio / torchaudio / transforms.py View on Github external
def forward(self, waveform):
        r"""Resample ``waveform`` from ``self.orig_freq`` to ``self.new_freq``.

        Args:
            waveform (torch.Tensor): The input signal of dimension (channel, time)

        Returns:
            torch.Tensor: Output signal of dimension (channel, time)

        Raises:
            ValueError: If ``self.resampling_method`` is not
                ``'sinc_interpolation'``, the only method handled here.
        """
        if self.resampling_method != 'sinc_interpolation':
            # Wrap the value in a one-element tuple: with a bare ``% value`` a
            # tuple-valued method would be unpacked as several format arguments
            # and raise TypeError instead of the intended ValueError.
            raise ValueError('Invalid resampling method: %s' % (self.resampling_method,))

        return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)