def _test_anndata_raw(self, sparse):
    data, sample_description = self.simulate()
    gene_names = ["gene" + str(i) for i in range(data.shape[1])]
    if sparse:
        data = scipy.sparse.csr_matrix(data)
    data = anndata.AnnData(data)
    data.var_names = gene_names
    data.raw = data
    self._test_wald(data=data.raw, sample_description=sample_description)
    self._test_lrt(data=data.raw, sample_description=sample_description)
    self._test_t_test(data=data, sample_description=sample_description)
    self._test_rank(data=data, sample_description=sample_description)
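
# A minimal, self-contained sketch of the AnnData ``.raw`` round-trip the test above
# relies on (the toy matrix and gene names here are illustrative, not from the test
# suite): assigning ``data.raw = data`` freezes a copy of X/var that later subsetting
# does not touch, which is why the Wald/LRT helpers can safely be fed ``data.raw``.
import numpy as np
import anndata

X = np.random.poisson(lam=2.0, size=(20, 5)).astype(float)
toy = anndata.AnnData(X)
toy.var_names = ["gene" + str(i) for i in range(toy.shape[1])]
toy.raw = toy                        # freeze the current matrix as .raw
toy = toy[:, ["gene0", "gene1"]]     # subsetting leaves .raw untouched
assert toy.raw.X.shape[1] == 5
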
def test_reg_mean_plot():
    train = sc.read("./tests/data/train.h5ad", backup_url="https://goo.gl/33HtVh")
    network = scgen.VAEArith(x_dimension=train.shape[1], model_path="../models/test")
    network.train(train_data=train, n_epochs=0)
    unperturbed_data = train[((train.obs["cell_type"] == "CD4T") & (train.obs["condition"] == "control"))]
    condition = {"ctrl": "control", "stim": "stimulated"}
    pred, delta = network.predict(adata=train, adata_to_predict=unperturbed_data, conditions=condition,
                                  condition_key="condition", cell_type_key="cell_type")
    pred_adata = anndata.AnnData(pred, obs={"condition": ["pred"] * len(pred)}, var={"var_names": train.var_names})
    CD4T = train[train.obs["cell_type"] == "CD4T"]
    all_adata = CD4T.concatenate(pred_adata)
    scgen.plotting.reg_mean_plot(all_adata, condition_key="condition", axis_keys={"x": "control", "y": "pred"},
                                 path_to_save="tests/reg_mean1.pdf")
    scgen.plotting.reg_mean_plot(all_adata, condition_key="condition", axis_keys={"x": "control", "y": "pred"},
                                 path_to_save="tests/reg_mean2.pdf", gene_list=["ISG15", "CD3D"])
    scgen.plotting.reg_mean_plot(all_adata, condition_key="condition", axis_keys={"x": "control", "y": "pred", "y1": "stimulated"},
                                 path_to_save="tests/reg_mean3.pdf")
    scgen.plotting.reg_mean_plot(all_adata, condition_key="condition", axis_keys={"x": "control", "y": "pred", "y1": "stimulated"},
                                 gene_list=["ISG15", "CD3D"], path_to_save="tests/reg_mean.pdf")
    network.sess.close()

if adata_source.shape[0] == 0:
    adata_source = pred_adatas.copy()[pred_adatas.obs[condition_key] == source_condition]
if adata_target.shape[0] == 0:
    adata_target = pred_adatas.copy()[pred_adatas.obs[condition_key] == target_condition]
source_labels = np.zeros(adata_source.shape[0]) + source_label
target_labels = np.zeros(adata_source.shape[0]) + target_label
pred_target = network.predict(adata_source,
                              encoder_labels=source_labels,
                              decoder_labels=target_labels,
                              size_factor=adata_source.obs['size_factors'].values)
pred_adata = anndata.AnnData(X=pred_target)
pred_adata.obs[condition_key] = [name] * pred_target.shape[0]
pred_adata.var_names = adata.var_names
if sparse.issparse(adata_source.X):
    adata_source.X = adata_source.X.A
if sparse.issparse(adata_target.X):
    adata_target.X = adata_target.X.A
if sparse.issparse(pred_adata.X):
    pred_adata.X = pred_adata.X.A
# adata_to_plot = pred_adata.concatenate(adata_target)
# trvae.plotting.reg_mean_plot(adata_to_plot,
#                              top_100_genes=top_100_genes,
#                              show=False)

decoded_latent_with_true_labels = network.predict(data=latent_with_true_labels, encoder_labels=true_labels,
                                                  decoder_labels=true_labels, data_space='latent')
cell_type_data = train[train.obs[cell_type_key] == cell_type]
unperturbed_data = train[((train.obs[cell_type_key] == cell_type) & (train.obs[condition_key] == ctrl_key))]
true_labels = np.zeros((len(unperturbed_data), 1))
fake_labels = np.ones((len(unperturbed_data), 1))
sc.tl.rank_genes_groups(cell_type_data, groupby=condition_key, n_genes=100)
diff_genes = cell_type_data.uns["rank_genes_groups"]["names"][stim_key]
# cell_type_data = cell_type_data.copy()[:, diff_genes.tolist()]
pred = network.predict(data=unperturbed_data, encoder_labels=true_labels, decoder_labels=fake_labels)
pred_adata = anndata.AnnData(pred, obs={condition_key: ["pred"] * len(pred)},
                             var={"var_names": cell_type_data.var_names})
all_adata = cell_type_data.concatenate(pred_adata)
scgen.plotting.reg_mean_plot(all_adata, condition_key=condition_key,
                             axis_keys={"x": ctrl_key, "y": stim_key, "y1": "pred"},
                             gene_list=diff_genes,
                             path_to_save=f"./figures/reg_mean_{z_dim}.pdf")
scgen.plotting.reg_var_plot(all_adata, condition_key=condition_key,
                            axis_keys={"x": ctrl_key, "y": stim_key, "y1": "pred"},
                            gene_list=diff_genes,
                            path_to_save=f"./figures/reg_var_{z_dim}.pdf")
sc.pp.neighbors(all_adata)
sc.tl.umap(all_adata)
sc.pl.umap(all_adata, color=condition_key, save="pred")

self._check_data(data)
n_pca, rank_threshold = self._parse_n_pca_threshold(data, n_pca, rank_threshold)
try:
    if isinstance(data, pd.SparseDataFrame):
        data = data.to_coo()
    elif isinstance(data, pd.DataFrame):
        try:
            data = data.sparse.to_coo()
        except AttributeError:
            data = np.array(data)
except NameError:
    # pandas not installed
    pass
try:
    if isinstance(data, anndata.AnnData):
        data = data.X
except NameError:
    # anndata not installed
    pass
self.data = data
self.n_pca = n_pca
self.rank_threshold = rank_threshold
self.random_state = random_state
self.data_nu = self._reduce_data()
super().__init__(**kwargs)
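
# The try/except NameError blocks above let pandas and anndata stay optional
# dependencies. A standalone sketch of the same DataFrame-to-matrix fallback (the
# helper name ``to_array_or_spmatrix`` is made up for illustration): in current
# pandas, ``DataFrame.sparse.to_coo()`` only works when every column has a
# SparseDtype, so the AttributeError branch densifies ordinary DataFrames.
import numpy as np
import pandas as pd

def to_array_or_spmatrix(data):
    if isinstance(data, pd.DataFrame):
        try:
            return data.sparse.to_coo()   # all-sparse DataFrame -> scipy COO matrix
        except AttributeError:
            return np.array(data)         # dense fallback
    return data

dense = to_array_or_spmatrix(pd.DataFrame(np.eye(3)))
sparse_df = pd.DataFrame(np.eye(3)).astype(pd.SparseDtype(float, 0.0))
coo = to_array_or_spmatrix(sparse_df)
print(type(dense).__name__, type(coo).__name__)  # ndarray, coo_matrix
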
If `return_info` is true, all estimated distribution parameters are stored in the
AnnData object, e.g.:
- `.obsm["X_dca_dropout"]`, the mixture coefficient (pi) of the zero component in
  ZINB, i.e. the dropout probability (only if `ae_type` is zinb or zinb-conddisp).
- `.obsm["X_dca_dispersion"]`, the dispersion parameter of the NB.
- `.uns["dca_loss_history"]`, which stores the training loss history.
Finally, the raw counts are stored in `.raw`.

If `return_model` is true, the trained model is returned. When both `copy` and
`return_model` are true, a tuple of AnnData and model is returned in that order.
"""
assert isinstance(adata, anndata.AnnData), 'adata must be an AnnData instance'
assert mode in ('denoise', 'latent'), '%s is not a valid mode.' % mode
# set seed for reproducibility
random.seed(random_state)
np.random.seed(random_state)
tf.set_random_seed(random_state)
os.environ['PYTHONHASHSEED'] = '0'
# this creates adata.raw with raw counts and copies adata if copy==True
adata = read_dataset(adata,
                     transpose=False,
                     test_split=False,
                     copy=copy)
# check for zero genes
nonzero_genes, _ = sc.pp.filter_genes(adata.X, min_counts=1)
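
# A hedged usage sketch of the behaviour the docstring above describes. It assumes the
# public entry point is ``dca.api.dca`` and that ``dca(..., copy=True)`` returns the
# modified copy; the dataset below is only an example of raw UMI counts.
import scanpy as sc
from dca.api import dca  # assumed import path

adata = sc.datasets.pbmc3k()                    # raw counts
sc.pp.filter_genes(adata, min_counts=1)         # dca expects no all-zero genes (see above)
denoised = dca(adata, mode='denoise', return_info=True, copy=True)

print(denoised.obsm["X_dca_dropout"].shape)     # ZINB dropout probability (pi)
print(denoised.obsm["X_dca_dispersion"].shape)  # NB dispersion
print(denoised.uns["dca_loss_history"])         # training loss history
print(denoised.raw.X[:3, :3])                   # original counts preserved in .raw
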
def plot_diff_map(X, pseudotime, brns):
    data = ad.AnnData(X)
    diffmap(adata=data)
    diff_map = data.obsm["X_diffmap"]
    cols = np.array(list(pylab.cm.Set1.colors))
    branch_names, indices = np.unique(brns, return_inverse=True)
    fig, ax = plt.subplots(ncols=2)
    fig.set_size_inches(w=9, h=4)
    ax[0].scatter(diff_map[:, 0], diff_map[:, 1], c=cols[indices])
    ax[0].set_title("branches")
    ax[1].scatter(diff_map[:, 0], diff_map[:, 1], c=pseudotime, cmap="viridis")
    ax[1].set_title("pseudotime")
    plt.show()
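
# ``plot_diff_map`` above relies on a ``diffmap(adata=...)`` helper that is not shown.
# A minimal sketch of such a helper with scanpy, assuming all it needs to do is
# populate ``adata.obsm["X_diffmap"]``:
import scanpy as sc

def diffmap(adata, n_comps=10):
    sc.pp.neighbors(adata, use_rep="X")    # kNN graph on the stored representation
    sc.tl.diffmap(adata, n_comps=n_comps)  # writes adata.obsm["X_diffmap"]
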
def louvain(X, N, resolution=1, seed=None, replace=False):
    from anndata import AnnData
    import scanpy.api as sc

    adata = AnnData(X=X)
    sc.pp.neighbors(adata, use_rep='X')
    sc.tl.louvain(adata, resolution=resolution, key_added='louvain')
    cluster_labels_full = adata.obs['louvain'].tolist()

    louv = {}
    for i, cluster in enumerate(cluster_labels_full):
        if cluster not in louv:
            louv[cluster] = []
        louv[cluster].append(i)

    lv_idx = []
    for n in range(N):
        louv_cells = list(louv.keys())
        louv_cell = louv_cells[np.random.choice(len(louv_cells))]
        samples = list(louv[louv_cell])
        sample = samples[np.random.choice(len(samples))]
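        # The original snippet is truncated here; a minimal completion sketch, assuming
        # each iteration simply records the sampled cell index (``seed`` and ``replace``
        # are not handled in this sketch):
        lv_idx.append(sample)

    return lv_idx
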
    corrected = anndata.AnnData(network.reconstruct(all_shared_ann.X, use_data=True))
    corrected.obs = all_shared_ann.obs.copy(deep=True)
    corrected.var_names = adata.var_names.tolist()
    corrected = corrected[adata.obs_names]
    if adata.raw is not None:
        adata_raw = anndata.AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    corrected.obsm["latent"] = all_shared_ann.X
    return corrected
else:
    all_not_shared_ann = anndata.AnnData.concatenate(*not_shared_ct, batch_key="concat_batch", index_unique=None)
    all_corrected_data = anndata.AnnData.concatenate(all_shared_ann, all_not_shared_ann, batch_key="concat_batch", index_unique=None)
    if "concat_batch" in all_shared_ann.obs.columns:
        del all_corrected_data.obs["concat_batch"]
    corrected = anndata.AnnData(network.reconstruct(all_corrected_data.X, use_data=True))
    corrected.obs = pd.concat([all_shared_ann.obs, all_not_shared_ann.obs])
    corrected.var_names = adata.var_names.tolist()
    corrected = corrected[adata.obs_names]
    if adata.raw is not None:
        adata_raw = anndata.AnnData(X=adata.raw.X, var=adata.raw.var)
        adata_raw.obs_names = adata.obs_names
        corrected.raw = adata_raw
    corrected.obsm["latent"] = all_corrected_data.X
    return corrected

encoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
decoder_labels, _ = trvae.utils.label_encoder(train_adata, condition_key="condition")
pred_adata = network.predict(train_adata, encoder_labels, decoder_labels)
```
"""
adata = remove_sparsity(adata)
encoder_labels = to_categorical(encoder_labels, num_classes=self.n_conditions)
decoder_labels = to_categorical(decoder_labels, num_classes=self.n_conditions)
reconstructed = self.trvae_model.predict([adata.X, encoder_labels, decoder_labels])[0]
reconstructed = np.nan_to_num(reconstructed)
if return_adata:
    output = anndata.AnnData(X=reconstructed)
    output.obs = adata.obs.copy(deep=True)
    output.var_names = adata.var_names
else:
    output = reconstructed
return output
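
# For reference, a small sketch of what the label-encoding step above consumes and
# produces, assuming ``label_encoder`` maps condition strings to integer codes (as the
# docstring example suggests) and ``to_categorical`` is the Keras one-hot encoder:
import numpy as np
from keras.utils import to_categorical

conditions = np.array(["control", "stimulated", "control"])
codes = np.array([{"control": 0, "stimulated": 1}[c] for c in conditions])  # [0, 1, 0]
one_hot = to_categorical(codes, num_classes=2)
print(one_hot)
# [[1. 0.]
#  [0. 1.]
#  [1. 0.]]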