Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _drop_held_out_cells(adata, cell_type_key, cell_type, condition_key, target_keys):
    """Return a copy of ``adata`` without the cells of ``cell_type`` whose condition is in ``target_keys``."""
    held_out = (adata.obs[cell_type_key] == cell_type) & adata.obs[condition_key].isin(target_keys)
    # Subset first, then copy, so the whole object is not duplicated just to take a view of it.
    return adata[~held_out].copy()


def create_data(data_dict):
    """Load a normalized dataset and build train/validation splits with one cell type held out.

    The dataset is read from ``./data/{name}/{name}_normalized.h5ad`` and restricted
    to cells whose condition is in ``source_conditions`` or ``target_conditions``.
    When it has more than 2000 genes, it is reduced to the 2000 highly variable genes
    selected by ``sc.pp.highly_variable_genes``. It is then split with
    ``train_test_split(adata, 0.80)``. From both splits, the cells that belong to the
    held-out cell type *and* to a target condition are removed.

    Parameters
    ----------
    data_dict : dict
        Configuration. Keys read here:

        * ``name`` (required).
        * ``source_conditions`` and ``target_conditions``: condition labels.
        * ``spec_cell_types``: a non-empty sequence; only the first entry is used.
        * ``cell_type_key`` (optional).
        * ``condition_key`` (optional, default ``'condition'``).

    Returns
    -------
    tuple
        ``(adata, net_train_adata, net_valid_adata)``: the filtered dataset, and the
        train and validation splits with the held-out cells removed.

    Raises
    ------
    ValueError
        If ``spec_cell_types`` is missing or empty.
    """
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type_key", None)
    condition_key = data_dict.get('condition_key', 'condition')
    spec_cell_types = data_dict.get("spec_cell_types", None)
    if not spec_cell_types:
        # Previously `data_dict.get(...)[0]` failed here with an opaque TypeError/IndexError.
        raise ValueError("data_dict['spec_cell_types'] must be a non-empty sequence of cell types")
    spec_cell_type = spec_cell_types[0]

    adata = sc.read(f"./data/{data_name}/{data_name}_normalized.h5ad")
    # Materialize the subset: highly_variable_genes writes into .var, which on a view
    # would force an implicit copy (with an ImplicitModificationWarning).
    adata = adata[adata.obs[condition_key].isin(source_keys + target_keys)].copy()
    if adata.shape[1] > 2000:
        sc.pp.highly_variable_genes(adata, n_top_genes=2000)
        adata = adata[:, adata.var['highly_variable']]

    train_adata, valid_adata = train_test_split(adata, 0.80)
    net_train_adata = _drop_held_out_cells(train_adata, cell_type_key, spec_cell_type,
                                           condition_key, target_keys)
    net_valid_adata = _drop_held_out_cells(valid_adata, cell_type_key, spec_cell_type,
                                           condition_key, target_keys)
    return adata, net_train_adata, net_valid_adata
# NOTE(review): fragment. The `def` line of this function (its name and leading
# parameters, presumably including `data_dict`) is missing from this file. The lines
# below start in the middle of the parameter list.
z_dim=100,
subsample=None,
alpha=0.001,
n_epochs=500,
batch_size=512,
dropout_rate=0.2,
learning_rate=0.001,
gpus=1,
verbose=2,
arch_style=1,
):
data_name = data_dict['name']
# NOTE(review): `metadata_path` is not used in the visible lines.
metadata_path = data_dict['metadata']
cell_type_key = data_dict['cell_type']
# Reads a hard-coded file name under ../data/{data_name}/anna/.
train_data = sc.read(f"../data/{data_name}/anna/processed_adata_Cusanovich_brain_May29_2019_5000.h5ad")
# Adds |min(X)| to every entry; when the minimum is negative this shifts it to 0.
train_data.X += abs(train_data.X.min())
# Optionally keep only the first `subsample` observations.
if subsample is not None:
train_data = train_data[:subsample]
spec_cell_type = data_dict.get("spec_cell_types", None)
# NOTE(review): `is not []` compares identity with a freshly created list, so this is
# always True (even when spec_cell_type is None). `if spec_cell_type:` was probably
# intended. `cell_types` is not used in the visible lines.
if spec_cell_type is not []:
cell_types = spec_cell_type
# Random (unseeded) 85% / 15% train/validation split over observations.
train_size = int(train_data.shape[0] * 0.85)
indices = np.arange(train_data.shape[0])
np.random.shuffle(indices)
train_idx = indices[:train_size]
valid_idx = indices[train_size:]
# NOTE(review): `.copy()` duplicates the full dataset before indexing.
# `train_data[train_idx, :].copy()` would avoid the extra full copy.
net_train_data = train_data.copy()[train_idx, :]
net_valid_data = train_data.copy()[valid_idx, :]
# NOTE(review): fragment. The enclosing `def` and the first `if data_name == ...:`
# branch are missing from this file. The lines below look like the tail of a
# per-dataset configuration chain followed by a per-cell-type training loop. Judging
# by the labels, the first four lines presumably belong to an "hpoly" branch; confirm.
stim_key = "Hpoly.Day10"
ctrl_key = "Control"
cell_type_key = "cell_label"
train = sc.read("../data/ch10_train_7000.h5ad")
elif data_name == "salmonella":
# NOTE(review): `cell_type_to_monitor` is not used in the visible lines.
cell_type_to_monitor = None
stim_key = "Salmonella"
ctrl_key = "Control"
cell_type_key = "cell_label"
train = sc.read("../data/chsal_train_7000.h5ad")
elif data_name == "species":
cell_type_to_monitor = "rat"
stim_key = "LPS6"
ctrl_key = "unst"
cell_type_key = "species"
train = sc.read("../data/train_all_lps6.h5ad")
# Train one network per cell type, each time leaving out that cell type's stimulated cells.
# NOTE(review): `condition_key`, `z_dim`, `alpha`, `dropout_rate`, `learning_rate`,
# `n_epochs` and `batch_size` are not defined in the visible lines. They presumably
# come from the missing signature; confirm.
for cell_type in train.obs[cell_type_key].unique().tolist():
# NOTE(review): both paths are relative and no chdir back is visible. From the second
# iteration on, the new directory is therefore created under the previous cell type's
# directory.
os.makedirs(f"./vae_results/{data_name}/{cell_type}/", exist_ok=True)
os.chdir(f"./vae_results/{data_name}/{cell_type}")
# Training data: every cell except the stimulated cells of the current cell type.
net_train_data = train[~((train.obs[cell_type_key] == cell_type) & (train.obs[condition_key] == stim_key))]
network = scgen.VAEArithKeras(x_dimension=net_train_data.X.shape[1],
z_dimension=z_dim,
alpha=alpha,
dropout_rate=dropout_rate,
learning_rate=learning_rate)
# network.restore_model()
# NOTE(review): the full `train` set is passed as the second positional argument,
# presumably as validation data. Confirm against scgen's VAEArithKeras.train.
network.train(net_train_data, train, n_epochs=n_epochs, batch_size=batch_size, verbose=2,
conditions={"ctrl": ctrl_key, "stim": stim_key},
condition_key=condition_key, cell_type_key=cell_type_key,
cell_type=cell_type, path_to_save="./figures/keras/")
def plot_reg_mean_with_genes(data_name="pbmc", gene_list=None):
    """Load the training data and the VAE reconstruction for ``data_name``.

    Each known dataset name is associated with four settings:

    * its stimulated condition label;
    * its control condition label;
    * the ``obs`` column that holds cell types;
    * the path of its training ``.h5ad`` file.

    For an unrecognised name none of these are set and no training file is read.
    The reconstruction file ``./vae_results/{data_name}/reconstructed.h5ad`` is
    read in every case. ``gene_list`` is accepted but not used in this body, and
    the function returns ``None`` as written.
    """
    # (name, stim label, control label, cell-type column, training file)
    dataset_table = (
        ("pbmc", "stimulated", "control", "cell_type", "../data/train.h5ad"),
        ("hpoly", "Hpoly.Day10", "Control", "cell_label", "../data/ch10_train_7000.h5ad"),
        ("salmonella", "Salmonella", "Control", "cell_label", "../data/chsal_train_7000.h5ad"),
        ("species", "LPS6", "unst", "species", "../data/train_all_lps6.h5ad"),
    )
    for known_name, stim_label, ctrl_label, type_column, train_file in dataset_table:
        if data_name == known_name:
            stim_key = stim_label
            ctrl_key = ctrl_label
            cell_type_key = type_column
            train = sc.read(train_file)
            break
    recon_data = sc.read(f"./vae_results/{data_name}/reconstructed.h5ad")
def visualize_trained_network_results(data_dict, z_dim=100):
# The loop runs over every cell type, or only over data_dict["spec_cell_types"] when
# that is given and non-empty. For each one it:
# creates ../results/RCVAE/{name}/{cell_type}/{z_dim}/{source_key} to {target_key}/Visualizations/,
# points scanpy's figure directory at it, and builds the data without that cell type's
# `target_key` cells.
# NOTE(review): this definition is truncated in this file. The `trvae.trVAE(` call on
# its last line is never closed, so the function cannot run as shown.
plt.close("all")
data_name = data_dict.get('name', None)
source_key = data_dict.get('source_key', None)
target_key = data_dict.get('target_key', None)
cell_type_key = data_dict.get("cell_type", None)
data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")
cell_types = data.obs[cell_type_key].unique().tolist()
spec_cell_type = data_dict.get("spec_cell_types", None)
if spec_cell_type:
cell_types = spec_cell_type
for cell_type in cell_types:
path_to_save = f"../results/RCVAE/{data_name}/{cell_type}/{z_dim}/{source_key} to {target_key}/Visualizations/"
os.makedirs(path_to_save, exist_ok=True)
sc.settings.figdir = os.path.abspath(path_to_save)
# NOTE(review): the condition column name 'condition' is hard-coded here. Also,
# `data.copy()` duplicates the whole dataset on every iteration before the mask is
# applied.
train_data = data.copy()[~((data.obs['condition'] == target_key) & (data.obs[cell_type_key] == cell_type))]
cell_type_adata = data[data.obs[cell_type_key] == cell_type]
network = trvae.trVAE(x_dimension=data.shape[1],
def visualize_batch_correction(data_dict, z_dim=100, mmd_dimension=128):
    """Prepare per-cell-type output folders and data subsets for batch-correction plots.

    The dataset comes from one of two places. When ``data_dict['need_merge']`` is
    truthy it is built with ``merge_data(data_dict)``; otherwise it is read from
    ``../data/{name}/train_{name}.h5ad``.

    The loop runs over every cell type in the data, or only over
    ``data_dict['spec_cell_types']`` when that is non-empty. For each cell type:

    * ``../results/RCVAEMulti/{name}/{cell_type}/{z_dim}/Visualizations/`` is
      created and set as scanpy's figure directory;
    * a copy of the data without that cell type's target-condition cells is built;
    * a view of that cell type's cells alone is taken.

    Parameters
    ----------
    data_dict : dict
        Dataset configuration. Keys read here are ``name``, ``source_conditions``,
        ``target_conditions``, ``cell_type``, ``need_merge``, ``label_encoder``,
        ``condition`` (default ``'condition'``) and ``spec_cell_types``.
    z_dim : int
        Only used to build the output path in this body.
    mmd_dimension : int
        Not used in this body.
    """
    # NOTE(review): `source_keys`, `label_encoder`, `mmd_dimension`, `train_data` and
    # `cell_type_adata` are not used further in this body. They look like scaffolding
    # for plotting code that is not present; confirm before removing them.
    plt.close("all")
    data_name = data_dict['name']
    source_keys = data_dict.get("source_conditions")
    target_keys = data_dict.get("target_conditions")
    cell_type_key = data_dict.get("cell_type", None)
    need_merge = data_dict.get('need_merge', False)
    label_encoder = data_dict.get('label_encoder', None)
    condition_key = data_dict.get('condition', 'condition')
    if need_merge:
        data, _ = merge_data(data_dict)
    else:
        data = sc.read(f"../data/{data_name}/train_{data_name}.h5ad")
    cell_types = data.obs[cell_type_key].unique().tolist()
    spec_cell_type = data_dict.get("spec_cell_types", None)
    if spec_cell_type:
        cell_types = spec_cell_type
    for cell_type in cell_types:
        path_to_save = f"../results/RCVAEMulti/{data_name}/{cell_type}/{z_dim}/Visualizations/"
        os.makedirs(path_to_save, exist_ok=True)
        sc.settings.figdir = os.path.abspath(path_to_save)
        in_cell_type = data.obs[cell_type_key] == cell_type
        # Subset first, then copy. The previous `data.copy()[mask]` duplicated the
        # whole dataset on every iteration just to take a view of it.
        train_data = data[~(data.obs[condition_key].isin(target_keys) & in_cell_type)].copy()
        cell_type_adata = data[in_cell_type]
def load_bipolar(path='/Users/josh/src/molecular-cross-validation/data/bipolar/bipolar.h5ad'):
    """Load the bipolar dataset and expose its cluster labels as ``obs['cell_type']``.

    Parameters
    ----------
    path : str or os.PathLike, optional
        Location of ``bipolar.h5ad``. Defaults to the original hard-coded,
        machine-specific path; pass a different path on any other machine.

    Returns
    -------
    AnnData
        The loaded dataset, with ``.raw`` set to the data as loaded and
        ``obs['cell_type']`` copied from ``obs['CLUSTER']``.
    """
    adata = sc.read(path)
    # Keep the as-loaded data in .raw, presumably so later processing steps can recover it.
    adata.raw = adata
    adata.obs['cell_type'] = adata.obs['CLUSTER']
    return adata
"""
Template for preprocessing function. Use copy and paste.
Returns
-------
adata : AnnData
Stores data matrix and sample and variable annotations as well
as an arbitrary amount of unstructured annotation. For the latter
it behaves like a Python dictionary.
"""
# Generate an AnnData object, which is similar
# to R's ExpressionSet (Huber et al., Nat. Meth. 2015)
# AnnData allows annotation of samples/cells and variables/genes via
# the attributes "obs" and "var"
path_to_data = 'data/myexample/'
adata = sc.read(path_to_data + 'myexample.csv')
# other data reading examples
#adata = sc.read(path_to_data + 'myexample.txt')
#adata = sc.read(path_to_data + 'myexample.h5', sheet='mysheet')
#adata = sc.read(path_to_data + 'myexample.xlsx', sheet='mysheet')
#adata = sc.read(path_to_data + 'myexample.txt.gz')
#adata = sc.read(path_to_data + 'myexample.soft.gz')
# if the first column does not store strings, rownames are not detected
# automatically, hence
#adata = sc.read(path_to_data + 'myexample.csv', first_column_names=True)
# transpose if needed to match the convention that rows store samples/cells
# and columns variables/genes
# adata = adata.transpose() # rows = samples/cells & columns = variables/genes
# read some annotation from a file, now we want strings, and not a numerical
# data matrix, the following reads from the first column of the file