API

Preprocessing

VDJPreprocessingV1Human and VDJPreprocessingV1Mouse are the main preprocessing classes for human and mouse data, respectively. They preprocess the raw Cell Ranger outputs (gene expression and V(D)J) and update the AnnData object with the preprocessed data.

class scatlasvae.pp.VDJPreprocessingV1Human(*, cellranger_gex_output_path: str, cellranger_vdj_output_path: str, output_path: str, check_existing_files: bool = False, check_valid_vdj: bool = True, vdj_all_contig: bool = False, cellranger_gex_output_path_opt: str | None = None, cellranger_vdj_output_path_opt: str | Iterable[str] | None = None, sample_name: str | None = None, study_name: str | None = None)

Bases: object

process(r_path: str = '/opt/anaconda3/envs/r403/bin/Rscript', ref_data_path: str = PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/scatlasvae/checkouts/latest/scatlasvae/preprocessing/data/refdata/human/reft_name.Rdata'))

Preprocess the output of cellranger vdj and cellranger gex.

Parameters:
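  • r_path – Path to the Rscript executable. The default, '/opt/anaconda3/envs/r403/bin/Rscript', is environment-specific and will usually need to be set to the Rscript on your system.

  • ref_data_path – Path to the reference data. Defaults to the human reference reft_name.Rdata shipped with the package. Available references are defined in scatlasvae.pp.HSAP_REF_DATA and scatlasvae.pp.MMUS_REF_DATA.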
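Because the default r_path in the signature points to one particular conda environment, it is often easier to look Rscript up on the current PATH. The sketch below does this with the standard library only; find_rscript is a hypothetical helper, not part of scatlasvae, and the commented-out call assumes only the constructor and process arguments documented above, with placeholder paths.

```python
import shutil
from typing import Optional


def find_rscript(fallback: Optional[str] = None) -> str:
    """Return the Rscript found on PATH, else `fallback`; raise if neither exists.

    Hypothetical helper for illustration, not part of scatlasvae.
    """
    path = shutil.which("Rscript") or fallback
    if path is None:
        raise FileNotFoundError("Rscript not found on PATH; pass r_path explicitly")
    return path


# Hypothetical usage with placeholder paths (needs scatlasvae, R and Cell Ranger outputs):
# import scatlasvae
# preprocessor = scatlasvae.pp.VDJPreprocessingV1Human(
#     cellranger_gex_output_path="path/to/gex",
#     cellranger_vdj_output_path="path/to/vdj",
#     output_path="path/to/output",
# )
# preprocessor.process(r_path=find_rscript())
```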

class scatlasvae.pp.VDJPreprocessingV1Mouse(*, cellranger_gex_output_path: str, cellranger_vdj_output_path: str, output_path: str, check_existing_files: bool = False, check_valid_vdj: bool = True, vdj_all_contig: bool = False, cellranger_gex_output_path_opt: str | None = None, cellranger_vdj_output_path_opt: str | Iterable[str] | None = None, sample_name: str | None = None, study_name: str | None = None)

Bases: VDJPreprocessingV1Human

process(r_path: str = '/opt/anaconda3/envs/r403/bin/Rscript', ref_data_path: str = PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/scatlasvae/checkouts/latest/scatlasvae/preprocessing/data/refdata/mouse/reft_name_mouse_nothymus.rds'))

Preprocess the output of cellranger vdj and cellranger gex.

Parameters:
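  • r_path – Path to the Rscript executable. The default, '/opt/anaconda3/envs/r403/bin/Rscript', is environment-specific and will usually need to be set to the Rscript on your system.

  • ref_data_path – Path to the reference data. Defaults to the mouse reference reft_name_mouse_nothymus.rds shipped with the package. Available references are defined in scatlasvae.pp.HSAP_REF_DATA and scatlasvae.pp.MMUS_REF_DATA.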

Model

class scatlasvae.model.scAtlasVAE(*, adata: AnnData | None = None, anndata_tensorstore_path: str | None = None, anndata_tensorstore_var_names: Iterable[str] | None = None, use_layer: str | None = None, hidden_stacks: List[int] = [128], n_latent: int = 10, n_batch: int = 0, n_label: int = 0, n_additional_batch: Iterable[int] | None = None, n_additional_label: Iterable[int] | None = None, batch_key: str | Iterable[str] | None = None, additional_batch_keys: Iterable[str] | None = None, label_key: str | Iterable[str] | None = None, additional_label_keys: Iterable[str] | None = None, encoder_type: EncoderType = EncoderType.SAE, dispersion: Literal['gene', 'gene-batch', 'gene-cell'] = 'gene-cell', rna_dropout: Literal['gene', 'cell'] = 'gene', log_variational: bool = True, total_variational: bool = False, bias: bool = True, use_batch_norm: bool = True, use_layer_norm: bool = False, batch_hidden_dim: int = 8, batch_embedding: Literal['embedding', 'onehot'] = 'embedding', reconstruction_method: Literal['mse', 'zg', 'zinb', 'nb'] = 'zinb', constrain_n_label: bool = True, constrain_n_batch: bool = True, constrain_latent_method: Literal['mse', 'normal'] = 'mse', constrain_latent_embedding: bool = False, constrain_latent_key: str = 'X_gex', encode_libsize: bool = False, decode_libsize: bool = True, dropout_rate: float = 0.1, activation_fn: torch.nn.Module = torch.nn.ReLU, inject_batch: bool = True, inject_label: bool = False, inject_additional_batch: bool = True, mmd_key: Literal['batch', 'additional_batch', 'both'] | None = None, unlabel_key: str = 'undefined', device: str | torch.device | None = None, pretrained_state_dict: str | Mapping[str, torch.Tensor] | None = None, low_memory_initialization: bool = False)

Bases: ReparameterizeLayerBase, MMDLayerBase

VAE model for atlas-level integration and label transfer.

Parameters:
  • adata – AnnData. If provided, initialize the model with the adata.

  • use_layer – Optional[str]. Name of the layer in adata to use. Default: None

  • hidden_stacks – List[int]. Number of hidden units in each layer. Default: [128] (one hidden layer with 128 units)

  • n_latent – int. Number of latent dimensions. Default: 10

  • n_batch – int. Number of batches. Default: 0

  • n_label – int. Number of labels. Default: 0

  • n_additional_batch – Optional[Iterable[int]]. Number of categories in each additional categorical covariate. Default: None

  • batch_key – str or Iterable[str]. Batch key(s) in adata.obs. Default: None

  • label_key – str or Iterable[str]. Label key(s) in adata.obs. Default: None

  • dispersion – Literal[“gene”, “gene-batch”, “gene-cell”]. Dispersion modeling method. Default: “gene-cell”

  • rna_dropout – Literal[“gene”, “cell”]. RNA dropout modeling method: “gene” models dropout at the gene level and “cell” models it at the cell level. Default: “gene”

  • log_variational – bool. If True, log the variational distribution. Default: True

  • total_variational – bool. If True, normalize the counts with library size. Default: False

  • bias – bool. If True, use bias in the linear layer. Default: True

  • use_batch_norm – bool. If True, use batch normalization. Default: True

  • use_layer_norm – bool. If True, use layer normalization. Default: False

  • batch_hidden_dim – int. Number of hidden units in the batch embedding layer. Default: 8

  • batch_embedding – Literal[“embedding”, “onehot”]. Batch embedding method. Default: “embedding”

  • constrain_latent_method – Literal[‘mse’, ‘normal’]. Method to constrain the latent embedding. Default: ‘mse’

  • constrain_latent_embedding – bool. If True, constrain the latent embedding. Default: False

  • constrain_latent_key – str. Key to the data to constrain the latent embedding. Default: ‘X_gex’

  • encode_libsize – bool. If True, encode the library size. Default: False

  • decode_libsize – bool. If True, decode the library size. Default: True

  • dropout_rate – float. Dropout rate. Default: 0.1

  • activation_fn – nn.Module. Activation function. Default: nn.ReLU

  • inject_batch – bool. If True, inject batch information. Default: True

  • inject_label – bool. If True, inject label information. Default: False

  • inject_additional_batch – bool. If True, inject categorical covariate information. Default: True

  • unlabel_key – str. Label value that marks unlabeled cells. Default: “undefined”

  • mmd_key – Optional[Literal[‘batch’, ‘additional_batch’, ‘both’]]. If provided, use MMD loss on the batch key, the additional batch keys, or both. Default: None (do not use MMD loss)

  • pretrained_state_dict – Optional[Union[str, Mapping[str, torch.Tensor]]]. Path to, or contents of, a pretrained state dict used to build the model. Default: None

  • device – Optional[Union[str, torch.device]]. Device to use. Default: determined by the availability of a CUDA device

  • anndata_tensorstore_path – Optional[str]. Path to the AnndataTensorStore. Default: None

Example:
>>> import scatlasvae
>>> model = scatlasvae.model.scAtlasVAE(
...    adata=adata,
...    batch_key=['sample_name', 'study_name'],
...    label_key=['cell_type', 'cell_subtype'],
... )
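Setting mmd_key enables an MMD (maximum mean discrepancy) loss, weighted by mmd_weight in fit, presumably between the latent embeddings of different batches. The kernel and bandwidth scAtlasVAE uses are not documented here, so the following is only a NumPy sketch of the standard biased squared-MMD estimator with a Gaussian RBF kernel and an assumed fixed bandwidth sigma. It illustrates the quantity such a loss penalizes, not the library's implementation; rbf_kernel and mmd2 are hypothetical names.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian RBF kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # guard against tiny negatives from rounding
    return np.exp(-sq_dists / (2.0 * sigma**2))


def mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of the squared MMD between samples x (n, d) and y (m, d)."""
    k_xx = rbf_kernel(x, x, sigma).mean()
    k_yy = rbf_kernel(y, y, sigma).mean()
    k_xy = rbf_kernel(x, y, sigma).mean()
    return float(k_xx + k_yy - 2.0 * k_xy)


rng = np.random.default_rng(0)
z_batch_a = rng.normal(0.0, 1.0, size=(200, 2))  # latent codes of cells from batch A
z_batch_b = rng.normal(0.0, 1.0, size=(200, 2))  # same distribution as batch A
z_batch_c = rng.normal(2.0, 1.0, size=(200, 2))  # shifted distribution
print(mmd2(z_batch_a, z_batch_b))  # small: batches are already well mixed
print(mmd2(z_batch_a, z_batch_c))  # larger: a batch effect the loss would penalize
```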

scAtlasVAE’s Methods table

partial_load_state_dict(state_dict)

Partially load the state dict

get_config()

Get the model config

save_to_disk(path_to_state_dict)

Save the model to disk

load_from_disk(path_to_state_dict)

Load the model from disk

setup_anndata(adata, path_to_state_dict[, ...])

Set up the model with adata

fit([max_epoch, n_per_batch, kl_weight, ...])

Fit the model.

scAtlasVAE’s Methods

scAtlasVAE.partial_load_state_dict(state_dict: Mapping[str, Tensor])

Partially load the state dict

Parameters:

state_dict – Mapping[str, torch.Tensor]. State dict to load
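One common way to implement partial loading is to keep only the entries of the incoming state dict whose key exists in the model and whose tensor shape matches, then load them with strict=False. This is an assumption about the general technique, not a description of scAtlasVAE's own logic. The sketch works on any mapping whose values have a .shape attribute (torch tensors or NumPy arrays); filter_compatible_state is a hypothetical helper.

```python
from typing import Any, Dict, List, Mapping, Tuple

import numpy as np


def filter_compatible_state(
    model_state: Mapping[str, Any], incoming: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Split `incoming` into entries that fit `model_state` and the keys that were skipped.

    An entry fits when its key exists in the model and its shape matches.
    """
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for key, value in incoming.items():
        if key in model_state and tuple(model_state[key].shape) == tuple(value.shape):
            compatible[key] = value
        else:
            skipped.append(key)
    return compatible, skipped


# NumPy arrays stand in for torch tensors; the parameter names are made up.
model_state = {"encoder.weight": np.zeros((128, 2000)), "decoder.weight": np.zeros((2000, 10))}
incoming = {
    "encoder.weight": np.ones((128, 2000)),  # same shape: kept
    "decoder.weight": np.ones((1500, 10)),   # different gene count: skipped
    "classifier.weight": np.ones((5, 10)),   # not in the model: skipped
}
compatible, skipped = filter_compatible_state(model_state, incoming)
print(sorted(compatible), skipped)  # ['encoder.weight'] ['decoder.weight', 'classifier.weight']
# With PyTorch, the filtered dict would then be applied with:
# model.load_state_dict(compatible, strict=False)
```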

scAtlasVAE.get_config()

Get the model config

Returns:

dict. Model config dictionary

scAtlasVAE.save_to_disk(path_to_state_dict: str | Path)

Save the model to disk

Parameters:

path_to_state_dict – str or Path. Path to save the model

scAtlasVAE.load_from_disk(path_to_state_dict: str | Path)

Load the model from disk

Parameters:

path_to_state_dict – str or Path. Path to load the model

static scAtlasVAE.setup_anndata(adata: AnnData, path_to_state_dict: str | Path, unlabel_key: str = 'undefined')

Set up the model with adata

Parameters:
  • adata – AnnData. AnnData to set up the model with

  • path_to_state_dict – str or Path. Path to the state dict to load

  • unlabel_key – str. Label value that marks unlabeled cells. Default: “undefined”

scAtlasVAE.fit(max_epoch: int | None = None, n_per_batch: int = 128, kl_weight: float = 1.0, pred_weight: float = 1.0, mmd_weight: float = 1.0, gate_weight: float = 1.0, constrain_weight: float = 1.0, optimizer_parameters: Iterable | None = None, validation_split: float = 0.2, lr: float = 5e-05, lr_schedule: bool = False, lr_factor: float = 0.6, lr_patience: int = 30, lr_threshold: float = 0.0, lr_min: float = 1e-06, n_epochs_kl_warmup: int | None = 400, weight_decay: float = 1e-06, random_seed: int = 12, subset_indices: tensor | ndarray | None = None, pred_last_n_epoch: int = 10, pred_last_n_epoch_fconly: bool = False, compute_batch_after_n_epoch: int = 0, reconstruction_reduction: str = 'sum', n_concurrent_batch: int = 1)

Fit the model.

Parameters:
  • max_epoch – int. Maximum number of epochs to train the model. If not provided, the model is trained for 400 epochs or 20000 / n_record * 400 epochs by default.

  • n_per_batch – int. Number of cells per mini-batch.

  • kl_weight – float. (Maximum) weight of the KL divergence loss.

  • pred_weight – float. Weight of the prediction loss.

  • mmd_weight – float. Weight of the MMD loss. Ignored if mmd_key is None.

  • constrain_weight – float. Weight of the latent-constraint loss. Ignored if constrain_latent_embedding is False.

  • optimizer_parameters – Iterable. Parameters to be optimized. If not provided, all parameters are optimized.

  • validation_split – float. Fraction of the data to use as the validation set.

  • lr – float. Learning rate.

  • lr_schedule – bool. Whether to use a learning rate scheduler.

  • lr_factor – float. Factor by which to reduce the learning rate.

  • lr_patience – int. Number of epochs to wait before reducing the learning rate.

  • lr_threshold – float. Threshold for triggering a learning rate reduction.

  • lr_min – float. Minimum learning rate.

  • n_epochs_kl_warmup – int. Number of epochs over which the KL divergence loss is warmed up (deterministic warm-up of the KL term).

  • weight_decay – float. Weight decay (L2 penalty).

  • random_seed – int. Random seed.

  • subset_indices – Union[torch.tensor, np.ndarray]. Indices of cells to use for training. If not provided, all cells are used.

  • pred_last_n_epoch – int. Number of epochs to train the prediction layer only.

  • pred_last_n_epoch_fconly – bool. Whether to train the prediction layer only.

  • reconstruction_reduction – str. Reduction method for the reconstruction loss: ‘sum’ or ‘mean’.
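The defaults for max_epoch, validation_split and n_epochs_kl_warmup can be illustrated with a small pure-Python sketch. It rests on three assumptions that should be checked against the source: the max_epoch default is read as min(400, round(20000 / n_record * 400)), the same kind of heuristic scvi-tools uses; the validation set is a random subset of int(n_record * validation_split) cells drawn with random_seed; and the KL warm-up is linear from 0 up to kl_weight. All function names are hypothetical.

```python
import random
from typing import List, Optional, Tuple


def default_max_epoch(n_record: int, cap: int = 400) -> int:
    """Assumed default: fewer epochs for large datasets, capped at `cap`."""
    return min(cap, round(20000 / n_record * cap))


def train_val_split(
    n_record: int, validation_split: float = 0.2, random_seed: int = 12
) -> Tuple[List[int], List[int]]:
    """Shuffle cell indices with a fixed seed and hold out a fraction for validation."""
    indices = list(range(n_record))
    random.Random(random_seed).shuffle(indices)
    n_val = int(n_record * validation_split)
    return indices[n_val:], indices[:n_val]


def kl_weight_at(
    epoch: int, kl_weight: float = 1.0, n_epochs_kl_warmup: Optional[int] = 400
) -> float:
    """Assumed linear (deterministic) KL warm-up from 0 to `kl_weight`."""
    if not n_epochs_kl_warmup:  # no warm-up: full weight from the first epoch
        return kl_weight
    return kl_weight * min(1.0, epoch / n_epochs_kl_warmup)


print(default_max_epoch(10_000))   # 400 (cap reached)
print(default_max_epoch(100_000))  # 80
train_idx, val_idx = train_val_split(1_000)
print(len(train_idx), len(val_idx))  # 800 200
print(kl_weight_at(0), kl_weight_at(200), kl_weight_at(800))  # 0.0 0.5 1.0
```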

Pipeline

pipeline.run_transfer(adata_query: AnnData, path_to_state_dict: str, label_key: str = 'cell_type', device='cpu')

Transfer cell type labels from reference to query dataset

Parameters:
  • adata_query – Query dataset

  • path_to_state_dict – Path to the state dict file of the model trained on the reference dataset

  • label_key – Key to store cell type labels

  • device – Device to use. Default is ‘cpu’

Tools

scatlasvae.tl.umap_alignment(reference_embedding, reference_umap, query_embedding, method: Literal['retrain-reference', 'retrain-both', 'knn'] = 'knn', subsample: int = 100000, return_subsampled_indices: bool = False, return_subsampled_reference_umap: bool = False, n_epochs: int = 10, use_cuml_umap: bool = False, return_reducer: bool = False, **kwargs)

Transfer UMAP from reference_embedding to query_embedding

Parameters:
  • reference_embedding – Reference embedding

  • reference_umap – Reference UMAP

  • query_embedding – Query embedding

  • method – Method to use: ‘retrain-reference’, ‘retrain-both’ or ‘knn’. If ‘retrain-reference’, retrain UMAP using reference_embedding, with reference_umap as the initialization. Slow; not recommended yet. If ‘retrain-both’, first get initial positions with the knn method and then retrain UMAP. Slow; not recommended yet. If ‘knn’, find the nearest reference cells of each query cell and average their reference_umap coordinates.

  • subsample – Number of cells to subsample from reference_embedding.

  • return_subsampled_indices – If True, return the subsampled indices.

  • return_subsampled_reference_umap – If True, return the subsampled reference UMAP.

  • n_epochs – Number of epochs to use for retraining. Ignored if method is ‘knn’.

  • return_reducer – If True, return the UMAP reducer. Ignored if method is ‘knn’.

  • kwargs – Additional arguments passed to UMAP. Ignored if method is ‘knn’.

Example:
>>> import scatlasvae
>>> adata_query.obsm['X_umap'] = scatlasvae.tl.umap_alignment(
...     adata_reference.obsm['X_gex'],
...     adata_reference.obsm['X_umap'],
...     adata_query.obsm['X_gex'],
... )["embedding"]

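The ‘knn’ method places each query cell at the average UMAP position of its nearest reference cells. A brute-force NumPy sketch of that idea follows. The neighbor count k, the Euclidean metric in embedding space and the unweighted mean are assumptions; the real function may differ, for example by using an approximate neighbor index or distance weighting. knn_umap_transfer is a hypothetical name, and the full distance matrix it builds is one reason to subsample large references.

```python
import numpy as np


def knn_umap_transfer(
    reference_embedding: np.ndarray,
    reference_umap: np.ndarray,
    query_embedding: np.ndarray,
    k: int = 5,
) -> np.ndarray:
    """Place each query cell at the mean UMAP position of its k nearest reference cells."""
    # Squared Euclidean distances between every query and reference cell: (n_query, n_ref).
    d2 = (
        np.sum(query_embedding**2, axis=1)[:, None]
        + np.sum(reference_embedding**2, axis=1)[None, :]
        - 2.0 * query_embedding @ reference_embedding.T
    )
    k = min(k, reference_embedding.shape[0])
    # Indices of the k closest reference cells per query cell (unordered within the k).
    nn_idx = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
    # Average their UMAP coordinates: (n_query, k, 2) -> (n_query, 2).
    return reference_umap[nn_idx].mean(axis=1)


rng = np.random.default_rng(0)
ref_emb = rng.normal(size=(500, 10))
ref_umap = rng.normal(size=(500, 2))
query_emb = ref_emb[:3] + 1e-6  # query cells nearly identical to three reference cells
print(knn_umap_transfer(ref_emb, ref_umap, query_emb, k=1))  # approximately ref_umap[:3]
```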
scatlasvae.tl.cell_type_alignment(adata: AnnData, obs_1: str, obs_2: str, palette_1: dict | None = None, palette_2: dict | None = None, perc_in_obs_1: float = 0.1, perc_in_obs_2: float = 0.1, ignore_label: str = 'undefined', return_fig: bool = True)

Plot a Sankey diagram of cell type alignment between two obs columns.

Parameters:
  • adata – Annotated data matrix.

  • obs_1 – First obs column to compare.

  • obs_2 – Second obs column to compare.

  • palette_1 – Color palette for obs_1.

  • palette_2 – Color palette for obs_2.

  • perc_in_obs_1 – Minimum percentage of cells in obs_1 to be considered.

  • perc_in_obs_2 – Minimum percentage of cells in obs_2 to be considered.

Returns:

Sankey diagram in plotly.graph_objects.Figure format.

Example:
>>> import scatlasvae
>>> fig = scatlasvae.tl.cell_type_alignment(adata, "cell_type", "cell_type_2")
>>> fig.show()
>>> fig.write_html("cell_type_alignment.html")
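A Sankey diagram of this kind is drawn from a table of (obs_1 label, obs_2 label, cell count) links. The pandas sketch below builds such a table. How the real function applies perc_in_obs_1, perc_in_obs_2 and ignore_label is not documented here, so the rule used is an assumption: drop cells carrying ignore_label, then keep a link only if it holds at least the given fraction of the cells of its obs_1 label and of its obs_2 label. alignment_links is a hypothetical helper. After mapping labels to integer node indices, the columns could feed the source, target and value fields of plotly.graph_objects.Sankey.

```python
import pandas as pd


def alignment_links(
    obs: pd.DataFrame,
    obs_1: str,
    obs_2: str,
    perc_in_obs_1: float = 0.1,
    perc_in_obs_2: float = 0.1,
    ignore_label: str = "undefined",
) -> pd.DataFrame:
    """Count cells for each (obs_1, obs_2) label pair and drop minor or ignored links."""
    df = obs[[obs_1, obs_2]].astype(str)
    df = df[(df[obs_1] != ignore_label) & (df[obs_2] != ignore_label)]
    links = df.groupby([obs_1, obs_2]).size().rename("value").reset_index()
    # Fraction of each obs_1 label (and of each obs_2 label) carried by the link.
    frac_1 = links["value"] / links.groupby(obs_1)["value"].transform("sum")
    frac_2 = links["value"] / links.groupby(obs_2)["value"].transform("sum")
    keep = (frac_1 >= perc_in_obs_1) & (frac_2 >= perc_in_obs_2)
    return links[keep].reset_index(drop=True)


obs = pd.DataFrame({
    "cell_type":   ["T", "T", "T", "T", "B", "B", "B", "undefined"],
    "cell_type_2": ["CD4", "CD4", "CD4", "CD8", "B", "B", "B", "CD4"],
})
# The T -> CD8 link carries 1/4 of the T cells, below the 0.3 threshold, so it is dropped.
print(alignment_links(obs, "cell_type", "cell_type_2", perc_in_obs_1=0.3, perc_in_obs_2=0.3))
```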

Utilities

scatlasvae.ut.get_default_device()

Get the default device for training.
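The scAtlasVAE device parameter is described as defaulting according to the availability of a CUDA device. The sketch below shows that kind of selection as an assumption, not the library's code. It returns a string and falls back to "cpu" when PyTorch is not installed, so it runs anywhere; default_device_sketch is a hypothetical name.

```python
def default_device_sketch() -> str:
    """Return "cuda" if PyTorch sees a CUDA device, otherwise "cpu"."""
    try:
        import torch
    except ImportError:  # PyTorch not installed: only the CPU is usable
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(default_device_sketch())
```

Since the device parameter accepts either a string or a torch.device, a string result like this can be passed to the model directly.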