Transferring Multi-source Gene Expression (GEX)

This repository contains code for transferring multi-source gene expression (GEX) data using the VAE model from the TCR-DeepInsight package.

Transfer query datasets to reference datasets without training

import scatlasvae

# Load the data
query_adata = scatlasvae.read_h5ad("query_adata.h5ad")
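Because the model expects the raw GEX count matrix in adata.X, a quick sanity check right after loading can catch log-normalized or scaled input early. The sketch below is a heuristic of our own, not part of scatlasvae; the helper name looks_like_raw_counts and the n_check spot-check limit are hypothetical. It only inspects values and assumes raw counts are non-negative whole numbers.

```python
import numpy as np

def looks_like_raw_counts(X, n_check=10_000):
    """Heuristic check that X holds raw counts (non-negative whole numbers).

    Hypothetical helper, not a scatlasvae function. Accepts a dense array or a
    SciPy sparse matrix; for sparse input only the stored non-zero values are
    inspected. At most n_check values are examined (a spot check).
    """
    values = X.data if hasattr(X, "tocsr") else np.asarray(X).ravel()
    values = np.asarray(values[:n_check])
    if values.size == 0:
        return True
    return bool(np.all(values >= 0) and np.all(np.mod(values, 1) == 0))

# Toy example: integer counts pass, log-normalized values do not.
counts = np.array([[0, 3, 1], [2, 0, 5]], dtype=np.float32)
print(looks_like_raw_counts(counts))            # True
print(looks_like_raw_counts(np.log1p(counts)))  # False
```

On real data this would be called as looks_like_raw_counts(query_adata.X). Passing the check does not prove the data are raw counts; it only rules out the most common mistakes.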

The adata is an anndata.AnnData object with the raw GEX count matrix stored in adata.X. To transfer the GEX data to a previously established reference, we first need to build a supervised VAE model trained on the reference data with cell type information.

reference_adata = scatlasvae.read_h5ad("reference_adata.h5ad")

# Supervised reference model: batch_key names the batch covariate,
# label_key names the cell type annotation used for supervision
reference_model = scatlasvae.model.scAtlasVAE(
  adata=reference_adata,
  batch_key="sample_name",
  label_key="cell_type",
  batch_embedding='embedding',
  device='cuda:0',
  batch_hidden_dim=8
)

reference_model.fit()
# Save the trained weights so the query model can reuse them
reference_model.save_to_disk("model.pt")

We need to make sure that the query data has the same number of genes as the reference data. If not, please see the Retraining Multi-source GEX Data tutorial for how to transfer GEX data with a different number of genes.
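A minimal way to compare the two gene lists before calling setup_anndata is sketched below. The helper compare_gene_sets is hypothetical and works on plain sequences of gene names, so it can be called as compare_gene_sets(reference_adata.var_names, query_adata.var_names). It reports genes missing from the query, extra genes in the query, and whether the order already matches.

```python
def compare_gene_sets(reference_genes, query_genes):
    """Summarize how a query gene list differs from a reference gene list.

    Hypothetical helper for illustration; not part of scatlasvae.
    """
    reference_genes = list(reference_genes)
    query_genes = list(query_genes)
    ref_set, query_set = set(reference_genes), set(query_genes)
    return {
        # Reference genes the query lacks (these block a direct transfer)
        "missing_in_query": [g for g in reference_genes if g not in query_set],
        # Query genes the reference model has never seen
        "extra_in_query": [g for g in query_genes if g not in ref_set],
        # True only if both lists are identical, including order
        "same_order": reference_genes == query_genes,
    }

report = compare_gene_sets(["CD3E", "CD8A", "MS4A1"], ["CD3E", "MS4A1", "CD8A"])
print(report)
# {'missing_in_query': [], 'extra_in_query': [], 'same_order': False}
```

If only the order differs, one cautious option (an assumption on our part, since this tutorial does not say whether setup_anndata reorders genes) is to subset the query in reference order with query_adata = query_adata[:, reference_adata.var_names].copy().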

# Set up the query AnnData against the saved reference model
scatlasvae.model.scAtlasVAE.setup_anndata(query_adata, "model.pt")
# Same settings as the reference model, initialized from its saved weights
query_model = scatlasvae.model.scAtlasVAE(
  adata=query_adata,
  batch_key="sample_name",
  label_key="cell_type",
  batch_embedding='embedding',
  device='cuda:0',
  batch_hidden_dim=8,
  pretrained_state_dict="model.pt",
)

Without further training, we can use the predict_labels method to transfer the cell type information from the reference to the query dataset.

# Predicted labels as a DataFrame indexed by cell
predictions = query_model.predict_labels(
  return_pandas=True,
  show_progress=True
)
# Prefix the columns (e.g. cell_type -> predicted_cell_type) and add them to obs
predictions = predictions.add_prefix('predicted_')
query_adata.obs = query_adata.obs.join(predictions)

# Non-pandas output, stored as logits for later inspection
predictions_logits = query_model.predict_labels(return_pandas=False)
query_adata.uns['predictions_logits'] = predictions_logits

# Compare the query's own annotation with the transferred labels
count, fig = scatlasvae.ut.cell_type_alignment(
  query_adata,
  obs_1='original_celltype',
  obs_2='predicted_cell_type',
  return_fig=True
)
fig.show()
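To make the two outputs above more concrete, the sketch below turns per-class scores into labels with a softmax and argmax, then counts (original, predicted) label pairs and the overall agreement rate. It assumes the scores for one label key form a cells × classes array; the actual structure returned by predict_labels(return_pandas=False) may differ, and the counting function is only our approximation of a contingency table, not the implementation of cell_type_alignment. Class names and data are made up.

```python
import numpy as np
from collections import Counter

def logits_to_labels(logits, class_names):
    """Softmax over classes, then pick the most probable class per cell.

    Assumption: logits is a (n_cells, n_classes) array for one label key.
    """
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    labels = [class_names[i] for i in probs.argmax(axis=1)]
    return labels, probs.max(axis=1)  # labels and their probabilities

def alignment_counts(original, predicted):
    """Count (original, predicted) label pairs and the overall agreement rate."""
    pairs = Counter(zip(original, predicted))
    agree = sum(n for (o, p), n in pairs.items() if o == p)
    return pairs, agree / len(original)

# Toy data: 4 cells, 3 hypothetical classes.
class_names = ["CD4 T", "CD8 T", "B"]
logits = [[2.0, 0.5, -1.0],
          [0.1, 3.0, 0.0],
          [0.0, 0.0, 4.0],
          [1.0, 1.2, 0.0]]
predicted, confidence = logits_to_labels(logits, class_names)
original = ["CD4 T", "CD8 T", "B", "CD4 T"]
pairs, accuracy = alignment_counts(original, predicted)
print(predicted)  # ['CD4 T', 'CD8 T', 'B', 'CD8 T']
print(accuracy)   # 0.75
```

The off-diagonal pairs (here ("CD4 T", "CD8 T")) are the ones worth inspecting first, and low softmax probabilities can flag cells whose transferred label is uncertain.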

Getting the transferred latent embedding

query_adata.obsm['X_gex'] = query_model.get_latent_embedding()

Mapping the UMAP representation to the reference

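import scanpy as sc
# The reference needs its own embedding and UMAP (skip if already computed)
reference_adata.obsm['X_gex'] = reference_model.get_latent_embedding()
sc.pp.neighbors(reference_adata, use_rep='X_gex')
sc.tl.umap(reference_adata)
query_adata.obsm['X_umap'] = scatlasvae.tl.umap_alignment(
  reference_adata.obsm['X_gex'],
  reference_adata.obsm['X_umap'],
  query_adata.obsm['X_gex'],
  method='knn'
)['embedding']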
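One plausible reading of method='knn' is sketched below: each query cell is placed at the mean UMAP coordinate of its k nearest reference cells in the shared latent space. This is our own illustration of the idea, not the implementation of scatlasvae.tl.umap_alignment; the function name knn_map_coordinates, Euclidean distance, plain averaging and k=5 are all assumptions.

```python
import numpy as np

def knn_map_coordinates(ref_latent, ref_coords, query_latent, k=5):
    """Place each query cell at the mean 2-D coordinate of its k nearest
    reference cells in latent space (Euclidean distance).

    Hypothetical helper; brute force, so memory grows as n_query * n_reference.
    """
    ref_latent = np.asarray(ref_latent, dtype=float)
    ref_coords = np.asarray(ref_coords, dtype=float)
    query_latent = np.asarray(query_latent, dtype=float)
    k = min(k, len(ref_latent))
    # Pairwise squared distances, shape (n_query, n_reference)
    d2 = ((query_latent[:, None, :] - ref_latent[None, :, :]) ** 2).sum(axis=-1)
    nearest = np.argsort(d2, axis=1)[:, :k]   # indices of the k closest cells
    return ref_coords[nearest].mean(axis=1)   # shape (n_query, 2)

# Toy reference with two well-separated clusters in latent space
ref_latent = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
ref_umap = np.array([[-1.0, -1.0], [-1.0, -0.8], [3.0, 3.0], [3.0, 3.2]])
query_latent = np.array([[0.05, 0.0], [5.05, 5.0]])
print(knn_map_coordinates(ref_latent, ref_umap, query_latent, k=2))
```

Each toy query cell lands between the two reference cells of its own cluster. For real datasets a nearest-neighbour index (for example scikit-learn's NearestNeighbors) would replace the brute-force distance matrix.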

Optionally, if label_key or additional_label_keys is set in the reference model, query_model.predict_labels() can be used to get the transferred cell types.

Transfer by training query datasets together with reference datasets

An alternative way to project query data onto reference data is to co-train on the reference and query datasets. However, this approach is more computationally expensive, since the model must be trained on both the reference and query datasets, and the resulting model is not guaranteed to be the same as the model trained on the reference dataset alone.
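The co-training code concatenates the two datasets and later selects the query cells with merged_adata[query_adata.obs.index], which only behaves as intended if cell names are unique across both datasets. A minimal pre-check is sketched below; the helper duplicated_cell_names is hypothetical and takes plain sequences, so it can be called as duplicated_cell_names(reference_adata.obs_names, query_adata.obs_names).

```python
from collections import Counter

def duplicated_cell_names(*name_lists):
    """Return cell names that occur more than once across all given lists
    (within one dataset or between datasets), sorted.

    Hypothetical helper for illustration; not part of scatlasvae.
    """
    counts = Counter(name for names in name_lists for name in names)
    return sorted(name for name, n in counts.items() if n > 1)

ref_cells = ["AAACCTG-1", "AAACGGG-1", "AAAGATG-1"]
query_cells = ["AAACCTG-1", "TTTGTCA-1"]
print(duplicated_cell_names(ref_cells, query_cells))  # ['AAACCTG-1']
```

If the list is not empty, one option is to prefix the query names before concatenating, for example query_adata.obs_names = ["query_" + n for n in query_adata.obs_names]; the "query_" prefix is an arbitrary choice.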

import scatlasvae
import scanpy as sc

# 'undefined' marks query cells as unlabeled; only reference labels supervise
query_adata.obs['cell_type'] = 'undefined'
merged_adata = sc.concat([reference_adata, query_adata])

model = scatlasvae.model.scAtlasVAE(
  adata=merged_adata,
  batch_key="sample_name",
  batch_embedding='embedding',
  label_key="cell_type",
  device='cuda:0',
  batch_hidden_dim=8
)
model.fit()

predictions = model.predict_labels(
  return_pandas=True,
  show_progress=True
)

# Prefix the columns (e.g. cell_type -> predicted_cell_type) and add them to obs
predictions = predictions.add_prefix('predicted_')
merged_adata.obs = merged_adata.obs.join(predictions)

predictions_logits = model.predict_labels(return_pandas=False)
merged_adata.uns['predictions_logits'] = predictions_logits

# Evaluate the transferred labels on the query cells only
count, fig = scatlasvae.ut.cell_type_alignment(
  merged_adata[query_adata.obs.index],
  obs_1='original_celltype',
  obs_2='predicted_cell_type',
  return_fig=True
)
fig.show()
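Setting the query labels to 'undefined' suggests that those cells are treated as unlabeled during co-training. The sketch below illustrates that idea under an assumption we have not confirmed for scAtlasVAE: the supervised classification loss simply skips cells labelled 'undefined', while (in the real model) they would still contribute to the unsupervised part of training. The function masked_cross_entropy is hypothetical and shows only the masking, not the library's loss.

```python
import numpy as np

def masked_cross_entropy(logits, labels, class_names, unlabeled="undefined"):
    """Mean cross-entropy over labelled cells only.

    Assumption: cells whose label equals `unlabeled` contribute nothing to
    the classification loss. Hypothetical helper, not the scAtlasVAE loss.
    """
    logits = np.asarray(logits, dtype=float)
    mask = np.array([lab != unlabeled for lab in labels])
    if not mask.any():
        return 0.0
    index = {name: i for i, name in enumerate(class_names)}
    targets = np.array([index[lab] for lab, keep in zip(labels, mask) if keep])
    kept = logits[mask]
    # Log-softmax with the usual max-shift for numerical stability
    shifted = kept - kept.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

class_names = ["CD4 T", "CD8 T"]
logits = [[2.0, 0.0], [0.0, 2.0], [5.0, -5.0]]
labels = ["CD4 T", "CD8 T", "undefined"]  # third cell plays the query cell
loss = masked_cross_entropy(logits, labels, class_names)
print(round(loss, 4))  # 0.1269
```

Changing the scores of the 'undefined' cell leaves the loss unchanged, which is the property the masking is meant to guarantee.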