Source code for stereo.algorithm.spatial_alignment.main

from typing import Union, Optional

import pandas as pd

from stereo.io.reader import stereo_to_anndata
from stereo.algorithm.algorithm_base import AlgorithmBase
from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
from stereo.core.stereo_exp_data import AnnBasedStereoExpData
from stereo.core.ms_data import MSData
from stereo.log_manager import logger
from stereo.preprocess.filter import filter_genes

from .spatialign import Spatialign

class SpatialAlignment(AlgorithmBase):
    """Batch-effect removal for an `AnnBasedStereoExpData` using `Spatialign`."""

    def main(
        self,
        use_hvg: bool = False,
        n_neighors: int = 15,
        is_undirected: bool = True,
        spatial_key: str = 'spatial',
        latent_dims: int = 100,
        gpu: Optional[Union[int, str]] = None,
        lr: float = 1e-3,
        max_epoch: int = 500,
        alpha: float = 0.5,
        patient: int = 15,
        tau1: float = 0.2,
        tau2: float = 1.0,
        tau3: float = 0.5,
        is_verbose: bool = True,
        inplace: bool = False
    ) -> None:
        """
        A method to remove batch effect.

        The method builds a `Spatialign` model from `self.stereo_exp_data.adata`
        (with `batch_key='batch'` and `is_reduce=False`), trains it, and stores the
        aligned expression matrix and the aligned low-dimensional embedding back
        on the data object. No random seed is passed to `Spatialign` from here.

        :param use_hvg: Whether to use a subset only contains highly variable genes, defaults to False,
                        if True, `data.tl.highly_variable_genes` should be run first and the data
                        will be filtered **in place** to only contain highly variable genes.
        :param n_neighors: The number of neighbors selected when constructing a spatial neighbor graph,
                        defaults to 15. (The spelling of the name is kept for backward compatibility.)
        :param is_undirected: Whether the constructed spatial neighbor graph is undirected graph,
                        defaults to True
        :param spatial_key: Forwarded to `Spatialign` as `spatial_key`, defaults to 'spatial';
                        presumably the key under which the spatial coordinates are stored
                        -- confirm against `Spatialign`.
        :param latent_dims: The number of embedding dimensions, defaults to 100,
                        a reduced dimension matrix whose shape is (n_cells, latent_dims) will be
                        stored in `data.cells_matrix['aligned_reduction']`.
                        (Earlier documentation states it is reachable through
                        `data.tl.result['aligned_reduction']`; that mapping is not visible here.)
        :param gpu: Whether to use GPU to train, set the ID of GPU to be used,
                        defaults to None to use CPU
        :param lr: Learning rate, defaults to 1e-3
        :param max_epoch: The number of maximum epochs, defaults to 500
        :param alpha: The momentum parameter, defaults to 0.5
        :param patient: Early stop parameter, defaults to 15
        :param tau1: Instance level and pseudo prototypical cluster level contrastive learning
                        parameters, defaults to 0.2
        :param tau2: Pseudo prototypical cluster entropy parameter, defaults to 1.
        :param tau3: Cross-batch instance self-supervised learning parameter, defaults to 0.5
        :param is_verbose: Whether to print the detail information, defaults to True
        :param inplace: a corrected expression matrix will replace the `data.exp_matrix` if True
                        or will be stored in `data.layers['aligned_matrix']` if False,
                        defaults to False.

        :return: None, results are written onto `self.stereo_exp_data`.
        :raises TypeError: if the input data is not an `AnnBasedStereoExpData`.
        :raises KeyError: if `use_hvg` is True but the genes carry no 'highly_variable' flag.
        """
        data = self.stereo_exp_data

        # `Spatialign` is fed `data.adata`, which only the AnnData-backed container provides.
        if not isinstance(data, AnnBasedStereoExpData):
            raise TypeError("The input data should be an object of AnnBasedStereoExpData.")

        if use_hvg:
            if 'highly_variable' not in data.genes:
                raise KeyError(
                    "The data does not contain highly variable genes, "
                    "run `data.tl.highly_variable_genes` first."
                )
            logger.info('The data only containing highly variable genes will be used.')
            hvg_flag = data.genes['highly_variable']
            hvg_genes = data.gene_names[hvg_flag]
            # NOTE: this mutates the caller's data -- non-HVG genes are dropped for good.
            filter_genes(data, gene_list=hvg_genes, inplace=True)

        # The model is kept on the instance so it stays accessible after `main` returns.
        self.model = Spatialign(
            merge_data=data.adata,
            batch_key='batch',
            is_reduce=False,
            n_neigh=n_neighors,
            is_undirected=is_undirected,
            latent_dims=latent_dims,
            tau1=tau1,
            tau2=tau2,
            tau3=tau3,
            is_verbose=is_verbose,
            gpu=gpu,
            spatial_key=spatial_key
        )

        self.model.train(lr=lr, max_epoch=max_epoch, alpha=alpha, patient=patient)
        aligned_matrix, aligned_reduction = self.model.alignment()

        # Either overwrite the expression matrix or keep the corrected one as a separate layer.
        if inplace:
            data.exp_matrix = aligned_matrix
        else:
            data.layers['aligned_matrix'] = aligned_matrix
        data.cells_matrix['aligned_reduction'] = aligned_reduction