Source code for stereo.plots.plot_ccc

import json
import os
from pathlib import Path
from typing import (
    Union,
    List,
    Dict,
    Optional
)

import matplotlib.pyplot as plt
import natsort
import networkx as nx
# third part module
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from pycirclize import Circos

from stereo.algorithm.cell_cell_communication.analysis_helper import mouse2human
from stereo.algorithm.cell_cell_communication.exceptions import PipelineResultInexistent
# module in self project
from stereo.log_manager import logger
from stereo.plots.plot_base import PlotBase
from stereo.stereo_config import stereo_conf


class PlotCellCellCommunication(PlotBase):
    def __init__(self, *args, **kwargs):
        super(PlotCellCellCommunication, self).__init__(*args, **kwargs)
        PlotCellCellCommunication.ccc_sankey_plot.__download__ = False

    # TODO: change default paths
[docs] def ccc_dot_plot( self, interacting_pairs: Union[str, list, np.ndarray] = None, clusters1: Union[str, list, np.ndarray] = None, clusters2: Union[str, list, np.ndarray] = None, separator_cluster: str = '|', palette: str = 'Reds', res_key: str = 'cell_cell_communication', width: int = None, height: int = None, **kw_args ): """Generate dot plot based on the result of CellCellCommunication. :param interacting_pairs: path, string, list or ndarray. specify the interacting pairs which would be shown on plot, defaults to None. 1) path: the path of file in which saves the interacting pairs which would be shown, one line one pair. 2) string: only one interacting pair. 3) list or ndarray: an array contains the interacting pairs which would be shown. :param clusters1: path, string, list or ndarray. the first clusters in cluster pairs which would be shown on plot, defaults to None. 1) path: the path of file in which saves the clusters which would be shown, one line one cluster. 2) string: only one cluster. 3) list or ndarray: an array contains the clusters which would be shown. :param clusters2: path, string, list or ndarray. the second clusters in cluster pairs which would be shown on plot, defaults to None. clusters1 and clusters2 together form cluster pairs each cluster in cluster1 will join with each one in cluster2 to form the cluster pairs. if set it to None, it will be set to all clusters. 1) path: the path of file in which saves the clusters which would be shown, one line one cluster. 2) string: only one cluster. 3) list or ndarray: an array contains the clusters which would be shown. :param separator_cluster: the symbol for joining the clusters1 and clusters2, defaults to '|' :param palette: plot palette, defaults to 'Reds' :param res_key: the key which specifies the cell cell communication result in data.tl.result, defaults to 'cell_cell_communication' :param width: the figure width in pixels, defaults to None :param height: the figure height in pixels, defaults to None :return: matplotlib.figure """ # noqa logger.info('Generating dot plot') if res_key not in self.pipeline_res: PipelineResultInexistent(res_key) if self.pipeline_res[res_key]['parameters']['analysis_type'] != 'statistical': logger.warning("This plot just only support analysis type 'statistical'") return None means_df = self.pipeline_res[res_key]['means'] pvalues_df = self.pipeline_res[res_key]['pvalues'] interacting_pairs = self._parse_interacting_pairs_or_clusters(interacting_pairs) if interacting_pairs is None: interacting_pairs = means_df['interacting_pair'].tolist() else: if all(np.isin(interacting_pairs, means_df['interacting_pair']) == False): # noqa raise Exception("there is no interacting pairs to show, maybe the parameter 'interacting_pairs' " "you set is not in result.") clusters1 = self._parse_interacting_pairs_or_clusters(clusters1) clusters2 = self._parse_interacting_pairs_or_clusters(clusters2) if clusters1 is None: cluster_pairs = natsort.natsorted([x for x in means_df.columns if separator_cluster in x]) else: if clusters2 is None: cluster_res_key = self.pipeline_res[res_key]['parameters']['cluster_res_key'] clusters2 = self.pipeline_res[cluster_res_key]['group'].unique() cluster_pairs = [f'{c1}{separator_cluster}{c2}' for c1 in clusters1 for c2 in clusters2] cluster_pairs = natsort.natsorted([x for x in cluster_pairs if x in means_df.columns]) if len(cluster_pairs) == 0: raise Exception( "there is no cluster pairs to show, maybe the parameter 'clusters' you set is not in result.") means_selected: pd.DataFrame = means_df[means_df['interacting_pair'].isin(interacting_pairs)][ ['interacting_pair'] + cluster_pairs] pvalues_selected: pd.DataFrame = pvalues_df[pvalues_df['interacting_pair'].isin(interacting_pairs)][ ['interacting_pair'] + cluster_pairs] nrows, ncols = means_selected.shape means = means_selected.melt(id_vars='interacting_pair', value_vars=cluster_pairs, value_name='mean') means['log2(mean+1)'] = np.log2(means['mean'] + 1) pvalues = pvalues_selected.melt(id_vars='interacting_pair', value_vars=cluster_pairs, value_name='pvalue') pvalues['-log10(pvalue)'] = pvalues['pvalue'].apply( lambda x: -np.log10(0.000001) if x == 0 else (-np.log10(x) if x != 1 else 0)) result = pd.merge(means, pvalues, on=["interacting_pair", "variable"]) result = result.rename(columns={'variable': 'cluster_pair'}) # plotting if width is None or height is None: width, height = int(5 + max(3, ncols * 0.8)), int(3 + max(5, nrows * 0.5)) else: width = width / 100 if width >= 100 else int(5 + max(3, ncols * 0.8)) height = height / 100 if height >= 100 else int(3 + max(5, nrows * 0.5)) fig, ax = plt.subplots(figsize=(width, height)) sns.scatterplot(data=result, x="cluster_pair", y="interacting_pair", palette=palette, hue='log2(mean+1)', size='-log10(pvalue)', sizes=(100, 300), legend='auto', ax=ax, **kw_args) ax.legend(fontsize=12, frameon=False, ncol=1, loc=(1.02, 0)) ax.tick_params(axis='x', labelsize=12, labelrotation=90) ax.tick_params(axis='y', labelsize=12) ax.set_xlabel('') ax.set_ylabel('') return fig
[docs] def ccc_heatmap( self, pvalue: float = 0.05, separator_cluster: str = '|', res_key: str = 'cell_cell_communication', width: int = None, height: int = None, palette: str = 'coolwarm', **kwargs ): """ Heatmap of number of interactions in each cluster pairs. Each off-diagonal cell value equals the number of interactions from A to B + the number of interactions from B to A :param pvalue: the threshold of pvalue, defaults to 0.05 :param separator_cluster: the symbol for joining the first and second cluster in cluster pairs, defaults to '|' :param res_key: the key which specifies the cell cell communication result in data.tl.result, defaults to 'cell_cell_communication' :return: _description_ """ # noqa logger.info('Generating heatmap plot') if res_key not in self.pipeline_res: PipelineResultInexistent(res_key) if self.pipeline_res[res_key]['parameters']['analysis_type'] != 'statistical': logger.warning("This plot just only support analysis type 'statistical'") return None cluster_res_key = self.pipeline_res[res_key]['parameters']['cluster_res_key'] meta_df = self.pipeline_res[cluster_res_key] clusters_all = natsort.natsorted(meta_df['group'].unique()) n_cluster: int = len(clusters_all) pvalues_df = self.pipeline_res[res_key]['pvalues'] cluster_pairs = np.array(np.meshgrid(clusters_all, clusters_all)).T.reshape(-1, 2) network = pd.DataFrame(cluster_pairs, columns=['source', 'target']) for index, row in network.iterrows(): col1 = row['source'] + separator_cluster + row['target'] col2 = row['target'] + separator_cluster + row['source'] if col1 in pvalues_df.columns and col2 in pvalues_df.columns: if col1 == col2: network.loc[index, 'number'] = (pvalues_df[col1] <= pvalue).sum() else: network.loc[index, 'number'] = (pvalues_df[col1] <= pvalue).sum() + ( pvalues_df[col2] <= pvalue).sum() else: network.loc[index, 'number'] = 0 network: pd.DataFrame = network.pivot("source", "target", "number") network = network.loc[clusters_all][clusters_all] rows = network.index.tolist() rows.reverse() network = network[rows] log_network = np.log1p(network) if width is None or height is None: width, height = int(3 + max(3, n_cluster * 0.5)) * 2, int(3 + max(3, n_cluster * 0.5)) else: width = width / 100 if width >= 100 else int(3 + max(3, n_cluster * 0.5)) * 2 height = height / 100 if height >= 100 else int(3 + max(3, n_cluster * 0.5)) fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(width, height), gridspec_kw={'wspace': 0}) sns.heatmap( data=network, square=True, cmap=palette, cbar_kws={'pad': 0.1, 'shrink': 0.5, 'location': 'top', 'orientation': 'horizontal'}, ax=axes[0], **kwargs) axes[0].yaxis.set_ticks_position('left') axes[0].invert_yaxis() axes[0].tick_params(axis='x', labelsize=13, labelrotation=90) axes[0].tick_params(axis='y', labelsize=13, labelrotation=0) axes[0].set_xlabel('') axes[0].set_ylabel('') axes[0].set_title('count') sns.heatmap( data=log_network, square=True, cmap=palette, cbar_kws={'pad': 0.1, 'shrink': 0.5, 'location': 'top', 'orientation': 'horizontal'}, ax=axes[1], **kwargs) axes[1].yaxis.set_ticks_position('right') axes[1].invert_yaxis() axes[1].tick_params(axis='x', labelsize=13, labelrotation=90) axes[1].tick_params(axis='y', labelsize=13, labelrotation=0) axes[1].set_xlabel('') axes[1].set_ylabel('') axes[1].set_title('log_count') return fig
[docs] def ccc_circos_plot( self, separator_cluster: str = '|', cluster_pair: list = None, palette: str = "RdYlBu_r", width: int = 8, height: int = 8, res_key: str = 'cell_cell_communication' ): # TODO: change default significant_path to the significant_means_xxx file generated by the main CCC analysis """ Circos plot of number of interactions in each cluster pairs. :param separator_cluster: separator used for cell cluster pairs. :param cluster_pair: if None, use all cluster pairs in the significant result; else list selected cluster pairs used in the plot. # noqa :param palette: colormap used. :param width: the figure width in pixels. :param height: the figure height in pixels. :param res_key: the key which specifies the cell cell communication result in data.tl.result, defaults to 'cell_cell_communication'. # noqa """ logger.info('Generating circos plot') if res_key not in self.pipeline_res: PipelineResultInexistent(res_key) significant_df: pd.DataFrame = self.pipeline_res[res_key]['significant_means'] if cluster_pair is None: significant_df = significant_df.drop( ['id_cp_interaction', 'interacting_pair', 'partner_a', 'partner_b', 'gene_a', 'gene_b', 'secreted', 'receptor_a', 'receptor_b', 'annotation_strategy', 'is_integrin', 'rank'], axis=1) significant_df = significant_df.dropna(how='all').reset_index(drop=True) cluster_pair = significant_df.columns else: significant_df = significant_df[cluster_pair] significant_df = significant_df.dropna(how='all').reset_index(drop=True) cell_types = [x.split(separator_cluster) for x in cluster_pair] cell_types = list(set([item for sublist in cell_types for item in sublist])) cell_types.sort() number_matrix = pd.DataFrame(index=cell_types, columns=cell_types) for source in cell_types: for target in cell_types: if source + separator_cluster + target in cluster_pair: number_matrix.loc[source, target] = significant_df[source + separator_cluster + target].count() number_matrix = number_matrix.fillna(0) circos = Circos.initialize_from_matrix( number_matrix, space=5, cmap=palette, label_kws=dict(size=12), link_kws=dict(ec="black", lw=0.5, direction=1), ) fig, ax = plt.subplots(figsize=(width, height), subplot_kw={"projection": "polar"}) circos.plotfig(ax=ax) return fig
[docs] def ccc_sankey_plot( self, sender_cluster: str, receiver_cluster: str, homo_transfer: bool = False, homogene_path: Optional[str] = None, separator_cluster: str = '|', separator_interaction: str = '_', weighted_network_path: str = None, regulons: Union[str, Dict[str, list]] = None, tfs: Optional[List[str]] = None, pct_expressed: float = 0.05, max_path_length: int = 4, width: int = 600, height: int = 880, out_path: Optional[str] = None, receptor_tf_paths_out_path: Optional[str] = None, res_key: str = 'cell_cell_communication', ): """ Sankey-plot showing inter- and/or intra-cellular gene interaction. Left pillar is ligands, middle pillar receptors, # noqa right pillar TFs. The width of each band is the average expression of the two genes. :param sender_cluster: sender cell type :param receiver_cluster: receiver cell type :param homo_transfer: If species is 'MOUSE' but database is 'cellphonedb' or 'liana', the gene names in the significant result have been transferred to 'HUMAN', we need to transfer them back in order to match the gene names in counts. :param homogene_path: path to the file storing mouse-human homologous genes ralations. :param separator_cluster: separator used for cell cluster pairs :param separator_interaction: separator used for LR interaction :param weighted_network_path: path to the weighted network :param regulons: path or dict to the spaGRN regulon output :param tfs: list of TFs used in the Sankey plot, if it is gave, the parameter `regulons` is ignored. :param pct_expressed: threshold used to detect expressed path between receptor and TF. :param max_path_length: the max path length between receptor and TF, paths longer than max_path_length is not considered as a potential pathway :param width: the figure width in pixels. :param height: the figure height in pixels. :param out_path: the path to save the plot. :param res_key: the key which specifies the cell cell communication result in data.tl.result, defaults to 'cell_cell_communication'. # noqa """ if res_key not in self.pipeline_res: PipelineResultInexistent(res_key) cluster_res_key = self.pipeline_res[res_key]['parameters']['cluster_res_key'] if cluster_res_key not in self.pipeline_res: PipelineResultInexistent(cluster_res_key) assert weighted_network_path is not None assert regulons is not None or tfs is not None cluster_pair_key = sender_cluster + separator_cluster + receiver_cluster significant_df: pd.DataFrame = self.pipeline_res[res_key]['significant_means'] significant_df = significant_df[['interacting_pair', 'partner_a', 'partner_b', cluster_pair_key]] def __clean_significant_df(row, cluster_pair_key): partner_a, partner_b = row['partner_a'], row['partner_b'] if row[cluster_pair_key] == -1: row[cluster_pair_key] = pd.NaT if partner_a.startswith('complex') or partner_b.startswith('complex'): return pd.NaT else: return row significant_df = significant_df.apply(__clean_significant_df, axis=1, result_type='broadcast', cluster_pair_key=cluster_pair_key) significant_df.dropna( how='any', subset=['partner_a', 'partner_b', cluster_pair_key], axis=0, inplace=True ) assert significant_df.size > 0, 'No significant LR pairs detected.' significant_df.reset_index(drop=True, inplace=True) ligands = list(set([lr.split('_')[0] for lr in significant_df['interacting_pair'].values])) ligands.sort() receptors = list(set([lr.split('_')[1] for lr in significant_df['interacting_pair'].values])) receptors.sort() significant_pairs = significant_df['interacting_pair'].values if homo_transfer: if homogene_path is None: homogene_path = Path(stereo_conf.data_dir, 'algorithm/cell_cell_communication/database/mouse2human.csv').absolute().as_posix() genes_mouse = self.stereo_exp_data.gene_names genes_human, human_genes_to_mouse = mouse2human(genes_mouse, homogene_path) ligands = [human_genes_to_mouse[x] for x in ligands] receptors = [human_genes_to_mouse[x] for x in receptors] significant_pairs_tmp = [] for lr in significant_pairs: interaction1, interaction2 = lr.split(separator_interaction) if interaction1 not in human_genes_to_mouse or interaction2 not in human_genes_to_mouse: continue significant_pairs_tmp.append( human_genes_to_mouse[interaction1] + separator_interaction + human_genes_to_mouse[interaction2]) significant_pairs = significant_pairs_tmp # Construct expressed weighted gene regulatory network counts_receiver = self._get_cell_counts(cluster_res_key, receiver_cluster) expressed_genes_receiver = self._get_expressed_genes(counts_receiver, pct_expressed) weighted_network_lr_sig = pd.read_csv(weighted_network_path, sep='\t') weighted_network_lr_sig_expressed = self._get_expressed_network(weighted_network_lr_sig, expressed_genes_receiver) weighted_network_lr_sig_expressed['distance'] = 1 / weighted_network_lr_sig_expressed['weight'] # return genes_mouse_obtained, counts_receiver, expressed_genes_receiver, weighted_network_lr_sig_expressed G = nx.DiGraph() for idx, row in weighted_network_lr_sig_expressed.iterrows(): G.add_edge(row['from'], row['to'], weight=row['weight'], distance=row['distance']) # Get TFs from the json file of GRN analysis if tfs is None: if isinstance(regulons, dict): tfs = list(map(lambda x: x[:-3], regulons.keys())) elif regulons.endswith('json'): with open(regulons, 'r') as fp: rgl = json.load(fp) tfs = list(map(lambda x: x[:-3], rgl.keys())) elif regulons.endswith('csv'): regulons_df = pd.read_csv(regulons, sep=',', header=0) tfs = regulons_df['Regulons'].apply(lambda x: x[:-3]).tolist() # Get paths between Receptors and TFs source_rtf = [] target_rtf = [] length_rtf = [] paths = [] for receptor in receptors: for tf in tfs: try: path, weight = self._get_shortest_path(G, source=receptor, target=tf, distance='distance', weight='weight') paths.append(path) source_rtf.append(receptor) target_rtf.append(tf) length_rtf.append(len(path)) except Exception: source_rtf.append(receptor) target_rtf.append(tf) paths.append([]) length_rtf.append(999) result_path = pd.DataFrame({'receptor': source_rtf, 'TF': target_rtf, 'path': paths, 'path_length': length_rtf}) if receptor_tf_paths_out_path: result_path.to_csv(receptor_tf_paths_out_path) result_path = result_path[result_path['path_length'] <= max_path_length] tfs = list(set(result_path['TF'])) tfs.sort() assert len(tfs) > 0, 'Can not find any path from Receptor to TF.' # Generate final data for plotting label = ligands + receptors + tfs counts_sender = self._get_cell_counts(cluster_res_key, sender_cluster) # The left part of Ligand-Receptor interaction source_lr = [] target_lr = [] value_lr = [] for i, ligand in enumerate(ligands): for j, receptor in enumerate(receptors): current_pair = ligand + separator_interaction + receptor if current_pair in significant_pairs: source_lr.append(i) target_lr.append(len(ligands) + j) mean_expression = (self._calculate_mean_expression(counts_sender, ligand) + self._calculate_mean_expression(counts_receiver, receptor)) / 2 value_lr.append(mean_expression) # The right part of Receptor-TF interaction source_rtf = [] target_rtf = [] value_rtf = [] for i, receptor in enumerate(receptors): # weight_r = [] for j, tf in enumerate(tfs): df = result_path[(result_path['receptor'] == receptor) & (result_path['TF'] == tf)] if df.empty: continue else: source_rtf.append(len(ligands) + i) target_rtf.append(len(ligands) + len(receptors) + j) mean_expression = (self._calculate_mean_expression(counts_receiver, receptor) + self._calculate_mean_expression(counts_receiver, tf)) / 2 value_rtf.append(mean_expression) # Generate sankey plot node = dict(thickness=8, pad=2, label=label) link = dict(source=source_lr + source_rtf, target=target_lr + target_rtf, value=value_lr + value_rtf) fig = go.Figure( data=[ go.Sankey( node=node, link=link) ]) fig.update_layout(height=height, width=width, font_size=12, font_family='Arial') if out_path is not None: fig.write_image(out_path) return fig
def _get_cell_counts(self, cluster_res_key, cluster): cluster_res: pd.DataFrame = self.pipeline_res[cluster_res_key] isin = cluster_res['group'].isin([cluster]).to_numpy() cell_counts = self.stereo_exp_data.exp_matrix[isin] cell_list = self.stereo_exp_data.cell_names[isin] if self.stereo_exp_data.issparse(): cell_counts = cell_counts.toarray() cell_counts = cell_counts.T return pd.DataFrame(cell_counts, columns=cell_list, index=self.stereo_exp_data.gene_names) def _get_expressed_genes(self, count_df, pct): count_df = count_df.loc[(count_df != 0).mean(axis=1) >= pct, :] return list(count_df.index) def _get_expressed_network(self, network, expressed_genes): sub_network = network[network['from'].isin(expressed_genes) & network['to'].isin(expressed_genes)].copy() return sub_network def _get_shortest_path(self, graph, source, target, distance, weight): shortest_path = nx.shortest_path(graph, source=source, target=target, weight=distance) path_weight_sum = 0 for i in range(len(shortest_path) - 1): path_weight_sum = path_weight_sum + graph.get_edge_data(shortest_path[i], shortest_path[i + 1])[weight] return shortest_path, path_weight_sum def _calculate_mean_expression(self, counts, gene): counts_gene = counts.loc[gene, :] return np.mean(counts_gene) def _parse_interacting_pairs_or_clusters(self, interacting_pairs_or_clusters: Union[str, list, np.ndarray]): if isinstance(interacting_pairs_or_clusters, list) or isinstance(interacting_pairs_or_clusters, np.ndarray): return np.unique(interacting_pairs_or_clusters) if isinstance(interacting_pairs_or_clusters, str): path = Path(interacting_pairs_or_clusters) if not path.is_file() and not path.is_dir(): return [interacting_pairs_or_clusters] if path.is_file() and path.exists(): with path.open('r') as fp: return np.unique([line.strip() for line in fp.readlines()]) return None def _ensure_path_exists(self, path: str): expanded_path = os.path.expanduser(path) if not os.path.exists(expanded_path): os.makedirs(expanded_path)