Returns
-------
Returns or updates `adata`
"""
adata = data.copy() if copy else data
logg.info('recovering dynamics', r=True)
# resolve `var_names`: a boolean mask in .var (e.g. 'velocity_genes'),
# 'all'/'genes' for all genes, or an explicit list of gene names
if isinstance(var_names, str) and var_names not in adata.var_names:
    if 'genes' in var_names and var_names in adata.var.keys():
        var_names = adata.var_names[adata.var[var_names] == True]
    elif 'all' in var_names or 'genes' in var_names:
        var_names = adata.var_names
var_names = make_unique_list(var_names, allow_array=True)
var_names = [name for name in var_names if name in adata.var_names]
if len(var_names) == 0:
raise ValueError('Variable name not found in var keys.')
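# Usage sketch (hypothetical gene names; illustrates the accepted `var_names` forms):
#   recover_dynamics(adata, var_names='velocity_genes')   # boolean mask in adata.var
#   recover_dynamics(adata, var_names=['Actb', 'Gapdh'])  # explicit list of genes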
if fit_connected_states:
fit_connected_states = get_connectivities(adata)
alpha, beta, gamma, t_, scaling = read_pars(adata)
idx = []
L, P = [], []
T = adata.layers['fit_t'] if 'fit_t' in adata.layers.keys() else np.full(adata.shape, np.nan)
progress = logg.ProgressReporter(len(var_names))
for i, gene in enumerate(var_names):
dm = DynamicsRecovery(adata, gene, use_raw=use_raw, load_pars=load_pars, fit_time=fit_time, fit_alpha=fit_alpha,
fit_switching=fit_switching, fit_scaling=fit_scaling, fit_steady_states=fit_steady_states,
fit_connected_states=fit_connected_states)
if max_iter > 1:
dm.fit(max_iter, learning_rate, assignment_mode=assignment_mode, min_loss=min_loss, method=method, **kwargs)
ix = np.where(adata.var_names == gene)[0][0]
idx.append(ix)
strings_to_categoricals(adata)
if groupby is not None:
    logg.warn(
        "Only set `groupby` when you have evidently distinct clusters/lineages,"
        " each with its own root and end point."
    )
kwargs.update({"self_transitions": self_transitions})
categories = [None]
if groupby is not None and groups is None:
categories = adata.obs[groupby].cat.categories
for cat in categories:
groups = cat if cat is not None else groups
cell_subset = groups_to_bool(adata, groups=groups, groupby=groupby)
_adata = adata if groups is None else adata[cell_subset]
connectivities = get_connectivities(_adata, "distances")
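# Roots and end points are read off the stationary distributions of the
# velocity-derived Markov chain: top eigenvectors of the backward transition
# matrix mark putative roots, those of the forward matrix mark end points;
# smoothing with the kNN connectivities regularizes the per-cell scores.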
T = transition_matrix(_adata, vkey=vkey, backward=True, **kwargs)
eigvecs_roots = eigs(T, eps=eps, perc=[2, 98], random_state=random_state)[1]
roots = csr_matrix.dot(connectivities, eigvecs_roots).sum(1)
roots = scale(np.clip(roots, 0, np.percentile(roots, 98)))
roots = verify_roots(_adata, roots)
write_to_obs(adata, "root_cells", roots, cell_subset)
T = transition_matrix(_adata, vkey=vkey, backward=False, **kwargs)
eigvecs_ends = eigs(T, eps=eps, perc=[2, 98], random_state=random_state)[1]
ends = csr_matrix.dot(connectivities, eigvecs_ends).sum(1)
ends = scale(np.clip(ends, 0, np.percentile(ends, 98)))
write_to_obs(adata, "end_points", ends, cell_subset)
n_roots, n_ends = eigvecs_roots.shape[1], eigvecs_ends.shape[1]
groups_str = f" ({groups})" if isinstance(groups, str) else ""
def set_iroot(self, root=None):
if (
isinstance(root, str)
and root in self._adata.obs.keys()
and self._adata.obs[root].max() != 0
):
self.iroot = get_connectivities(self._adata).dot(self._adata.obs[root])
self.iroot = scale(self.iroot).argmax()
elif isinstance(root, str) and root in self._adata.obs_names:
self.iroot = np.where(self._adata.obs_names == root)[0][0]
elif isinstance(root, (int, np.integer)) and root < self._adata.n_obs:
self.iroot = root
else:
self.iroot = None
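# Usage sketch (the owning class is not shown in this snippet; barcode is hypothetical):
#   obj.set_iroot('root_cells')        # obs key holding root-cell probabilities
#   obj.set_iroot('AAACCTGAGAAACCAT')  # an explicit cell name
#   obj.set_iroot(42)                  # an integer cell index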
pars = read_pars(adata)
alpha, beta, gamma, t_, scaling, std_u, std_s, likelihood = pars[:8]
u0, s0, pval, steady_u, steady_s, varx = pars[8:]
# likelihood[np.isnan(likelihood)] = 0
idx, L, P = [], [], []
T = np.full(adata.shape, np.nan)
Tau = np.full(adata.shape, np.nan)
Tau_ = np.full(adata.shape, np.nan)
if "fit_t" in adata.layers.keys():
T = adata.layers["fit_t"]
if "fit_tau" in adata.layers.keys():
Tau = adata.layers["fit_tau"]
if "fit_tau_" in adata.layers.keys():
Tau_ = adata.layers["fit_tau_"]
conn = get_connectivities(adata) if fit_connected_states else None
progress = logg.ProgressReporter(len(var_names))
for i, gene in enumerate(var_names):
dm = DynamicsRecovery(
adata,
gene,
use_raw=use_raw,
load_pars=load_pars,
max_iter=max_iter,
fit_time=fit_time,
fit_steady_states=fit_steady_states,
fit_connected_states=conn,
fit_scaling=fit_scaling,
fit_basal_transcription=fit_basal_transcription,
steady_state_prior=steady_state_prior,
**kwargs,
)
try:
    self.initialize_weights()
except Exception:
    self.recoverable = False
    logg.warn(f'Model for {self.gene} could not be instantiated.')
self.refit_time = fit_time
self.assignment_mode = None
self.steady_state_ratio = None
self.steady_state_prior = steady_state_prior
self.fit_scaling = fit_scaling
self.fit_steady_states = fit_steady_states
self.fit_connected_states = fit_connected_states
self.connectivities = get_connectivities(adata) if self.fit_connected_states is True else self.fit_connected_states
self.high_pars_resolution = high_pars_resolution
self.init_vals = init_vals
# for differential kinetic test
self.clusters, self.cats, self.varx, self.orth_beta = None, None, None, None
self.diff_kinetics, self.pval_kinetics, self.pvals_kinetics = None, None, None
# `_layers` refers to the layer slice for the current gene; with use_raw the raw
# unspliced/spliced counts are used, otherwise the moment-smoothed layers Mu/Ms
u = make_dense(_layers['unspliced']) if use_raw else make_dense(_layers['Mu'])
s = make_dense(_layers['spliced']) if use_raw else make_dense(_layers['Ms'])
self.s, self.u = s, u
# set weights for fitting (exclude dropouts and extreme outliers)
nonzero = np.ravel(s > 0) & np.ravel(u > 0)
s_filter = np.ravel(s < np.percentile(s[nonzero], 98))
u_filter = np.ravel(u < np.percentile(u[nonzero], 98))
self.weights = s_filter & u_filter & nonzero
self.fit_scaling = fit_scaling
self.fit_time = fit_time
self.fit_alpha = fit_alpha
self.fit_switching = fit_switching
self.fit_steady_states = fit_steady_states
self.connectivities = get_connectivities(adata) if fit_connected_states is True else fit_connected_states
if load_pars and 'fit_alpha' in adata.var.keys():
self.load_pars(adata, gene)
else:
self.initialize()
keys = [key for key in terminal_keys if key in adata.obs.keys()]
if len(keys) > 0:
root_key = keys[0]
if root_key not in adata.uns.keys() and root_key not in adata.obs.keys():
root_key = "root_cells"
if root_key not in adata.obs.keys():
terminal_states(adata, vkey=vkey)
t = np.array(adata.layers["fit_t"])
idx_valid = ~np.isnan(t.sum(0))
if min_likelihood is not None:
likelihood = adata.var["fit_likelihood"].values
idx_valid &= np.array(likelihood >= min_likelihood, dtype=bool)
t = t[:, idx_valid]
t_sum = np.sum(t, 1)
conn = get_connectivities(adata)
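# Candidate roots are cells ranked by their summed gene-wise latent time
# (earliest first); if `root_key` provides root-cell probabilities, the
# thresholding below keeps only high-confidence roots (probability ~ 1).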
if root_key not in adata.uns.keys():
roots = np.argsort(t_sum)
idx_roots = np.array(adata.obs[root_key][roots])
idx_roots[pd.isnull(idx_roots)] = 0
if np.any([isinstance(ix, str) for ix in idx_roots]):
idx_roots = np.array([isinstance(ix, str) for ix in idx_roots], dtype=int)
idx_roots = idx_roots.astype(float) > 1 - 1e-3
if np.sum(idx_roots) > 0:
roots = roots[idx_roots]
else:
logg.warn(
"No root cells detected. Consider specifying "
"root cells to improve latent time prediction."
)
else:
    ...  # branch body not included in this snippet

# dict head truncated in this snippet; only the keys shown below survive
kwargs_ = {'constraint_time_increments': False, 'fit_steady_states': True, 'fit_basal_transcription': None,
           'std_u': std_u, 'std_s': std_s, 'pval_steady': pval_steady, 'steady_u': steady_u, 'steady_s': steady_s}
kwargs_.update(adata.uns['recover_dynamics'])
kwargs_.update(**kwargs)
reg_time = None
if use_latent_time is True:
    use_latent_time = 'latent_time'
if isinstance(use_latent_time, str) and use_latent_time in adata.obs.keys():
reg_time = adata.obs[use_latent_time].values
u, s = get_reads(vdata, use_raw=kwargs_['use_raw'])
if kwargs_['fit_basal_transcription']:
    u, s = u - u0, s - s0
tau = np.array(vdata.layers['fit_tau']) if 'fit_tau' in vdata.layers.keys() else None
tau_ = np.array(vdata.layers['fit_tau_']) if 'fit_tau_' in vdata.layers.keys() else None
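# compute_divergence scores, per cell and gene, how well the observed (u, s)
# match the fitted trajectory at the assigned time points; the exact output
# (likelihood, divergence, soft assignments) depends on `mode`, and kNN
# connectivities optionally smooth the result across neighboring cells.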
res = compute_divergence(u, s, alpha, beta, gamma, scaling, t_, tau=tau, tau_=tau_, reg_time=reg_time, mode=mode,
connectivities=get_connectivities(adata) if use_connectivities else None, **kwargs_)
return res