Source code for stereo.plots.plot_genes_in_pseudotime

from typing import Literal

import matplotlib.pylab as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from .plot_base import PlotBase


class PlotGenesInPseudotime(PlotBase):
    __DEFAULT_FIGURE_WIDTH = 17
    __DEFAULT_AXES_HEIGHT = 3
    __FONT_SIZE = 10

[docs] def plot_genes_in_pseudotime( self, marker_genes_res_key: str = 'marker_genes', group: str = None, pseudotime_key: str = 'dpt_pseudotime', sort_by: Literal[ 'scores', 'pvalues', 'pvalues_adj', 'log2fc', 'pct', 'pct_rest' ] = 'scores', topn: int = 5, cmap: str = 'plasma', width: int = None, height: int = None, size: int = 30, marker: str = '.' ): """ Distribution of expression count of marker genes along with pseudotime :param marker_genes_res_key: Specifies the key to get result of `find_marker_genes` , defaults to 'marker_genes' :param group: Sepcifies the cell type to plot, defaults to None :param sort_by: Get top N marker genes in this order, defaults to 'scores' :param topn: Get top N marker genes, defaults to 5 :param cmap: Color map, definded in `matplotlib <https://matplotlib.org/stable/users/explain/colors/colormaps.html>`_, defaults to 'plasma' :param width: The figure width in pixels, defaults to None :param height: The figure height in pixels, defaults to None :param size: The size of markers in plot, defaults to 30 :param marker: The style of markers in plot, defaults to '.' """ # noqa if group is None: raise ValueError('Must specify the group to plot.') if marker_genes_res_key not in self.pipeline_res: raise KeyError(f"Can not find 'find_marker_genes' result by key '{marker_genes_res_key}'.") # if 'dpt_pseudotime' not in self.pipeline_res: # raise KeyError("Can not find pseudotime infomation.") if pseudotime_key not in self.stereo_exp_data.cells and pseudotime_key not in self.pipeline_res: raise KeyError("Can not find pseudotime infomation.") marker_genes_res: dict = self.pipeline_res[marker_genes_res_key] marker_genes_res_used = None for key in marker_genes_res.keys(): if 'vs' not in key: continue group_in_key = key.split('.vs.')[0] if group == group_in_key: marker_genes_res_used: pd.DataFrame = marker_genes_res[key] break if marker_genes_res_used is None: raise ValueError(f"Can not find the 'find_marker_genes' result related to group '{group}'.") if sort_by not in marker_genes_res_used.columns: raise ValueError(f"The key '{sort_by}' used to sort Can not be found in 'find_marker_genes' result.") marker_genes_res_used.sort_values(by=sort_by, ascending=False, inplace=True) topn = min(topn, marker_genes_res_used.shape[0]) genes = marker_genes_res_used[0:topn]['genes'].to_numpy() # ptime = self.pipeline_res['dpt_pseudotime'] if pseudotime_key in self.stereo_exp_data.cells: ptime = self.stereo_exp_data.cells[pseudotime_key].to_numpy() else: ptime = self.pipeline_res[pseudotime_key] if width is None: width = self.__DEFAULT_FIGURE_WIDTH if height is None: height = self.__DEFAULT_AXES_HEIGHT * len(genes) fig, axes = plt.subplots(nrows=len(genes), ncols=1, sharex=True, figsize=(width, height)) norm = Normalize(vmin=ptime.min(), vmax=ptime.max(), clip=False) x_label = 'Pseudotime' y_label = 'Expression' hue_label = 'Pseudotime' for i, gene in enumerate(genes): flag = self.stereo_exp_data.gene_names == gene gene_exp = self.stereo_exp_data.exp_matrix[:, flag] if self.stereo_exp_data.issparse(): gene_exp = gene_exp.toarray() gene_exp = gene_exp.reshape(-1) plot_data = pd.DataFrame({ x_label: ptime, y_label: gene_exp, 'hue': ptime, 'size': size }) ax: Axes = axes[i] ax.set_title(gene, fontdict=dict(fontsize=self.__FONT_SIZE), loc='center') ax.spines['right'].set_visible(False) ax.spines['top'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.xaxis.set_visible(False) ax.set_xlabel(x_label, fontdict=dict(fontsize=self.__FONT_SIZE), loc='center') ax.set_ylabel(y_label, fontdict=dict(fontsize=self.__FONT_SIZE), loc='center') sns.scatterplot( data=plot_data, x=x_label, y=y_label, hue='hue', hue_norm=norm, palette=cmap, size='size', sizes=(size, size), marker=marker, ax=ax, legend=False ) axes[-1].spines['bottom'].set_visible(True) axes[-1].xaxis.set_visible(True) sm = ScalarMappable(norm=norm, cmap=cmap) cb = fig.colorbar(sm, ax=axes, shrink=0.2, orientation='vertical') cb.ax.set_title(hue_label, fontdict=dict(fontsize=self.__FONT_SIZE), loc='left') return fig