How to use the scvelo.preprocessing.moments.get_connectivities function in scvelo

To help you get started, we’ve selected a few scvelo examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github theislab / scvelo / scvelo / tools / dynamical_model_utils_deprecated.py View on Github external
-------
    Returns or updates `adata`
    """
    adata = data.copy() if copy else data
    logg.info('recovering dynamics', r=True)

    if isinstance(var_names, str) and var_names not in adata.var_names:
        var_names = adata.var_names[adata.var[var_names] == True] if 'genes' in var_names and var_names in adata.var.keys() \
            else adata.var_names if 'all' in var_names or 'genes' in var_names else var_names
    var_names = make_unique_list(var_names, allow_array=True)
    var_names = [name for name in var_names if name in adata.var_names]
    if len(var_names) == 0:
        raise ValueError('Variable name not found in var keys.')

    if fit_connected_states:
        fit_connected_states = get_connectivities(adata)

    alpha, beta, gamma, t_, scaling = read_pars(adata)
    idx = []
    L, P, T = [], [], adata.layers['fit_t'] if 'fit_t' in adata.layers.keys() else np.zeros(adata.shape) * np.nan

    progress = logg.ProgressReporter(len(var_names))
    for i, gene in enumerate(var_names):
        dm = DynamicsRecovery(adata, gene, use_raw=use_raw, load_pars=load_pars, fit_time=fit_time, fit_alpha=fit_alpha,
                              fit_switching=fit_switching, fit_scaling=fit_scaling, fit_steady_states=fit_steady_states,
                              fit_connected_states=fit_connected_states)
        if max_iter > 1:
            dm.fit(max_iter, learning_rate, assignment_mode=assignment_mode, min_loss=min_loss, method=method, **kwargs)

        ix = np.where(adata.var_names == gene)[0][0]
        idx.append(ix)
github theislab / scvelo / scvelo / tools / terminal_states.py View on Github external
strings_to_categoricals(adata)
    if groupby is not None:
        logg.warn(
            "Only set groupby, when you have evident distinct clusters/lineages,"
            " each with an own root and end point."
        )

    kwargs.update({"self_transitions": self_transitions})
    categories = [None]
    if groupby is not None and groups is None:
        categories = adata.obs[groupby].cat.categories
    for cat in categories:
        groups = cat if cat is not None else groups
        cell_subset = groups_to_bool(adata, groups=groups, groupby=groupby)
        _adata = adata if groups is None else adata[cell_subset]
        connectivities = get_connectivities(_adata, "distances")

        T = transition_matrix(_adata, vkey=vkey, backward=True, **kwargs)
        eigvecs_roots = eigs(T, eps=eps, perc=[2, 98], random_state=random_state)[1]
        roots = csr_matrix.dot(connectivities, eigvecs_roots).sum(1)
        roots = scale(np.clip(roots, 0, np.percentile(roots, 98)))
        roots = verify_roots(_adata, roots)
        write_to_obs(adata, "root_cells", roots, cell_subset)

        T = transition_matrix(_adata, vkey=vkey, backward=False, **kwargs)
        eigvecs_ends = eigs(T, eps=eps, perc=[2, 98], random_state=random_state)[1]
        ends = csr_matrix.dot(connectivities, eigvecs_ends).sum(1)
        ends = scale(np.clip(ends, 0, np.percentile(ends, 98)))
        write_to_obs(adata, "end_points", ends, cell_subset)

        n_roots, n_ends = eigvecs_roots.shape[1], eigvecs_ends.shape[1]
        groups_str = f" ({groups})" if isinstance(groups, str) else ""
github theislab / scvelo / scvelo / tools / velocity_pseudotime.py View on Github external
def set_iroot(self, root=None):
        """Resolve ``root`` to an index and store it in ``self.iroot``.

        The following cases are tried in order:

        * ``root`` is a string that is a key of ``self._adata.obs`` and the
          maximum of that entry is non-zero: the result of
          ``get_connectivities(self._adata)`` is multiplied (``.dot``) with
          that entry, the product is passed through ``scale`` and the
          position of its maximum (``argmax``) becomes ``self.iroot``.
        * ``root`` is a string contained in ``self._adata.obs_names``: the
          position of the first match becomes ``self.iroot``.
        * ``root`` is an ``int`` / ``np.integer`` smaller than
          ``self._adata.n_obs``: it is stored unchanged.
        * anything else (including the default ``None``): ``self.iroot`` is
          set to ``None``.
        """
        if isinstance(root, str):
            adata = self._adata
            if root in adata.obs.keys() and adata.obs[root].max() != 0:
                # Stored in two steps, exactly like the original: first the
                # raw product, then the index of its rescaled maximum.
                self.iroot = get_connectivities(adata).dot(adata.obs[root])
                self.iroot = scale(self.iroot).argmax()
                return
            if root in adata.obs_names:
                matches = np.where(adata.obs_names == root)[0]
                self.iroot = matches[0]
                return
        elif isinstance(root, (int, np.integer)):
            if root < self._adata.n_obs:
                self.iroot = root
                return

        # No rule matched: unknown name, out-of-range index or wrong type.
        self.iroot = None
github theislab / scvelo / scvelo / tools / dynamical_model.py View on Github external
pars = read_pars(adata)
    alpha, beta, gamma, t_, scaling, std_u, std_s, likelihood = pars[:8]
    u0, s0, pval, steady_u, steady_s, varx = pars[8:]
    # likelihood[np.isnan(likelihood)] = 0
    idx, L, P = [], [], []
    T = np.zeros(adata.shape) * np.nan
    Tau = np.zeros(adata.shape) * np.nan
    Tau_ = np.zeros(adata.shape) * np.nan
    if "fit_t" in adata.layers.keys():
        T = adata.layers["fit_t"]
    if "fit_tau" in adata.layers.keys():
        Tau = adata.layers["fit_tau"]
    if "fit_tau_" in adata.layers.keys():
        Tau_ = adata.layers["fit_tau_"]

    conn = get_connectivities(adata) if fit_connected_states else None
    progress = logg.ProgressReporter(len(var_names))
    for i, gene in enumerate(var_names):
        dm = DynamicsRecovery(
            adata,
            gene,
            use_raw=use_raw,
            load_pars=load_pars,
            max_iter=max_iter,
            fit_time=fit_time,
            fit_steady_states=fit_steady_states,
            fit_connected_states=conn,
            fit_scaling=fit_scaling,
            fit_basal_transcription=fit_basal_transcription,
            steady_state_prior=steady_state_prior,
            **kwargs,
        )
github theislab / scvelo / scvelo / tools / dynamical_model_utils.py View on Github external
try:
            self.initialize_weights()
        except:
            self.recoverable = False
            logg.warn(f'Model for {self.gene} could not be instantiated.')

        self.refit_time = fit_time

        self.assignment_mode = None
        self.steady_state_ratio = None
        self.steady_state_prior = steady_state_prior

        self.fit_scaling = fit_scaling
        self.fit_steady_states = fit_steady_states
        self.fit_connected_states = fit_connected_states
        self.connectivities = get_connectivities(adata) if self.fit_connected_states is True else self.fit_connected_states
        self.high_pars_resolution = high_pars_resolution
        self.init_vals = init_vals

        # for differential kinetic test
        self.clusters, self.cats, self.varx, self.orth_beta = None, None, None, None
        self.diff_kinetics, self.pval_kinetics, self.pvals_kinetics = None, None, None
github theislab / scvelo / scvelo / tools / dynamical_model_utils_deprecated.py View on Github external
u = make_dense(_layers['unspliced']) if use_raw else make_dense(_layers['Mu'])
            s = make_dense(_layers['spliced']) if use_raw else make_dense(_layers['Ms'])
        self.s, self.u = s, u

        # set weights for fitting (exclude dropouts and extreme outliers)
        nonzero = np.ravel(s > 0) & np.ravel(u > 0)
        s_filter = np.ravel(s < np.percentile(s[nonzero], 98))
        u_filter = np.ravel(u < np.percentile(u[nonzero], 98))

        self.weights = s_filter & u_filter & nonzero
        self.fit_scaling = fit_scaling
        self.fit_time = fit_time
        self.fit_alpha = fit_alpha
        self.fit_switching = fit_switching
        self.fit_steady_states = fit_steady_states
        self.connectivities = get_connectivities(adata) if fit_connected_states is True else fit_connected_states

        if load_pars and 'fit_alpha' in adata.var.keys():
            self.load_pars(adata, gene)
        else:
            self.initialize()
github theislab / scvelo / scvelo / tools / dynamical_model.py View on Github external
keys = [key for key in terminal_keys if key in adata.obs.keys()]
        if len(keys) > 0:
            root_key = keys[0]
    if root_key not in adata.uns.keys() and root_key not in adata.obs.keys():
        root_key = "root_cells"
    if root_key not in adata.obs.keys():
        terminal_states(adata, vkey=vkey)

    t = np.array(adata.layers["fit_t"])
    idx_valid = ~np.isnan(t.sum(0))
    if min_likelihood is not None:
        likelihood = adata.var["fit_likelihood"].values
        idx_valid &= np.array(likelihood >= min_likelihood, dtype=bool)
    t = t[:, idx_valid]
    t_sum = np.sum(t, 1)
    conn = get_connectivities(adata)

    if root_key not in adata.uns.keys():
        roots = np.argsort(t_sum)
        idx_roots = np.array(adata.obs[root_key][roots])
        idx_roots[pd.isnull(idx_roots)] = 0
        if np.any([isinstance(ix, str) for ix in idx_roots]):
            idx_roots = np.array([isinstance(ix, str) for ix in idx_roots], dtype=int)
        idx_roots = idx_roots.astype(np.float) > 1 - 1e-3
        if np.sum(idx_roots) > 0:
            roots = roots[idx_roots]
        else:
            logg.warn(
                "No root cells detected. Consider specifying "
                "root cells to improve latent time prediction."
            )
    else:
github theislab / scvelo / scvelo / tools / dynamical_model_utils.py View on Github external
'constraint_time_increments': False, 'fit_steady_states': True, 'fit_basal_transcription': None,
               'std_u': std_u, 'std_s': std_s, 'pval_steady': pval_steady, 'steady_u': steady_u, 'steady_s': steady_s}
    kwargs_.update(adata.uns['recover_dynamics'])
    kwargs_.update(**kwargs)

    reg_time = None
    if use_latent_time is True: use_latent_time = 'latent_time'
    if isinstance(use_latent_time, str) and use_latent_time in adata.obs.keys():
        reg_time = adata.obs[use_latent_time].values
    u, s = get_reads(vdata, use_raw=kwargs_['use_raw'])
    if kwargs_['fit_basal_transcription']: u, s = u - u0, s - s0
    tau = np.array(vdata.layers['fit_tau']) if 'fit_tau' in vdata.layers.keys() else None
    tau_ = np.array(vdata.layers['fit_tau_']) if 'fit_tau_' in vdata.layers.keys() else None

    res = compute_divergence(u, s, alpha, beta, gamma, scaling, t_, tau=tau, tau_=tau_, reg_time=reg_time, mode=mode,
                             connectivities=get_connectivities(adata) if use_connectivities else None, **kwargs_)
    return res