Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Things are more efficient when the frequencies are over the first axis
Y = np.zeros((n_freq, n_frames, n_src), dtype=X.dtype)
X = X.swapaxes(0, 1).copy()
# Compute the demixed output
def demix(Y, X, W):
    """Apply the demixing filters ``W`` to ``X``, writing the result into ``Y``.

    Computes ``X @ conj(W)`` (element-wise conjugate of ``W``, no transpose)
    via ``np.matmul`` and stores it into the existing ``Y`` buffer in place.
    Nothing is returned.
    """
    # NOTE(review): shapes presumably X=(n_freq, n_frames, n_chan) and
    # W=(n_freq, n_chan, n_src), matching Y's (n_freq, n_frames, n_src)
    # allocation in the caller — confirm.
    conj_filters = np.conjugate(W)
    Y[:, :, :] = np.matmul(X, conj_filters)
for epoch in range(n_iter):
demix(Y, X, W)
if callback is not None and epoch % 10 == 0:
Y_tmp = Y.swapaxes(0, 1)
if proj_back:
z = projection_back(Y_tmp, X[:, :, 0].swapaxes(0, 1))
callback(Y_tmp * np.conj(z[None, :, :]))
else:
callback(Y_tmp)
# simple loop as a start
# shape: (n_frames, n_src)
if model == 'laplace':
r_inv[:, :] = 1. / (2. * np.linalg.norm(Y, axis=0))
elif model == 'gauss':
r_inv[:, :] = n_freq / (np.linalg.norm(Y, axis=0) ** 2)
# Update now the demixing matrix
for s in range(n_src):
# Compute Auxiliary Variable
# shape: (n_freq, n_chan, n_chan)
V[:, :, :] = (X.swapaxes(1, 2) * r_inv[None, None, :, s]) @ np.conj(X) / n_frames
# activations of background
r[:, -1] = np.mean(np.abs(Z * np.conj(Z)), axis=(1, 2))
if epoch % 3 == 0:
the_cost.append(cost_func(r, A))
plt.figure()
plt.plot(np.arange(len(the_cost)) * 3, the_cost)
plt.title("The cost function")
plt.xlabel("Number of iterations")
plt.ylabel("Neg. log-likelihood")
if proj_back:
print("proj back!")
z = projection_back(Y, X[:, :, 0])
Y *= np.conj(z[None, :, :])
if return_filters:
return Y, W, A
else:
return Y
WV = np.conj(W).swapaxes(1, 2) @ V
rhs = I[None, :, s][[0] * WV.shape[0], :]
W[:, :, s] = np.linalg.solve(WV, rhs)
# normalize
denom = np.conj(W[:, None, :, s]) @ V[:, :, :] @ W[:, :, None, s]
W[:, :, s] /= np.sqrt(denom[:, :, 0])
demix(Y, X, W)
Y = Y.swapaxes(0, 1).copy()
X = X.swapaxes(0, 1)
if proj_back:
z = projection_back(Y, X[:, :, 0])
Y *= np.conj(z[None, :, :])
if return_filters:
return Y, W
else:
return Y
X_ref = X # keep a reference to input signal
X = X.swapaxes(0, 1).copy() # more efficient order for processing
for epoch in range(n_iter):
# compute the switching criterion
if update == "switching" and epoch % 10 == 0:
switching_criterion()
# Extract the target signal
demix(Y, X, w)
# Now run any necessary callback
if callback is not None and epoch % 100 == 0:
Y_tmp = Y.swapaxes(0, 1)
if proj_back:
z = projection_back(Y_tmp, X_ref[:, :, 0])
callback(Y_tmp * np.conj(z[None, :, :]))
else:
callback(Y_tmp)
# simple loop as a start
# shape: (n_frames, n_src)
if model == "laplace":
r[:, :] = np.linalg.norm(Y, axis=0) / np.sqrt(n_freq)
elif model == "gauss":
r[:, :] = (np.linalg.norm(Y, axis=0) ** 2) / n_freq
eps = 1e-15
r[r < eps] = eps
r_inv[:, :] = 1.0 / r
update_a_from_w(I_do_w)
update_w_from_a(I_do_a)
max_delta = np.max(np.linalg.norm(delta, axis=(1, 2)))
if max_delta < tol:
break
# Extract target
demix(Y, X, w)
Y = Y.swapaxes(0, 1).copy()
X = X.swapaxes(0, 1)
if proj_back:
z = projection_back(Y, X_ref[:, :, 0])
Y *= np.conj(z[None, :, :])
if return_filters:
return Y, w
else:
return Y
# Things are more efficient when the frequencies are over the first axis
Y = np.zeros((n_freq, n_frames, n_src), dtype=X.dtype)
X = X.swapaxes(0, 1).copy()
# Compute the demixed output
def demix(Y, X, W):
    # In-place demixing: overwrite Y with X multiplied by the complex
    # conjugate of W (batched matrix product, no transpose of W).
    # NOTE(review): assumes X=(n_freq, n_frames, n_chan) and
    # W=(n_freq, n_chan, n_src) so the product fits Y's
    # (n_freq, n_frames, n_src) buffer — confirm against the caller.
    Y[:, :, :] = np.matmul(X, np.conjugate(W))
for epoch in range(n_iter):
demix(Y, X, W)
if callback is not None and epoch % 10 == 0:
Y_tmp = Y.swapaxes(0, 1)
if proj_back:
z = projection_back(Y_tmp, X[:, :, 0].swapaxes(0, 1))
callback(Y_tmp * np.conj(z[None, :, :]))
else:
callback(Y_tmp)
# simple loop as a start
# shape: (n_frames, n_src)
if model == 'laplace':
r[:, :] = (2. * np.linalg.norm(Y, axis=0))
elif model == 'gauss':
r[:, :] = (np.linalg.norm(Y, axis=0) ** 2) / n_freq
# set the scale of r
gamma = r.mean(axis=0)
r /= gamma[None, :]
if model == 'laplace':
X_matlab, step_size, aini, n_iter, "sign", nargout=4
)
elif update == "demix":
# Run the MATLAB version of OGIVE_w, which updates the demixing vector
w, a, shat, numit = eng.ogive_w(
X_matlab, step_size, aini, n_iter, "sign", nargout=4
)
else:
raise ValueError(f"Unknown update type {update}")
# Now convert back the output (shat, shape=(n_freq, n_frames))
Y = np.array(shat)
Y = Y[:, :, None].transpose([1, 0, 2]).copy()
if proj_back:
z = projection_back(Y, X[:, :, 0])
Y *= np.conj(z[None, :, :])
if callback is not None:
callback(Y)
return Y
# shape: (n_frames, n_src)
r[:, :] = np.mean(np.abs(Y * np.conj(Y)), axis=1)
if epoch % 3 == 0:
the_cost.append(cost_func(r))
plt.figure()
plt.plot(np.arange(len(the_cost)) * 3, the_cost)
plt.title("The cost function")
plt.xlabel("Number of iterations")
plt.ylabel("Neg. log-likelihood")
if proj_back:
print("proj back!")
z = projection_back(Y, X[:, :, 0])
Y *= np.conj(z[None, :, :])
if return_filters:
return Y, W
else:
return Y