# Pytorch
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.distributions import kl_divergence as kld
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Third Party
import numpy as np
from pathlib import Path
import pandas as pd
import scanpy as sc
from scipy.sparse import issparse
# import anndata_tensorstore as ats
# Built-in
import time
from collections import Counter
from itertools import chain
from copy import deepcopy
import json
from typing import Mapping, Union, Iterable, Tuple, Optional, Mapping, Dict
from concurrent.futures import ThreadPoolExecutor
import os
import warnings
# Package
from ._primitives import *
from ..utils._tensor_utils import one_hot, get_k_elements, get_last_k_elements, get_elements
from ..utils._decorators import typed
from ..utils._loss import LossFunction
from ..utils._logger import mt, mw, Colors, get_tqdm, is_notebook
from ..utils._utilities import random_subset_by_key_fast, next_unicode_char
from ..utils._compat import Literal
from ..utils._utilities import get_default_device
from ..tools._umap import umap_alignment
from ..preprocessing._preprocess import subset_adata_by_genes_fill_zeros
from ..externals.tabnet.tab_network import TabNetEncoder
# Directory that contains this source file.
MODULE_PATH = Path(__file__).parent
# Installs a blanket "ignore" filter: once this module is imported, warnings of
# every category are suppressed for the whole process.
warnings.filterwarnings("ignore")
class scAtlasVAE(ReparameterizeLayerBase, MMDLayerBase):
"""
VAE model for atlas-level integration and label transfer
:param adata: AnnData. If provided, initialize the model with the adata.
:param anndata_tensorstore_path: Optional[str]. Path to the AnndataTensorStore. Provide either this or adata, not both. Default: None
:param use_layer: Optional[str]. Use the layer in the adata. Default: None
:param hidden_stacks: List[int]. Number of hidden units in each layer. Default: [128] (one hidden layer with 128 units)
:param n_latent: int. Number of latent dimensions. Default: 10
:param n_batch: int. Number of batch. Default: 0
:param n_label: int. Number of label. Default: 0
:param n_additional_batch: Optional[Iterable[int]]. Number of categorical covariate. Default: None
:param batch_key: str. Batch key in adata.obs. Default: None
:param label_key: str. Label key in adata.obs. Default: None
:param dispersion: Literal["gene", "gene-batch", "gene-cell"]. Dispersion modeling method. Default: "gene-cell"
:param rna_dropout: Literal["gene", "cell"]. RNA dropout modeling method. Default: "gene" models dropout at the gene level. Alternative: "cell" models dropout at the cell level.
:param log_variational: bool. If True, log the variational distribution. Default: True
:param total_variational: bool. If True, normalize the counts with library size. Default: False
:param bias: bool. If True, use bias in the linear layer. Default: True
:param use_batch_norm: bool. If True, use batch normalization. Default: True
:param use_layer_norm: bool. If True, use layer normalization. Default: False
:param batch_hidden_dim: int. Number of hidden units in the batch embedding layer. Default: 8
:param batch_embedding: Literal["embedding", "onehot"]. Batch embedding method. Default: "embedding"
:param constrain_latent_method: Literal['mse', 'normal']. Method to constrain the latent embedding. Default: 'mse'
:param constrain_latent_embedding: bool. If True, constrain the latent embedding. Default: False
:param constrain_latent_key: str. Key to the data to constrain the latent embedding. Default: 'X_gex'
:param encode_libsize: bool. If True, encode the library size. Default: False
:param decode_libsize: bool. If True, decode the library size. Default: True
:param dropout_rate: float. Dropout rate. Default: 0.1
:param activation_fn: nn.Module. Activation function. Default: nn.ReLU
:param inject_batch: bool. If True, inject batch information. Default: True
:param inject_label: bool. If True, inject label information. Default: False
:param inject_additional_batch: bool. If True, inject categorical covariate information. Default: True
:param unlabel_key: str. key for unlabeled cells. Default: "undefined"
:param mmd_key: Optional[Literal['batch']]. If provided, use MMD loss. Default: None (do not use MMD loss)
:param pretrained_state_dict: Optional[Union[str, Mapping[str, torch.Tensor]]]. Path to a saved checkpoint, or a state dict, that is partially loaded after the model is built. Default: None
:param device: Optional[Union[str, torch.device]]. Device to use. Default: determined by availability of CUDA device
:example:
>>> import scatlasvae
>>> model = scatlasvae.model.scAtlasVAE(
>>> adata,
>>> batch_key = ['sample_name','study_name'],
>>> label_key = ['cell_type', 'cell_subtype'],
>>> )
"""
def __init__(self, *,
    adata: Optional[sc.AnnData] = None,
    anndata_tensorstore_path: Optional[str] = None,
    anndata_tensorstore_var_names: Optional[Iterable[str]] = None,
    use_layer: Optional[str] = None,
    hidden_stacks: List[int] = [128],
    n_latent: int = 10,
    n_batch: int = 0,
    n_label: int = 0,
    n_additional_batch: Optional[Iterable[int]] = None,
    n_additional_label: Optional[Iterable[int]] = None,
    batch_key: Union[str, Iterable[str]] = None,
    additional_batch_keys: Iterable[str] = None, #TODO: deprecate in the future
    label_key: Union[str, Iterable[str]] = None,
    additional_label_keys: Iterable[str] = None, #TODO: deprecate in the future
    encoder_type: EncoderType = EncoderType.SAE,
    dispersion: Literal["gene", "gene-batch", "gene-cell"] = "gene-cell",
    rna_dropout: Literal["gene", "cell"] = "gene",
    log_variational: bool = True,
    total_variational: bool = False,
    bias: bool = True,
    use_batch_norm: bool = True,
    use_layer_norm: bool = False,
    batch_hidden_dim: int = 8,
    batch_embedding: Literal["embedding", "onehot"] = "embedding",
    reconstruction_method: Literal['mse', 'zg', 'zinb', 'nb'] = 'zinb',
    constrain_n_label: bool = True,
    constrain_n_batch: bool = True,
    constrain_latent_method: Literal['mse', 'normal'] = 'mse',
    constrain_latent_embedding: bool = False,
    constrain_latent_key: str = 'X_gex',
    encode_libsize: bool = False,
    decode_libsize: bool = True,
    dropout_rate: float = 0.1,
    activation_fn: nn.Module = nn.ReLU,
    inject_batch: bool = True,
    inject_label: bool = False,
    inject_additional_batch: bool = True,
    mmd_key: Optional[Literal['batch','additional_batch','both']] = None,
    unlabel_key: str = 'undefined',
    device: Optional[Union[str, torch.device]] = None,
    pretrained_state_dict: Union[str, Optional[Mapping[str, torch.Tensor]]] = None,
    low_memory_initialization: bool = False,
) -> None:
    """
    Build the model from an in-memory AnnData or from an AnndataTensorStore.

    The parameters are described in the class docstring. All arguments are
    keyword-only.

    :raises ValueError: if neither or both of ``adata`` and
        ``anndata_tensorstore_path`` are given, if ``n_batch``/``n_label`` is
        positive without the matching key, or if ``dispersion`` is not one of
        the supported values.
    """
    if device is None:
        device = get_default_device()

    # ------------------------------------------------------------------
    # Data source: exactly one of `adata` / `anndata_tensorstore_path`
    # ------------------------------------------------------------------
    if anndata_tensorstore_path is None and adata is None:
        raise ValueError("Please provide either anndata or anndata_tensorstore_path")
    elif anndata_tensorstore_path is not None and adata is not None:
        raise ValueError("Please provide either anndata or anndata_tensorstore_path, not both")
    elif anndata_tensorstore_path is not None:
        # NOTE(review): `ats` is only imported in a commented-out line at the
        # top of this file; unless `._primitives` re-exports it, this branch
        # raises NameError. Confirm before relying on the tensorstore path.
        low_memory_initialization = True
        var = pd.read_parquet(os.path.join(anndata_tensorstore_path, ats.ATS_FILE_NAME.var))
        obs = pd.read_parquet(os.path.join(anndata_tensorstore_path, ats.ATS_FILE_NAME.obs))
        if constrain_latent_key is not None:
            obsm = ats._ext.load_np_array_from_tensorstore(
                os.path.join(anndata_tensorstore_path, ats.ATS_FILE_NAME.obsm, constrain_latent_key)
            )
        # Only obs (and optionally one obsm entry) are held in memory here.
        self.adata = sc.AnnData(
            obs=obs,
            obsm={
                constrain_latent_key: obsm
            } if constrain_latent_key is not None else None,
        )
    else:
        self.adata = adata
        if adata.is_view:
            mw("adata is a view of another AnnData object. \n" + \
                " "*40 + "This may cause slower training. \n" + \
                " "*40 + "Use adata=adata.copy() to create a new AnnData object."
            )
        # Count-based likelihoods expect raw integer counts; warn otherwise.
        if use_layer is None:
            if adata.X.dtype != np.int32 and reconstruction_method in ['zinb', 'nb']:
                mw("adata.X is not of type np.int32. \n" + \
                    " "*40 + "\tCheck whether you are using raw count matrix.")
        else:
            if adata.layers[use_layer].dtype != np.int32 and reconstruction_method in ['zinb', 'nb']:
                mw(f"adata.layers[{use_layer}] is not of type np.int32. \n" + \
                    " "*40 + "\tCheck whether you are using raw count matrix.")

    super(scAtlasVAE, self).__init__()

    self.anndata_tensorstore_path = anndata_tensorstore_path
    self.anndata_tensorstore_var_names = anndata_tensorstore_var_names
    self.anndata_tensorstore_var_indices = None
    if anndata_tensorstore_var_names is not None:
        self.anndata_tensorstore_var_indices = np.argwhere(np.isin(var.index, anndata_tensorstore_var_names)).flatten()
        var = var.loc[anndata_tensorstore_var_names]

    self.use_layer = use_layer
    # Test against None explicitly: the truthiness of an AnnData follows its
    # length, so an AnnData without observations would be treated as missing.
    self.in_dim = adata.shape[1] if adata is not None else var.shape[0]
    self.n_hidden = hidden_stacks[-1]
    self.n_latent = n_latent
    self.n_additional_batch = n_additional_batch
    self.n_additional_label = n_additional_label
    # Copy so the shared default list is never aliased by the instance.
    self._hidden_stacks = list(hidden_stacks)
    self.encoder_type = encoder_type

    if n_batch > 0 and not batch_key:
        raise ValueError("Please provide a batch key if n_batch is greater than 0")
    if n_label > 0 and not label_key:
        raise ValueError("Please provide a label key if n_label is greater than 0")

    # A key given as a sequence is split into a primary key (first element)
    # and additional keys (remaining elements).
    self.label_key = label_key if isinstance(label_key, str) else (
        label_key[0] if label_key is not None and isinstance(label_key, Iterable) else None
    )
    self.label_category = None
    self.label_category_summary = None

    self.batch_key = batch_key if isinstance(batch_key, str) else (
        batch_key[0] if batch_key is not None and isinstance(batch_key, Iterable) else None
    )
    self.batch_category = None
    self.batch_category_summary = None

    if additional_batch_keys is None:
        self.additional_batch_keys = None if (
            isinstance(batch_key, str) or (isinstance(batch_key, Iterable) and len(batch_key) == 1)
        ) else (batch_key[1:] if batch_key is not None else None)
    else:
        #TODO: deprecate in the future
        mw("additional_batch_keys is going to be deprecated. Use batch_key as a List instead.")
        self.additional_batch_keys = additional_batch_keys
    self.additional_batch_category = None
    self.additional_batch_category_summary = None

    if additional_label_keys is None:
        self.additional_label_keys = None if (
            isinstance(label_key, str) or (isinstance(label_key, Iterable) and len(label_key) == 1)
        ) else (label_key[1:] if label_key is not None else None)
    else:
        #TODO: deprecate in the future
        mw("additional_label_keys is going to be deprecated. Use label_key as a List instead.")
        self.additional_label_keys = additional_label_keys
    self.additional_label_category = None
    self.additional_label_category_summary = None

    self.n_batch = n_batch
    self.n_label = n_label
    self.unlabel_key = unlabel_key

    # Patch fix for the unlabel_key: prefix it with the character following the
    # first character of the last-sorted label, so that the unlabeled marker
    # sorts after every real label.
    all_label_keys = []
    if self.label_key is not None and unlabel_key in set(self.adata.obs[self.label_key]):
        all_label_keys = list(set(self.adata.obs[self.label_key]))
        all_label_keys.remove(unlabel_key)
    if self.additional_label_keys is not None:
        for k in self.additional_label_keys:
            if unlabel_key in set(self.adata.obs[k]):
                all_label_keys += list(set(self.adata.obs[k]))
                all_label_keys.remove(unlabel_key)

    if len(all_label_keys) > 0:
        last_unicode = sorted(all_label_keys)[-1][0]
        if ord(last_unicode) > ord(unlabel_key[0]):
            self.unlabel_key = next_unicode_char(last_unicode) + '-' + self.unlabel_key
            # Report the key that is actually in effect from now on.
            mw(f"unlabel_key is set to {self.unlabel_key}")
            # `label_key` is the raw argument here: when it is a list, this
            # selects and rewrites every listed column at once.
            self.adata.obs[label_key] = self.adata.obs[label_key].replace(unlabel_key, self.unlabel_key)
            if additional_label_keys is not None:
                for k in additional_label_keys:
                    self.adata.obs[k] = self.adata.obs[k].replace(unlabel_key, self.unlabel_key)

    self.new_adata_code = None
    self.log_variational = log_variational
    self.total_variational = total_variational
    self.mmd_key = mmd_key
    self.reconstruction_method = reconstruction_method
    self.constrain_latent_embedding = constrain_latent_embedding
    self.constrain_latent_method = constrain_latent_method
    self.constrain_latent_key = constrain_latent_key
    self.constrain_n_label = constrain_n_label
    self.constrain_n_batch = constrain_n_batch
    self.low_memory_initialization = low_memory_initialization
    if self.low_memory_initialization and self.anndata_tensorstore_path is None:
        mw(
            "low_memory_initialization is set to True. \n" + \
            " "*40 + "This will reduce the memory usage during initialization,\n" + \
            " "*40 + "but may significantly slow down the training and \n" + \
            " "*40 + "not fully tested for all functionalities."
        )

    self.device=device

    # Must run after the key/flag attributes above are set and before the
    # layers below are built: it may update n_batch / n_label / n_additional_*.
    self.initialize_dataset()

    self.batch_embedding = batch_embedding
    if batch_embedding == "onehot":
        batch_hidden_dim = self.n_batch
    self.batch_hidden_dim = batch_hidden_dim
    self.inject_batch = inject_batch
    self.inject_label = inject_label
    self.inject_additional_batch = inject_additional_batch
    self.encode_libsize = encode_libsize
    self.decode_libsize = decode_libsize
    self.dispersion = dispersion
    self.rna_dropout = rna_dropout

    self.fcargs = dict(
        bias = bias,
        dropout_rate = dropout_rate,
        use_batch_norm = use_batch_norm,
        use_layer_norm = use_layer_norm,
        activation_fn = activation_fn,
        device = device
    )

    #############################
    # Model Trainable Variables #
    #############################
    if self.dispersion == "gene":
        self.px_rate = torch.nn.Parameter(torch.randn(self.in_dim))
    elif self.dispersion == "gene-batch":
        self.px_rate = torch.nn.Parameter(torch.randn(self.in_dim, self.n_batch))
    elif self.dispersion == "gene-cell":
        pass
    else:
        raise ValueError(
            "dispersion must be one of ['gene', 'gene-batch',"
            " 'gene-cell'], but input was "
            "{}".format(self.dispersion)
        )

    ############
    # ENCODERS #
    ############
    if self.encoder_type == EncoderType.SAE:
        self.encoder = SAE(
            self.in_dim if not self.encode_libsize else self.in_dim + 1,
            stacks = hidden_stacks,
            # n_cat_list = [self.n_batch] if self.n_batch > 0 else None,
            cat_dim = batch_hidden_dim,
            cat_embedding = batch_embedding,
            encode_only = True,
            **self.fcargs
        )
    elif self.encoder_type == EncoderType.TABNET:
        self.encoder = TabNetEncoder(
            input_dim = self.in_dim,
            output_dim = self.n_hidden,
            n_d = self.n_hidden,
        )

    # The latent cell representation z ~ Logisticnormal(0, I)
    self.z_mean_fc = nn.Linear(self.n_hidden, self.n_latent)
    self.z_var_fc = nn.Linear(self.n_hidden, self.n_latent)
    self.z_transformation = nn.Softmax(dim=-1)

    ############
    # DECODERS #
    ############
    # Categorical covariates injected into the decoder, in the order
    # [batch, label, *additional batches], depending on the inject_* flags.
    if self.n_additional_batch_ is not None and self.inject_additional_batch:
        if self.n_batch > 0 and self.n_label > 0 and inject_batch and inject_label:
            decoder_n_cat_list = [self.n_batch, self.n_label, *self.n_additional_batch]
        elif self.n_batch > 0 and inject_batch:
            decoder_n_cat_list = [self.n_batch, *self.n_additional_batch]
        elif self.n_label > 0 and inject_label:
            decoder_n_cat_list = [self.n_label, *self.n_additional_batch]
        else:
            decoder_n_cat_list = None
    else:
        if self.n_batch > 0 and self.n_label > 0 and inject_batch and inject_label:
            decoder_n_cat_list = [self.n_batch, self.n_label]
        elif self.n_batch > 0 and inject_batch:
            decoder_n_cat_list = [self.n_batch]
        elif self.n_label > 0 and inject_label:
            decoder_n_cat_list = [self.n_label]
        else:
            decoder_n_cat_list = None

    self.decoder_n_cat_list = decoder_n_cat_list

    self.decoder = FCLayer(
        in_dim = self.n_latent,
        out_dim = self.n_hidden,
        n_cat_list = decoder_n_cat_list,
        cat_dim = batch_hidden_dim,
        cat_embedding = batch_embedding,
        use_layer_norm=False,
        use_batch_norm=True,
        dropout_rate=0,
        device=device
    )

    self.px_rna_rate_decoder = nn.Linear(self.n_hidden, self.in_dim)
    self.px_rna_scale_decoder = nn.Sequential(
        nn.Linear(self.n_hidden, self.in_dim),
        nn.Softmax(dim=-1)
    )

    if self.rna_dropout == "gene":
        self.px_rna_dropout_decoder = Linear(self.n_hidden, self.in_dim, init='final')
    elif self.rna_dropout == "cell":
        self.px_rna_dropout_decoder = Linear(self.n_hidden, 1, init='final')

    # Classification heads on the latent space: one for the primary label
    # and one per additional label.
    if self.n_label > 0:
        self.fc = nn.Sequential(
            nn.Linear(self.n_latent, self.n_label)
        )

    if self.n_additional_label is not None:
        self.additional_fc = nn.ModuleList([
            nn.Linear(self.n_latent, x) for x in self.n_additional_label
        ])

    self._trained = False

    self.to(device)

    if pretrained_state_dict is not None:
        if isinstance(pretrained_state_dict, str):
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be read on the device this model lives on.
            pretrained_state_dict = torch.load(
                pretrained_state_dict, map_location=device
            )['model_state_dict']
        self.partial_load_state_dict(pretrained_state_dict)
def __repr__(self):
    """
    Return a colored, multi-line summary of the model.

    The header, adata, in_dim and n_hidden lines are always present. The
    labels, batches and additional_batches lines are each included only when
    the corresponding key is set.
    """
    # Collect the pieces in a list. Chaining `a + b if cond else '' + c`
    # makes each conditional swallow everything written before it, so whole
    # sections (including the header) would be dropped.
    parts = [
        f'{Colors.ORANGE}VAEModel{Colors.NC} object containing:\n',
        f' {Colors.GREEN}adata{Colors.NC}: {self.adata}\n',
        f' {Colors.GREEN}in_dim{Colors.NC}: {Colors.CYAN}{self.in_dim}{Colors.NC}\n',
        f' {Colors.GREEN}n_hidden{Colors.NC}: {Colors.CYAN}{self.n_hidden}{Colors.NC}\n',
    ]
    if self.label_key:
        parts.append(
            f' {Colors.GREEN}labels{Colors.NC}: {self.label_key} of {Colors.CYAN}{self.n_label}{Colors.NC}\n'
        )
    if self.batch_key:
        parts.append(
            f' {Colors.GREEN}batches{Colors.NC}: {self.batch_key} of {Colors.CYAN}{self.n_batch}{Colors.NC}\n'
        )
    if self.additional_batch_keys:
        parts.append(
            f' {Colors.GREEN}additional_batches{Colors.NC}: {self.additional_batch_keys} of {Colors.CYAN}{self.n_additional_batch}{Colors.NC}\n'
        )
    return ''.join(parts)
[docs]
def partial_load_state_dict(self, state_dict: Mapping[str, torch.Tensor]):
    """
    Partially load the state dict

    Entries of ``state_dict`` whose key is unknown to the model are ignored,
    entries whose shape differs from the model's are replaced by the model's
    current value, and keys missing from ``state_dict`` keep the model's
    current value. The provided ``state_dict`` is left unmodified.

    :param state_dict: Mapping[str, torch.Tensor]. State dict to load
    """
    from collections import OrderedDict

    current_state = self.state_dict()
    merged = OrderedDict()

    for key, value in state_dict.items():
        if key not in current_state:
            mt(f"Warning: {key} not found in the model. Ignoring {key} in the provided state dict.")
            continue
        if value.shape != current_state[key].shape:
            mw(f"Warning: shape of {key} does not match. \n" + \
                ' '*40 + "\tOriginal:" + f" {current_state[key].shape},\n" + \
                ' '*40 + f"\tNew: {value.shape}")
            # Fall back to the model's own tensor for incompatible shapes.
            value = current_state[key]
        merged[key] = value

    for key, value in current_state.items():
        if key not in merged:
            mw(f"Warning: {key} not found in the provided state dict. " + \
                f"Using {key} in the original state dict.")
            merged[key] = value

    # Keep the versioning metadata torch attaches to state dicts, if present.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        merged._metadata = metadata

    self.load_state_dict(merged)
[docs]
def get_config(self):
    """
    Get the model config

    ``batch_key`` / ``label_key`` are returned as a plain string when there
    are no additional keys, otherwise as a list whose first element is the
    primary key followed by the additional keys.

    :return: dict. Model config dictionary
    """
    # Unpack instead of `[key] + additional`: the additional keys may be a
    # tuple or another non-list sequence, for which list concatenation raises
    # TypeError.
    if self.additional_batch_keys is None:
        batch_key = self.batch_key
    else:
        batch_key = [self.batch_key, *self.additional_batch_keys]

    if self.additional_label_keys is None:
        label_key = self.label_key
    else:
        label_key = [self.label_key, *self.additional_label_keys]

    return {
        'hidden_stacks': self._hidden_stacks,
        'n_latent': self.n_latent,
        'n_batch': self.n_batch,
        'n_label': self.n_label,
        'n_additional_batch': self.n_additional_batch,
        'n_additional_label': self.n_additional_label,
        'batch_key': batch_key,
        'label_key': label_key,
        'dispersion': self.dispersion,
        'log_variational': self.log_variational,
        'bias': self.fcargs['bias'],
        'use_batch_norm': self.fcargs['use_batch_norm'],
        'use_layer_norm': self.fcargs['use_layer_norm'],
        'batch_hidden_dim': self.batch_hidden_dim,
        'batch_embedding': self.batch_embedding,
        'reconstruction_method': self.reconstruction_method,
        'encode_libsize': self.encode_libsize,
        'decode_libsize': self.decode_libsize,
        'dropout_rate': self.fcargs['dropout_rate'],
        'activation_fn': self.fcargs['activation_fn'],
        'inject_batch': self.inject_batch,
        'inject_label': self.inject_label,
        'inject_additional_batch': self.inject_additional_batch,
        'mmd_key': self.mmd_key,
        'unlabel_key': self.unlabel_key,
    }
[docs]
def save_to_disk(self, path_to_state_dict: Union[str, Path]):
    """
    Save the model to disk

    The checkpoint is a single dict written with ``torch.save``. It holds the
    model weights, the variable index of ``self.adata``, the model config and
    the category bookkeeping for batches, labels and their additional keys.

    :param path_to_state_dict: str or Path. Path to save the model
    """
    checkpoint = {
        "model_state_dict": self.state_dict(),
        "model_var_index": self.adata.var.index,
        "model_config": self.get_config(),
    }
    # Category attributes are stored under their own attribute names.
    category_attributes = (
        "batch_category",
        "batch_category_summary",
        "label_category",
        "label_category_summary",
        "additional_label_category",
        "additional_label_category_summary",
        "additional_batch_category",
        "additional_batch_category_summary",
    )
    for attribute in category_attributes:
        checkpoint[attribute] = getattr(self, attribute)
    torch.save(checkpoint, path_to_state_dict)
[docs]
def load_from_disk(self, path_to_state_dict: Union[str, Path]):
    """
    Load the model from disk

    Reads a checkpoint written by ``save_to_disk`` and partially loads its
    ``"model_state_dict"`` entry into this model.

    :param path_to_state_dict: str or Path. Path to load the model
    """
    # Remap stored tensors onto this model's device so that a checkpoint saved
    # on another device (e.g. a GPU) can be read on a host that lacks it.
    # When no device attribute is set, None keeps torch.load's default.
    state_dict = torch.load(
        path_to_state_dict,
        map_location=getattr(self, "device", None),
    )
    self.partial_load_state_dict(state_dict["model_state_dict"])
[docs]
@staticmethod
def setup_anndata(
    adata: sc.AnnData,
    path_to_state_dict: Union[str, Path],
    unlabel_key: str = 'undefined'
):
    """
    Setup the model with adata

    Aligns ``adata`` with a saved checkpoint: the variables are matched to the
    checkpoint's variable index, and every batch / label column recorded in the
    checkpoint is turned into a categorical whose categories start with the
    checkpoint's categories, followed by any value only present in ``adata``.
    Columns missing from ``adata.obs`` are created and filled with
    ``unlabel_key``; missing values are replaced by ``unlabel_key``.

    :param adata: AnnData. AnnData to setup the model
    :param path_to_state_dict: Optional[str, Path]. Path to the state dict to load
    :param unlabel_key: str. Value used for missing columns and missing values. Default: 'undefined'
    :return: AnnData. The aligned AnnData (``adata.obs`` is modified in place)
    """

    def _align_column(target, key, reference_categories):
        # Make target.obs[key] categorical with the reference categories first.
        base = list(reference_categories)
        if key in target.obs.keys():
            # Go through a plain list so fillna also works for columns that
            # are already categorical.
            values = list(pd.Series(list(target.obs[key])).fillna(unlabel_key))
        else:
            values = [unlabel_key] * target.shape[0]
        known = set(base)
        # Append only values that are not reference categories already:
        # Categorical.add_categories raises ValueError on any overlap.
        extra = [c for c in np.unique(values).tolist() if c not in known]
        target.obs[key] = pd.Categorical(values, categories=base + extra)

    # Only metadata is read from the checkpoint here, so tensors can be
    # deserialized on CPU regardless of the device they were saved from.
    state_dict = torch.load(path_to_state_dict, map_location="cpu")

    if 'model_var_index' in state_dict:
        model_var_index = state_dict['model_var_index']
        # NOTE(review): only membership is compared; identical variables in a
        # different order are not re-aligned here. Confirm whether that matters.
        if set(adata.var.index) != set(model_var_index):
            mw("the provided adata contains variables not in the state dict.")
            mt(" The model will be initialized with the variables in the state dict.")
            adata = subset_adata_by_genes_fill_zeros(adata, list(model_var_index))

    if state_dict["batch_category"] is not None:
        configured = state_dict['model_config']['batch_key']
        batch_key = configured if isinstance(configured, str) else configured[0]
        _align_column(adata, batch_key, state_dict["batch_category"].categories)

    if state_dict["label_category"] is not None:
        configured = state_dict['model_config']['label_key']
        label_key = configured if isinstance(configured, str) else configured[0]
        _align_column(adata, label_key, state_dict["label_category"].categories)

    if state_dict["additional_batch_category"] is not None:
        configured = state_dict['model_config']['batch_key']
        if isinstance(configured, list):
            additional_batch_keys = configured[1:]
        else:
            additional_batch_keys = state_dict['model_config']['additional_batch_keys']
        for i, k in enumerate(additional_batch_keys):
            _align_column(adata, k, state_dict["additional_batch_category"][i].categories)

    if state_dict["additional_label_category"] is not None:
        configured = state_dict['model_config']['label_key']
        if isinstance(configured, list):
            additional_label_keys = configured[1:]
        else:
            additional_label_keys = state_dict['model_config']['additional_label_keys']
        for i, k in enumerate(additional_label_keys):
            _align_column(adata, k, state_dict["additional_label_category"][i].categories)

    return adata
def initialize_dataset(self):
mt("Initializing dataset into memory")
if self.batch_key is not None:
n_batch_ = len(np.unique(self.adata.obs[self.batch_key]))
if self.n_batch != n_batch_:
mt(f"warning: the provided n_batch={self.n_batch} does not match the number of batch in the adata.")
if self.constrain_n_batch:
mt(f" setting n_batch to {n_batch_}")
self.n_batch = n_batch_
if not (isinstance(self.adata.obs[self.batch_key], pd.Categorical) or hasattr(self.adata.obs[self.batch_key], 'cat')):
self.adata.obs[self.batch_key] = pd.Categorical(self.adata.obs[self.batch_key])
self.batch_category = pd.Categorical(self.adata.obs[self.batch_key])
self.batch_category_summary = dict(Counter(self.batch_category))
for k in self.batch_category.categories:
if k not in self.batch_category_summary.keys():
self.batch_category_summary[k] = 0
if self.label_key is not None:
n_label_ = len(np.unique(list(filter(lambda x: x != self.unlabel_key, pd.Categorical(self.adata.obs[self.label_key]).categories))))
if self.n_label != n_label_:
mt(f"warning: the provided n_label={self.n_label} does not match the number of label in the adata.")
if self.constrain_n_label:
mt(f" setting n_label to {n_label_}")
self.n_label = n_label_
if not (isinstance(self.adata.obs[self.label_key], pd.Categorical) or hasattr(self.adata.obs[self.label_key], 'cat')):
self.adata.obs[self.label_key] = list(self.adata.obs[self.label_key])
self.adata.obs[self.label_key] = pd.Categorical(self.adata.obs[self.label_key].fillna(self.unlabel_key))
if isinstance(self.adata.obs[self.label_key], pd.Categorical):
self.label_category = self.adata.obs[self.label_key]
elif isinstance(self.adata.obs[self.label_key].cat, pd.core.arrays.categorical.CategoricalAccessor):
cat = self.adata.obs[self.label_key].cat
self.label_category = pd.Categorical([cat.categories[x] for x in cat.codes], categories=cat.categories)
else:
self.label_category = pd.Categorical(list(self.adata.obs[self.label_key]))
self.label_category_summary = dict(Counter(list(filter(lambda x: x != self.unlabel_key, self.label_category))))
for k in self.label_category.categories:
if k not in self.label_category_summary.keys() and k != self.unlabel_key:
self.label_category_summary[k] = 0
self.label_category_weight = len(self.label_category) / torch.tensor([
self.label_category_summary[x] for x in list(filter(lambda x:
x != self.unlabel_key,
self.label_category.categories
))], dtype=torch.float64).to(self.device)
if self.unlabel_key in self.label_category.categories:
self.new_adata_code = list(self.label_category.categories).index(self.unlabel_key)
self.n_additional_label_ = None
self.additional_label_category = None
self.additional_label_category_summary = None
if self.additional_label_keys is not None:
self.n_cell_additional_label = [len(list(filter(lambda x: x != self.unlabel_key,self.adata.obs[x]))) for x in [self.label_key] + self.additional_label_keys]
self.n_additional_label_ = [len(np.unique(list(filter(lambda x: x != self.unlabel_key,pd.Categorical(self.adata.obs[x]).categories)))) for x in self.additional_label_keys]
# self.additional_label_weight = sum(self.n_cell_additional_label) / torch.tensor(self.n_cell_additional_label)
self.additional_label_weight = torch.tensor([1] * len(self.n_cell_additional_label), dtype=torch.float64).to(self.device)
if self.n_additional_label == None or len(self.n_additional_label_) != len(self.n_additional_label):
mt(f"warning: the provided n_additional_label={self.n_additional_label} does not match the number of additional label in the adata.")
if self.constrain_n_label:
mt(f" setting n_additional_label to {self.n_additional_label_}")
self.n_additional_label = self.n_additional_label_
else:
for e,(i,j) in enumerate(zip(self.n_additional_label_, self.n_additional_label)):
if i != j:
mt(f"n_additional_label {self.additional_label_keys[e]} does not match the number in the adata.")
if self.constrain_n_label:
mt(f" setting n_additional_label {e} to {i}")
self.n_additional_label[e] = i
def get_category(x):
if isinstance(x, pd.Categorical):
return x
elif isinstance(x.cat, pd.core.arrays.categorical.CategoricalAccessor):
cat = x.cat
return pd.Categorical([cat.categories[x] for x in cat.codes], categories=cat.categories)
else:
return pd.Categorical(list(x))
self.additional_label_category = [
get_category(self.adata.obs[x])
for x in self.additional_label_keys
]
self.additional_label_category_summary = [dict(Counter(x)) for x in self.additional_label_category]
for i in range(len(self.additional_label_category_summary)):
for k in self.additional_label_category[i].categories:
if k not in self.additional_label_category_summary[i].keys() and k != self.unlabel_key:
self.additional_label_category_summary[i][k] = 0
self.additional_label_category_weight = [len(label_category) / torch.tensor([
self.additional_label_category_summary[e][x] for x in list(filter(lambda x:
x != self.unlabel_key,
label_category.categories
))], dtype=torch.float64).to(self.device) for e,label_category in enumerate(self.additional_label_category)]
self.additional_new_adata_code = [list(x.categories).index(self.unlabel_key) if self.unlabel_key in x.categories else -1 for x in self.additional_label_category]
self.n_additional_batch_ = None
if self.additional_batch_keys is not None:
self.n_additional_batch_ = [len(np.unique(self.adata.obs[x])) for x in self.additional_batch_keys]
if self.n_additional_batch == None or len(self.n_additional_batch_) != len(self.n_additional_batch):
mt(f"warning: the provided n_additional_batch={self.n_additional_batch} does not match the number of categorical covariate in the adata.")
if self.constrain_n_batch:
mt(f" setting n_additional_batch to {self.n_additional_batch_}")
self.n_additional_batch = self.n_additional_batch_
else:
for e,(i,j) in enumerate(zip(self.n_additional_batch_, self.n_additional_batch)):
if i != j:
mt(f"n_additional_batch {self.additional_batch_keys[e]} does not match the number in the adata.")
if self.constrain_n_batch:
mt(f" setting n_additional_batch {e} to {i}")
self.n_additional_batch[e] = i
self.additional_batch_category = [pd.Categorical(self.adata.obs[x]) for x in self.additional_batch_keys]
self.additional_batch_category_summary = [dict(Counter(x)) for x in self.additional_batch_category]
for i in range(len(self.additional_batch_category_summary)):
for k in self.additional_batch_category[i].categories:
if k not in self.additional_batch_category_summary[i].keys():
self.additional_batch_category_summary[i][k] = 0
self._n_record = self.adata.shape[0]
self._indices = np.array(list(range(self._n_record)))
batch_categories, label_categories = None, None
additional_label_categories = None
additional_batch_categories = None
if self.batch_key is not None:
if self.batch_key not in self.adata.obs.columns:
raise ValueError(f"batch_key {self.batch_key} is not found in AnnData obs")
batch_categories = np.array(self.batch_category.codes)
if self.label_key is not None:
if self.label_key not in self.adata.obs.columns:
raise ValueError(f"label_key {self.label_key} is not found in AnnData obs")
label_categories = np.array(self.label_category.codes)
if self.additional_label_keys is not None:
for e,i in enumerate(self.additional_label_keys):
if i not in self.adata.obs.columns:
raise ValueError(f"additional_label_keys {i} is not found in AnnData obs")
additional_label_categories = [np.array(x.codes) for x in self.additional_label_category]
if self.additional_batch_keys is not None:
for e,i in enumerate(self.additional_batch_keys):
if i not in self.adata.obs.columns:
raise ValueError(f"additional_batch_keys {i} is not found in AnnData obs")
additional_batch_categories = [np.array(x.codes) for x in self.additional_batch_category]
if self.low_memory_initialization:
if self.constrain_latent_embedding and self.constrain_latent_key in self.adata.obsm.keys():
P = self.adata.obsm[self.constrain_latent_key]
if additional_batch_categories is not None:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(P, batch_categories, label_categories, *additional_label_categories, *additional_batch_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(P, batch_categories, label_categories, *additional_batch_categories))
elif batch_categories is not None:
_dataset = list(zip(P, batch_categories, *additional_batch_categories))
elif label_categories is not None:
_dataset = list(zip(P, label_categories, *additional_batch_categories))
else:
_dataset = list(zip(P, *additional_batch_categories))
else:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(P, batch_categories, label_categories, *additional_label_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(P, batch_categories, label_categories))
elif batch_categories is not None:
_dataset = list(zip(P, batch_categories))
elif label_categories is not None:
_dataset = list(zip(P, label_categories))
else:
_dataset = list(zip(P))
else:
if additional_batch_categories is not None:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(batch_categories, label_categories, *additional_label_categories, *additional_batch_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(batch_categories, label_categories, *additional_batch_categories))
elif batch_categories is not None:
_dataset = list(zip(batch_categories, *additional_batch_categories))
elif label_categories is not None:
_dataset = list(zip(label_categories, *additional_batch_categories))
else:
_dataset = list(zip(*additional_batch_categories))
else:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(batch_categories, label_categories, *additional_label_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(batch_categories, label_categories))
elif batch_categories is not None:
_dataset = list(zip(batch_categories))
elif label_categories is not None:
_dataset = list(zip(label_categories))
else:
_dataset = list(np.arange(self._n_record))
else:
if self.use_layer is None:
X = self.adata.X
else:
X = self.adata.layers[self.use_layer]
if self.constrain_latent_embedding and self.constrain_latent_key in self.adata.obsm.keys():
P = self.adata.obsm[self.constrain_latent_key]
if additional_batch_categories is not None:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(X, P, batch_categories, label_categories, *additional_label_categories, *additional_batch_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(X, P, batch_categories, label_categories, *additional_batch_categories))
elif batch_categories is not None:
_dataset = list(zip(X, P, batch_categories, *additional_batch_categories))
elif label_categories is not None:
_dataset = list(zip(X, P, label_categories, *additional_batch_categories))
else:
_dataset = list(zip(X, P, *additional_batch_categories))
else:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(X, P, batch_categories, label_categories, *additional_label_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(X, P, batch_categories, label_categories))
elif batch_categories is not None:
_dataset = list(zip(X, P, batch_categories))
elif label_categories is not None:
_dataset = list(zip(X, P, label_categories))
else:
_dataset = list(zip(X, P))
else:
if additional_batch_categories is not None:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(X, batch_categories, label_categories, *additional_label_categories, *additional_batch_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(X, batch_categories, label_categories, *additional_batch_categories))
elif batch_categories is not None:
_dataset = list(zip(X, batch_categories, *additional_batch_categories))
elif label_categories is not None:
_dataset = list(zip(X, label_categories, *additional_batch_categories))
else:
_dataset = list(zip(X, *additional_batch_categories))
else:
if batch_categories is not None and label_categories is not None and additional_label_categories is not None:
_dataset = list(zip(X, batch_categories, label_categories, *additional_label_categories))
elif batch_categories is not None and label_categories is not None:
_dataset = list(zip(X, batch_categories, label_categories))
elif batch_categories is not None:
_dataset = list(zip(X, batch_categories))
elif label_categories is not None:
_dataset = list(zip(X, label_categories))
else:
_dataset = list(X)
_shuffle_indices = list(range(len(_dataset)))
np.random.shuffle(_shuffle_indices)
self._dataset = np.array([_dataset[i] for i in _shuffle_indices], dtype=object)
self._shuffle_indices = np.array(
[x for x, _ in sorted(zip(range(len(_dataset)), _shuffle_indices), key=lambda x: x[1])]
)
self._shuffled_indices_inverse = np.array(_shuffle_indices)
mt("Finished initializing dataset into memory")
def as_dataloader(
    self,
    subset_indices: Union[torch.tensor, np.ndarray] = None,
    n_per_batch: int = 128,
    train_test_split: bool = False,
    random_seed: int = 42,
    validation_split: float = .2,
    shuffle: bool = True,
):
    """
    Build DataLoader(s) that yield mini-batches of record indices.

    :param subset_indices: Union[torch.tensor, np.ndarray]. Record indices to serve. If not provided, ``self._indices`` is used. The input is copied and never reordered in place.
    :param n_per_batch: int. Requested batch size. It is shifted by one when a split would otherwise leave a remainder of exactly one record.
    :param train_test_split: bool. If True, return a ``(train, validation)`` pair of DataLoaders.
    :param random_seed: int. Seed given to ``np.random.seed`` before shuffling.
    :param validation_split: float. Fraction of the served indices placed in the validation loader.
    :param shuffle: bool. Shuffle the indices before splitting / batching.
    :return: one DataLoader, or a ``(train, validation)`` tuple when ``train_test_split`` is True.
    """
    source = self._indices if subset_indices is None else subset_indices
    # Work on a private copy so that shuffling never reorders the caller's
    # array or self._indices.
    if isinstance(source, torch.Tensor):
        indices = source.detach().cpu().numpy().copy()
    else:
        indices = np.array(source)

    np.random.seed(random_seed)
    if shuffle:
        np.random.shuffle(indices)

    if train_test_split:
        n_total = len(indices)
        split = int(np.floor(validation_split * n_total))
        # Avoid a trailing batch that contains a single record.
        if split % n_per_batch == 1:
            n_per_batch -= 1
        elif (n_total - split) % n_per_batch == 1:
            n_per_batch += 1
        # The samplers address *positions* inside ``indices``; the DataLoader
        # then looks up the record index stored at that position.
        positions = np.arange(n_total)
        train_sampler = SubsetRandomSampler(positions[split:])
        valid_sampler = SubsetRandomSampler(positions[:split])
        return (
            DataLoader(indices, n_per_batch, sampler=train_sampler),
            DataLoader(indices, n_per_batch, sampler=valid_sampler),
        )

    if len(indices) % n_per_batch == 1:
        n_per_batch -= 1
    return DataLoader(indices, n_per_batch, shuffle=shuffle)
def _normalize_data(self, X, after=None, copy=True):
X = X.clone() if copy else X
X = X.to(torch.float32) # Check if torch.float64 should be used
counts = X.sum(axis=1)
counts_greater_than_zero = counts[counts > 0]
after = torch.median(counts_greater_than_zero, dim=0).values if after is None else after
counts += counts == 0
counts = counts / after
X /= counts.unsqueeze(1)
return X
def encode(self, X: torch.Tensor, batch_index: torch.Tensor = None, eps: float = 1e-4):
# Encode X and draw a reparameterised sample from the resulting Normal posterior.
# Returns a dict with the encoder output ``q``, the posterior mean ``q_mu``,
# the posterior variance ``q_var`` and the sample ``z``.
# ``batch_index`` is accepted but not used by the active code of this method.
# ``eps`` is added to the variance so that it stays strictly positive.
# Encode for hidden space
# if batch_index is not None and self.inject_batch:
# X = torch.hstack([X, batch_index])
# Log of the per-row sum of the raw input, taken before the transformations below.
libsize = torch.log(X.sum(1))
if self.reconstruction_method == 'zinb' or self.reconstruction_method == 'nb':
if self.total_variational:
# Rescale every row to a total of 1e4 (see _normalize_data).
X = self._normalize_data(X, after=1e4, copy=True)
if self.log_variational:
X = torch.log(1+X)
# When encode_libsize is set, the log row sum is appended as one extra input column.
if self.encoder_type == EncoderType.SAE:
q = self.encoder.encode(torch.hstack([X,libsize.unsqueeze(1)])) if self.encode_libsize else self.encoder.encode(X)
elif self.encoder_type == EncoderType.TABNET:
# The encoder returns per-step outputs plus M_loss; M_loss is not used here.
steps_output, M_loss = self.encoder(torch.hstack([X,libsize.unsqueeze(1)])) if self.encode_libsize else self.encoder(X)
# The step outputs are summed element-wise into a single representation.
q = torch.sum(torch.stack(steps_output, dim=0), dim=0)
# NOTE(review): if encoder_type matches neither branch, q is never assigned and the next line fails.
q_mu = self.z_mean_fc(q)
# z_var_fc is exponentiated, i.e. its output is treated as a log-variance.
q_var = torch.exp(self.z_var_fc(q)) + eps
z = Normal(q_mu, q_var.sqrt()).rsample()
H = dict(
q = q,
q_mu = q_mu,
q_var = q_var,
z = z
)
return H
def decode(self,
    H: Mapping[str, torch.tensor],
    lib_size:torch.tensor,
    batch_index: torch.Tensor = None,
    label_index: torch.Tensor = None,
    additional_batch_index: torch.Tensor = None,
    eps: float = 1e-4
):
    """
    Decode the latent sample ``H["z"]`` into the reconstruction parameters.

    :param H: Mapping. Output of ``encode``; only ``H["z"]`` is read.
    :param lib_size: torch.tensor. Multiplied onto the decoded scale when ``self.decode_libsize`` is set and the reconstruction method is not 'mse'.
    :param batch_index: torch.Tensor. Appended to ``z`` when ``self.inject_batch`` is set; also used for the 'gene-batch' dispersion.
    :param label_index: torch.Tensor. Appended to ``z`` when ``self.inject_label`` is set.
    :param additional_batch_index: sequence of tensors. Appended after batch/label when ``self.inject_additional_batch`` is set.
    :param eps: float. Not used by this method.
    :return: dict with keys ``h``, ``px``, ``px_rna_scale_orig``, ``px_rna_scale``, ``px_rna_rate`` and ``px_rna_dropout``.
    """
    z = H["z"]  # cell latent representation

    with_batch = batch_index is not None and self.inject_batch
    with_label = label_index is not None and self.inject_label
    with_additional = additional_batch_index is not None and self.inject_additional_batch

    covariates = []
    if with_batch:
        covariates.append(batch_index)
    if with_label:
        covariates.append(label_index)
    # Additional batch covariates are only appended together with a batch or
    # label covariate; on their own they leave z untouched.
    if covariates:
        if with_additional:
            covariates.extend(additional_batch_index)
        z = torch.hstack([z, *covariates])

    px = self.decoder(z)
    px_rna_scale = self.px_rna_scale_decoder(px)

    is_mse = self.reconstruction_method == 'mse'
    if self.decode_libsize and not is_mse:
        px_rna_scale_final = px_rna_scale * lib_size.unsqueeze(1)
    elif is_mse:
        px_rna_scale_final = torch.log(px_rna_scale * 1e4 + 1)
    else:
        px_rna_scale_final = px_rna_scale

    if self.dispersion == "gene-cell":
        px_rna_rate = self.px_rna_rate_decoder(px)  # In logits
    elif self.dispersion == "gene-batch":
        px_rna_rate = F.linear(one_hot(batch_index, self.n_batch), self.px_rate)
    elif self.dispersion == "gene":
        px_rna_rate = self.px_rate

    px_rna_dropout = self.px_rna_dropout_decoder(px)  # In logits

    return dict(
        h=None,
        px=px,
        px_rna_scale_orig=px_rna_scale,
        px_rna_scale=px_rna_scale_final,
        px_rna_rate=px_rna_rate,
        px_rna_dropout=px_rna_dropout,
    )
def forward(
    self,
    X: torch.Tensor,
    lib_size: torch.Tensor,
    batch_index: torch.Tensor = None,
    label_index: torch.Tensor = None,
    additional_label_index: torch.Tensor = None,
    additional_batch_index: torch.Tensor = None,
    P: torch.Tensor = None,
    reduction: str = "sum",
    compute_mmd: bool = False
):
    """
    Encode ``X``, decode it and assemble every loss term.

    :param X: torch.Tensor. Input passed to ``encode`` and used as the reconstruction target.
    :param lib_size: torch.Tensor. Forwarded to ``decode``.
    :param batch_index: torch.Tensor. Forwarded to ``encode`` / ``decode``; also the grouping for the 'batch' and 'both' MMD terms.
    :param label_index: torch.Tensor. Target of the label classifier when ``self.n_label > 0``.
    :param additional_label_index: sequence of tensors. One target per additional label head.
    :param additional_batch_index: sequence of tensors. Forwarded to ``decode``; grouping for the 'additional_batch' and 'both' MMD terms.
    :param P: torch.Tensor. Reference embedding for the latent constraint. Rows whose mean is zero are ignored.
    :param reduction: str. Reduction passed to the reconstruction loss.
    :param compute_mmd: bool. Compute the MMD term (only if ``self.mmd_key`` is set).
    :return: tuple ``(H, R, loss_record)``: the ``encode`` output, the ``decode`` output and a dict with the keys ``reconstruction_loss``, ``prediction_loss``, ``additional_prediction_loss``, ``kldiv_loss``, ``mmd_loss`` and ``latent_constrain_loss``.
    :raises ValueError: if ``self.reconstruction_method`` is not 'zinb', 'nb', 'zg' or 'mse'.
    """
    H = self.encode(X, batch_index)
    q_mu = H["q_mu"]
    q_var = H["q_var"]

    # KL divergence between the posterior and a standard normal prior,
    # summed over the latent dimensions (one value per row).
    mean = torch.zeros_like(q_mu)
    scale = torch.ones_like(q_var)
    kldiv_loss = kld(Normal(q_mu, q_var.sqrt()), Normal(mean, scale)).sum(dim=1)

    prediction_loss = torch.tensor(0., device=self.device)
    additional_prediction_loss = torch.tensor(0., device=self.device)

    R = self.decode(H, lib_size, batch_index, label_index, additional_batch_index)

    if self.reconstruction_method == 'zinb':
        reconstruction_loss = LossFunction.zinb_reconstruction_loss(
            X,
            mu=R['px_rna_scale'],
            theta=R['px_rna_rate'].exp(),
            gate_logits=R['px_rna_dropout'],
            reduction=reduction
        )
    elif self.reconstruction_method == 'nb':
        reconstruction_loss = LossFunction.nb_reconstruction_loss(
            X,
            mu=R['px_rna_scale'],
            theta=R['px_rna_rate'].exp(),
            reduction=reduction
        )
    elif self.reconstruction_method == 'zg':
        reconstruction_loss = LossFunction.zi_gaussian_reconstruction_loss(
            X,
            mean=R['px_rna_scale'],
            variance=R['px_rna_rate'].exp(),
            gate_logits=R['px_rna_dropout'],
            reduction=reduction
        )
    elif self.reconstruction_method == 'mse':
        # The target is the log1p of the input rescaled to 1e4 per row.
        X_norm = self._normalize_data(X, after=1e4)
        X_norm = torch.log(X_norm + 1)
        reconstruction_loss = nn.functional.mse_loss(
            X_norm,
            R['px_rna_scale'],
            reduction=reduction
        )
    else:
        raise ValueError(f"reconstruction_method {self.reconstruction_method} is not supported")

    if self.n_label > 0:
        criterion = nn.CrossEntropyLoss(weight=self.label_category_weight)
        prediction = self.fc(H['z'])
        # Rows carrying the "unlabelled" code are excluded from the loss.
        # NOTE(review): the truthiness test also disables the mask when that
        # code is 0 -- confirm this is intended.
        if self.new_adata_code and self.new_adata_code in label_index:
            prediction_index = (label_index != self.new_adata_code).squeeze()
            prediction_loss = criterion(
                prediction[prediction_index],
                one_hot(label_index[prediction_index], self.n_label)
            )
        else:
            prediction_loss = criterion(prediction, one_hot(label_index, self.n_label))

    if self.n_additional_label is not None:
        # Index 0 of additional_label_weight is applied to the primary label,
        # index e + 1 to the e-th additional label.
        prediction_loss = prediction_loss * self.additional_label_weight[0]
        for e, n_class in enumerate(self.n_additional_label):
            criterion = nn.CrossEntropyLoss(weight=self.additional_label_category_weight[e])
            additional_prediction = self.additional_fc[e](H['z'])
            target_index = additional_label_index[e]
            unlabelled_code = self.additional_new_adata_code[e]
            if unlabelled_code and unlabelled_code in target_index:
                keep = (target_index != unlabelled_code).squeeze()
                additional_prediction = additional_prediction[keep]
                target_index = target_index[keep]
            # The weight scales the loss in both cases (it used to scale the
            # one-hot target in the masked case).
            additional_prediction_loss += criterion(
                additional_prediction,
                one_hot(target_index, n_class)
            ) * self.additional_label_weight[e + 1]

    latent_constrain = torch.tensor(0., device=self.device)
    if self.constrain_latent_embedding and P is not None:
        # Constrains on cells with no PCA information will be ignored
        latent_constrain_mask = P.mean(1) != 0
        # Number of constrained rows; at least one, so that a batch without
        # any reference embedding yields 0 instead of 0 / 0 = NaN.
        n_constrained = latent_constrain_mask.sum().clamp(min=1)
        if self.constrain_latent_method == 'mse':
            latent_constrain = (
                nn.MSELoss(reduction='none')(P, q_mu).sum(1) * latent_constrain_mask
            ).sum() / n_constrained
        elif self.constrain_latent_method == 'normal':
            latent_constrain = (
                kld(Normal(q_mu, q_var.sqrt()), Normal(P, torch.ones_like(P))).sum(1) * latent_constrain_mask
            ).sum() / n_constrained

    mmd_loss = torch.tensor(0., device=self.device)
    if self.mmd_key is not None and compute_mmd:
        if self.mmd_key == 'batch':
            mmd_loss = self.mmd_loss(
                H['q_mu'],
                batch_index.detach().cpu().numpy(),
                dim=1
            )
        elif self.mmd_key == 'additional_batch':
            for i in range(len(self.additional_batch_keys)):
                mmd_loss = mmd_loss + self.mmd_loss(
                    H['q_mu'],
                    additional_batch_index[i].detach().cpu().numpy(),
                    dim=1
                )
        elif self.mmd_key == 'both':
            mmd_loss = self.mmd_loss(
                H['q_mu'],
                batch_index.detach().cpu().numpy(),
                dim=1
            )
            for i in range(len(self.additional_batch_keys)):
                mmd_loss = mmd_loss + self.hierarchical_mmd_loss_2(
                    H['q_mu'],
                    batch_index.detach().cpu().numpy(),
                    additional_batch_index[i].detach().cpu().numpy(),
                    dim=1
                )

    loss_record = {
        "reconstruction_loss": reconstruction_loss,
        "prediction_loss": prediction_loss,
        "additional_prediction_loss": additional_prediction_loss,
        "kldiv_loss": kldiv_loss,
        "mmd_loss": mmd_loss,
        "latent_constrain_loss": latent_constrain
    }
    return H, R, loss_record
def calculate_metric(self, X_test, kl_weight, pred_weight, mmd_weight, reconstruction_reduction):
    """
    Evaluate the losses on held-out batches without tracking gradients.

    :param X_test: Iterable of index batches; each one is passed to ``self._prepare_batch``.
    :param kl_weight: float. Weight applied to the KL divergence loss.
    :param pred_weight: float. Weight applied to the prediction losses.
    :param mmd_weight: float. Weight applied to the MMD loss; the MMD term is only computed when it is positive.
    :param reconstruction_reduction: str. Reduction passed to ``forward`` for the reconstruction loss.
    :return: dict of Python floats with the keys ``epoch_reconstruction_loss``, ``epoch_kldiv_loss``, ``epoch_mmd_loss``, ``epoch_prediction_loss`` and ``epoch_total_loss``, each averaged over the number of batches. ``epoch_total_loss`` is reconstruction + KL + MMD.
    """
    epoch_total_loss = 0
    epoch_reconstruction_loss = 0
    epoch_kldiv_loss = 0
    epoch_prediction_loss = 0
    epoch_mmd_loss = 0
    b = 0
    X_test = list(X_test)
    with torch.no_grad():
        # A single worker prepares the next batch while the current one is evaluated.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = None
            for b, batch_indices in enumerate(X_test):
                if future is None:
                    prepared = self._prepare_batch(batch_indices)
                else:
                    prepared = future.result()
                X, P, batch_index, label_index, additional_label_index, additional_batch_index, lib_size = prepared
                if b + 1 < len(X_test):
                    future = executor.submit(self._prepare_batch, X_test[b + 1])

                H, R, L = self.forward(
                    X,
                    lib_size,
                    batch_index,
                    label_index,
                    additional_label_index,
                    additional_batch_index,
                    P,
                    reduction=reconstruction_reduction,
                    compute_mmd=mmd_weight > 0
                )
                reconstruction_loss = L['reconstruction_loss']
                prediction_loss = pred_weight * L['prediction_loss']
                additional_prediction_loss = pred_weight * L['additional_prediction_loss']
                kldiv_loss = kl_weight * L['kldiv_loss']
                mmd_loss = mmd_weight * L['mmd_loss']

                avg_reconstruction_loss = reconstruction_loss.sum()
                avg_kldiv_loss = kldiv_loss.sum()
                avg_mmd_loss = mmd_loss

                epoch_reconstruction_loss += avg_reconstruction_loss.item()
                epoch_kldiv_loss += avg_kldiv_loss.item()
                if self.n_label > 0:
                    epoch_prediction_loss += prediction_loss.sum().item()
                if self.n_additional_label is not None:
                    epoch_prediction_loss += additional_prediction_loss.sum().item()
                # Accumulate a plain float like every other metric instead of
                # keeping a tensor alive across batches.
                epoch_mmd_loss += float(avg_mmd_loss)
                epoch_total_loss += (avg_reconstruction_loss + avg_kldiv_loss + avg_mmd_loss).item()

    n_batches = b + 1
    return {
        "epoch_reconstruction_loss": epoch_reconstruction_loss / n_batches,
        "epoch_kldiv_loss": epoch_kldiv_loss / n_batches,
        "epoch_mmd_loss": epoch_mmd_loss / n_batches,
        "epoch_prediction_loss": epoch_prediction_loss / n_batches,
        "epoch_total_loss": epoch_total_loss / n_batches,
    }
# [docs]
def fit(self,
    max_epoch: Optional[int] = None,
    n_per_batch:int = 128,
    kl_weight: float = 1.,
    pred_weight: float = 1.,
    mmd_weight: float = 1.,
    gate_weight: float = 1.,
    constrain_weight: float = 1.,
    optimizer_parameters: Iterable = None,
    validation_split: float = .2,
    lr: float = 5e-5,
    lr_schedule: bool = False,
    lr_factor: float = 0.6,
    lr_patience: int = 30,
    lr_threshold: float = 0.0,
    lr_min: float = 1e-6,
    n_epochs_kl_warmup: Union[int, None] = 400,
    weight_decay: float = 1e-6,
    random_seed: int = 12,
    subset_indices: Union[torch.tensor, np.ndarray] = None,
    pred_last_n_epoch: int = 10,
    pred_last_n_epoch_fconly: bool = False,
    compute_batch_after_n_epoch: int = 0,
    reconstruction_reduction: str = 'sum',
    n_concurrent_batch: int = 1,
):
    """
    Fit the model.

    :param max_epoch: int. Maximum number of epoch to train the model. If not provided, min(round(20000 / n_record * 400), 400) is used.
    :param n_per_batch: int. Number of cells per batch.
    :param kl_weight: float. (Maximum) weight of the KL divergence loss.
    :param pred_weight: float. weight of the prediction loss.
    :param mmd_weight: float. weight of the mmd loss. ignored if mmd_key is None
    :param gate_weight: float. weight of the gate penalty, i.e. the per-cell sum of the sigmoid of the dropout logits averaged over the batch.
    :param constrain_weight: float. weight of the constrain loss. ignored if constrain_latent_embedding is False.
    :param optimizer_parameters: Iterable. Parameters to be optimized. If not provided, all parameters will be optimized.
    :param validation_split: float. Fraction of data to be used as validation set.
    :param lr: float. Learning rate.
    :param lr_schedule: bool. Whether to use learning rate scheduler.
    :param lr_factor: float. Factor to reduce learning rate.
    :param lr_patience: int. Number of epoch to wait before reducing learning rate.
    :param lr_threshold: float. Threshold to trigger learning rate reduction.
    :param lr_min: float. Minimum learning rate.
    :param n_epochs_kl_warmup: int. Number of epoch to warmup the KL divergence loss (deterministic warm-up
    of the KL-term). The KL weight starts at 0 and grows linearly up to kl_weight.
    :param weight_decay: float. Weight decay (L2 penalty).
    :param random_seed: int. Random seed for the train / validation split. It is incremented after every epoch.
    :param subset_indices: Union[torch.tensor, np.ndarray]. Indices of cells to be used for training. If not provided, all cells will be used.
    :param pred_last_n_epoch: int. The prediction losses are added to the optimised loss only during the last pred_last_n_epoch epochs.
    :param pred_last_n_epoch_fconly: bool. If True, the optimizer is replaced by one over ``self.att`` and ``self.fc`` when the prediction phase starts.
    :param compute_batch_after_n_epoch: int. The mmd loss is only computed from this epoch on.
    :param reconstruction_reduction: str. Reduction method for reconstruction loss. Can be 'sum' or 'mean'.
    :param n_concurrent_batch: int. Number of upcoming batches handed to the background worker for preparation.
    :return: dict of lists holding, per epoch, the accumulated training losses.
    """
    self.train()
    if max_epoch is None:
        max_epoch = np.min([round((20000 / self._n_record ) * 400), 400])
        mt(f"max_epoch is not provided, setting max_epoch to {max_epoch}")

    if n_epochs_kl_warmup:
        n_epochs_kl_warmup = min(max_epoch, n_epochs_kl_warmup)
        kl_warmup_gradient = kl_weight / n_epochs_kl_warmup
        kl_weight_max = kl_weight
        kl_weight = 0.

    if optimizer_parameters is None:
        optimizer = optim.AdamW(self.parameters(), lr, weight_decay=weight_decay)
    else:
        optimizer = optim.AdamW(optimizer_parameters, lr, weight_decay=weight_decay)

    scheduler = ReduceLROnPlateau(
        optimizer,
        patience=lr_patience,
        factor=lr_factor,
        threshold=lr_threshold,
        min_lr=lr_min,
        threshold_mode="abs",
        verbose=True,
    ) if lr_schedule else None

    # NOTE(review): these are never updated below, so the "restore best
    # state" branch after the training loop cannot trigger.
    best_state_dict = None
    best_score = 0
    current_score = 0

    pbar = get_tqdm()(
        range(max_epoch),
        desc="Epoch",
        bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}' if not is_notebook() else '',
        position=0,
        leave=True
    )
    loss_record = {
        "epoch_reconstruction_loss": 0,
        "epoch_kldiv_loss": 0,
        "epoch_prediction_loss": 0,
        "epoch_mmd_loss": 0,
        "epoch_total_loss": 0
    }
    epoch_total_loss_list = []
    epoch_reconstruction_loss_list = []
    epoch_kldiv_loss_list = []
    epoch_prediction_loss_list = []
    epoch_mmd_loss_list = []
    epoch_gate_loss_list = []
    epoch_constraint_loss_list = []

    for epoch in range(1, max_epoch+1):
        self._trained = True
        pbar.desc = "Epoch {}".format(epoch)
        epoch_total_loss = 0
        epoch_reconstruction_loss = 0
        epoch_kldiv_loss = 0
        epoch_prediction_loss = 0
        epoch_mmd_loss = 0
        epoch_gate_loss = 0
        epoch_constrain_loss = 0

        X_train, X_test = self.as_dataloader(
            n_per_batch=n_per_batch,
            train_test_split = True,
            validation_split = validation_split,
            random_seed=random_seed,
            subset_indices=subset_indices
        )

        if self.n_label > 0 and epoch == max_epoch - pred_last_n_epoch:
            # Snapshot taken right before the prediction losses enter the objective.
            mt("saving transcriptome only state dict")
            self.gene_only_state_dict = deepcopy(self.state_dict())
            if pred_last_n_epoch_fconly:
                optimizer = optim.AdamW(chain(self.att.parameters(), self.fc.parameters()), lr, weight_decay=weight_decay)

        X_train = list(X_train) # convert to list
        future_dict = {}
        with ThreadPoolExecutor(max_workers=1) as executor:
            for b, batch_indices in enumerate(X_train):
                future = future_dict.get(b, None)
                if future is None:
                    prepared = self._prepare_batch(batch_indices)
                else:
                    prepared = future.result()
                    future_dict.pop(b)
                X, P, batch_index, label_index, additional_label_index, additional_batch_index, lib_size = prepared

                # Queue the upcoming batches; a batch that is already queued
                # is not submitted a second time.
                for fb in range(b+1, b+1+n_concurrent_batch):
                    if fb < len(X_train) and fb not in future_dict:
                        future_dict[fb] = executor.submit(self._prepare_batch, X_train[fb])

                H, R, L = self.forward(
                    X,
                    lib_size,
                    batch_index,
                    label_index,
                    additional_label_index,
                    additional_batch_index,
                    P,
                    reduction=reconstruction_reduction,
                    compute_mmd = mmd_weight > 0 and epoch >= compute_batch_after_n_epoch
                )

                reconstruction_loss = L['reconstruction_loss']
                prediction_loss = pred_weight * L['prediction_loss']
                additional_prediction_loss = pred_weight * L['additional_prediction_loss']
                kldiv_loss = L['kldiv_loss']
                mmd_loss = mmd_weight * L['mmd_loss']

                avg_gate_loss = gate_weight * torch.sigmoid(R['px_rna_dropout']).sum(dim=1).mean()
                avg_reconstruction_loss = reconstruction_loss.sum() / n_per_batch
                avg_kldiv_loss = kldiv_loss.sum() / n_per_batch
                avg_mmd_loss = mmd_loss / n_per_batch

                epoch_reconstruction_loss += avg_reconstruction_loss.item()
                epoch_kldiv_loss += avg_kldiv_loss.item()
                epoch_mmd_loss += avg_mmd_loss.item()
                epoch_gate_loss += avg_gate_loss.item()
                if self.n_label > 0:
                    epoch_prediction_loss += prediction_loss.sum().item()

                if epoch > max_epoch - pred_last_n_epoch:
                    # Average over the primary label head plus every additional
                    # head. The conditional is parenthesised so that "+ 1" applies
                    # in both cases.
                    n_prediction_heads = (
                        len(self.n_additional_label) if self.n_additional_label is not None else 0
                    ) + 1
                    loss = avg_reconstruction_loss + avg_kldiv_loss * kl_weight + avg_mmd_loss + (prediction_loss.sum() + additional_prediction_loss.sum()) / n_prediction_heads + avg_gate_loss
                else:
                    loss = avg_reconstruction_loss + avg_kldiv_loss * kl_weight + avg_mmd_loss + avg_gate_loss

                if self.constrain_latent_embedding:
                    loss += constrain_weight * L['latent_constrain_loss']
                    epoch_constrain_loss += L['latent_constrain_loss'].item()

                epoch_total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Shows the validation metrics of the previous epoch plus the current step.
                pbar.set_postfix({
                    'rec': '{:.2e}'.format(loss_record["epoch_reconstruction_loss"]),
                    'kl': '{:.2e}'.format(loss_record["epoch_kldiv_loss"]),
                    'pred': '{:.2e}'.format(loss_record["epoch_prediction_loss"]),
                    'mmd': '{:.2e}'.format(loss_record["epoch_mmd_loss"]),
                    'step': f'{b} / {len(X_train)}'
                })

        loss_record = self.calculate_metric(X_test, kl_weight, pred_weight, mmd_weight, reconstruction_reduction)
        if lr_schedule:
            scheduler.step(loss_record["epoch_total_loss"])
        pbar.set_postfix({
            'rec': '{:.2e}'.format(loss_record["epoch_reconstruction_loss"]),
            'kl': '{:.2e}'.format(loss_record["epoch_kldiv_loss"]),
            'pred': '{:.2e}'.format(loss_record["epoch_prediction_loss"]),
            'mmd': '{:.2e}'.format(loss_record["epoch_mmd_loss"]),
        })

        epoch_total_loss_list.append(epoch_total_loss)
        epoch_reconstruction_loss_list.append(epoch_reconstruction_loss)
        epoch_kldiv_loss_list.append(epoch_kldiv_loss)
        epoch_prediction_loss_list.append(epoch_prediction_loss)
        epoch_mmd_loss_list.append(epoch_mmd_loss)
        epoch_gate_loss_list.append(epoch_gate_loss)
        epoch_constraint_loss_list.append(epoch_constrain_loss)
        pbar.update(1)

        if n_epochs_kl_warmup:
            kl_weight = min( kl_weight + kl_warmup_gradient, kl_weight_max)
        random_seed += 1

    if current_score < best_score:
        mt("restoring state dict with best performance")
        self.load_state_dict(best_state_dict)

    pbar.close()
    self.trained_state_dict = deepcopy(self.state_dict())
    return dict(
        epoch_total_loss_list=epoch_total_loss_list,
        epoch_reconstruction_loss_list=epoch_reconstruction_loss_list,
        epoch_kldiv_loss_list=epoch_kldiv_loss_list,
        epoch_prediction_loss_list=epoch_prediction_loss_list,
        epoch_mmd_loss_list=epoch_mmd_loss_list,
        epoch_gate_loss_list=epoch_gate_loss_list,
        epoch_constraint_loss_list=epoch_constraint_loss_list
    )
@torch.no_grad()
def predict_labels(
    self,
    n_per_batch: int = 128,
    return_pandas: bool = False,
    show_progress: bool = True
) -> List:
    """
    Predict labels from trained model.

    :param n_per_batch: int. Number of cells for each mini-batch during inference.
    :param return_pandas: bool. return a pandas DataFrame if True else return a pytorch tensor.
    :param show_progress: bool. Show progress bar of total progress.
    :return: with ``return_pandas`` a DataFrame of predicted categories indexed like ``adata.obs`` (one column per label key); otherwise the classifier outputs reordered with ``self._shuffle_indices``, followed by the raw per-batch outputs of the additional heads when those exist.
    """
    self.eval()
    loader = self.as_dataloader(
        subset_indices = list(range(len(self._dataset))),
        shuffle=False,
        n_per_batch=n_per_batch
    )
    has_additional = self.n_additional_label is not None

    if show_progress:
        pbar = get_tqdm()(
            loader,
            desc="Predicting Labels",
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}' if not is_notebook() else '',
            position=0,
            leave=True
        )

    predictions = []
    additional_predictions = []
    for batch in loader:
        X, P, batch_index, label_index, additional_label_index, additional_batch_index, lib_size = self._prepare_batch(batch)
        H = self.encode(X, batch_index if batch_index != None else None)
        head_output = self.fc(H['z'])
        predictions.append(H.get('prediction', head_output).detach().cpu())
        if has_additional:
            additional_predictions.append([
                self.additional_fc[i](H['z']).detach().cpu()
                for i in range(len(self.n_additional_label))
            ])
        if show_progress:
            pbar.update(1)
    if show_progress:
        pbar.close()

    # Bring the rows back into the order of adata.obs.
    predictions = torch.vstack(predictions)[self._shuffle_indices]
    best_codes = torch.argmax(predictions, dim=1).detach().cpu().numpy()
    predictions_argmax = pd.DataFrame(
        [self.label_category.categories[code] for code in best_codes],
        index=self.adata.obs.index
    )
    predictions_argmax.columns = [self.label_key]

    if return_pandas and not has_additional:
        return predictions_argmax

    if has_additional:
        additional_labels = []
        for i in range(len(self.n_additional_label)):
            stacked = torch.vstack([per_batch[i] for per_batch in additional_predictions])[self._shuffle_indices]
            best = np.argmax(stacked, axis=1)
            categories = self.additional_label_category[i].categories
            additional_labels.append([categories[code] for code in best.numpy()])
        if return_pandas:
            additional_frame = pd.DataFrame(
                additional_labels,
                columns = self.adata.obs.index
            ).T
            additional_frame.columns = self.additional_label_keys
            return pd.concat([predictions_argmax, additional_frame], axis=1)
        # NOTE(review): the additional outputs are returned as collected per
        # batch (not stacked and not reordered), unlike ``predictions``.
        return predictions, additional_predictions
    return predictions
@torch.no_grad()
def get_latent_embedding(
    self,
    latent_key: Literal["z", "q_mu"] = "q_mu",
    n_per_batch: int = 128,
    show_progress: bool = True
) -> np.ndarray:
    """
    Encode every record of the dataset and collect the requested latent output(s).

    :param latent_key: key of the ``encode`` output to collect, or a sequence of such keys.
    :param n_per_batch: int. Number of cells for each mini-batch during inference.
    :param show_progress: bool. Show progress bar of total progress.
    :return: one array for a single key, or a list with one array per key; rows are reordered with ``self._shuffle_indices``.
    """
    self.eval()
    loader = self.as_dataloader(
        subset_indices = list(range(len(self._dataset))),
        shuffle=False,
        n_per_batch=n_per_batch
    )

    if isinstance(latent_key, str):
        collected = []
    elif isinstance(latent_key, Iterable):
        collected = [[] for _ in range(len(latent_key))]

    if show_progress:
        pbar = get_tqdm()(
            loader,
            desc="Latent Embedding",
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}' if not is_notebook() else '',
            position=0,
            leave=True
        )

    for batch in loader:
        X, P, batch_index, label_index, additional_label_index, additional_batch_index, lib_size = self._prepare_batch(batch)
        H = self.encode(X, batch_index if batch_index != None else None)
        if isinstance(latent_key, str):
            collected.append(H[latent_key].detach().cpu().numpy())
        elif isinstance(latent_key, Iterable):
            for i, bucket in enumerate(collected):
                bucket.append(H[latent_key[i]].detach().cpu().numpy())
        if show_progress:
            pbar.update(1)
    if show_progress:
        pbar.close()

    if isinstance(latent_key, str):
        return np.vstack(collected)[self._shuffle_indices]
    elif isinstance(latent_key, Iterable):
        return [np.vstack(chunks)[self._shuffle_indices] for chunks in collected]
@torch.no_grad()
def get_reconstructed_expression(self, k = 'px_rna_scale_orig', n_per_batch=256,show_progress=True) -> np.ndarray:
    """
    Run the full forward pass over every record of the dataset and collect one decoder output.

    :param k: str. Key of the second element returned by ``self.forward`` to collect. Default: 'px_rna_scale_orig'
    :param n_per_batch: int. Mini-batch size forwarded to ``self.as_dataloader``. Default: 256
    :param show_progress: bool. Display a progress bar. Default: True
    :return: The per-batch outputs stacked with ``np.vstack`` and indexed by ``self._shuffle_indices``.
    """
    self.eval()
    loader = self.as_dataloader(
        subset_indices=list(range(len(self._dataset))),
        shuffle=False,
        n_per_batch=n_per_batch
    )

    pbar = None
    if show_progress:
        # The bar is only used as a counter; the loader itself is iterated below.
        pbar = get_tqdm()(
            loader,
            desc="Reconstructing gene expression",
            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}' if not is_notebook() else '',
            position=0,
            leave=True
        )

    chunks = []
    try:
        for batch in loader:
            (
                X,
                P,
                batch_index,
                label_index,
                additional_label_index,
                additional_batch_index,
                lib_size,
            ) = self._prepare_batch(batch)
            _, R, _ = self.forward(
                X,
                lib_size,
                batch_index,
                label_index,
                additional_label_index,
                additional_batch_index,
                P,
            )
            chunks.append(R[k].detach().cpu().numpy())
            if pbar is not None:
                pbar.update(1)
    finally:
        # Close the bar on success and on failure alike.
        if pbar is not None:
            pbar.close()

    return np.vstack(chunks)[self._shuffle_indices]
def to(self, device:str):
    """
    Move the module to ``device`` and remember the target on ``self.device``.

    :param device: str. Target device, forwarded unchanged to the parent class's ``to``.
    :return: self, so calls can be chained.
    """
    super().to(device)
    # Other methods read self.device when creating tensors and layers.
    self.device = device
    return self
def transfer(self,
    new_adata: sc.AnnData,
    batch_key: str,
    concat_with_original: bool = True,
    fraction_of_original: Optional[float] = None,
    times_of_new: Optional[float] = None
):
    """
    Attach a new (query) dataset to the model and rebuild the decoder input layer for it.

    The query cells are tagged ``'query'`` and the retained original cells ``'reference'``
    in ``obs['_transfer_label']``; ``self.adata`` is replaced, ``self.initialize_dataset()``
    is re-run, and ``self.decoder`` is replaced by a fresh ``FCLayer`` into which the
    previously learned first-layer weights (and batch embedding rows, when
    ``self.batch_embedding == 'embedding'``) are copied.

    :param new_adata: AnnData. Query dataset. Its ``obs`` gains ``self.batch_key`` and ``'_transfer_label'``.
    :param batch_key: str. Column of ``new_adata.obs`` holding the query batch annotation.
    :param concat_with_original: bool. Concatenate the (sub-sampled) original data with the query data. Default: True
    :param fraction_of_original: Optional[float]. Keep ``len(self.adata) * fraction`` original cells. Default: None
    :param times_of_new: Optional[float]. Keep ``len(new_adata) * times`` original cells;
        only used when ``fraction_of_original`` is None. Default: None
    """
    new_batch_category = new_adata.obs[batch_key]
    original_batch_dim = self.batch_hidden_dim
    new_n_batch = len(np.unique(new_batch_category))

    if self.batch_embedding == "embedding":
        original_embedding_weight = self.decoder.cat_embedding[0].weight

    # Expose the query batch annotation under the column name the model uses.
    new_adata.obs[self.batch_key] = new_adata.obs[batch_key]
    original_batch_categories = self.batch_category.categories

    # The original data is sub-sampled on the model's own batch column: the
    # query-side `batch_key` is only known to exist in new_adata.obs.
    if fraction_of_original is not None:
        old_adata = random_subset_by_key_fast(
            self.adata,
            key=self.batch_key,
            n=int(len(self.adata) * fraction_of_original)
        )
    elif times_of_new is not None:
        old_adata = random_subset_by_key_fast(
            self.adata,
            key=self.batch_key,
            n=int(len(new_adata) * times_of_new)
        )
    else:
        old_adata = self.adata

    old_adata.obs['_transfer_label'] = 'reference'
    new_adata.obs['_transfer_label'] = 'query'

    if concat_with_original:
        self.adata = sc.concat([old_adata, new_adata])
    else:
        self.adata = new_adata

    # NOTE(review): initialize_dataset() is assumed to refresh n_batch,
    # batch_category and related attributes for the new adata - confirm.
    self.initialize_dataset()

    if self.batch_embedding == "onehot":
        self.batch_hidden_dim = self.n_batch

    # Categorical covariates injected into the decoder: batch first, then label,
    # then (optionally) the additional batches. None when neither batch nor
    # label is injected.
    with_additional_batch = self.n_additional_batch_ is not None and self.inject_additional_batch
    injected = []
    if self.n_batch > 0 and self.inject_batch:
        injected.append(self.n_batch)
    if self.n_label > 0 and self.inject_label:
        injected.append(self.n_label)
    if not injected:
        decoder_n_cat_list = None
    elif with_additional_batch:
        decoder_n_cat_list = [*injected, *self.n_additional_batch]
    else:
        decoder_n_cat_list = injected
    self.decoder_n_cat_list = decoder_n_cat_list

    # Snapshot the learned first-layer weights before the decoder is replaced.
    original_weight = self.decoder._fclayer[0].weight.detach().clone()

    self.decoder = FCLayer(
        in_dim = self.n_latent,
        out_dim = self.n_hidden,
        n_cat_list = self.decoder_n_cat_list,
        cat_dim = self.batch_hidden_dim,
        cat_embedding = self.batch_embedding,
        use_layer_norm=False,
        use_batch_norm=True,
        dropout_rate=0,
        device=self.device
    )

    if self.batch_embedding == 'embedding':
        new_embedding = nn.Embedding(self.n_batch + new_n_batch, self.batch_hidden_dim).to(self.device)
        # Rows of the previously known batches keep their learned embedding,
        # placed at each category's position in the refreshed category list.
        current_categories = list(self.batch_category.categories)
        original_category_index = [current_categories.index(x) for x in original_batch_categories]
        new_embedding_weight = new_embedding.weight.detach()
        new_embedding_weight[original_category_index] = original_embedding_weight.detach()
        new_embedding.weight = nn.Parameter(new_embedding_weight)
        new_embedding = new_embedding.to(self.device)
        self.decoder.cat_embedding[0] = new_embedding

    # Copy the input columns for the latent dimensions and the original batch
    # dimensions; remaining columns keep the fresh layer's initialisation.
    kept_columns = self.n_latent + original_batch_dim
    new_weight = self.decoder._fclayer[0].weight.detach().clone()
    new_weight[:, :kept_columns] = original_weight[:, :kept_columns]
    self.decoder._fclayer[0].weight = nn.Parameter(new_weight)

    self.to(self.device)
def transfer_label(
    self,
    reference_adata: sc.AnnData,
    label_key: str,
    method: Literal['knn'] = 'knn',
    use_rep: str = 'X_gex',
    **method_kwargs
):
    """
    Transfer label from reference_adata to self.adata

    Cells of ``self.adata`` whose obs name also occurs in ``reference_adata`` receive the
    reference label; a classifier fitted on those cells then predicts the label of every
    cell with ``obs['_transfer_label'] == 'query'``. Results are written to
    ``self.adata.obs[label_key]``.

    :param reference_adata: sc.AnnData. Source of the labels.
    :param label_key: str. Column of ``reference_adata.obs`` to transfer.
    :param method: str. Classifier to use; only 'knn' is implemented. Default: 'knn'
    :param use_rep: str. Key in ``self.adata.obsm`` used as features; computed with
        ``self.get_latent_embedding()`` when missing. Default: 'X_gex'
    :param method_kwargs: forwarded to ``sklearn.neighbors.KNeighborsClassifier``.
    :raises NotImplementedError: if ``method`` is not 'knn'.
    """
    # Reject unsupported methods before touching self.adata.
    if method != 'knn':
        raise NotImplementedError()

    # Boolean mask over self.adata.obs: cells that also exist in the reference.
    is_shared = np.asarray(self.adata.obs.index.isin(reference_adata.obs.index))
    shared_names = list(self.adata.obs.index[is_shared])

    self.adata.obs[label_key] = np.nan
    # Single-step .loc assignment: chained `obs[col][mask] = ...` may write to a
    # temporary copy and is a no-op under pandas copy-on-write.
    self.adata.obs.loc[is_shared, label_key] = reference_adata[shared_names].obs[label_key]

    # Compute the representation only when the requested key is absent.
    if use_rep not in self.adata.obsm.keys():
        self.adata.obsm[use_rep] = self.get_latent_embedding()
    features = self.adata.obsm[use_rep]

    from sklearn.neighbors import KNeighborsClassifier
    knn = KNeighborsClassifier(**method_kwargs)
    knn.fit(features[is_shared], self.adata.obs.loc[is_shared, label_key])

    # '_transfer_label' is written by transfer(); query cells get predicted labels.
    is_query = (self.adata.obs['_transfer_label'] == 'query').to_numpy()
    self.adata.obs.loc[is_query, label_key] = knn.predict(features[is_query])
def umap_alignment(
    self,
    reference_adata: sc.AnnData,
    label_key: str,
    method: Literal['retrain','knn'] = 'knn',
    use_rep: str = 'X_gex',
    **method_kwargs
):
    """
    Call the module-level ``umap_alignment`` helper on the reference embedding.

    :param reference_adata: sc.AnnData. Must provide ``obsm[use_rep]`` and ``obsm['X_umap']``.
    :param label_key: str. Currently unused.
    :param method: str. Forwarded to the helper. Default: 'knn'
    :param use_rep: str. Key of the representation in ``reference_adata.obsm``. Default: 'X_gex'
    :param method_kwargs: Currently unused.
    :return: whatever the module-level ``umap_alignment`` helper returns.
    """
    # Inside a method body the bare name resolves to the imported module-level
    # function, not to this method, so this is not a recursive call.
    # NOTE(review): the first and third arguments are the same reference
    # representation; presumably the third should be the query embedding
    # (e.g. self.adata.obsm[use_rep]) - confirm against the helper's signature.
    return umap_alignment(
        reference_adata.obsm[use_rep],
        reference_adata.obsm['X_umap'],
        reference_adata.obsm[use_rep],
        method=method,
    )
def _prepare_batch(self, batch_indices):
    """
    Turn a batch of dataset indices into the tensors consumed by the model.

    :param batch_indices: tensor of dataset indices, as yielded by the dataloader.
    :return: tuple ``(X, P, batch_index, label_index, additional_label_index,
        additional_batch_index, lib_size)``. ``X`` is a float tensor on ``self.device`` and
        ``lib_size`` is ``X.sum(1)``. The remaining entries are ``None`` when the
        corresponding field is not configured; ``batch_index``/``label_index`` are float
        column tensors and the ``additional_*`` entries are lists of float column tensors.
    :raises ValueError: from the sanity checks on ``batch_data`` (see inline notes).
    """
    P = None
    batch_data = self._dataset[batch_indices.cpu().numpy().astype(int)]
    batch_index, label_index, additional_label_index, additional_batch_index = None, None, None, None

    if self.low_memory_initialization:
        # Expression values are not stored in the dataset; they are read on
        # demand from the tensorstore or from adata.
        obs_rows = self._shuffled_indices_inverse[batch_indices.cpu().numpy()]
        if self.anndata_tensorstore_path is not None:
            # NOTE(review): `ats` must be importable as a module-level name for
            # this branch to work - confirm the corresponding import is active.
            if self.use_layer is None:
                store_path = os.path.join(self.anndata_tensorstore_path, ats.ATS_FILE_NAME.X)
            else:
                store_path = os.path.join(
                    self.anndata_tensorstore_path, ats.ATS_FILE_NAME.layers, self.use_layer
                )
            X = ats._ext.load_X(
                store_path,
                obs_indices = obs_rows,
                var_indices = self.anndata_tensorstore_var_indices,
                to_sparse=False
            )
        elif self.use_layer is None:
            X = self.adata.X[obs_rows]
        else:
            X = self.adata.layers[self.use_layer][obs_rows]

        if self.n_batch > 0 or self.n_label > 0:
            if not (isinstance(batch_data, Iterable) and len(batch_data) > 1):
                raise ValueError("batch_data is not iterable or has only one element")
            (
                _,
                P,
                batch_index,
                label_index,
                additional_label_index,
                additional_batch_index,
            ) = self._unpack_batch_fields(batch_data, has_expression=False)

        X = torch.tensor((X.toarray() if issparse(X) else X).astype(np.float32))
    else:
        if self.n_batch > 0 or self.n_label > 0:
            # NOTE(review): operator precedence differs from the low-memory
            # check above (`not (a and b)`); kept as is because changing it
            # could reject batches that are accepted today - confirm intent.
            if not isinstance(batch_data, Iterable) and len(batch_data) > 1:
                raise ValueError()
        (
            X,
            P,
            batch_index,
            label_index,
            additional_label_index,
            additional_batch_index,
        ) = self._unpack_batch_fields(batch_data, has_expression=True)
        X = torch.tensor(
            np.vstack(list(map(lambda x: x.toarray() if issparse(x) else x, X)))
        )

    if self.constrain_latent_embedding:
        P = torch.tensor(np.vstack(P)).type(torch.FloatTensor).to(self.device)
    if self.n_label > 0:
        label_index = self._as_float_column(label_index)
    if self.n_batch > 0:
        batch_index = self._as_float_column(batch_index)
    if self.n_additional_label is not None:
        additional_label_index = [self._as_float_column(v) for v in additional_label_index]
    if self.n_additional_batch_ is not None:
        additional_batch_index = [self._as_float_column(v) for v in additional_batch_index]

    if not isinstance(X, torch.FloatTensor):
        X = X.type(torch.FloatTensor)
    X = X.to(self.device)
    lib_size = X.sum(1).to(self.device)
    return X, P, batch_index, label_index, additional_label_index, additional_batch_index, lib_size

def _unpack_batch_fields(self, batch_data, has_expression):
    """
    Split ``batch_data`` into its fields by walking the field positions left to right.

    Field order: expression (only when ``has_expression``), P (when
    ``self.constrain_latent_embedding``), batch, label, additional labels, additional batches.

    :param batch_data: the records selected from ``self._dataset`` for this batch.
    :param has_expression: bool. Whether position 0 holds the expression values.
    :return: tuple ``(X, P, batch_index, label_index, additional_label_index,
        additional_batch_index)``; fields that are not configured are ``None``.
    """
    X = P = batch_index = label_index = None
    additional_label_index = additional_batch_index = None
    cursor = 0

    if has_expression:
        X = get_k_elements(batch_data, cursor)
        cursor += 1
    if self.constrain_latent_embedding:
        P = get_k_elements(batch_data, cursor)
        cursor += 1
    if self.n_batch > 0:
        batch_index = get_k_elements(batch_data, cursor)
        cursor += 1
    if self.n_label > 0:
        label_index = get_k_elements(batch_data, cursor)
        cursor += 1

    # Additional labels are only read when both batch and label are configured.
    if self.n_batch > 0 and self.n_label > 0 and self.n_additional_label is not None:
        if self.n_additional_batch_ is not None:
            # Additional batches follow, so take exactly the label fields.
            n_fields = len(self.n_additional_label)
            additional_label_index = get_elements(batch_data, cursor, n_fields)
            cursor += n_fields
        else:
            additional_label_index = get_last_k_elements(batch_data, cursor)
        additional_label_index = list(np.vstack(additional_label_index).T.astype(int))

    if self.n_additional_batch_ is not None:
        if self.n_batch > 0 or self.n_label > 0:
            # Everything from the cursor onwards. In the label-only,
            # unconstrained, low-memory layout this starts right after the
            # label field (position 1).
            additional_batch_index = get_last_k_elements(batch_data, cursor)
        else:
            additional_batch_index = get_k_elements(batch_data, cursor)
        additional_batch_index = list(np.vstack(additional_batch_index).T.astype(int))

    return X, P, batch_index, label_index, additional_label_index, additional_batch_index

def _as_float_column(self, values):
    """Convert ``values`` to a float tensor on ``self.device`` with a trailing unit dimension."""
    column = torch.tensor(values)
    if not isinstance(column, torch.FloatTensor):
        column = column.type(torch.FloatTensor)
    return column.to(self.device).unsqueeze(1)