Source code for stereo.algorithm.cell_cell_communication.main

# python core module
from functools import partial
from multiprocessing.pool import Pool
from pathlib import Path
from typing import Tuple
from typing import Union

import numpy as np
import numpy_groupies as npg
import pandas as pd
from sqlalchemy import create_engine
# third part module
from tqdm import tqdm

from stereo.algorithm.algorithm_base import AlgorithmBase
from stereo.algorithm.cell_cell_communication.analysis_helper import (
    Subsampler,
    write_to_file,
    mouse2human
)
from stereo.algorithm.cell_cell_communication.exceptions import (
    ProcessMetaException,
    ParseCountsException,
    ThresholdValueException,
    AllCountsFilteredException,
    NoInteractionsFound,
    InvalidDatabase,
    PipelineResultInexistent,
    InvalidSpecies
)
from stereo.algorithm.cell_cell_communication.utils.database_utils import Database, DatabaseManager
from stereo.algorithm.cell_cell_communication.utils.sqlalchemy_model import Base
from stereo.algorithm.cell_cell_communication.utils.sqlalchemy_repository import (
    ComplexRepository,
    GeneRepository,
    InteractionRepository,
    MultidataRepository,
    ProteinRepository
)
from stereo.algorithm.cell_cell_communication.utils.visualization_process import visualization_process
# module in self project
from stereo.log_manager import logger
from stereo.stereo_config import stereo_conf


class CellCellCommunication(AlgorithmBase):
    # FIXME: change the default output_path in linux, change default homogene_path
def main(
        self,
        analysis_type: str = 'statistical',
        cluster_res_key: str = 'cluster',
        micro_envs: Union[pd.DataFrame, str] = None,
        species: str = "HUMAN",
        database: str = 'cellphonedb',
        homogene_path: str = None,
        counts_identifiers: str = "hgnc_symbol",
        subsampling: bool = False,
        subsampling_log: bool = False,
        subsampling_num_pc: int = 100,
        subsampling_num_cells: int = None,
        pca_res_key: str = None,
        separator_cluster: str = "|",
        separator_interaction: str = "_",
        iterations: int = 500,
        threshold: float = 0.1,
        processes: int = 1,
        pvalue: float = 0.05,
        result_precision: int = 3,
        output_path: str = None,
        means_filename: str = 'means',
        pvalues_filename: str = 'pvalues',
        significant_means_filename: str = 'significant_means',
        deconvoluted_filename: str = 'deconvoluted',
        output_format: str = 'csv',
        res_key: str = 'cell_cell_communication'
):
    """
    Cell-cell communication analysis main function.

    :param analysis_type: type of analysis: "simple" or "statistical".
    :param cluster_res_key: the key which specifies the clustering result in data.tl.result.
    :param micro_envs: a dataframe or a string.
        if a dataframe, it has two columns, column names should be "cell_type" and "microenvironment".
        if a string, it is a key which specifies the `gen_ccc_micro_envs` result in data.tl.result.
        The given dataframe (or the stored result) is not modified; validation works on a copy.
    :param species: 'HUMAN' or 'MOUSE' (case-insensitive).
    :param database: if species is HUMAN, choose from 'cellphonedb' or 'liana';
        if MOUSE, use 'cellphonedb' or 'liana' or 'celltalkdb'; you can also specify the path of a database.
    :param homogene_path: path to the file storing mouse-human homologous genes relations.
        if species is MOUSE but database is 'cellphonedb' or 'liana', the human homologous genes
        are used for the input mouse genes.
    :param counts_identifiers: type of gene identifiers in the Counts data: "ensembl", "gene_name" or "hgnc_symbol".
    :param subsampling: flag of subsampling.
    :param subsampling_log: flag of doing log1p transformation before subsampling.
    :param subsampling_num_pc: number of pcs used when doing subsampling, <= min(m,n).
    :param subsampling_num_cells: size of the subsample.
    :param pca_res_key: the key which specifies the pca result in data.tl.result;
        if subsampling is True and this is None, the pca will be run.
    :param separator_cluster: separator of cluster names used in the result and plots, e.g. '|'.
    :param separator_interaction: separator of interactions used in the result and plots, e.g. '_'.
    :param iterations: number of iterations for the 'statistical' analysis type.
    :param threshold: threshold of percentage of gene expression, in [0, 1],
        above which being considered as significant.
    :param processes: number of processes used for doing the statistical analysis,
        on notebook only one process is supported.
    :param pvalue: the cut-point of p-value, below which being considered significant.
    :param result_precision: result precision for the results, default=3.
    :param output_path: the path of directory to save the result files, set it to output the result to files.
    :param means_filename: name of the means result file.
    :param pvalues_filename: name of the pvalues result file.
    :param significant_means_filename: name of the significant mean result file.
    :param deconvoluted_filename: name of the deconvoluted result file.
    :param output_format: format of result, 'txt', 'csv', 'tsv', 'tab'.
    :param res_key: set a key to store the result to data.tl.result.

    :raises ValueError: if `analysis_type` is neither 'simple' nor 'statistical'.
    :raises PipelineResultInexistent: if `pca_res_key`, `cluster_res_key` or a string `micro_envs`
        is not found in the pipeline results.
    :raises InvalidSpecies: if `species` is not HUMAN or MOUSE.
    :raises InvalidDatabase: if the database cannot be resolved or 'celltalkdb' is used with HUMAN.
    :raises ThresholdValueException: if `threshold` is not within [0, 1].
    :raises NoInteractionsFound: if no interaction survives the pre-filtering.

    :return: None; the results are stored in data.tl.result[res_key] and,
        when `output_path` is set, written to files.
    """
    # ---- 0. validate arguments before any expensive work ----
    if analysis_type not in ('simple', 'statistical'):
        # Any other value used to run the whole pipeline and then fail on
        # `significant_means['rank']`, because no result was ever built.
        raise ValueError(f"analysis_type must be 'simple' or 'statistical', got {analysis_type!r}")
    if subsampling and pca_res_key is not None and pca_res_key not in self.pipeline_res:
        raise PipelineResultInexistent(pca_res_key)
    if species is None or species.upper() not in ('HUMAN', 'MOUSE'):
        raise InvalidSpecies(species)
    species = species.upper()
    if species == 'HUMAN' and database == 'celltalkdb':
        raise InvalidDatabase("The database 'celltalkdb' can not be used with species 'HUMAN'")
    db_path = self._check_database(database)
    if db_path is None:
        raise InvalidDatabase()
    threshold = float(threshold)
    # Written as a chained range test so that NaN is rejected as well.
    if not 0 <= threshold <= 1:
        raise ThresholdValueException(threshold)

    logger.info(f'species: {species}')
    logger.info(f'database: {database}')

    interactions, genes, complex_composition, _complex_expanded = self._get_ref_database(db_path)
    counts, meta = self._prepare_data(cluster_res_key)

    # ---- 1. preprocess and validate input data ----
    # 1.1. meta comes from _prepare_data with the cell name as index and 'cell_type' as column.
    # 1.2. preprocess and validate counts data
    self._check_counts_data(counts, counts_identifiers)
    counts = self._counts_validations(counts, meta)

    # 1.3. mouse data against a human database: translate mouse genes to human homologs.
    human_genes_to_mouse = None
    if species == 'MOUSE' and database in ('cellphonedb', 'liana'):
        if homogene_path is None:
            homogene_path = Path(
                stereo_conf.data_dir,
                'algorithm/cell_cell_communication/database/mouse2human.csv'
            ).absolute().as_posix()
        genes_human, human_genes_to_mouse = mouse2human(counts.index.tolist(), homogene_path)
        counts.index = genes_human
        if 'NotAvailable' in genes_human:
            counts = counts.drop('NotAvailable')
        # rows that ended up with the same human gene name are summed together
        counts = counts.groupby(counts.index, as_index=True).sum()

    # 1.4. preprocess and validate micro_env data
    if micro_envs is None:
        micro_envs = pd.DataFrame()
    else:
        if isinstance(micro_envs, str):
            if micro_envs not in self.pipeline_res:
                raise PipelineResultInexistent(micro_envs)
            micro_envs = self.pipeline_res[micro_envs]['micro_envs']
        # _check_microenvs_data drops duplicates in place and overwrites the column names,
        # so hand it a copy to keep the caller's / stored DataFrame untouched.
        micro_envs = self._check_microenvs_data(micro_envs.copy(), meta)

    # ---- 2. subsampling if required ----
    if subsampling:
        subsampler = Subsampler(subsampling_log, subsampling_num_pc, subsampling_num_cells)
        if pca_res_key is not None:
            counts = subsampler.subsample(counts, self.pipeline_res[pca_res_key])
        else:
            counts = subsampler.subsample(counts)
        # keep only the meta rows of the cells kept by the subsampler
        meta = meta.filter(items=list(counts), axis=0)

    # ---- 3. do the analysis ----
    # 3.1. filter input and database data
    if analysis_type == 'statistical':
        logger.info(
            '[{} analysis] Threshold:{} Precision:{} Iterations:{} Threads:{}'.format(
                analysis_type, threshold, result_precision, iterations, processes))
    else:
        logger.info(
            '[{} analysis] Threshold:{} Precision:{}'.format(analysis_type, threshold, result_precision))

    interactions_reduced = interactions[['multidata_1_id', 'multidata_2_id']].drop_duplicates()

    # counts becomes the per-id_multidata means; counts_relations keeps
    # 'id_multidata', 'ensembl', 'gene_name', 'hgnc_symbol'.
    counts, counts_relations = self.add_multidata_and_means_to_counts(counts, genes, counts_identifiers)

    complex_composition_filtered, interactions_filtered, counts_filtered = self.prefilters(
        interactions_reduced, counts, complex_composition)
    if interactions_filtered.empty:
        raise NoInteractionsFound()

    # align meta rows with the (sorted) cell columns of counts
    meta = meta.loc[counts.columns]

    # 3.2. build the clusters (names, means, percents) and run the analysis
    clusters = self.build_clusters(meta, counts_filtered, complex_composition_filtered, skip_percent=False)

    logger.info('Running Real Analysis')
    cluster_interactions = self.get_cluster_combinations(clusters['names'], micro_envs)
    base_result = self.build_result_matrix(interactions_filtered, cluster_interactions, separator_cluster)
    # (x > 0) * (y > 0) * (x + y) / 2
    real_mean_analysis = self.mean_analysis(
        interactions_filtered, clusters, cluster_interactions, separator_cluster)
    # ((x > threshold) * (y > threshold)).astype(int)
    real_percents_analysis = self.percent_analysis(
        clusters, threshold, interactions_filtered, cluster_interactions, separator_cluster)

    if analysis_type == 'statistical':
        logger.info('Running Statistical Analysis')
        statistical_mean_analysis = self.shuffled_analysis(
            iterations, meta, counts_filtered, interactions_filtered, cluster_interactions,
            complex_composition_filtered, real_mean_analysis, processes, separator_cluster)
        result_pvalues = self.build_pvalue_result(
            real_mean_analysis, real_percents_analysis, statistical_mean_analysis, base_result)
    else:
        # 'simple' analysis has no p-values; build_results receives the percent analysis instead
        result_pvalues = real_percents_analysis

    # 3.3. output results
    pvalues_result, means_result, significant_means, deconvoluted_result = self.build_results(
        analysis_type, interactions_filtered, interactions, counts_relations, real_mean_analysis,
        result_pvalues, clusters['means'], complex_composition_filtered, counts, genes,
        result_precision, pvalue, counts_identifiers, separator_interaction
    )

    # rank 0 is moved after the largest rank so those rows sort to the bottom
    max_rank = significant_means['rank'].max()
    significant_means['rank'] = significant_means['rank'].apply(
        lambda rank: rank if rank != 0 else (1 + max_rank))
    significant_means.sort_values('rank', inplace=True)

    visualization_data = visualization_process(
        significant_means, separator_cluster, separator_interaction, human_genes_to_mouse)

    self.pipeline_res[res_key] = {
        'means': means_result,
        'significant_means': significant_means,
        'deconvoluted': deconvoluted_result,
        'visualization_data': visualization_data
    }
    if analysis_type == "statistical":
        self.pipeline_res[res_key]['pvalues'] = pvalues_result
    self.pipeline_res[res_key]['parameters'] = {
        'analysis_type': analysis_type,
        'cluster_res_key': cluster_res_key
    }
    self.stereo_exp_data.tl.reset_key_record('cell_cell_communication', res_key)

    if output_path is not None:
        logger.info('Writing results to files')
        write_to_file(means_result, means_filename, output_path=output_path, output_format=output_format)
        write_to_file(significant_means, significant_means_filename, output_path=output_path,
                      output_format=output_format)
        write_to_file(deconvoluted_result, deconvoluted_filename, output_path=output_path,
                      output_format=output_format)
        if analysis_type == "statistical":
            write_to_file(pvalues_result, pvalues_filename, output_path=output_path,
                          output_format=output_format)
def _prepare_data(self, cluster_res_key): if cluster_res_key not in self.pipeline_res: raise PipelineResultInexistent(cluster_res_key) cluster: pd.DataFrame = self.pipeline_res[cluster_res_key].copy() cluster['bins'] = cluster['bins'].astype(str) cluster.rename({'group': 'cell_type'}, axis=1, inplace=True) cluster.set_index('bins', drop=True, inplace=True) cluster.index.name = 'cell' if self.stereo_exp_data.issparse(): data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T.toarray()) else: data = pd.DataFrame(self.stereo_exp_data.exp_matrix.T) data.columns = self.stereo_exp_data.cell_names.astype(str) data.index = self.stereo_exp_data.gene_names return data, cluster def _check_database(self, database: str): if (database is None) or (not isinstance(database, str)): return None database_dir = Path(stereo_conf.data_dir, "algorithm/cell_cell_communication/database") path = Path(database) if path.is_dir(): return None if path.is_file(): return path.absolute().as_posix() if path.exists() else None if database not in ('cellphonedb', 'liana', 'celltalkdb'): return None return (database_dir / f'{database}.db').absolute().as_posix() def _get_ref_database(self, db_path): """ preprocessing the reference database """ url = 'sqlite:///{}'.format(db_path) engine = create_engine(url) database = Database(engine) database.base_model = Base database_manager = DatabaseManager(None, database) # load repositories database_manager.add_repository(ComplexRepository) database_manager.add_repository(GeneRepository) database_manager.add_repository(InteractionRepository) database_manager.add_repository(MultidataRepository) database_manager.add_repository(ProteinRepository) # get data form database interactions = database_manager.get_repository('interaction').get_all_expanded(include_gene=False) genes = database_manager.get_repository('gene').get_all_expanded() # join gene, protein, multidata complex_composition = database_manager.get_repository('complex').get_all_compositions() complex_expanded 
= database_manager.get_repository('complex').get_all_expanded() # index interactions and complex dataframes interactions.set_index('id_interaction', drop=True, inplace=True) complex_composition.set_index('id_complex_composition', inplace=True, drop=True) return interactions, genes, complex_composition, complex_expanded def _check_meta_data(self, meta_raw: pd.DataFrame): """ Preprocess the meta dataframe: When the dataframe has both "cell" and "cell_type" columns, take "cell" as the index and "cell_type" as the only column. When the dataframe does not have a "cell" column and the index type is range, take the 1st column of meta as "cell" and use it as the index. When the dataframe does not have a "cell" column and the index type is base, rename the index as "cell". When the dataframe has no "cell" and "cell_type" at all, take the 1st column as the index and the 2nd column as "cell_type". """ meta_raw.columns = map(str.lower, meta_raw.columns) try: if 'cell' in meta_raw and 'cell_type' in meta_raw: meta = meta_raw[['cell', 'cell_type']] meta.set_index('cell', inplace=True, drop=True) return meta if type(meta_raw.index) == pd.core.indexes.multi.MultiIndex: # noqa raise ProcessMetaException elif 'cell_type' in meta_raw: meta = meta_raw[['cell_type']] if type(meta_raw.index) == pd.core.indexes.range.RangeIndex: # noqa meta.set_index(meta_raw.iloc[:, 0], inplace=True) meta.index.name = 'cell' return meta if type(meta_raw.index) == pd.core.indexes.base.Index: # noqa meta.index.name = 'cell' return meta meta = pd.DataFrame(data={'cell_type': meta_raw.iloc[:, 1]}) meta.set_index(meta_raw.iloc[:, 0], inplace=True) meta.index.name = 'cell' meta.index = meta.index.astype(str) return meta except Exception: raise ProcessMetaException def _check_counts_data(self, counts: pd.DataFrame, counts_data: str) -> None: """Naive check count data against counts gene names. This method quickly checks if count_data matches the all gene names and gives a comprehensive warning. 
Parameters ---------- counts: pd.DataFrame Counts data counts_data: str Gene identifier expected in counts data """ if ~np.all(counts.index.str.startswith(("ENSG0", "ENSMUSG0"))) and counts_data == "ensembl": logger.warning(f"Gene format missmatch. Using gene type '{counts_data}' " f"expects gene names to start with 'ENSG' (human) or 'ENSMUSG0' (mouse) but " f"some genes seem to be in another format. " f"Try using hgnc_symbol if all counts are filtered.") def _counts_validations(self, counts: pd.DataFrame, meta: pd.DataFrame) -> pd.DataFrame: """ Change counts type to np.float32. Check if the column names of counts matches the indexes of meta. """ if not len(counts.columns): raise ParseCountsException('Counts values are not decimal values', 'Incorrect file format') try: if np.any(counts.dtypes.values != np.dtype('float32')): counts = counts.astype(np.float32) except Exception: raise ParseCountsException("Counts values cannot be changed to np.float32") meta.index = meta.index.astype(str) if np.any(~meta.index.isin(counts.columns)): raise ParseCountsException("Some cells in meta did not exist in counts", "Maybe incorrect file format") if np.any(~counts.columns.isin(meta.index)): logger.debug("Dropping counts cells that are not present in meta") counts = counts.loc[:, counts.columns.isin(meta.index)].copy() return counts def _check_microenvs_data(self, microenvs: pd.DataFrame, meta: pd.DataFrame) -> pd.DataFrame: """ Runs validations to make sure the file has enough columns and that all the cell types in the microenvironment are included in meta. Rename the two columns as "cell_type" and "microenvironment". """ microenvs.drop_duplicates(inplace=True) len_columns = len(microenvs.columns) if len_columns < 2: raise Exception(f"Missing columns in microenvironments: 2 required but {len_columns} provieded") elif len_columns > 2: logger.warning(f"Microenvrionemnts expects 2 columns and got {len_columns}. 
Droppoing extra columns.") microenvs = microenvs.iloc[:, 0:2] if not all(microenvs.iloc[:, 0].isin(meta['cell_type'])): raise Exception("Some clusters/cell_types in microenvironments are not present in meta") microenvs.columns = ["cell_type", "microenvironment"] return microenvs def add_multidata_and_means_to_counts(self, counts: pd.DataFrame, genes: pd.DataFrame, counts_identifiers: str): """Adds multidata and means to counts. This method merges multidata ids into counts data using counts_identifiers as column name for the genes. Then sorts the counts columns based on the cell names, makes sure count data is of type float32 and finally calculates the means grouped by id_multidata. Returns ------- Tuple: A tuple containing: - counts: counts data merged with mutidata and indexsed by id_multidata - counts_relations: a subset of counts with only id_multidata and all gene identifiers """ # sort cell names cells_names = sorted(counts.columns) # add id multidata to counts input, INNER join, new index is the range index in genes. counts = counts.merge( genes[['id_multidata', 'ensembl', 'gene_name', 'hgnc_symbol']], left_index=True, right_on=counts_identifiers ) if counts.empty: raise AllCountsFilteredException(hint='Are you using human data?') counts_relations = counts[['id_multidata', 'ensembl', 'gene_name', 'hgnc_symbol']].copy() counts.set_index('id_multidata', inplace=True, drop=True) # id_multidata not unique counts = counts[cells_names] if np.any(counts.dtypes.values != np.dtype('float32')): counts = counts.astype(np.float32) # one protein could correspond to multiple genes # e.g. genes.loc[[21,22]] counts = counts.groupby(counts.index).mean() return counts, counts_relations def prefilters( self, interactions: pd.DataFrame, counts: pd.DataFrame, complex_composition: pd.DataFrame ): """ Filter complex_composition, interaction and counts. 
""" # Remove rows with all zero values if counts.empty: counts_filtered = counts else: counts_filtered = counts[counts.apply(lambda row: row.sum() > 0, axis=1)] # Filter complex_composition, keeping only complexes whose composing proteins are all in counts. # Also filter the counts, keep only proteins in the filtered complex_composition. complex_composition_filtered, counts_complex = self._filter_complex_composition_by_counts(counts_filtered, complex_composition) # Filter interactions, keeping only interactions whose two parts are both in the filtered complex or counts. interactions_filtered = self._filter_interactions_by_counts(interactions, counts_filtered, complex_composition_filtered) # Filter counts, keeping only simple proteins in the interactions. counts_simple = self._filter_counts_by_interactions(counts_filtered, interactions_filtered) # Combine counts of proteins in the interaction and proteins in the complexes. counts_filtered = counts_simple.append(counts_complex, sort=False) counts_filtered = counts_filtered[~counts_filtered.index.duplicated()] return complex_composition_filtered, interactions_filtered, counts_filtered def _filter_complex_composition_by_counts( self, counts: pd.DataFrame, complex_composition: pd.DataFrame ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Filter the counts and complex_composition: - keep only complexes whose composing proteins are all in the counts. - keep only the above proteins in counts. 
Returns ------- Tuple: A tuple containing: - complex_composition filtered - counts filtered """ proteins_in_complexes = complex_composition['protein_multidata_id'].drop_duplicates().tolist() # Remove counts that can't be part of a complex counts_filtered = counts[counts.apply(lambda count: count.name in proteins_in_complexes, axis=1)] # Find complexes with all components defined in counts multidata_protein = list(counts_filtered.index) composition_filtered = complex_composition[complex_composition['protein_multidata_id'].apply( lambda protein_multidata: protein_multidata in multidata_protein)] if composition_filtered.empty: complex_composition_filtered = pd.DataFrame(columns=complex_composition.columns) else: def all_protein_involved(current_complex: pd.Series) -> bool: number_proteins_in_counts = len(composition_filtered[ composition_filtered['complex_multidata_id'] == current_complex[ 'complex_multidata_id']]) if number_proteins_in_counts < current_complex['total_protein']: return False return True complex_composition_filtered = composition_filtered[ composition_filtered.apply(all_protein_involved, axis=1)] if complex_composition_filtered.empty: return complex_composition_filtered, pd.DataFrame(columns=counts.columns) available_complex_proteins = complex_composition_filtered['protein_multidata_id'].drop_duplicates().to_list() # Remove counts that are not defined in selected complexes counts_filtered = counts_filtered[ counts_filtered.apply(lambda count: count.name in available_complex_proteins, axis=1)] return complex_composition_filtered, counts_filtered def _filter_interactions_by_counts( self, interactions: pd.DataFrame, counts: pd.DataFrame, complex_composition: pd.DataFrame ) -> pd.DataFrame: """ Use filtered complex_composition and unfiltered counts to filter interactions. - keep only interactions that both parts are in the complex or counts. 
""" multidatas = list(counts.index) if not complex_composition.empty: multidatas += complex_composition['complex_multidata_id'].to_list() + complex_composition[ 'protein_multidata_id'].to_list() multidatas = list(set(multidatas)) def filter_interactions(interaction: pd.Series) -> bool: if interaction['multidata_1_id'] in multidatas and interaction['multidata_2_id'] in multidatas: return True return False interactions_filtered = interactions[interactions.apply(filter_interactions, axis=1)] return interactions_filtered def _filter_counts_by_interactions( self, counts: pd.DataFrame, interactions: pd.DataFrame ) -> pd.DataFrame: """ Removes counts if is not defined in interactions components. """ multidata_ids = interactions['multidata_1_id'].append( interactions['multidata_2_id']).drop_duplicates().tolist() counts_filtered = counts.filter(multidata_ids, axis=0) return counts_filtered def build_clusters( self, meta: pd.DataFrame, counts: pd.DataFrame, complex_composition: pd.DataFrame, skip_percent: bool ) -> dict: """ Build the means and percent values for each cluster and stores the results in a dictionary with the following keys: 'names', 'means' and 'percents'. Parameters ---------- meta: pd.DataFrame Meta data. counts: pd.DataFrame Counts data complex_composition: pd.DataFrame Complex data. skip_percent: bool Calculate percent values for each cluster or not. 
Returns ------- dict: Dictionary containing the following: - names: cluster names - means: cluster means - percents: cluster percents """ meta['cell_type'] = meta['cell_type'].astype('category') cluster_names = meta['cell_type'].cat.categories # cell counts to cluster counts cluster_means = pd.DataFrame( npg.aggregate(meta['cell_type'].cat.codes, counts.values, func='mean', axis=1), index=counts.index, columns=cluster_names.to_list() ) if not skip_percent: cluster_pcts = pd.DataFrame( npg.aggregate(meta['cell_type'].cat.codes, (counts > 0).astype(int).values, func='mean', axis=1), index=counts.index, columns=cluster_names.to_list() ) else: cluster_pcts = pd.DataFrame(index=counts.index, columns=cluster_names.to_list()) # Complex genes cluster counts if not complex_composition.empty: complexes = complex_composition.groupby('complex_multidata_id').apply( lambda x: x['protein_multidata_id'].values).to_dict() # complex_id as rows, clusters as columns complex_cluster_means = pd.DataFrame( {complex_id: cluster_means.loc[protein_ids].min(axis=0).values for complex_id, protein_ids in complexes.items()}, index=cluster_means.columns ).T cluster_means = cluster_means.append(complex_cluster_means) if not skip_percent: complex_cluster_pcts = pd.DataFrame( {complex_id: cluster_pcts.loc[protein_ids].min(axis=0).values for complex_id, protein_ids in complexes.items()}, index=cluster_pcts.columns ).T cluster_pcts = cluster_pcts.append(complex_cluster_pcts) return {'names': cluster_names, 'means': cluster_means, 'percents': cluster_pcts} def get_cluster_combinations(self, cluster_names: np.array, microenvs: pd.DataFrame = pd.DataFrame()) -> np.array: """ Calculates and sorts combinations of clusters. Generates all possible combinations between the 'cluster_names' provided. Combinations include each cluster with itself. If `microenvs` is provided then the combinations are limited to the clusters within each microenvironment as specified. 
Parameters ---------- cluster_names: np.array Array of cluster names. microenvs: pd.DataFrame Microenvironments data. Example ------- INPUT cluster_names = ['cluster1', 'cluster2', 'cluster3'] RESULT [('cluster1','cluster1'),('cluster1','cluster2'),('cluster1','cluster3'), ('cluster2','cluster1'),('cluster2','cluster2'),('cluster2','cluster3'), ('cluster3','cluster1'),('cluster3','cluster2'),('cluster3','cluster3')] if microenvironments are provided combinations are performed only within each microenv INPUT cluster_names = ['cluster1', 'cluster2', 'cluster3'] microenvs = [ ('cluster1', 'env1'), ('cluster2', 'env1'), ('cluster3', 'env2')] RESULT [('cluster1','cluster1'),('cluster1','cluster2'), ('cluster2','cluster1'),('cluster2','cluster2'), ('cluster3','cluster3')] Returns ------- np.array An array of arrays representing cluster combinations. Each inner array represents the combination of two clusters. """ result = np.array([]) if microenvs.empty: result = np.array(np.meshgrid(cluster_names.values, cluster_names.values)).T.reshape(-1, 2) else: logger.info('Limiting cluster combinations using microenvironments') cluster_combinations = [] for me in microenvs["microenvironment"].unique(): me_cell_types = microenvs[microenvs["microenvironment"] == me]["cell_type"] combinations = np.array(np.meshgrid(me_cell_types, me_cell_types)) cluster_combinations.extend(combinations.T.reshape(-1, 2)) result = pd.DataFrame(cluster_combinations).drop_duplicates().to_numpy() logger.debug(f'Using {len(result)} cluster combinations for analysis') return result def build_result_matrix(self, interactions: pd.DataFrame, cluster_interactions: list, separator: str) -> pd.DataFrame: """ builds an empty cluster matrix to fill it later, index is id_interaction """ columns = [] for cluster_interaction in cluster_interactions: columns.append('{}{}{}'.format(cluster_interaction[0], separator, cluster_interaction[1])) result = pd.DataFrame(index=interactions.index, columns=columns, dtype=float) 
return result def mean_analysis( self, interactions: pd.DataFrame, clusters: dict, cluster_interactions: list, separator: str ) -> pd.DataFrame: """ Calculates the mean for the list of interactions and for each cluster interaction Based on the interactions from CellPhoneDB database (gene1|gene2) and each cluster means (gene|cluser) this method calculates the mean of an interaction (gene1|gene2) and a cluster combination (cluster1|cluster2). When any of the values is 0, the result is set to 0, otherwise the mean is used. The following expression is used to get the result `(x > 0) * (y > 0) * (x + y) / 2` where `x = mean(gene1|cluster1)` and `y = mean(gene2|cluster2)` and the output is expected to be mean(gene1|gene2, cluster1|cluster2). Parameters ---------- interactions: pd.DataFrame Interactions from CellPhoneDB database. Gene names will be taken from here and interpret as 'multidata_1_id' for gene1 and 'multidata_2_id' for gene2. clusters: dict Clusters information. 'means' key will be used to get the means of a gene/cluster combination/ cluster_interactions: list List of cluster interactions obtained from the combination of the cluster names and possibly filtered using microenvironments. separator: str Character to use as a separator when joining cluster as column names. Example ---------- cluster_means cluster1 cluster2 cluster3 ensembl1 0.0 0.2 0.3 ensembl2 0.4 0.5 0.6 ensembl3 0.7 0.0 0.9 interactions: ensembl1,ensembl2 ensembl2,ensembl3 RESULT: cluster1_cluster1 cluster1_cluster2 ... cluster3_cluster2 cluster3_cluster3 ensembl1_ensembl2 mean(0.0,0.4)* mean(0.0,0.5)* mean(0.3,0.5) mean(0.3,0.6) ensembl2_ensembl3 mean(0.4,0.7) mean(0.4,0.0)* mean(0.6,0.0)* mean(0.6,0.9) results with * are 0 because one of both components is 0. Returns ------- DataFrame A DataFrame where each column is a cluster combination (cluster1|cluster2) and each row represents an interaction (gene1|gene2). Values are the mean for that interaction and that cluster combination. 
""" GENE_ID1 = 'multidata_1_id' GENE_ID2 = 'multidata_2_id' cluster1_names = cluster_interactions[:, 0] cluster2_names = cluster_interactions[:, 1] gene1_ids = interactions[GENE_ID1].values gene2_ids = interactions[GENE_ID2].values x = clusters['means'].loc[gene1_ids, cluster1_names].values y = clusters['means'].loc[gene2_ids, cluster2_names].values result = pd.DataFrame( (x > 0) * (y > 0) * (x + y) / 2, index=interactions.index, columns=(pd.Series(cluster1_names) + separator + pd.Series(cluster2_names)).values) return result def percent_analysis( self, clusters: dict, threshold: float, interactions: pd.DataFrame, cluster_interactions: list, separator: str ) -> pd.DataFrame: """ Calculates the percents for cluster interactions and for each gene interaction. This method builds a gene1|gene2,cluster1|cluster2 table of percent values. As the first step, calculates the percents for each gene|cluster. The cluster percent is 0 if the number of positive cluster cells divided by total of cluster cells is greater than threshold and 1 if not. If one of both is NOT 0 then sets the value to 0 else sets the value to 1. Then it calculates the percent value of the interaction. Parameters ---------- clusters: dict Clusters information. 'percents' key will be used to get the precent of a gene/cell combination. threshold: float Cutoff value for percentages (number of positive cluster cells divided by total of cluster cells). interactions: pd.DataFrame Interactions from CellPhoneDB database. Gene names will be taken from here and interpret as 'multidata_1_id' for gene1 and 'multidata_2_id' for gene2. cluster_interactions: list List of cluster interactions obtained from the combination of the cluster names and possibly filtered using microenvironments. separator: str Character to use as a separator when joining cluster as column names. 
Returns ---------- pd.DataFrame: A DataFrame where each column is a cluster combination (cluster1|cluster2) and each row represents an interaction (gene1|gene2). Values are the percent values calculated for each interaction and cluster combination. """ GENE_ID1 = 'multidata_1_id' GENE_ID2 = 'multidata_2_id' cluster1_names = cluster_interactions[:, 0] cluster2_names = cluster_interactions[:, 1] gene1_ids = interactions[GENE_ID1].values gene2_ids = interactions[GENE_ID2].values x = clusters['percents'].loc[gene1_ids, cluster1_names].values y = clusters['percents'].loc[gene2_ids, cluster2_names].values result = pd.DataFrame( ((x > threshold) * (y > threshold)).astype(int), index=interactions.index, columns=(pd.Series(cluster1_names) + separator + pd.Series(cluster2_names)).values) return result def shuffled_analysis( self, iterations: int, meta: pd.DataFrame, counts: pd.DataFrame, interactions: pd.DataFrame, cluster_interactions: list, complex_composition: pd.DataFrame, real_mean_analysis: pd.DataFrame, processes: int, separator: str ) -> list: """ Shuffles meta and calculates the means for each and saves it in a list. 
Runs it in a multiple processes to run it faster Note that on notebook just only support one process """ statistical_analysis_thread = partial(self._statistical_analysis, cluster_interactions, counts, interactions, meta, complex_composition, separator, real_mean_analysis) if processes > 1: with Pool(processes=processes) as pool: results = pool.map(statistical_analysis_thread, range(iterations)) else: results = [statistical_analysis_thread(i) for i in tqdm(range(iterations), desc='statistical analysis', ncols=100)] return results def _statistical_analysis( self, cluster_interactions: list, counts: pd.DataFrame, interactions: pd.DataFrame, meta: pd.DataFrame, complex_composition: pd.DataFrame, separator: str, real_mean_analysis: pd.DataFrame, iteration_number: int ): """ Shuffles meta dataset and calculates the means """ def shuffle_meta(meta: pd.DataFrame) -> pd.DataFrame: """ Get a randomly shuffled copy of the input meta. """ meta_copy = meta.copy() labels = list(meta_copy['cell_type'].values) np.random.shuffle(labels) meta_copy['cell_type'] = labels return meta_copy shuffled_meta = shuffle_meta(meta) shuffled_clusters = self.build_clusters(shuffled_meta, counts, complex_composition, skip_percent=True) shuffled_mean_analysis = self.mean_analysis(interactions, shuffled_clusters, cluster_interactions, separator) result_mean_analysis = np.packbits(shuffled_mean_analysis.values > real_mean_analysis.values, axis=None) return result_mean_analysis def build_pvalue_result( self, real_mean_analysis: pd.DataFrame, real_percents_analysis: pd.DataFrame, statistical_mean_analysis: list, base_result: pd.DataFrame ) -> pd.DataFrame: """ Calculates the pvalues after statistical analysis. 
If real_percent or real_mean are zero, result_pvalue is 1 If not: Calculates how many shuffled means are bigger than real mean and divides it for the number of the total iterations Parameters ---------- real_mean_analysis: pd.DataFrame Means cluster analyisis real_percents_analysis: pd.DataFrame Percents cluster analyisis statistical_mean_analysis: list Statitstical means analyisis base_result: pd.DataFrame Contains the index and columns that will be used by the returned object Returns ------- pd.DataFrame A DataFrame with interactions as rows and cluster combinations as columns. """ logger.info('Building Pvalues result') percent_result = np.zeros(real_mean_analysis.shape) result_size = percent_result.size result_shape = percent_result.shape for statistical_mean in statistical_mean_analysis: percent_result += np.unpackbits(statistical_mean, axis=None)[:result_size].reshape(result_shape) percent_result /= len(statistical_mean_analysis) mask = (real_mean_analysis.values == 0) | (real_percents_analysis == 0) percent_result[mask] = 1 return pd.DataFrame(percent_result, index=base_result.index, columns=base_result.columns) def build_results( self, analysis_type, interactions: pd.DataFrame, interactions_original: pd.DataFrame, counts_relations: pd.DataFrame, real_mean_analysis: pd.DataFrame, result_percent: pd.DataFrame, clusters_means: pd.DataFrame, complex_compositions: pd.DataFrame, counts: pd.DataFrame, genes: pd.DataFrame, result_precision: int, pvalue: float = None, counts_data: str = None, separator: str = '|' ): """ Sets the results data structure from method generated data. Results documents are defined by specs. 
""" logger.info('Building results') interactions: pd.DataFrame = interactions_original.loc[interactions.index] # get full interaction info interactions['interaction_index'] = interactions.index # add 'id_multidata', 'ensembl', 'gene_name', 'hgnc_symbol' interactions = interactions.merge(counts_relations, how='left', left_on='multidata_1_id', right_on='id_multidata', ) interactions = interactions.merge(counts_relations, how='left', left_on='multidata_2_id', right_on='id_multidata', suffixes=('_1', '_2')) interactions.set_index('interaction_index', inplace=True, drop=True) interacting_pair = self._interacting_pair_build(interactions, separator) def simple_complex_indicator(interaction: pd.Series, suffix: str) -> str: """ Add simple/complex prefixes to interaction components """ if interaction['is_complex{}'.format(suffix)]: return 'complex:{}'.format(interaction['name{}'.format(suffix)]) return 'simple:{}'.format(interaction['name{}'.format(suffix)]) interactions['partner_a'] = interactions.apply(lambda interaction: simple_complex_indicator(interaction, '_1'), axis=1) interactions['partner_b'] = interactions.apply(lambda interaction: simple_complex_indicator(interaction, '_2'), axis=1) significant_means = None significant_mean_rank = None if analysis_type == 'simple': significant_mean_rank, significant_means = self.build_significant_means(real_mean_analysis, result_percent) if analysis_type == 'statistical': significant_mean_rank, significant_means = self.build_significant_means(real_mean_analysis, result_percent, pvalue) significant_means = significant_means.round(result_precision) gene_columns = ['{}_{}'.format(counts_data, suffix) for suffix in ('1', '2')] # ['ensembl_1', 'ensembl_2'] gene_renames = {column: 'gene_{}'.format(suffix) for column, suffix in zip(gene_columns, ['a', 'b'])} # Remove useless columns interactions_data_result = pd.DataFrame( interactions[['id_cp_interaction', 'partner_a', 'partner_b', 'receptor_1', 'receptor_2', *gene_columns, 
'annotation_strategy']].copy()) interactions_data_result = pd.concat([interacting_pair, interactions_data_result], axis=1, sort=False) interactions_data_result['secreted'] = (interactions['secreted_1'] | interactions['secreted_2']) interactions_data_result['is_integrin'] = (interactions['integrin_1'] | interactions['integrin_2']) interactions_data_result.rename( columns={**gene_renames, 'receptor_1': 'receptor_a', 'receptor_2': 'receptor_b'}, inplace=True) # Dedupe rows and filter only desired columns interactions_data_result.drop_duplicates(inplace=True) means_columns = ['id_cp_interaction', 'interacting_pair', 'partner_a', 'partner_b', 'gene_a', 'gene_b', 'secreted', 'receptor_a', 'receptor_b', 'annotation_strategy', 'is_integrin'] interactions_data_result = interactions_data_result[means_columns] real_mean_analysis = real_mean_analysis.round(result_precision) significant_means = significant_means.round(result_precision) # Round result decimals for key, cluster_means in clusters_means.items(): clusters_means[key] = cluster_means.round(result_precision) # Document 1 pvalues_result = pd.DataFrame() if analysis_type == 'statistical': pvalues_result = pd.merge(interactions_data_result, result_percent, left_index=True, right_index=True, how='inner') # Document 2 means_result = pd.merge(interactions_data_result, real_mean_analysis, left_index=True, right_index=True, how='inner') # Document 3 significant_means_result = pd.merge(interactions_data_result, significant_mean_rank, left_index=True, right_index=True, how='inner') significant_means_result = pd.merge(significant_means_result, significant_means, left_index=True, right_index=True, how='inner') # Document 4 deconvoluted_result = self.deconvoluted_complex_result_build(clusters_means, interactions, complex_compositions, counts, genes, counts_data) def fillna_func(column: pd.Series): if column.dtype == object: return column.fillna(value='') if column.dtype == bool: return column.fillna(value=False) return 
column.fillna(value=-1) pvalues_result = pvalues_result.apply(fillna_func, axis=0) means_result = means_result.apply(fillna_func, axis=0) significant_means_result = significant_means_result.apply(fillna_func, axis=0) deconvoluted_result = deconvoluted_result.apply(fillna_func, axis=0) return pvalues_result, means_result, significant_means_result, deconvoluted_result def _interacting_pair_build(self, interactions: pd.DataFrame, separator) -> pd.Series: """ Returns the interaction result formated with name1_name2 """ def get_interactor_name(interaction: pd.Series, suffix: str) -> str: """ If part of interaction is complex, return name; if not, return gene_name """ if interaction['is_complex{}'.format(suffix)]: return interaction['name{}'.format(suffix)] return interaction['gene_name{}'.format(suffix)] interacting_pair = interactions.apply( lambda interaction: '{}{}{}'.format(get_interactor_name(interaction, '_1'), separator, get_interactor_name(interaction, '_2')), axis=1) interacting_pair.rename('interacting_pair', inplace=True) return interacting_pair def build_significant_means( self, real_mean_analysis: pd.DataFrame, result_percent: pd.DataFrame, min_significant_mean: float = None ) -> Tuple[pd.Series, pd.DataFrame]: """ Calculates the significant means and adds rank (number of non-empty entries divided by total entries) :param real_mean_analysis: the real mean results :param result_percent: the real percent results if simple analysis, pvalue results if statistical analysis :param min_significant_mean: """ significant_means = self._get_significant_means(real_mean_analysis, result_percent, min_significant_mean) significant_mean_rank = significant_means.count(axis=1) # type: pd.Series number_of_clusters = len(significant_means.columns) significant_mean_rank = significant_mean_rank.apply(lambda rank: rank / number_of_clusters) significant_mean_rank = significant_mean_rank.round(3) significant_mean_rank.name = 'rank' return significant_mean_rank, significant_means 
def _get_significant_means(
        self,
        real_mean_analysis: pd.DataFrame,
        result_percent: pd.DataFrame,
        min_significant_mean: float = None
) -> pd.DataFrame:
    """
    Get the significant means for gene1_gene2|cluster1_cluster2.

    For the statistical analysis `min_significant_mean` (the p-value cutoff) is provided. Entries with
    `result_percent > min_significant_mean` are set to NaN; the others keep their mean. A cutoff of 0 is
    honoured, so only p-values equal to 0 are kept.
    For the simple analysis `min_significant_mean` is None. Entries with `result_percent == 0` are set to NaN;
    the others keep their mean.

    Parameters
    ----------
    real_mean_analysis : pd.DataFrame
        Mean results for each gene|cluster combination
    result_percent : pd.DataFrame
        Same shape and column order as `real_mean_analysis`:
        - Simple analysis: real percent result
        - Statistical analysis: p-value result
    min_significant_mean : float, optional
        - Simple analysis: None.
        - Statistical analysis: p-value cutoff.

    Returns
    -------
    pd.DataFrame
        Significant means data frame. Columns are cluster interactions (cluster1|cluster2); values are NaN
        where there is no significant interaction, otherwise the mean value of the interaction.
    """
    significant_means = real_mean_analysis.values.copy()
    # `is not None` so that a cutoff of 0.0 still selects the statistical branch.
    if min_significant_mean is not None:
        mask = np.asarray(result_percent) > min_significant_mean
    else:
        mask = np.asarray(result_percent) == 0
    significant_means[mask] = np.nan
    return pd.DataFrame(significant_means, index=real_mean_analysis.index, columns=real_mean_analysis.columns)


def deconvoluted_complex_result_build(
        self,
        clusters_means: pd.DataFrame,
        interactions: pd.DataFrame,
        complex_compositions: pd.DataFrame,
        counts: pd.DataFrame,
        genes: pd.DataFrame,
        counts_data: str
) -> pd.DataFrame:
    """
    Builds the deconvoluted result: one row per gene taking part in an interaction (complexes are expanded
    into their member genes), joined with that gene's cluster means.

    :param clusters_means: cluster means indexed by multidata id.
    :param interactions: interactions with '_1'/'_2' suffixed partner columns (as built by `build_results`).
    :param complex_compositions: complex -> protein composition table; may be empty.
    :param counts: counts matrix; only genes whose 'id_multidata' is in its index are used for complexes.
    :param genes: gene table with an 'id_multidata' column.
    :param counts_data: gene identifier type used in the counts (column prefix in `interactions`).
    :return: DataFrame indexed by gene identifier, duplicates removed.
    """
    genes_filtered = genes[genes['id_multidata'].isin(counts.index)]

    parts = []
    for suffix in ('_1', '_2'):
        parts.append(self._deconvolute_complex_interaction_component(
            complex_compositions, genes_filtered, interactions, suffix, counts_data))
        parts.append(self._deconvolute_interaction_component(interactions, suffix, counts_data))
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    deconvoluted_result = pd.concat(parts, sort=False)

    deconvoluted_result.set_index('multidata_id', inplace=True, drop=True)

    deconvoluted_columns = ['gene_name', 'name', 'is_complex', 'protein_name', 'complex_name',
                            'id_cp_interaction', 'gene']
    deconvoluted_result = deconvoluted_result[deconvoluted_columns].rename(columns={'name': 'uniprot'})
    deconvoluted_result = pd.concat([deconvoluted_result, clusters_means.reindex(deconvoluted_result.index)],
                                    axis=1, join='inner', sort=False)
    deconvoluted_result.set_index('gene', inplace=True, drop=True)
    deconvoluted_result.drop_duplicates(inplace=True)

    return deconvoluted_result


def _deconvolute_interaction_component(self, interactions, suffix, counts_data):
    """
    Extract the non-complex partners on side `suffix` ('_1' or '_2') of the interactions.

    :return: DataFrame with columns 'gene', 'multidata_id', 'protein_name', 'gene_name', 'name', 'is_complex',
             'id_cp_interaction', 'receptor', 'complex_name' (the last is NaN for simple partners).
    """
    interactions = interactions[~interactions['is_complex{}'.format(suffix)]]
    deconvoluted_result = pd.DataFrame()
    deconvoluted_result['gene'] = interactions['{}{}'.format(counts_data, suffix)]

    interactions_index_data = interactions[
        ['multidata{}_id'.format(suffix), 'protein_name{}'.format(suffix), 'gene_name{}'.format(suffix),
         'name{}'.format(suffix), 'is_complex{}'.format(suffix), 'id_cp_interaction',
         'receptor{}'.format(suffix)]]
    # Positional assignment: each source column above fills the target column at the same position.
    deconvoluted_result[['multidata_id', 'protein_name', 'gene_name', 'name', 'is_complex', 'id_cp_interaction',
                         'receptor']] = interactions_index_data
    deconvoluted_result['complex_name'] = np.nan

    return deconvoluted_result


def _deconvolute_complex_interaction_component(
        self,
        complex_compositions,
        genes_filtered,
        interactions,
        suffix,
        counts_data
):
    """
    Expand the complex partners on side `suffix` ('_1' or '_2') of the interactions into their member genes.

    :return: DataFrame with columns 'gene', 'multidata_id', 'protein_name', 'gene_name', 'name', 'is_complex',
             'id_cp_interaction', 'receptor', 'complex_name'. An empty frame with the same columns is returned
             when `complex_compositions` is empty, so it concatenates cleanly with the simple-partner frames.
    """
    result_columns = ['gene', 'multidata_id', 'protein_name', 'gene_name', 'name', 'is_complex',
                      'id_cp_interaction', 'receptor', 'complex_name']
    if complex_compositions.empty:
        return pd.DataFrame(columns=result_columns)

    deconvoluted_result = pd.DataFrame()
    component = pd.DataFrame()
    component[counts_data] = interactions['{}{}'.format(counts_data, suffix)]
    interactions_index_data = interactions[
        ['{}{}'.format(counts_data, suffix), 'protein_name{}'.format(suffix), 'gene_name{}'.format(suffix),
         'name{}'.format(suffix), 'is_complex{}'.format(suffix), 'id_cp_interaction',
         'multidata{}_id'.format(suffix), 'receptor{}'.format(suffix)]]
    component[[counts_data, 'protein_name', 'gene_name', 'name', 'is_complex', 'id_cp_interaction',
               'id_multidata', 'receptor']] = interactions_index_data

    deconvolution_complex = pd.merge(complex_compositions,
                                     component,
                                     left_on='complex_multidata_id',
                                     right_on='id_multidata')
    deconvolution_complex = pd.merge(deconvolution_complex,
                                     genes_filtered,
                                     left_on='protein_multidata_id',
                                     right_on='protein_multidata_id',
                                     suffixes=['_complex', '_simple'])

    deconvoluted_result['gene'] = deconvolution_complex['{}_simple'.format(counts_data)]

    deconvolution_complex_index_data = deconvolution_complex[
        ['complex_multidata_id', 'protein_name_simple', 'gene_name_simple', 'name_simple', 'is_complex_complex',
         'id_cp_interaction', 'receptor_simple', 'name_complex']]
    deconvoluted_result[result_columns[1:]] = deconvolution_complex_index_data

    return deconvoluted_result