# Source code for stereo.plots.plot_grn

# python core modules
from typing import Union

# third party modules
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from pyscenic.rss import regulon_specificity_scores

# modules in self project
from stereo.log_manager import logger
from stereo.plots.decorator import plot_scale
from stereo.plots.decorator import reorganize_coordinate
from stereo.plots.plot_base import PlotBase
from stereo.plots.scatter import base_scatter


class PlotRegulatoryNetwork(PlotBase):
    """
    Plot Gene Regulatory Networks related plots
    """

    # dotplot method for StereoExpData
    @staticmethod
    def _cal_percent_exp_df(
            exp_matrix: pd.DataFrame,
            cluster_meta: pd.DataFrame,
            regulon_genes: Union[str, list],
            celltype: list,
            groupby: str = 'group',
            cell_label: str = 'bins',
            cutoff: float = 0
    ):
        """
        Compute the dot-plot statistics of one regulon within one cell type.

        :param exp_matrix: cell-by-gene expression matrix; its index holds cell names and
            its columns hold gene names.
        :param cluster_meta: cell annotation table containing the `cell_label` and `groupby` columns.
        :param regulon_genes: target gene(s) of the regulon, used as column labels of `exp_matrix`.
        :param celltype: the single cell type (a value of the `groupby` column) to evaluate.
        :param groupby: column of `cluster_meta` holding the cell type.
        :param cell_label: column of `cluster_meta` holding the cell names.
        :param cutoff: a cell counts as expressing the regulon when the summed expression of the
            regulon genes is strictly greater than this value.
        :return: a `(percentage, average_expression)` tuple, both rounded to 2 decimals;
            `(0, 0)` when none of the cell type's cells are present in `exp_matrix`.
        """
        if isinstance(regulon_genes, str):
            # keep a 2D selection so the row-wise sum below works for a single gene too
            regulon_genes = [regulon_genes]
        # cells annotated as `celltype`
        cells = cluster_meta.loc[cluster_meta[groupby] == celltype, cell_label]
        # expression of the regulon genes in those cells that exist in the matrix
        g_ct_exp = exp_matrix.loc[exp_matrix.index.isin(cells), regulon_genes]
        total_cell_num = g_ct_exp.shape[0]
        if total_cell_num == 0:
            # the caller unpacks two values, so an empty group must still yield a pair
            return 0, 0
        # summed regulon expression per cell; kept as a standalone Series instead of being
        # written back into `g_ct_exp`, which may be a copy of `exp_matrix`
        total = g_ct_exp.sum(axis=1)
        # number of cells whose summed regulon expression exceeds the cutoff
        regulon_cell_num = (total > cutoff).sum()
        reg_ct_percent = regulon_cell_num / total_cell_num
        reg_ct_avg_exp = total.mean()
        return round(reg_ct_percent, 2), round(reg_ct_avg_exp, 2)

    def grn_dotplot(self,
                    cluster_res_key: str,
                    regulon_names: Union[str, list] = None,
                    celltypes: Union[str, list] = None,
                    groupby: str = 'group',
                    cell_label: str = 'bins',
                    network_res_key: str = 'regulatory_network_inference',
                    palette: str = 'Reds',
                    width: int = None,
                    height: int = None,
                    **kwargs):
        """
        Intuitive way of visualizing how feature expression changes across different identity classes (clusters).
        The size of the dot encodes the percentage of cells within a class, while the color encodes the
        AverageExpression level across all cells within a class (red is high).

        :param cluster_res_key: the key which specifies the clustering result in data.tl.result,
            or a column of the cells' obs holding the clustering labels.
        :param regulon_names: the regulons which would be shown on plot, defaults to None.
            If set it to None (or 'ALL'), it will be set to all regulons.
            1) string: only one regulon.
            2) list: an array contains the regulons which would be shown.
            Names lacking the '(+)' suffix have it appended.
        :param celltypes: the celltypes in cluster pairs which would be shown on plot, defaults to None.
            If set it to None (or 'ALL'), it will be set to all clusters.
            1) string: only one cluster.
            2) list: an array contains the clusters which would be shown.
        :param groupby: column name of the cell type label in the cluster annotation table.
        :param cell_label: column name of the cell (bin) name in the cluster annotation table.
        :param network_res_key: the key which specifies inference regulatory network result
            in data.tl.result, defaults to 'regulatory_network_inference'
        :param palette: Color theme, defaults to 'Reds'
        :param width: the figure width in pixels; if None or below 100, a width derived from
            the number of regulons is used.
        :param height: the figure height in pixels; if None or below 100, a height derived from
            the number of cell types is used.
        :param kwargs: other keyword arguments passed to `seaborn.scatterplot`.

        :return: matplotlib.figure
        """
        if network_res_key not in self.pipeline_res:
            logger.info(f"The result specified by {network_res_key} is not exists.")

        expr_matrix = self.stereo_exp_data.tl.raw.to_df()
        dot_data = {'cell type': [], 'regulons': [], 'percentage': [], 'avg exp': []}

        regulon_dict = self.pipeline_res[network_res_key]['regulons']

        if cluster_res_key in self.stereo_exp_data.cells._obs.columns:
            # build the annotation table under the column names the helper looks up
            meta = pd.DataFrame({
                cell_label: self.stereo_exp_data.cells.cell_name,
                groupby: self.stereo_exp_data.cells._obs[cluster_res_key].tolist()
            })
        else:
            meta = self.pipeline_res[cluster_res_key]

        if celltypes is None or (isinstance(celltypes, str) and celltypes.upper() == 'ALL'):
            celltypes = sorted(meta[groupby].drop_duplicates())
        elif isinstance(celltypes, str):
            celltypes = [celltypes]

        if regulon_names is None or (isinstance(regulon_names, str) and regulon_names.upper() == 'ALL'):
            regulon_names = list(regulon_dict.keys())
        elif isinstance(regulon_names, str):
            regulon_names = [regulon_names]

        for reg in regulon_names:
            if '(+)' not in reg:
                reg = reg + '(+)'
            target_genes = regulon_dict[reg]
            for ct in celltypes:
                reg_ct_percent, reg_ct_avg_exp = PlotRegulatoryNetwork._cal_percent_exp_df(
                    exp_matrix=expr_matrix,
                    cluster_meta=meta,
                    regulon_genes=target_genes,
                    celltype=ct,
                    groupby=groupby,
                    cell_label=cell_label
                )
                dot_data['regulons'].append(reg)
                dot_data['cell type'].append(ct)
                dot_data['percentage'].append(reg_ct_percent)
                dot_data['avg exp'].append(reg_ct_avg_exp)

        dot_df = pd.DataFrame(dot_data)

        # plotting
        ncols = len(dot_df['regulons'].unique())
        nrows = len(dot_df['cell type'].unique())
        default_width = int(5 + max(3, ncols * 0.8))
        default_height = int(3 + max(5, nrows * 0.5))
        # each dimension is resolved independently: pixels -> inches, or the default
        width = width / 100 if width is not None and width >= 100 else default_width
        height = height / 100 if height is not None and height >= 100 else default_height

        fig, ax = plt.subplots(figsize=(width, height))
        sns.scatterplot(data=dot_df, size='percentage', hue='avg exp', x='regulons', y='cell type',
                        sizes=(100, 300), marker='o', palette=palette, legend='auto', ax=ax, **kwargs)
        ax.legend(fontsize=12, frameon=False, ncol=1, loc=(1.02, 0))
        ax.tick_params(axis='x', labelsize=12, labelrotation=90)
        ax.tick_params(axis='y', labelsize=12)
        ax.set_xlabel('Regulon')
        ax.set_ylabel('Cell type')
        return fig

    def auc_heatmap(
            self,
            network_res_key: str = 'regulatory_network_inference',
            width: int = 8,
            height: int = 8,
            pivot_kws: dict = None,
            method: str = 'average',
            metric: str = 'euclidean',
            z_score: int = None,
            standard_scale: int = None,
            cbar_kws: dict = None,
            row_cluster: bool = True,
            col_cluster: bool = True,
            row_linkage: np.ndarray = None,
            col_linkage: np.ndarray = None,
            row_colors: Union[list, pd.DataFrame, pd.Series] = None,
            col_colors: Union[list, pd.DataFrame, pd.Series] = None,
            mask: Union[list, pd.DataFrame] = None,
            dendrogram_ratio: float = (.1, .2),
            colors_ratio: float = 0.03,
            cbar_pos: tuple = (-.05, .2, .03, .4),
            tree_kws: dict = None
    ):
        """
        Plot heatmap for auc value for regulons

        Parameters
        ----------
        network_res_key: str, optional
            the key which specifies inference regulatory network result in data.tl.result,
            defaults to 'regulatory_network_inference'
        height: int, optional
            height of drawing
        width: int, optional
            width of drawing
        pivot_kws : dict, optional
            If `data` is a tidy dataframe, can provide keyword arguments for
            pivot to create a rectangular dataframe.
        method : str, optional
            Linkage method to use for calculating clusters. See
            :func:`scipy.cluster.hierarchy.linkage` documentation for more
            information.
        metric : str, optional
            Distance metric to use for the data. See
            :func:`scipy.spatial.distance.pdist` documentation for more options.
            To use different metrics (or methods) for rows and columns, you may
            construct each linkage matrix yourself and provide them as
            `{row,col}_linkage`.
        z_score : int or None, optional
            Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
            for the rows or the columns. Z scores are: z = (x - mean)/std, so
            values in each row (column) will get the mean of the row (column)
            subtracted, then divided by the standard deviation of the row (column).
            This ensures that each row (column) has mean of 0 and variance of 1.
        standard_scale : int or None, optional
            Either 0 (rows) or 1 (columns). Whether or not to standardize that
            dimension, meaning for each row or column, subtract the minimum and
            divide each by its maximum.
        cbar_kws : dict, optional
            Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to
            add a label to the colorbar.
        {row,col}_cluster : bool, optional
            If ``True``, cluster the {rows, columns}.
        {row,col}_linkage : :class:`numpy.ndarray`, optional
            Precomputed linkage matrix for the rows or columns. See
            :func:`scipy.cluster.hierarchy.linkage` for specific formats.
        {row,col}_colors : list-like or pandas DataFrame/Series, optional
            List of colors to label for either the rows or columns. Useful to evaluate
            whether samples within a group are clustered together. Can use nested lists or
            DataFrame for multiple color levels of labeling. If given as a
            :class:`pandas.DataFrame` or :class:`pandas.Series`, labels for the colors are
            extracted from the DataFrames column names or from the name of the Series.
            DataFrame/Series colors are also matched to the data by their index, ensuring
            colors are drawn in the correct order.
        mask : bool array or DataFrame, optional
            If passed, data will not be shown in cells where `mask` is True.
            Cells with missing values are automatically masked. Only used for
            visualizing, not for calculating.
        {dendrogram,colors}_ratio : float, or pair of floats, optional
            Proportion of the figure size devoted to the two marginal elements. If
            a pair is given, they correspond to (row, col) ratios.
        cbar_pos : tuple of (left, bottom, width, height), optional
            Position of the colorbar axes in the figure. Setting to ``None`` will
            disable the colorbar.
        tree_kws : dict, optional
            Parameters for the :class:`matplotlib.collections.LineCollection`
            that is used to plot the lines of the dendrogram tree.

        :return: matplotlib.figure
        """
        logger.info('Generating auc heatmap plot')
        if network_res_key not in self.pipeline_res:
            logger.info(f"The result specified by {network_res_key} is not exists.")

        g = sns.clustermap(
            self.pipeline_res[network_res_key]['auc_matrix'],
            pivot_kws=pivot_kws,
            method=method,
            metric=metric,
            z_score=z_score,
            standard_scale=standard_scale,
            cbar_kws=cbar_kws,
            row_cluster=row_cluster,
            col_cluster=col_cluster,
            col_linkage=col_linkage,
            row_linkage=row_linkage,
            row_colors=row_colors,
            col_colors=col_colors,
            mask=mask,
            colors_ratio=colors_ratio,
            tree_kws=tree_kws,
            figsize=(width, height),
            dendrogram_ratio=dendrogram_ratio,
            cbar_pos=cbar_pos,
        )
        return g.figure

    @plot_scale
    @reorganize_coordinate
    def spatial_scatter_by_regulon(
            self,
            network_res_key: str = 'regulatory_network_inference',
            reg_name: str = None,
            dot_size: int = None,
            palette: str = 'CET_L4',
            width: int = None,
            height: int = None,
            vmin: float = None,
            vmax: float = None,
            **kwargs):
        """
        Plot the AUC z-scores of one regulon on a 2D map

        :param network_res_key: the key which specifies inference regulatory network result
            in data.tl.result, defaults to 'regulatory_network_inference'
        :param reg_name: specify the regulon you want to draw, defaults to None,
            if None, the first regulon of the result is drawn.
            A missing '(+)' suffix is appended.
        :param dot_size: marker size, defaults to None
        :param palette: Color theme, defaults to 'CET_L4'
        :param height: height of drawing
        :param width: width of drawing
        :param vmin: The value representing the lower limit of the color scale.
            Values smaller than vmin are plotted with the same color as vmin.
        :param vmax: The value representing the upper limit of the color scale.
            Values larger than vmax are plotted with the same color as vmax.
        :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
                if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                                ---------------
                                | data1 data2
                                | data3 data4
                                | data5 ...
                                | ...   ...
                                ---------------
                if set it to `False`, the coordinates will not be changed.
        :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
        :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
        :param kwargs: `color_bar_reverse` (defaults to False) is consumed here; the rest is
            passed to `base_scatter`.

        :return: matplotlib.figure
        """  # noqa
        logger.info('Please adjust the dot_size to prevent dots from covering each other')

        if network_res_key not in self.pipeline_res:
            logger.info(f"The result specified by {network_res_key} is not exists.")

        if reg_name is None:
            regulon_dict = self.pipeline_res[network_res_key]['regulons']
            reg_name = list(regulon_dict.keys())[0]
        elif '(+)' not in reg_name:
            reg_name = reg_name + '(+)'
        cell_coor = self.stereo_exp_data.position
        # prepare plotting data
        auc_zscore = cal_zscore(self.pipeline_res[network_res_key]['auc_matrix'][reg_name])
        # sort data points by zscore (low to high), because first dot will be covered by latter dots
        df = pd.DataFrame({'x': cell_coor[:, 0], 'y': cell_coor[:, 1], 'auc_zscore': auc_zscore})
        df.sort_values(by=['auc_zscore'], inplace=True)
        # plot cell/bin dot, x y coor
        color_bar_reverse = kwargs.pop('color_bar_reverse', False)
        fig = base_scatter(
            x=df['x'],
            y=df['y'],
            hue=df['auc_zscore'],
            title=reg_name,
            x_label='spatial1',
            y_label='spatial2',
            dot_size=dot_size,
            palette=palette,
            color_bar=True,
            color_bar_reverse=color_bar_reverse,
            width=width,
            height=height,
            vmin=vmin,
            vmax=vmax,
            **kwargs
        )

        return fig

    @staticmethod
    def plot_2d_reg_h5ad(data: anndata.AnnData, pos_label, auc_mtx, reg_name: str, **kwargs):
        """
        Plot the AUC z-scores of one regulon on a 2D map and save it as '<regulon>.png'.

        :param data: AnnData object holding the cell coordinates in `obsm`.
        :param pos_label: key of `data.obsm` holding the coordinates (first two columns are used).
        :param auc_mtx: AUC matrix whose columns are regulon names and whose rows are aligned
            with the cells of `data`.
        :param reg_name: regulon to draw; a missing '(+)' suffix is appended.
        :param kwargs: other keyword arguments passed to `matplotlib.pyplot.scatter`.
        :return: None; the figure is written to '<regulon name without suffix>.png' in the
            current directory and then closed.

        Example:
            plot_2d_reg_h5ad(data, 'spatial', auc_mtx, 'Zfp354c')
        """
        if '(+)' not in reg_name:
            reg_name = reg_name + '(+)'
        cell_coor = np.asarray(data.obsm[pos_label])
        auc_zscore = cal_zscore(auc_mtx)
        # prepare plotting data: one z-score per cell for the requested regulon
        sub_zscore = auc_zscore[reg_name].to_numpy()
        # sort data points by zscore (low to high), because first dot will be covered by latter dots
        zorder = np.argsort(sub_zscore)
        # plot cell/bin dot, x y coor
        sc = plt.scatter(cell_coor[:, 0][zorder], cell_coor[:, 1][zorder], c=sub_zscore[zorder], marker='.',
                         edgecolors='none', cmap='plasma', lw=0, **kwargs)
        plt.box(False)
        plt.axis('off')
        plt.colorbar(sc, shrink=0.35)
        plt.savefig(f'{reg_name.split("(")[0]}.png')
        plt.close()

    def auc_heatmap_by_group(
            self,
            network_res_key: str = 'regulatory_network_inference',
            cluster_res_key: str = None,
            top_n_feature: int = 5,
            width: int = 18,
            height: int = 28,
            pivot_kws: dict = None,
            method: str = 'average',
            metric: str = 'euclidean',
            z_score: int = None,
            standard_scale: int = None,
            cbar_kws: dict = None,
            row_cluster: bool = True,
            col_cluster: bool = True,
            row_linkage: np.ndarray = None,
            col_linkage: np.ndarray = None,
            col_colors: Union[list, pd.DataFrame, pd.Series] = None,
            mask: Union[list, pd.DataFrame] = None,
            dendrogram_ratio: float = (.1, .2),
            colors_ratio: float = 0.03,
            cbar_pos: tuple = (-.05, .2, .03, .4),
            tree_kws: dict = None
    ):
        """
        Plot heatmap for Regulon specificity scores (RSS) value

        The heatmap shows the AUC z-scores of the union of the `top_n_feature` regulons with the
        highest RSS in each cell type; rows are colored by cell type.

        Parameters
        ----------
        network_res_key: str, optional
            the key which specifies inference regulatory network result in data.tl.result,
            defaults to 'regulatory_network_inference'
        cluster_res_key: str, optional
            the key which specifies the clustering result in data.tl.result, or a column of the
            cells' obs holding the clustering labels, defaults to None
        top_n_feature: int, optional
            get the top n feature
        height: int, optional
            height of drawing
        width: int, optional
            width of drawing
        pivot_kws : dict, optional
            If `data` is a tidy dataframe, can provide keyword arguments for
            pivot to create a rectangular dataframe.
        method : str, optional
            Linkage method to use for calculating clusters. See
            :func:`scipy.cluster.hierarchy.linkage` documentation for more
            information.
        metric : str, optional
            Distance metric to use for the data. See
            :func:`scipy.spatial.distance.pdist` documentation for more options.
            To use different metrics (or methods) for rows and columns, you may
            construct each linkage matrix yourself and provide them as
            `{row,col}_linkage`.
        z_score : int or None, optional
            Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
            for the rows or the columns. Z scores are: z = (x - mean)/std, so
            values in each row (column) will get the mean of the row (column)
            subtracted, then divided by the standard deviation of the row (column).
            This ensures that each row (column) has mean of 0 and variance of 1.
        standard_scale : int or None, optional
            Either 0 (rows) or 1 (columns). Whether or not to standardize that
            dimension, meaning for each row or column, subtract the minimum and
            divide each by its maximum.
        cbar_kws : dict, optional
            Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to
            add a label to the colorbar.
        {row,col}_cluster : bool, optional
            If ``True``, cluster the {rows, columns}.
        {row,col}_linkage : :class:`numpy.ndarray`, optional
            Precomputed linkage matrix for the rows or columns. See
            :func:`scipy.cluster.hierarchy.linkage` for specific formats.
        col_colors : list-like or pandas DataFrame/Series, optional
            List of colors to label the columns. Useful to evaluate whether samples
            within a group are clustered together. Can use nested lists or DataFrame
            for multiple color levels of labeling. If given as a
            :class:`pandas.DataFrame` or :class:`pandas.Series`, labels for the colors
            are extracted from the DataFrames column names or from the name of the
            Series. DataFrame/Series colors are also matched to the data by their
            index, ensuring colors are drawn in the correct order. Row colors are
            derived from the cell types.
        mask : bool array or DataFrame, optional
            If passed, data will not be shown in cells where `mask` is True.
            Cells with missing values are automatically masked. Only used for
            visualizing, not for calculating.
        {dendrogram,colors}_ratio : float, or pair of floats, optional
            Proportion of the figure size devoted to the two marginal elements. If
            a pair is given, they correspond to (row, col) ratios.
        cbar_pos : tuple of (left, bottom, width, height), optional
            Position of the colorbar axes in the figure. Setting to ``None`` will
            disable the colorbar.
        tree_kws : dict, optional
            Parameters for the :class:`matplotlib.collections.LineCollection`
            that is used to plot the lines of the dendrogram tree.

        :return: matplotlib.figure
        """
        if network_res_key not in self.pipeline_res:
            logger.info(f"The result specified by {network_res_key} is not exists.")
        cluster_in_obs = cluster_res_key in self.stereo_exp_data.cells._obs.columns
        # only warn when the clustering result is found in neither of the supported places
        if not cluster_in_obs and cluster_res_key not in self.pipeline_res:
            logger.info(f"The result specified by {cluster_res_key} is not exists.")

        auc_mtx = self.pipeline_res[network_res_key]['auc_matrix']

        if cluster_in_obs:
            meta = pd.DataFrame({
                'bins': self.stereo_exp_data.cells.cell_name,
                'group': self.stereo_exp_data.cells._obs[cluster_res_key].tolist()
            })
        else:
            # copied because the 'group' column is overwritten with colors below
            meta = self.pipeline_res[cluster_res_key].copy(deep=True)

        # Regulon specificity scores (RSS) across predicted cell types
        rss_cellType = regulon_specificity_scores(auc_mtx, meta['group'])
        # Select the top `top_n_feature` regulons from each cell type
        cats = sorted(set(meta['group']))
        topreg = []
        for c in cats:
            topreg.extend(
                rss_cellType.T[c].sort_values(ascending=False)[:top_n_feature].index
            )
        # de-duplicate while keeping a deterministic first-seen order; a plain set would
        # change the column order between runs because str hashing is randomized
        topreg = list(dict.fromkeys(topreg))

        # plot z-score
        auc_zscore = cal_zscore(auc_mtx)
        sns.set(font_scale=1.2)

        # set group color
        groups = meta['group'].unique()
        lut = dict(zip(groups, ncolors(len(groups))))
        meta['group'] = meta['group'].map(lut)

        g = sns.clustermap(
            auc_zscore[topreg],
            row_colors=meta.set_index(['bins']),
            figsize=(width, height),
            pivot_kws=pivot_kws,
            method=method,
            metric=metric,
            z_score=z_score,
            standard_scale=standard_scale,
            cbar_kws=cbar_kws,
            row_cluster=row_cluster,
            col_cluster=col_cluster,
            col_linkage=col_linkage,
            row_linkage=row_linkage,
            col_colors=col_colors,
            mask=mask,
            colors_ratio=colors_ratio,
            tree_kws=tree_kws,
            dendrogram_ratio=dendrogram_ratio,
            cbar_pos=cbar_pos,
        )
        return g.figure

    def spatial_scatter_by_regulon_3D(
            self,
            network_res_key: str = 'regulatory_network_inference',
            reg_name: str = None,
            view_vertical: int = 0,
            view_horizontal: int = 0,
            show_axis: bool = False,
            width: int = None,
            height: int = None,
            **kwargs):
        """
        Plot the AUC z-scores of one regulon on a 3D map

        :param network_res_key: the key which specifies inference regulatory network result
            in data.tl.result, defaults to 'regulatory_network_inference'
        :param reg_name: specify the regulon you want to draw, defaults to None,
            if None, the first regulon of the result is drawn.
            A missing '(+)' suffix is appended.
        :param view_vertical: vertical (elevation) angle to view the 3D object
        :param view_horizontal: horizontal (azimuth) angle to view the 3D object
        :param show_axis: whether to draw the axes, defaults to False
        :param width: figure width in inches, defaults to matplotlib's default
        :param height: figure height in inches, defaults to matplotlib's default
        :param kwargs: other keyword arguments passed to `Axes3D.scatter`
        :raises ValueError: if the x/y positions plus `position_z` do not form 3D coordinates.
        :return: matplotlib.figure

        Example:
            data.plt.spatial_scatter_by_regulon_3D('regulatory_network_inference', 'Zfp354c',
                                                   view_vertical=30, view_horizontal=-30)
        """
        if reg_name is None:
            regulon_dict = self.pipeline_res[network_res_key]['regulons']
            reg_name = list(regulon_dict.keys())[0]
        elif '(+)' not in reg_name:
            reg_name = reg_name + '(+)'

        # prepare plotting data: append z as a third column; reshape accepts (n,) or (n, 1)
        position_z = np.asarray(self.stereo_exp_data.position_z).reshape(-1, 1)
        cell_coor = np.concatenate((self.stereo_exp_data.position, position_z), axis=1)
        if cell_coor.shape[1] != 3:
            raise ValueError(
                f'Expected 3D cell coordinates, got {cell_coor.shape[1]} dimensions; '
                'check `position` and `position_z`.'
            )

        auc_mtx = self.pipeline_res[network_res_key]['auc_matrix']
        auc_zscore = cal_zscore(auc_mtx)
        sub_zscore = auc_zscore[reg_name]

        # plot
        fig = plt.figure()
        if width is not None:
            fig.set_figwidth(width)
        if height is not None:
            fig.set_figheight(height)
        ax: Axes3D = fig.add_subplot(projection='3d')
        sc = ax.scatter(cell_coor[:, 0], cell_coor[:, 1], cell_coor[:, 2], c=sub_zscore, marker='.',
                        edgecolors='none', cmap='plasma', lw=0, **kwargs)
        # set view angle
        ax.view_init(view_vertical, view_horizontal)
        if not show_axis:
            ax.set_frame_on(False)
            ax.set_axis_off()
        fig.colorbar(sc, ax=ax, shrink=0.35)
        return fig
def get_n_hls_colors(num):
    """
    Generate `num` HLS colors with evenly spaced hues.

    Saturation is drawn uniformly from [0.9, 1.0) and lightness from [0.5, 0.6)
    using the `random` module, so results vary between calls unless it is seeded.

    :param num: number of colors to generate.
    :return: a list of `num` ``[hue, lightness, saturation]`` lists with components
        in [0, 1]; an empty list when `num` < 1.
    """
    import random
    if num < 1:
        # avoids the division by zero (num == 0) and the endless loop (num < 0)
        return []
    hls_colors = []
    step = 360.0 / num
    # derive each hue from its index instead of accumulating `step`: repeated float
    # addition can stay just below 360 and produce one extra color
    for k in range(num):
        h = k * step
        s = 90 + random.random() * 10
        l = 50 + random.random() * 10  # noqa
        hls_colors.append([h / 360.0, l / 100.0, s / 100.0])
    return hls_colors


def rgb_to_hex(r, g, b):
    """Format 0-255 integer RGB components as a '#rrggbb' hex string."""
    return '#{:02x}{:02x}{:02x}'.format(r, g, b)


def ncolors(num):
    """
    Generate `num` distinct hex colors with evenly spaced hues.

    :param num: number of colors to generate.
    :return: a list of '#rrggbb' strings; empty when `num` < 1.
    """
    import colorsys
    rgb_colors = []
    if num < 1:
        return rgb_colors
    hls_colors = get_n_hls_colors(num)
    for hlsc in hls_colors:
        _r, _g, _b = colorsys.hls_to_rgb(hlsc[0], hlsc[1], hlsc[2])
        r, g, b = [int(x * 255.0) for x in (_r, _g, _b)]
        rgb_colors.append(rgb_to_hex(r, g, b))
    return rgb_colors


def cal_zscore(auc_mtx: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize every column of `auc_mtx` to z-scores across its rows.

    Each column is centered on its mean and divided by its population standard
    deviation (ddof=0). A constant column has zero deviation and yields NaN/inf.

    :param auc_mtx: matrix to standardize column by column (e.g. regulons as columns).
    :return: a matrix of the same shape holding the z-scores.
    """
    func = lambda x: (x - x.mean()) / x.std(ddof=0)  # noqa
    auc_zscore = auc_mtx.transform(func, axis=0)
    return auc_zscore