Re-Training Gene Expression (GEX)

This repository contains code for retraining the VAE model from the TCR-DeepInsight package on multi-source gene expression (GEX) data.

import scatlasvae

# Load the data
reference_adata = scatlasvae.read_h5ad("reference_adata.h5ad")
query_adata = scatlasvae.read_h5ad("query_adata.h5ad")

# The two datasets are expected to contain different numbers of genes
assert reference_adata.shape[1] != query_adata.shape[1]
# The reference must carry the embedding from the previous training
assert "X_gex" in reference_adata.obsm.keys()

The reference_adata and query_adata are anndata.AnnData objects with the raw GEX count matrix stored in adata.X, and they contain different numbers of genes. To enable transfer between the two datasets, we first re-train a VAE model on the reference dataset using only the genes shared by both datasets.

X_gex is the VAE embedding of the GEX data obtained from the previous training. The constrain_latent_embedding and constrain_latent_key arguments constrain the new VAE embedding to stay close to the X_gex embedding. This is useful when the VAE model is trained on a different subset of genes (e.g. highly variable genes) and we want to keep using the VAE embedding learned from the full set of genes.
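The shared-gene selection can be sketched in plain Python, independent of any AnnData object. The helper name shared_genes_in_reference_order is hypothetical and not part of scatlasvae; it only illustrates one way to obtain a reproducible gene order, since iterating over a Python set of strings can yield a different order in each process.

```python
from typing import Iterable, List


def shared_genes_in_reference_order(
    reference_genes: Iterable[str], query_genes: Iterable[str]
) -> List[str]:
    """Return genes present in both inputs, in the order of the reference.

    Hypothetical helper for illustration only; not part of scatlasvae.
    """
    query_set = set(query_genes)
    seen = set()
    shared = []
    for gene in reference_genes:
        # Keep each shared gene once, at its first position in the reference
        if gene in query_set and gene not in seen:
            seen.add(gene)
            shared.append(gene)
    return shared


if __name__ == "__main__":
    ref = ["CD3E", "CD8A", "GZMB", "FOXP3"]
    qry = ["FOXP3", "CD8A", "IL7R"]
    print(shared_genes_in_reference_order(ref, qry))
```

The resulting list can be used to subset both datasets, so that the columns of the reference and query matrices line up gene by gene.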

# Sort the shared genes so both datasets get the same, reproducible gene order
shared_genes = sorted(set(reference_adata.var_names).intersection(query_adata.var_names))

# Copy the subsets so that later writes (e.g. to .obsm) do not act on a view
reference_adata = reference_adata[:, shared_genes].copy()
query_adata = query_adata[:, shared_genes].copy()

# Retrain the VAE model
vae_model = scatlasvae.model.scAtlasVAE(
  adata=reference_adata,
  batch_key="sample_name",
  batch_embedding='embedding',
  device='cuda:0',
  batch_hidden_dim=10,
  constrain_latent_embedding=True,
  constrain_latent_key='X_gex'
)
vae_model.fit(max_epoch=8)
vae_model.save_to_disk("retrained_vae_model.pt")

# Get the VAE embedding of the query dataset
vae_model = scatlasvae.model.scAtlasVAE(
  adata=query_adata,
  batch_key="sample_name",
  batch_embedding='embedding',
  device='cuda:0',
  batch_hidden_dim=10,
  pretrained_state_dict="retrained_vae_model.pt"
)
query_adata.obsm['X_gex'] = vae_model.get_latent_embedding()
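
After retraining with a latent constraint, it can be useful to check how far the new reference embedding has moved from the original X_gex. The following is a minimal diagnostic sketch using NumPy only. The function mean_embedding_shift is hypothetical and not part of scatlasvae, and it is not the loss the model optimizes. It assumes both arrays have one row per cell, in the same cell order.

```python
import numpy as np


def mean_embedding_shift(old_embedding: np.ndarray, new_embedding: np.ndarray) -> float:
    """Mean Euclidean distance per cell between two embeddings.

    Hypothetical diagnostic for illustration only; rows must refer to the
    same cells in the same order.
    """
    old = np.asarray(old_embedding, dtype=float)
    new = np.asarray(new_embedding, dtype=float)
    if old.shape != new.shape:
        raise ValueError(f"shape mismatch: {old.shape} vs {new.shape}")
    # Distance per cell (row), then averaged over all cells
    return float(np.linalg.norm(old - new, axis=1).mean())


if __name__ == "__main__":
    old = np.zeros((3, 2))
    new = np.array([[3.0, 4.0], [0.0, 0.0], [0.0, 0.0]])
    print(mean_embedding_shift(old, new))
```

In the workflow above, old_embedding would be reference_adata.obsm['X_gex'] and new_embedding the output of get_latent_embedding() from the model retrained on the reference data. A small value relative to the typical distance between cells suggests the constraint kept the two embeddings comparable.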