# Source code for stereo.algorithm.score_genes

from typing import Union, Tuple, List

import pandas as pd
import numpy as np
from scipy.sparse import spmatrix, issparse
from scipy import stats

from stereo.algorithm.algorithm_base import AlgorithmBase
from stereo.log_manager import logger


class ScoreGenes(AlgorithmBase):

    def _expression_mean(
        self,
        exp_matrix: Union[np.ndarray, spmatrix],
        axis: int
    ) -> np.ndarray:
        if issparse(exp_matrix):
            s = exp_matrix.sum(axis=axis, dtype=np.float64)
            m = s / exp_matrix.shape[axis]
            return m.A.flatten()
        return exp_matrix.mean(axis=axis, dtype=np.float64)

    def _get_expression_subset(
        self,
        genes: np.ndarray,
        use_raw: bool
    ) -> Union[np.ndarray, spmatrix]:
        data = self.stereo_exp_data
        gene_names = data.raw.gene_names if use_raw else data.gene_names
        exp_matrix = data.raw.exp_matrix if use_raw else data.exp_matrix

        if len(genes) == len(gene_names):
            return exp_matrix
        idx = pd.Index(gene_names).get_indexer(genes)
        return exp_matrix[:, idx]

    def _check_score_genes(
        self,
        genes_used: np.ndarray,
        genes_reference: Union[np.ndarray, None],
        use_raw: bool,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Restrict `genes_used` and `genes_reference` to present genes in `data`.
        """
        data = self.stereo_exp_data
        gene_names = data.raw.gene_names if use_raw else data.gene_names
        genes_used = np.array([genes_used] if isinstance(genes_used, str) else genes_used, dtype='U')
        isin = np.isin(genes_used, gene_names)
        genes_to_ignore = genes_used[~isin]  # first get missing
        genes_used = genes_used[isin]  # then restrict to present
        if len(genes_to_ignore) > 0:
            logger.warning(f"genes are not in gene_names and ignored: {genes_to_ignore}")
        if len(genes_used) == 0:
            raise ValueError("No valid genes were passed for scoring.")

        if genes_reference is None:
            genes_reference = gene_names
        else:
            genes_reference = np.array([genes_reference] if isinstance(genes_reference, str) else genes_reference, dtype='U')
            genes_reference = np.intersect1d(genes_reference, gene_names)
        if len(genes_reference) == 0:
            raise ValueError("No valid genes are passed for reference set.")

        return genes_used, genes_reference


    def _score_genes_bins(
        self,
        genes_used: np.ndarray,
        genes_reference: np.ndarray,
        ctrl_as_ref: bool,
        ctrl_size: int,
        n_bins: int,
        use_raw: bool
    ) -> np.ndarray:
        # mean expression of genes in `genes_reference`
        exp_matrix = self._get_expression_subset(genes_reference, use_raw)
        genes_exp_mean = self._expression_mean(exp_matrix, axis=0)
        n_items = int(np.round(len(genes_exp_mean) / (n_bins - 1)))
        cells_cut = stats.rankdata(genes_exp_mean, method='min') // n_items
        keep_ctrl_in_cells_cut = np.zeros(genes_reference.size, dtype=bool) if ctrl_as_ref else np.isin(genes_reference, genes_used)

        # now pick `ctrl_size` genes from every cut
        control_genes = pd.array([], dtype="U")
        isin_used = np.isin(genes_reference, genes_used)
        cells_cut_iterable = np.unique(cells_cut[isin_used])
        for cut in cells_cut_iterable:
            r_genes = genes_reference[(cells_cut == cut) & ~keep_ctrl_in_cells_cut]
            if len(r_genes) == 0:
                msg = (
                    f"No control genes for {cut=}. You may need to increase the"
                    f"size of genes_reference (current size: {len(genes_reference)})"
                )
                logger.warning(msg)
            if ctrl_size < len(r_genes):
                r_genes = np.random.choice(r_genes, ctrl_size, replace=False)
            if ctrl_as_ref:  # otherwise `r_genes` is already filtered
                r_genes = np.setdiff1d(r_genes, genes_used)
            control_genes = np.union1d(control_genes, r_genes)
        return control_genes

[docs] def main( self, genes_used: Union[np.ndarray, List[str], Tuple[str]], ctrl_as_ref: bool = True, ctrl_size: int = 50, genes_reference: Union[np.ndarray, List[str], Tuple[str], None] = None, n_bins: int = 25, random_state: Union[int, np.random.RandomState, None] = 0, use_raw: bool = None, res_key: str = "score", ): """ Score a set of genes for each cell/bin. The score is the average expression of a set of genes subtracted with the average expression of a reference set of genes. The reference set is randomly sampled from the `genes_reference` for each binned expression value. :param genes_used: The list of gene names used for score calculation. :param ctrl_as_ref: Allow to use the control genes as reference, defaults to True :param ctrl_size: Number of reference genes to be sampled from each bin, defaults to 50, you can set `ctrl_size=len(genes_used)` if the length of `genes_used` is not too short. :param genes_reference: Genes for sampling the reference set, default is all genes. :param n_bins: Number of expression level bins for sampling, defaults to 25 :param random_state: The random seed for sampling, defaults to 0, fixed value to fixed result. 
:param use_raw: Whether to use the `data.raw`, defaults to `True` if `data.raw` is not `None` :param res_key: the column name of the result to be added in `data.cells`, defaults to "score" """ logger.info(f"calculating score, the result will be saved in data.cells['{res_key}']") if random_state is not None: np.random.seed(random_state) if not isinstance(genes_used,(np.ndarray, list, tuple)): raise ValueError("genes_used must be a list, tuple or numpy array.") if isinstance(genes_used, (list, tuple)): genes_used = np.array(genes_used, dtype="U") if genes_reference is not None: if not isinstance(genes_reference, (np.ndarray, list, tuple, str)): raise ValueError("genes_reference must be a list, tuple, numpy array or string.") if isinstance(genes_reference, str): genes_reference = [genes_reference] if isinstance(genes_reference, (list, tuple)): genes_reference = np.array(genes_reference, dtype="U") data = self.stereo_exp_data if use_raw is None: use_raw = True if data.raw is not None else False else: use_raw = use_raw and data.raw is not None genes_used, genes_reference = self._check_score_genes( genes_used, genes_reference, use_raw ) # Trying here to match the Seurat approach in scoring cells. # Basically we need to compare genes against random genes in a matched # interval of expression. control_genes = self._score_genes_bins( genes_used, genes_reference, ctrl_as_ref=ctrl_as_ref, ctrl_size=ctrl_size, n_bins=n_bins, use_raw=use_raw ) if len(control_genes) == 0: msg = "No control genes found in any cut." if ctrl_as_ref: msg += " Try setting `ctrl_as_ref` to False." raise RuntimeError(msg) means_list = self._expression_mean( self._get_expression_subset(genes_used, use_raw), axis=1 ) means_control = self._expression_mean( self._get_expression_subset(control_genes, use_raw), axis=1 ) score = means_list - means_control self.stereo_exp_data.cells[res_key] = pd.Series( score, index=self.stereo_exp_data.cells.cell_name, dtype=np.float64 )