Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
'blackman_coeff', 'energy_floor', 'frame_length', 'frame_shift', 'high_freq', 'htk_compat',
'low_freq', 'num_mel_bins', 'preemphasis_coefficient', 'raw_energy', 'remove_dc_offset',
'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
fn_split = fn.split('-')
assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))}
# print flags for C++
s = ' '.join(['--' + arr[i].replace('_', '-') + '=' + fn_split[i] for i in range(len(arr))])
logging.info(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
logging.info()
# print args for python
inputs['dither'] = 0.0
logging.info(inputs)
sound, sample_rate = torchaudio.load_wav(sound_path)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(out_fn)}
res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
torch.set_printoptions(precision=10, sci_mode=False)
logging.info(res)
logging.info(kaldi_output_dict['my_id'])
def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
                            expected_num_args, get_output_fn, atol=1e-5, rtol=1e-8):
    """
    Load a sound file and walk the reference kaldi output files registered under a key.

    Inputs:
        sound_filepath (str): The location of the sound file
        filepath_key (str): A key to `test_filepaths` selecting which files to use
        expected_num_files (int): The number of kaldi files that must be registered under the key
        expected_num_args (int): The expected number of arguments used in a kaldi configuration
        get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
            and a configuration and returns an output
        atol (float): absolute tolerance
        rtol (float): relative tolerance

    NOTE(review): `expected_num_args`, `get_output_fn`, `atol` and `rtol` are not
    referenced by the statements below, and `sound`, `sample_rate`, `kaldi_output`
    and `args` are assigned but never read -- the body looks cut short; confirm
    against the full file.
    """
    sound, sample_rate = torchaudio.load_wav(sound_filepath)

    files = self.test_filepaths[filepath_key]
    # Guard against reference files being silently added to or removed from the fixture set.
    assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))

    for kaldi_file in files:
        print(kaldi_file)

        # Read kaldi's output from file
        kaldi_output_path = os.path.join(self.kaldi_output_dir, kaldi_file)
        kaldi_output_dict = dict(torchaudio.kaldi_io.read_mat_ark(kaldi_output_path))

        # Each ark file must hold exactly one matrix, stored under the key 'my_id'.
        has_only_expected_entry = len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict
        assert has_only_expected_entry, 'invalid test kaldi ark file'
        kaldi_output = kaldi_output_dict['my_id']

        # Construct the same configuration used by kaldi: the file name is a
        # '-'-separated list of fields.
        args = kaldi_file.split('-')
def test_resample_waveform_identity_size(self):
    """Resampling to the original sample rate must keep the last dimension's size unchanged."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    # Source and target rate are identical, so the output length must match the input length.
    resampled_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate)

    # assertEqual reports both sizes on failure; assertTrue(a == b) only says "False is not true".
    self.assertEqual(resampled_sound.size(-1), sound.size(-1))
def test_resample_waveform_upsample_size(self):
    """Resampling to twice the original sample rate must double the last dimension's size."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    upsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)

    # assertEqual reports both sizes on failure; assertTrue(a == b) only says "False is not true".
    self.assertEqual(upsample_sound.size(-1), sound.size(-1) * 2)
def __getitem__(self, index):
    """
    Build the sample stored at position `index`.

    The audio file listed in `self.aud_paths[index]` is loaded, converted with
    `kaldi.fbank` using this dataset's `num_mel_bins`, `frame_length` and
    `frame_shift`, and passed through `data_utils.apply_mv_norm`.

    Returns:
        dict: `{"id": index, "data": [features, target]}` where `features` is the
        detached normalized output and `target` is `self.tgt[index]`, or `None`
        when the dataset has no targets.

    Raises:
        FileNotFoundError: if the audio path does not exist on disk.
    """
    # Imported lazily, as in the original, so torchaudio is only required once
    # an item is actually fetched.
    import torchaudio
    import torchaudio.compliance.kaldi as kaldi

    if self.tgt is None:
        tgt_item = None
    else:
        tgt_item = self.tgt[index]

    audio_path = self.aud_paths[index]
    if not os.path.exists(audio_path):
        raise FileNotFoundError("Audio file not found: {}".format(audio_path))

    # The sample rate returned by the loader is not used.
    waveform, _ = torchaudio.load_wav(audio_path)

    fbank_options = dict(
        num_mel_bins=self.num_mel_bins,
        frame_length=self.frame_length,
        frame_shift=self.frame_shift,
    )
    features = kaldi.fbank(waveform, **fbank_options)
    normalized_features = data_utils.apply_mv_norm(features)

    return {"id": index, "data": [normalized_features.detach(), tgt_item]}
def __getitem__(self, index):
utt_id, path = self.file_list[index]
if self.from_kaldi:
feature = kio.load_mat(path)
else:
wavform, sample_frequency = ta.load_wav(path)
feature = compute_fbank(wavform, num_mel_bins=self.params['num_mel_bins'], sample_frequency=sample_frequency, dither=0.0)
if self.params['apply_cmvn']:
spk_id = self.utt2spk[utt_id]
stats = kio.load_mat(self.cmvns[spk_id])
feature = apply_cmvn(feature, stats)
if self.params['normalization']:
feature = normalization(feature)
if self.apply_spec_augment:
feature = spec_augment(feature)
feature_length = feature.shape[0]
targets = self.targets_dict[utt_id]
targets_length = len(targets)