Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
m = np.minimum(y.shape[0] - framesize // 2, ref.shape[1])
sdr, sir, sar, perm = bss_eval_sources(
ref[:n_sources_target, :m, 0],
y[framesize // 2 : m + framesize // 2, :n_sources_target].T,
)
SDR.append(sdr)
SIR.append(sir)
# Convergence monitoring can be switched off from the command line
if args.no_cb:
    convergence_callback = None

# START BSS
###########
# STFT of the microphone signals, cast to complex128
# shape: (n_frames, n_freq, n_mics)
X_all = pra.transform.analysis(
    mics_signals.T, L=framesize, hop=framesize // 2, win=win_a
)
X_all = X_all.astype(np.complex128)
# only the first n_mics channels are handed to the separation
X_mics = X_all[:, :, :n_mics]

# wall-clock start of the separation stage
tic = time.perf_counter()

# Run BSS
if args.algo == "auxiva":
    # Run AuxIVA (through the overiva routine)
    Y = overiva(
        X_mics,
        n_iter=n_iter,
        proj_back=True,
        model=args.dist,
        callback=convergence_callback,
    )
# Mix down the recorded signals: the first n_targets premixed sources plus
# the background
mix = premix[:n_targets].sum(axis=0) + background

# Reference signals, shape (n_targets + 1, n_samples, n_mics): one slot per
# target (premix with its last two axes swapped), and a final slot for the
# background
ref = np.zeros(
    (n_targets + 1, premix.shape[2], premix.shape[1]),
    dtype=premix.dtype,
)
ref[:n_targets] = premix[:n_targets].swapaxes(1, 2)
ref[n_targets] = background.T

# Same-shaped companion array; only the background slot (channel 0) is filled,
# with white noise, so there is something to compare against the background
synth = np.zeros_like(ref)
synth[n_targets, :, 0] = np.random.randn(synth.shape[1])

# START BSS
###########
# STFT of the mixture; shape: (n_frames, n_freq, n_mics)
X_all = pra.transform.analysis(mix.T, L=framesize, hop=framesize // 2, win=win_a)
X_mics = X_all[:, :, :n_mics]
# Convergence monitoring callback.
# Resynthesizes the STFT-domain estimate `Y` to the time domain (frame size
# `framesize`, hop `framesize // 2`, synthesis window `win_s`) and, unless
# `algo_name` is listed in parameters["overdet_algos"], reorders the output
# channels by decreasing standard deviation.
# NOTE(review): in this copy the body stops after the reordering. `y` is never
# returned or used, and `n_targets`, `SDR`, `SIR`, `ref` and the imported
# `bss_eval_sources` are unused, so the metric computation appears to be
# missing. The body's indentation has also been lost. Confirm against the
# full source before relying on this function.
# NOTE(review): `pra`, `np` and `parameters` are read as globals, not arguments.
def convergence_callback(Y, n_targets, SDR, SIR, ref, framesize, win_s, algo_name):
from mir_eval.separation import bss_eval_sources
# The third axis of Y is treated as the channel/source axis
# (presumably (n_frames, n_freq, n_channels) -- TODO confirm).
if Y.shape[2] == 1:
# Single channel: synthesize the 2-D slice, then append a trailing axis,
# presumably so `y` has a channel axis in both branches.
y = pra.transform.synthesis(
Y[:, :, 0], framesize, framesize // 2, win=win_s
)[:, None]
else:
y = pra.transform.synthesis(Y, framesize, framesize // 2, win=win_s)
if algo_name not in parameters["overdet_algos"]:
# Column indices ordered by descending standard deviation along axis 0.
new_ord = np.argsort(np.std(y, axis=0))[::-1]
y = y[:, new_ord]
# STFT block length (also used as the hop, i.e. no overlap)
L = args.block
# Let's hard code sampling frequency to avoid some problems
fs = 16000

## RECORD
if args.device is not None:
    sd.default.device[0] = args.device

print('* Recording started... ', end='')
# sounddevice returns the recording as an array of shape (n_samples, n_channels)
mics_signals = sd.rec(int(args.duration * fs), samplerate=fs, channels=2, blocking=True)
print('done')

## STFT ANALYSIS
# pra.transform.analysis takes (n_samples, n_channels), which is already the
# layout returned by sd.rec, so the signal must NOT be transposed here.
# shape == (n_frames, n_freq, n_chan)
X = pra.transform.analysis(mics_signals, L, L, zp_back=L//2, zp_front=L//2)

## Monitor convergence
it = 10


def cb_print(*_cb_args):
    """Progress callback: print the running counter, then advance it by 10."""
    global it
    print(' AuxIVA Iter', it)
    it += 10


## Run live BSS
print('* Starting BSS')
bss_type = args.algo
if bss_type == 'auxiva':
    # Run AuxIVA
    Y = pra.bss.auxiva(X, n_iter=args.n_iter, proj_back=True, callback=cb_print)
elif bss_type == 'ilrma':
    # Run ILRMA (the progress label printed by cb_print still reads "AuxIVA")
    Y = pra.bss.ilrma(X, n_iter=args.n_iter, n_components=2, proj_back=True,
                      callback=cb_print)
else:
    # Without this, an unknown algorithm would leave Y unbound
    raise ValueError('Unknown BSS algorithm: {}'.format(bss_type))
if FD is True:
# STFT processing
if self.weights is None and self.filters is not None:
self.weights_from_filters()
elif self.weights is None and self.filters is None:
raise NameError('Beamforming weights or filters need to be '
'computed first.')
# create window functions
analysis_win = windows.hann(self.L)
# perform STFT
sig_stft = transform.analysis(self.signals.T,
L=self.L,
hop=self.hop,
win=analysis_win,
zp_back=self.zpb,
zp_front=self.zpf)
# beamform
sig_stft_bf = np.sum(sig_stft * self.weights.conj().T, axis=2)
# back to time domain
output = transform.synthesis(sig_stft_bf,
L=self.L,
hop=self.hop,
zp_back=self.zpb,
zp_front=self.zpf)
## Monitor Convergence
# Reference signals with axis 1 of separate_recordings moved to the last
# position (presumably microphones last -- TODO confirm)
ref = np.moveaxis(separate_recordings, 1, 2)
SDR, SIR = [], []


def convergence_callback(Y):
    """Append SDR/SIR of the current separation estimate to the global lists.

    ``Y`` is the STFT-domain estimate handed over by the BSS routine. It is
    resynthesized (frame ``L``, hop ``hop``, window ``win_s``), its first
    ``L - hop`` samples are dropped, and it is scored against channel 0 of the
    references with ``bss_eval_sources`` over their common length.
    """
    global SDR, SIR
    from mir_eval.separation import bss_eval_sources

    reference = np.moveaxis(separate_recordings, 1, 2)
    estimate = pra.transform.synthesis(Y, L, hop, win=win_s)[L - hop:, :].T
    n_common = np.minimum(estimate.shape[1], reference.shape[1])
    sdr, sir, _sar, _perm = bss_eval_sources(
        reference[:, :n_common, 0], estimate[:, :n_common]
    )
    SDR.append(sdr)
    SIR.append(sir)
## STFT ANALYSIS
# one-shot STFT of the microphone signals (frame L, hop hop, window win_a)
X = pra.transform.analysis(mics_signals.T, L=L, hop=hop, win=win_a)
# wall-clock start of the separation stage
t_begin = time.perf_counter()
## START BSS
bss_type = args.algo
if bss_type == 'auxiva':
# Run AuxIVA
Y = pra.bss.auxiva(X, n_iter=30, proj_back=True,
callback=convergence_callback)
elif bss_type == 'ilrma':
# Run ILRMA
Y = pra.bss.ilrma(X, n_iter=30, n_components=2, proj_back=True,
callback=convergence_callback)
elif bss_type == 'fastmnmf':
# Run FastMNMF
Y = pra.bss.fastmnmf(X, n_iter=100, n_components=8, n_src=2,