Source code for stereo.plots.plot_clusters_heatmap

from typing import Optional, Sequence, Literal

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from matplotlib import gridspec

from stereo.log_manager import logger
from stereo.plots.plot_base import PlotBase
from stereo.utils.pipeline_utils import cell_cluster_to_gene_exp_cluster
from ._plot_basic.heatmap_plt import heatmap


class ClustersGenesHeatmap(PlotBase):
    __category_width = 0.37
    __category_height = 0.35
    __legend_width = 2
    __dendrogram_height = 0.6
    __title_font_size = 8

[docs] def clusters_genes_heatmap( self, cluster_res_key: str, dendrogram_res_key: Optional[str] = None, topn: Optional[int] = 5, gene_names: Optional[Sequence[str]] = None, expression_kind: Literal['mean', 'sum'] = 'mean', groups: Optional[Sequence[str]] = None, width: int = None, height: int = None, colormap: str = 'Greens', standard_scale: str = 'gene' ): """ Heatmap representing mean expression of genes on each cell cluster. :param cluster_res_key: the key to get cluster result. :param dendrogram_res_key: the key to get dendrogram result, defaults to None to avoid show dendrogram on plot. :param topn: select `topn` expressed genes in each cluster, defaults to 5, ignored if `gene_names` is not None, the number of genes shown in plot may be more than `topn` because the `topn` genes in each cluster are not the same. :param gene_names: a list of genes to show, defaults to None to show all genes. :param expression_kind: the kind of expression to show, 'mean' or 'sum', defaults to 'mean'. :param groups: a list of cell clusters to show, defaults to None to show all cell clusters. :param width: the figure width in pixels, defaults to None. :param height: the figure height in pixels, defaults to None. :param colormap: colormap used on plot, defaults to 'Greens' :param standard_scale: Whether or not to standardize that dimension between 0 and 1, meaning for each gene or cluster, subtract the minimum and divide each by its maximum, defaults to 'gene' """ if cluster_res_key not in self.pipeline_res: raise KeyError(f"Can not find the cluster result in data.tl.result by key {cluster_res_key}") drg_res = None if dendrogram_res_key is not None: if dendrogram_res_key not in self.pipeline_res: raise KeyError(f"Can not find the dendrogram result in data.tl.result by key {dendrogram_res_key}") else: drg_res = self.pipeline_res[dendrogram_res_key] if cluster_res_key != drg_res['cluster_res_key'][0]: raise KeyError( f'The cluster result used in dendrogram may not be the same as that specified ' f'by key {cluster_res_key}') if gene_names is not None: topn = None if topn is None: if gene_names is None: gene_names = self.stereo_exp_data.gene_names else: if isinstance(gene_names, str): gene_names = np.array([gene_names], dtype='U') elif not isinstance(gene_names, np.ndarray): gene_names = np.array(gene_names, dtype='U') if len(gene_names) == 0: return None if groups is None or drg_res is not None: cluster_res: pd.DataFrame = self.pipeline_res[cluster_res_key] cluster_res['group'] = cluster_res['group'].astype('category') group_codes = cluster_res['group'].cat.categories.to_numpy() groups = None if drg_res is not None: categories_idx_ordered = drg_res['categories_idx_ordered'] group_codes = group_codes[categories_idx_ordered] else: group_codes = groups if isinstance(group_codes, str): group_codes = np.array([group_codes], dtype='U') elif not isinstance(group_codes, np.ndarray): group_codes = np.array(group_codes, dtype='U') if topn is None: genes_expression: pd.DataFrame = cell_cluster_to_gene_exp_cluster( self.stereo_exp_data, cluster_res_key, groups=groups, genes=gene_names, kind='mean' ) else: genes_expression = cell_cluster_to_gene_exp_cluster( self.stereo_exp_data, cluster_res_key, groups=groups, kind=expression_kind ) gene_names = [] for c in genes_expression.columns: gene_names.extend(genes_expression[c].sort_values(ascending=False).index[:topn].tolist()) gene_names = np.unique(gene_names) genes_expression = genes_expression.loc[gene_names] genes_expression = genes_expression[group_codes] if standard_scale == 'cluster': genes_expression -= genes_expression.min(0) genes_expression = (genes_expression / genes_expression.max(0)).fillna(0) elif standard_scale == 'gene': genes_expression = genes_expression.sub(genes_expression.min(1), axis=0) genes_expression = genes_expression.div(genes_expression.max(1), axis=0).fillna(0) elif standard_scale is None: pass else: logger.warning('Unknown type for standard_scale, ignored') if width is None: if len(group_codes) < 10: main_area_width = self.__category_width * 10 else: main_area_width = self.__category_width * len(group_codes) else: if width < self.__legend_width: width = self.__category_width * 10 + self.__legend_width main_area_width = width - self.__legend_width if height is None: if len(gene_names) < 10: main_area_height = self.__category_height * 10 else: main_area_height = self.__category_height * len(gene_names) else: if height < self.__dendrogram_height: height = self.__category_height * 10 + self.__dendrogram_height main_area_height = height - self.__dendrogram_height width_ratios = [main_area_width, self.__legend_width] height_ratios = [self.__dendrogram_height + main_area_height] width = sum(width_ratios) height = sum(height_ratios) fig = plt.figure(figsize=(width, height)) axs = gridspec.GridSpec( nrows=1, ncols=2, width_ratios=width_ratios, height_ratios=height_ratios, wspace=(0.15 / main_area_width), hspace=0 ) axs_main = gridspec.GridSpecFromSubplotSpec( nrows=2, ncols=1, width_ratios=[main_area_width], height_ratios=[self.__dendrogram_height, main_area_height], wspace=0, hspace=0, subplot_spec=axs[0, 0] ) axs_on_right = gridspec.GridSpecFromSubplotSpec( nrows=2, ncols=1, height_ratios=[0.95, 0.05], subplot_spec=axs[0, 1], hspace=0.1 ) ax_heatmap = fig.add_subplot(axs_main[1, 0]) ax_colorbar = fig.add_subplot(axs_on_right[1, 0]) heatmap( genes_expression, ax=ax_heatmap, plot_colorbar=True, colorbar_ax=ax_colorbar, cmap=colormap, colorbar_orientation='horizontal', colorbar_ticklocation='bottom', colorbar_title=f'{expression_kind.capitalize()} expression in group', show_xaxis=True, show_yaxis=True, ) if drg_res is not None: from .plot_dendrogram import PlotDendrogram ax_drg = fig.add_subplot(axs_main[0, 0], sharex=ax_heatmap) plt_drg = PlotDendrogram(self.stereo_exp_data, self.pipeline_res) plt_drg.dendrogram( orientation='top', ax=ax_drg, res_key=dendrogram_res_key, ticks=list(range(len(group_codes))) ) ax_drg.set_axis_off() return fig