Source code for stereo.plots.plot_clusters_scatter

from typing import Optional, Sequence, Literal

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from matplotlib import gridspec
from matplotlib.axes import Axes

from stereo.log_manager import logger
from stereo.plots.plot_base import PlotBase
from stereo.utils.pipeline_utils import calc_pct_and_pct_rest
from stereo.utils.pipeline_utils import cell_cluster_to_gene_exp_cluster


class ClustersGenesScatter(PlotBase):
    __category_width = 0.37
    __category_height = 0.35
    __legend_width = 4
    __dendrogram_height = 0.6
    __title_font_size = 8

[docs] def clusters_genes_scatter( self, cluster_res_key: str, dendrogram_res_key: Optional[str] = None, topn: Optional[int] = 5, gene_names: Optional[Sequence[str]] = None, expression_kind: Literal['mean', 'sum'] = 'mean', groups: Optional[Sequence[str]] = None, width: int = None, height: int = None, colormap: str = 'Reds', standard_scale: str = 'gene' ): """ Scatter representing mean expression of genes on each cell cluster. :param cluster_res_key: the key to get cluster result. :param dendrogram_res_key: the key to get dendrogram result, defaults to None to avoid show dendrogram on plot. :param topn: select `topn` expressed genes in each cluster, defaults to 5, ignored if `gene_names` is not None, the number of genes shown in plot may be more than `topn` because the `topn` genes in each cluster are not the same. :param gene_names: a list of genes to show, defaults to None to show all genes. :param expression_kind: the kind of expression to show, 'mean' or 'sum', defaults to 'mean'. :param groups: a list of cell clusters to show, defaults to None to show all cell clusters. :param width: the figure width in pixels, defaults to None :param height: the figure height in pixels, defaults to None :param colormap: colormap used on plot, defaults to 'Reds' :param standard_scale: Whether or not to standardize that dimension between 0 and 1, meaning for each gene or cluster, subtract the minimum and divide each by its maximum, defaults to 'gene' """ if cluster_res_key not in self.pipeline_res: raise KeyError(f"Can not find the cluster result in data.tl.result by key {cluster_res_key}") drg_res = None if dendrogram_res_key is not None: if dendrogram_res_key not in self.pipeline_res: raise KeyError(f"Can not find the dendrogram result in data.tl.result by key {dendrogram_res_key}") else: drg_res = self.pipeline_res[dendrogram_res_key] if cluster_res_key != drg_res['cluster_res_key'][0]: raise KeyError(f'The cluster result used in dendrogram may not be the same as that ' f'specified by key {cluster_res_key}') if gene_names is not None: topn = None if topn is None: if gene_names is None: gene_names = self.stereo_exp_data.gene_names else: if isinstance(gene_names, str): gene_names = np.array([gene_names], dtype='U') elif not isinstance(gene_names, np.ndarray): gene_names = np.array(gene_names, dtype='U') if len(gene_names) == 0: return None if groups is None or drg_res is not None: cluster_res: pd.DataFrame = self.pipeline_res[cluster_res_key] cluster_res['group'] = cluster_res['group'].astype('category') group_codes = cluster_res['group'].cat.categories.to_numpy() groups = None if drg_res is not None: categories_idx_ordered = drg_res['categories_idx_ordered'] group_codes = group_codes[categories_idx_ordered] else: group_codes = groups if isinstance(group_codes, str): group_codes = np.array([group_codes], dtype='U') elif not isinstance(group_codes, np.ndarray): group_codes = np.array(group_codes, dtype='U') if topn is None: genes_expression = cell_cluster_to_gene_exp_cluster( self.stereo_exp_data, cluster_res_key, groups=groups, genes=gene_names, kind=expression_kind ) else: genes_expression = cell_cluster_to_gene_exp_cluster( self.stereo_exp_data, cluster_res_key, groups=groups, kind=expression_kind ) gene_names = [] for c in genes_expression.columns: gene_names.extend(genes_expression[c].sort_values(ascending=False).index[:topn].tolist()) gene_names = np.unique(gene_names) genes_expression = genes_expression.loc[gene_names] if standard_scale == 'cluster': genes_expression -= genes_expression.min(0) genes_expression = (genes_expression / genes_expression.max(0)).fillna(0) elif standard_scale == 'gene': genes_expression = genes_expression.sub(genes_expression.min(1), axis=0) genes_expression = genes_expression.div(genes_expression.max(1), axis=0).fillna(0) elif standard_scale is None: pass else: logger.warning('Unknown type for standard_scale, ignored') pct, _ = calc_pct_and_pct_rest( self.stereo_exp_data, cluster_res_key, gene_names=gene_names, groups=groups ) if 'genes' in pct.columns: pct.set_index('genes', inplace=True) dot_plot_data = self._create_dot_plot_data( pct, genes_expression, group_codes, gene_names ) if width is None: if len(group_codes) < 10: main_area_width = self.__category_width * 10 else: main_area_width = self.__category_width * len(group_codes) else: if width < self.__legend_width: width = self.__category_width * 10 + self.__legend_width main_area_width = width - self.__legend_width if height is None: if len(gene_names) < 10: main_area_height = self.__category_height * 10 else: main_area_height = self.__category_height * len(gene_names) else: if height < self.__dendrogram_height: height = self.__category_height * 10 + self.__dendrogram_height main_area_height = height - self.__dendrogram_height width_ratios = [main_area_width, self.__legend_width] height_ratios = [self.__dendrogram_height + main_area_height] width = sum(width_ratios) height = sum(height_ratios) fig = plt.figure(figsize=(width, height)) axs = gridspec.GridSpec( nrows=1, ncols=2, width_ratios=width_ratios, height_ratios=height_ratios, wspace=(0.15 / main_area_width), hspace=0 ) axs_main = gridspec.GridSpecFromSubplotSpec( nrows=2, ncols=1, width_ratios=[main_area_width], height_ratios=[self.__dendrogram_height, main_area_height], wspace=0, hspace=0, subplot_spec=axs[0, 0] ) ax_scatter = fig.add_subplot(axs_main[1, 0]) main_im = self._dotplot(ax_scatter, dot_plot_data, group_codes, gene_names, colormap) if drg_res is not None: from .plot_dendrogram import PlotDendrogram ax_drg = fig.add_subplot(axs_main[0, 0], sharex=ax_scatter) plt_drg = PlotDendrogram(self.stereo_exp_data, self.pipeline_res) plt_drg.dendrogram( orientation='top', ax=ax_drg, res_key=dendrogram_res_key, ticks=list(range(len(group_codes))) ) ax_drg.set_axis_off() axs_on_right = gridspec.GridSpecFromSubplotSpec( nrows=4, ncols=1, height_ratios=[0.55, 0.05, 0.2, 0.1], subplot_spec=axs[0, 1], hspace=0.1 ) ax_colorbar = fig.add_subplot(axs_on_right[1, 0]) self._plot_colorbar(ax_colorbar, main_im, expression_kind) ax_dot_size_map = fig.add_subplot(axs_on_right[3, 0]) self._plot_dot_size_map(ax_dot_size_map) return fig
def _dotplot( self, ax: Axes, dot_plot_data: pd.DataFrame, group_codes: Sequence[str], gene_names: Sequence[str], colormap: str ): ax.set_xlim(left=-1, right=len(group_codes)) ax.xaxis.set_ticks(range(len(group_codes)), group_codes) ax.set_ylim(bottom=-1, top=len(gene_names)) ax.yaxis.set_ticks(range(len(gene_names)), gene_names) ax.tick_params(axis='x', labelrotation=90) return ax.scatter( x=dot_plot_data['x'], y=dot_plot_data['y'], s=dot_plot_data['dot_size'], c=dot_plot_data['dot_color'], cmap=colormap ) def _plot_colorbar( self, ax: Axes, im, expression_kind: str ): ax.set_title(f'{expression_kind.capitalize()} expression in group', fontdict={'fontsize': self.__title_font_size}) plt.colorbar(im, cax=ax, orientation='horizontal', ticklocation='bottom') def _create_dot_plot_data( self, pct: pd.DataFrame, genes_expression: pd.DataFrame, group_codes: Sequence[str], gene_names: Sequence[str] ): x = [i for i in range(len(group_codes))] data_list = [] for y, g in enumerate(gene_names): dot_size = pct.loc[g][group_codes] * 100 dot_color = genes_expression.loc[g][group_codes] df = pd.DataFrame({ 'x': x, 'y': y, 'dot_size': dot_size, 'dot_color': dot_color }) data_list.append(df) return pd.concat(data_list, axis=0) def _plot_dot_size_map( self, ax: Axes ): ax.set_title('Fraction of cells in group(%)', fontdict={'fontsize': self.__title_font_size}) ax.set_xlim(left=5, right=105) ax.set_ylim(bottom=0, top=1) ax.xaxis.set_ticks([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]) ax.axes.set_frame_on(False) ax.yaxis.set_tick_params(left=False, labelleft=False) ax.scatter( x=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100], y=[0.4] * 10, s=[10, 20, 30, 40, 50, 60, 70, 80, 90, 100], c='grey' )