Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# Train calibMMDNet on train_real_ctrl with sourceLabels as the fit target,
# holding out 10% of the samples for validation and stopping early once
# val_loss has not improved for 50 epochs.
# `epochs` replaces the Keras-1 `nb_epoch` keyword: Keras 2 only accepts
# `nb_epoch` as a deprecated alias and tf.keras rejects it.
# NOTE(review): `lrate` is defined elsewhere — presumably a learning-rate
# scheduler callback; confirm.
calibMMDNet.fit(
    train_real_ctrl,
    sourceLabels,
    epochs=n_epochs,
    batch_size=batch_size,
    validation_split=0.1,
    verbose=2,
    callbacks=[lrate, EarlyStopping(monitor='val_loss', patience=50, mode='auto')],
)
# Evaluate the calibration network on one cell type (`spec_cell_type`):
#   1. route scanpy figures/output to the per-dataset results directory,
#   2. select the cell type and its control / stimulated subsets,
#   3. rank genes between conditions to pick the genes to plot,
#   4. predict stimulated profiles from the control cells and stack control,
#      real stimulated and predicted cells into one AnnData (`all_Data`).
path_to_save = f"../results/MMDResNet/{data_name}/{spec_cell_type}"
sc.settings.figdir = os.path.abspath(path_to_save)
sc.settings.writedir = os.path.abspath(path_to_save)

# Index first, then copy. The previous `data.copy()[mask]` duplicated the whole
# dataset for each subset and returned views, which AnnData then had to turn
# into actual objects implicitly once `.X` / `.uns` were written below.
cell_type_mask = data.obs[cell_type_key] == spec_cell_type
CD4T = data[cell_type_mask].copy()
ctrl_CD4T = data[cell_type_mask & (data.obs['condition'] == source_key)].copy()
stim_CD4T = data[cell_type_mask & (data.obs['condition'] == target_key)].copy()

# np.concatenate below needs dense arrays. `.toarray()` replaces the sparse
# `.A` shorthand, which SciPy deprecated in 1.11 and has since removed.
if sparse.issparse(ctrl_CD4T.X):
    ctrl_CD4T.X = ctrl_CD4T.X.toarray()
if sparse.issparse(stim_CD4T.X):
    stim_CD4T.X = stim_CD4T.X.toarray()

# Wilcoxon gene ranking per condition within this cell type; the call is
# identical for both dataset flavours, so it runs once.
sc.tl.rank_genes_groups(CD4T, groupby="condition", n_genes=100, method="wilcoxon")
ranked_names = CD4T.uns["rank_genes_groups"]["names"]
if data_name == "pbmc":
    top_100_genes = ranked_names[target_key].tolist()
    gene_list = top_100_genes[:10]
else:
    # Genes ranked highest for the source condition ("down") and for the target
    # condition ("up"). Take 50 of each so `top_100_genes` really holds 100
    # genes; previously each list held 100 names, giving 200 in total.
    top_50_down_genes = ranked_names[source_key].tolist()[:50]
    top_50_up_genes = ranked_names[target_key].tolist()[:50]
    top_100_genes = top_50_up_genes + top_50_down_genes
    gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]

# Predict stimulated profiles from the control cells and stack all three groups.
pred_stim = calibMMDNet.predict(ctrl_CD4T.X)
all_Data = sc.AnnData(np.concatenate([ctrl_CD4T.X, stim_CD4T.X, pred_stim]))
all_Data.obs["condition"] = (["ctrl"] * len(ctrl_CD4T.X)
                             + ["real_stim"] * len(stim_CD4T.X)
                             + ["pred_stim"] * len(pred_stim))
all_Data.var_names = CD4T.var_names
# Per-gene regression plot on `all_Data` (control / real stimulated / predicted
# cells) — presumably of variances, per the function name.
# NOTE(review): this call is truncated in this view; its remaining arguments
# and closing parenthesis are not visible here.
trvae.plotting.reg_var_plot(all_Data,
# Predict perturbed profiles for `unperturbed_data` (passing `fake_labels` as
# the labels), append them to CD4T as "pred" cells, and save regression, UMAP
# and violin plots comparing predicted with stimulated cells.
predicted_cells = network.predict(unperturbed_data, fake_labels)
# obs needs one entry per row of X, so size the condition column from the
# prediction itself rather than from `fake_labels`, which only matches indirectly.
adata = sc.AnnData(predicted_cells, obs={"condition": ["pred"] * len(predicted_cells)})
adata.var_names = CD4T.var_names
all_adata = CD4T.concatenate(adata)
scgen.plotting.reg_mean_plot(all_adata, condition_key="condition",
                             axis_keys={"x": "pred", "y": "stimulated"},
                             gene_list=["ISG15", "CD3D"],
                             path_to_save=f"figures/reg_mean_{z_dim}.pdf")
scgen.plotting.reg_var_plot(all_adata, condition_key="condition",
                            axis_keys={"x": "pred", "y": "stimulated"},
                            gene_list=["ISG15", "CD3D"],
                            path_to_save=f"figures/reg_var_{z_dim}.pdf")
sc.pp.neighbors(all_adata)
sc.tl.umap(all_adata)
sc.pl.umap(all_adata, color="condition", save="pred")
sc.pl.violin(all_adata, keys="ISG15", groupby="condition", save=f"violin_{z_dim}")
# Cross-study prediction: restore a previously trained MMDCVAE, encode every
# cell of the control dataset with label 0 and decode it with label 1, then
# save a UMAP and the combined (original + predicted) AnnData.
control_data = sc.read("data/zheng_cross_study.h5ad")
network = scgen.MMDCVAE(x_dimension=train.X.shape[1], z_dimension=z_dim, alpha=alpha, beta=beta,
                        batch_mmd=True, kernel=kernel, train_with_fake_labels=False,
                        model_path="./")
# Training is disabled here; weights are restored instead (presumably from
# `model_path`). Uncomment to retrain.
# network.train(train, n_epochs=n_epochs, batch_size=batch_size, verbose=2)
network.restore_model()
# Report what actually happened: the old message claimed the network had been
# trained although the train() call is commented out.
print("network has been restored!")

true_labels = np.zeros(shape=(control_data.shape[0], 1))
fake_labels = np.ones(shape=(control_data.shape[0], 1))
pred = network.predict(data=control_data, encoder_labels=true_labels, decoder_labels=fake_labels)
pred_adata = anndata.AnnData(pred,
                             obs={condition_key: ["pred"] * len(pred),
                                  "cell_type": control_data.obs["cell_type"].tolist()},
                             var={"var_names": control_data.var_names})
all_adata = control_data.concatenate(pred_adata)
sc.pp.neighbors(all_adata)
sc.tl.umap(all_adata)
# condition_key is already an obs column name (a string); no f-string needed.
sc.pl.umap(all_adata, color=[condition_key, "cell_type", "ISG15"],
           save="cross_pred")
all_adata.write("cross_study_mmd.h5ad")
# Project feed_data through the network twice — with train_labels and with
# fake_labels — both into the latent space and into the MMD layer.
latent_with_true_labels, latent_with_fake_labels = (
    network.to_latent(feed_data, train_labels),
    network.to_latent(feed_data, fake_labels),
)
mmd_latent_with_true_labels, mmd_latent_with_fake_labels = (
    network.to_mmd_layer(feed_data, train_labels),
    network.to_mmd_layer(feed_data, fake_labels),
)
# Predict the stimulated state of one cell type's control cells, rank genes
# between conditions, and append the predictions to `cell_type_adata` as
# "predicted" cells.
# Index before copying; `cell_type_adata.copy()[mask]` copied the whole
# object only to return a view of the subset.
cell_type_ctrl = cell_type_adata[cell_type_adata.obs['condition'] == source_key].copy()
print(cell_type_ctrl.shape, cell_type_adata.shape)
pred_celltypes = network.predict(cell_type_ctrl, labels=np.ones((cell_type_ctrl.shape[0], 1)))
pred_adata = anndata.AnnData(X=pred_celltypes)
pred_adata.obs['condition'] = ['predicted'] * pred_adata.shape[0]
# Give the predicted AnnData its own var frame instead of aliasing
# cell_type_adata's, so later changes to one do not leak into the other.
pred_adata.var = cell_type_adata.var.copy()

sc.tl.rank_genes_groups(cell_type_adata, groupby="condition", n_genes=100, method="wilcoxon")
ranked_names = cell_type_adata.uns["rank_genes_groups"]["names"]
if data_name == "pbmc":
    top_100_genes = ranked_names[target_key].tolist()
    gene_list = top_100_genes[:10]
else:
    # 50 genes per direction so `top_100_genes` really holds 100 genes
    # (each ranked list has 100 names, which previously summed to 200).
    top_50_down_genes = ranked_names[source_key].tolist()[:50]
    top_50_up_genes = ranked_names[target_key].tolist()[:50]
    top_100_genes = top_50_up_genes + top_50_down_genes
    gene_list = top_50_down_genes[:5] + top_50_up_genes[:5]
cell_type_adata = cell_type_adata.concatenate(pred_adata)
# Per-gene mean regression plot of predicted cells (x) vs. target-condition
# cells (y) in `cell_type_adata`, passing the DEG list `top_100_genes` and
# the short `gene_list`.
# NOTE(review): this call is truncated in this view; its remaining arguments
# and closing parenthesis are not visible here.
trvae.plotting.reg_mean_plot(cell_type_adata,
top_100_genes=top_100_genes,
gene_list=gene_list,
condition_key='condition',
axis_keys={"x": 'predicted', 'y': target_key},
# Score cell types using the genes ranked highest for the stimulated condition:
# rank genes between conditions (Wilcoxon), keep the top `n_genes` names, and
# compare mean stimulated vs. control expression per cell type.
# NOTE(review): only the start of this function is visible here — the `else:`
# branch at the end has no body in this view, and `n_deg` / `sortby` are not
# used in the visible part.
# NOTE(review): `conditions` is a mutable dict default; safe only while the
# function never mutates it — confirm in the rest of the body.
def score(adata, n_deg=10, n_genes=1000, condition_key="condition", cell_type_key="cell_type",
conditions={"stim": "stimulated", "ctrl": "control"},
sortby="median_score"):
import scanpy as sc
import numpy as np
from scipy.stats import entropy
import pandas as pd
# Gene names and log fold changes reported for the stimulated group.
sc.tl.rank_genes_groups(adata, groupby=condition_key, method="wilcoxon", n_genes=n_genes)
gene_names = adata.uns["rank_genes_groups"]['names'][conditions['stim']]
gene_lfcs = adata.uns["rank_genes_groups"]['logfoldchanges'][conditions['stim']]
diff_genes_df = pd.DataFrame({"names": gene_names, "lfc": gene_lfcs})
diff_genes = diff_genes_df["names"].tolist()[:n_genes]
print(len(diff_genes))
# Restrict the data to the selected genes.
adata_deg = adata[:, diff_genes].copy()
# Requires obs[cell_type_key] to be categorical (uses the `.cat` accessor).
cell_types = adata_deg.obs[cell_type_key].cat.categories.tolist()
# One row per cell type, one column per selected gene.
lfc_temp = np.zeros((len(cell_types), n_genes))
for j, ct in enumerate(cell_types):
# Per-gene mean expression of stimulated / control cells of this type.
# NOTE(review): `.A1` exists on the np.matrix returned by a sparse `.mean(0)`;
# it would fail if X were a dense ndarray.
if cell_type_key == "cell_type": # if data is pbmc
stim = adata_deg[(adata_deg.obs[cell_type_key] == ct) &
(adata_deg.obs[condition_key] == conditions["stim"])].X.mean(0).A1
ctrl = adata_deg[(adata_deg.obs[cell_type_key] == ct) &
(adata_deg.obs[condition_key] == conditions["ctrl"])].X.mean(0).A1
else:
def _umap_of_mmd_layer(mmd_values, save_suffix):
    """Wrap MMD-layer values in an AnnData annotated with the training cells'
    condition and cell type, compute a UMAP embedding, save its plot, and
    return the AnnData."""
    mmd_adata = sc.AnnData(X=mmd_values,
                           obs={condition_key: net_train_data.obs[condition_key].tolist(),
                                cell_type_key: pd.Categorical(net_train_data.obs[cell_type_key])})
    sc.pp.neighbors(mmd_adata)
    sc.tl.umap(mmd_adata)
    sc.pl.umap(mmd_adata, color=[condition_key, cell_type_key],
               save=save_suffix,
               show=False)
    return mmd_adata


# MMD layer computed with the true labels (values produced earlier).
mmd_with_true_labels = _umap_of_mmd_layer(mmd_with_true_labels,
                                          f"_mmd_true_labels_{z_dim}")
# MMD layer recomputed with feed_fake=True on the same training data.
mmd_with_fake_labels = _umap_of_mmd_layer(
    network.to_mmd_layer(network.cvae_model, net_train_data.X,
                         encoder_labels=true_labels, feed_fake=True),
    f"_mmd_fake_labels_{z_dim}")
# Select one cell type from `train`, predict its control cells with encoder
# label 0 and decoder label 1, and join the predictions onto the real data.
is_cell_type = train.obs[cell_type_key] == cell_type
cell_type_data = train[is_cell_type]
unperturbed_data = train[is_cell_type & (train.obs[condition_key] == ctrl_key)]
n_unperturbed = len(unperturbed_data)
true_labels = np.zeros((n_unperturbed, 1))
fake_labels = np.ones((n_unperturbed, 1))
pred = network.predict(data=unperturbed_data, encoder_labels=true_labels, decoder_labels=fake_labels)
pred_adata = anndata.AnnData(
    pred,
    obs={condition_key: ["pred"] * len(pred)},
    var={"var_names": cell_type_data.var_names},
)
all_adata = cell_type_data.copy().concatenate(pred_adata.copy())
# Per-gene mean regression plot: control cells on x, predicted cells on y,
# with stimulated cells passed under "y1" (presumably a second y series).
# NOTE(review): this call is truncated in this view; its remaining arguments
# and closing parenthesis are not visible here.
scgen.plotting.reg_mean_plot(all_adata, condition_key=condition_key,
axis_keys={"x": ctrl_key, "y": "pred", "y1": stim_key},
# Train on net_train_adata with net_valid_adata for validation, monitoring
# val_loss. early_stop_limit / lr_reducer are interpreted by network.train,
# which is not visible here (presumably patience values in epochs).
# NOTE(review): `batch_size_choices` is passed as a single batch size although
# its name suggests a collection — confirm it holds one int at this point.
network.train(
    net_train_adata, net_valid_adata, condition_encoder, condition_key,
    n_epochs=10000, batch_size=batch_size_choices, shuffle=True, verbose=2,
    monitor='val_loss', early_stop_limit=250, lr_reducer=200,
    save=False,
)
# For one cell type, rank the top-10 genes of the target condition against the
# source condition ("up_reg_genes") and of the source against the target
# ("down_reg_genes"), then collect both name lists.
# Index before copying: `train_adata.copy()[mask]` duplicated the whole dataset
# and returned a view that AnnData then actualized implicitly once
# rank_genes_groups wrote to `.uns`.
cell_type_adata = train_adata[train_adata.obs[cell_type_key] == cell_type].copy()
sc.tl.rank_genes_groups(cell_type_adata,
                        key_added='up_reg_genes',
                        groupby=condition_key,
                        groups=[target_condition],
                        reference=source_condition,
                        n_genes=10)
sc.tl.rank_genes_groups(cell_type_adata,
                        key_added='down_reg_genes',
                        groupby=condition_key,
                        groups=[source_condition],
                        reference=target_condition,
                        n_genes=10)
up_genes = cell_type_adata.uns['up_reg_genes']['names'][target_condition].tolist()
down_genes = cell_type_adata.uns['down_reg_genes']['names'][source_condition].tolist()
# NOTE(review): tail of a call whose beginning is outside this view; it saves
# the plot as reg_mean_top_50_genes.pdf under path_to_save. The f-string has
# no placeholders, so a plain string literal would do.
path_to_save=os.path.join(path_to_save, f"reg_mean_top_50_genes.pdf"))
# Per-gene regression plot (presumably of variances) on the top-50-gene data:
# predicted cells on x, stimulated cells on y, labelling the first five genes
# of `diff_genes`.
scgen.plotting.reg_var_plot(all_adata_top_50_genes, condition_key=condition_key,
axis_keys={"x": "pred", "y": conditions["stim"]},
gene_list=diff_genes[:5],
path_to_save=os.path.join(path_to_save, f"reg_var_top_50_genes.pdf"))
# Optional UMAPs of the full data and of the top-100 / top-50 gene subsets.
# NOTE(review): indentation is lost in this view, so which following statements
# belong to `if plot_umap:` cannot be read from here — presumably the three
# neighbors/umap/plot groups; confirm against the original file.
if plot_umap:
sc.pp.neighbors(all_adata)
sc.tl.umap(all_adata)
sc.pl.umap(all_adata, color=condition_key,
save="pred_all_genes",
show=False)
sc.pp.neighbors(all_adata_top_100_genes)
sc.tl.umap(all_adata_top_100_genes)
sc.pl.umap(all_adata_top_100_genes, color=condition_key,
save="pred_top_100_genes",
show=False)
sc.pp.neighbors(all_adata_top_50_genes)
sc.tl.umap(all_adata_top_50_genes)
sc.pl.umap(all_adata_top_50_genes, color=condition_key,
save="pred_top_50_genes",
show=False)
# Violin plot of the first gene in `diff_genes`, grouped by condition.
# (`diff_genes` supports `.tolist()`, so it is an array-like, not a list.)
sc.pl.violin(all_adata, keys=diff_genes.tolist()[0], groupby=condition_key,
save=f"_{diff_genes.tolist()[0]}",
show=False)
# Close every open matplotlib figure.
plt.close("all")
def sweep_pca_sqrt(base_adata, max_pcs=30, p=0.9, save_reconstruction=True):
adata = base_adata.copy()
adata1, adata2 = split_adata(adata, p)
sc.pp.sqrt(adata1)
sc.pp.sqrt(adata2)
sc.pp.sqrt(adata)
sc.tl.pca(adata, n_comps=max_pcs, zero_center=False, random_state=1)
sc.tl.pca(adata1, n_comps=max_pcs, zero_center=False, random_state=1)
k_range = pca_range(max_pcs)
denoised = []
for i, k in enumerate(tqdm(k_range)):
reconstruction = adata.obsm['X_pca'][:, :k].dot(adata.varm['PCs'].T[:k])
reconstruction = np.maximum(reconstruction, 0)
reconstruction1 = adata1.obsm['X_pca'][:, :k].dot(
adata1.varm['PCs'].T[:k])
mcv = mean_squared_error(convert_expectations(reconstruction1, p, 1 - p), adata2.X)
if not save_reconstruction:
reconstruction = dok_matrix(reconstruction.shape)
adata_denoised = sc.AnnData(X=reconstruction,