#!/usr/bin/env python3
# coding: utf-8
"""
@author: qindanhua@genomics.cn
@time:2021/08/31
"""
from random import randint
from typing import (
Optional,
Union,
Sequence,
Literal,
Iterable
)
import hvplot.pandas # noqa
import matplotlib.pyplot as plt
import numpy as np
import panel as pn
import seaborn as sns
import tifffile as tiff
from matplotlib.axes import Axes
from natsort import natsorted
from stereo.constant import (
N_GENES_BY_COUNTS,
PCT_COUNTS_MT,
TOTAL_COUNTS,
PLOT_SCATTER_SIZE_FACTOR,
PLOT_BASE_IMAGE_EXPANSION
)
from stereo.core.stereo_exp_data import StereoExpData
from stereo.log_manager import logger
from stereo.stereo_config import stereo_conf
from .decorator import (
plot_scale,
download,
reorganize_coordinate
)
from .plot_base import PlotBase
from .scatter import (
base_scatter,
multi_scatter,
marker_gene_volcano,
highly_variable_genes
)
# Set the class-level default on Panel's ParamMethod at import time, so it applies process-wide.
# NOTE(review): per Panel's docs `loading_indicator` shows a spinner while the pane is updating -
# confirm against the Panel version this project pins.
pn.param.ParamMethod.loading_indicator = True
class PlotCollection:
"""
The plot collection for StereoExpData object.
Parameters
--------------
data:
- a StereoExpData object.
"""
def __init__(
        self,
        data: StereoExpData
):
    """
    Bind the collection to a StereoExpData object.

    :param data: the StereoExpData object whose data and analysis results are plotted.
    """
    self.data: StereoExpData = data
    # Analysis results are read from the data's `tl.result`.
    self.result: dict = data.tl.result
    # Expose `marker_genes_volcano` under the name `marker_gene_volcano` as well.
    self.marker_gene_volcano = self.marker_genes_volcano
def __getattr__(self, item):
    """
    Lazily resolve a plot function that is not defined on the class.

    Python calls this only after normal attribute lookup has failed. The name is
    handed to `PlotBase.get_attribute_helper`; a resolved attribute is wrapped with
    `download` (unless it carries a falsy `__download__` flag), stored on the
    instance so later lookups bypass this method, and returned.

    :param item: the attribute name being looked up.
    :return: the resolved (and possibly `download`-wrapped) plot function.
    :raises AttributeError: if the name is a dunder, is `data`/`result`, or cannot be resolved.
    """
    cached = self.__dict__.get(item, None)
    if cached:
        return cached

    # start with __ may not be our algorithm function, and will cause import problem.
    # `data` and `result` are read through `self` below: if they are missing from the
    # instance (e.g. `__init__` has not run) resolving them here would recurse forever.
    if item.startswith('__') or item in ('data', 'result'):
        raise AttributeError(item)

    new_attr = PlotBase.get_attribute_helper(item, self.data, self.result)
    # Check the helper's result BEFORE wrapping it: previously a not-found result was
    # passed to `download` first, so whether the AttributeError below was ever reached
    # depended on `download` handing a falsy value back.
    # NOTE(review): assumes the helper signals "not found" with None - confirm in PlotBase.
    if new_attr is not None:
        if getattr(new_attr, '__download__', True):
            new_attr = download(new_attr)
        if new_attr:
            self.__setattr__(item, new_attr)
            logger.info(f'register plot_func {item} to {self}')
            return new_attr

    raise AttributeError(
        f'{item} not existed, please check the function name you called!'
    )
@reorganize_coordinate
def interact_cluster(
        self,
        res_key: str,
        inline: Optional[bool] = True,
        width: Optional[int] = 700,
        height: Optional[int] = 500
):
    """
    Draw the clustering result as an interactive spatial scatter.

    :param res_key: the result key of clustering.
    :param inline: display inside the notebook; when `False`, `show()` is called on the figure.
    :param width: the figure width.
    :param height: the figure height.
    :param reorganize_coordinate: for data merged from several slices, whether to rearrange the coordinates of the obs(cells);
        a number such as 2 lays the slices out in 2 columns on the coordinate system, like:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        `False` leaves the coordinates unchanged.
    :param horizontal_offset_additional: extra horizontal offset inserted between slices when rearranging coordinates.
    :param vertical_offset_additional: extra vertical offset inserted between slices when rearranging coordinates.
    :return: the interactive figure.
    """  # noqa
    cluster_res = self.check_res_key(res_key)

    from .interact_plot.spatial_cluster import interact_spatial_cluster
    import pandas as pd

    # One row per cell: spatial position plus its cluster label.
    coords = self.data.position
    frame = pd.DataFrame({
        'x': coords[:, 0],
        'y': coords[:, 1],
        'group': np.array(cluster_res['group'])
    })
    figure = interact_spatial_cluster(frame, width=width, height=height)
    if inline:
        return figure
    figure.show()
    return figure
@reorganize_coordinate
def interact_annotation_cluster(
        self,
        res_cluster_key: str,
        res_marker_gene_key: str,
        res_key: str,
        inline: Optional[bool] = True,
        width: Optional[int] = 700,
        height: Optional[int] = 500,
):
    """
    Interactive spatial scatter of a clustering result, used for annotating clusters.

    :param res_cluster_key: the result key of the clustering whose groups are displayed.
    :param res_marker_gene_key: the result key of marker genes.
    :param res_key: the key forwarded to the interactive plot for accessing `self.result`.
    :param inline: display inside the notebook; when `False`, `show()` is called on the figure.
    :param width: the figure width.
    :param height: the figure height.
    :param reorganize_coordinate: for data merged from several slices, whether to rearrange the coordinates of the obs(cells);
        a number such as 2 lays the slices out in 2 columns on the coordinate system, like:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        `False` leaves the coordinates unchanged.
    :param horizontal_offset_additional: extra horizontal offset inserted between slices when rearranging coordinates.
    :param vertical_offset_additional: extra vertical offset inserted between slices when rearranging coordinates.
    :return: the interactive figure.
    """  # noqa
    cluster_res = self.check_res_key(res_cluster_key)
    marker_gene_res = self.check_res_key(res_marker_gene_key)

    from .interact_plot.annotation_cluster import interact_spatial_cluster_annotation
    import pandas as pd

    # One row per cell: spatial position, cell name and cluster label.
    coords = self.data.position
    frame = pd.DataFrame({
        'x': coords[:, 0],
        'y': coords[:, 1],
        'bins': self.data.cell_names,
        'group': np.array(cluster_res['group'])
    })
    figure = interact_spatial_cluster_annotation(
        self.data,
        frame,
        marker_gene_res,
        res_key,
        width=width,
        height=height
    )
    if inline:
        return figure
    figure.show()
    return figure
@download
def highly_variable_genes(
        self,
        res_key: str,
        width: Optional[int] = None,
        height: Optional[int] = None,
        xy_label: Optional[list] = ['mean expression of genes', 'dispersions of genes (normalized)'],
        xyII_label: Optional[list] = ['mean expression of genes', 'dispersions of genes (not normalized)']
):
    """
    Scatter plots of the highly variable genes result.

    :param res_key: the result key of highly variable genes.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param xy_label: the x and y labels of the first figure.
    :param xyII_label: the x and y labels of the second figure.
    :return: whatever the module-level `highly_variable_genes` plot function returns.
    """
    hvg_res = self.check_res_key(res_key)
    # Inside the method body this name resolves to the function imported from `.scatter`.
    plot_options = dict(width=width, height=height, xy_label=xy_label, xyII_label=xyII_label)
    return highly_variable_genes(hvg_res, **plot_options)
@download
def marker_genes_volcano(
        self,
        group_name: str,
        res_key: Optional[str] = 'marker_genes',
        hue_order: Optional[set] = ('down', 'normal', 'up'),
        palette: Optional[Union[list, tuple]] = ("#377EB8", "grey", "#E41A1C"),
        alpha: Optional[int] = 1,
        dot_size: Optional[int] = 15,
        text_genes: Optional[list] = None,
        x_label: Optional[str] = 'log2(fold change)',
        y_label: Optional[str] = '-log10(pvalue)',
        vlines: Optional[bool] = True,
        cut_off_pvalue: Optional[float] = 0.01,
        cut_off_logFC: Optional[int] = 1,
        width: Optional[int] = None,
        height: Optional[int] = None,
        **kwargs
):
    """
    Volcano plot of the marker genes of one group.

    :param group_name: the group name, used to pick that group's entry from the marker gene result.
    :param res_key: the result key of marker genes.
    :param hue_order: the order of the gene categories.
    :param palette: the color theme, a list of 3 colors which respectively
        specify the color of 'down', 'normal' and 'up' marker genes.
    :param alpha: the opacity.
    :param dot_size: the dot size.
    :param text_genes: the genes whose names are shown.
    :param x_label: the x label.
    :param y_label: the y label.
    :param vlines: plot the cutoff lines or not.
    :param cut_off_pvalue: p-value cutoff used to classify genes: p-value < cut_off_pvalue and
        log2fc > cut_off_logFC means 'up', p-value < cut_off_pvalue and log2fc < -cut_off_logFC means 'down'.
    :param cut_off_logFC: log2fc cutoff used to classify genes.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param kwargs: forwarded unchanged to the underlying volcano plot function.
    :return: the figure produced by the underlying volcano plot function.
    """
    group_res = self.check_res_key(res_key)[group_name]
    return marker_gene_volcano(
        group_res,
        text_genes=text_genes,
        cut_off_pvalue=cut_off_pvalue,
        cut_off_logFC=cut_off_logFC,
        hue_order=hue_order,
        palette=palette,
        alpha=alpha,
        s=dot_size,  # the plot function takes the dot size as `s`
        x_label=x_label,
        y_label=y_label,
        vlines=vlines,
        width=width,
        height=height,
        **kwargs
    )
@download
def genes_count(
        self,
        x_label: Optional[list] = ["total_counts", "total_counts"],
        y_label: Optional[list] = ["pct_counts_mt", "n_genes_by_counts"],
        ncols: Optional[int] = 2,
        dot_size: Optional[int] = None,
        palette: Optional[str] = '#808080',
        width: Optional[int] = None,
        height: Optional[int] = None,
        **kwargs
):
    """
    Scatter plots of pairs of quality control indexes.

    Each (x_label[i], y_label[i]) pair names two cell properties plotted against each other.
    When both are '' or both are [], the default pairs are plotted with blank axis labels.

    :param x_label: a cell property name or a list of them, used for the x axes.
    :param y_label: a cell property name or a list of them, used for the y axes.
    :param ncols: the number of columns.
    :param dot_size: the dot size.
    :param palette: a single color specifying the color of markers.
    :param width: the figure width in pixels; values below 100 fall back to the default.
    :param height: the figure height in pixels; values below 100 fall back to the default.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param kwargs: forwarded unchanged to `base_scatter`.
    :return: the matplotlib figure.
    """  # noqa
    import math
    from matplotlib import gridspec

    labels_blank = (x_label == y_label == '') or (x_label == y_label == [])
    if labels_blank:
        x_keys = [TOTAL_COUNTS] * 2
        y_keys = [PCT_COUNTS_MT, N_GENES_BY_COUNTS]
    else:
        x_keys = [x_label] if isinstance(x_label, str) else x_label
        y_keys = [y_label] if isinstance(y_label, str) else y_label

    # Pixels are converted to inches at 100 px per inch; default size is 12 x 6 inches.
    if width is None or height is None:
        fig_w, fig_h = 12, 6
    else:
        fig_w = width / 100 if width >= 100 else 12
        fig_h = height / 100 if height >= 100 else 6

    nrows = math.ceil(len(x_keys) / ncols)
    fig = plt.figure(figsize=(fig_w, fig_h))
    grid = gridspec.GridSpec(
        nrows=nrows,
        ncols=ncols,
    )
    for slot, (x_key, y_key) in enumerate(zip(x_keys, y_keys)):
        points = np.c_[self.data.cells.get_property(x_key), self.data.cells.get_property(y_key)]
        ax = fig.add_subplot(grid[slot])
        base_scatter(
            points[:, 0],
            points[:, 1],
            # A constant hue so every point takes the single palette color.
            hue=[0] * len(points),
            ax=ax,
            palette=[palette],
            x_label='' if labels_blank else x_key.replace('_', ' '),
            y_label='' if labels_blank else y_key.replace('_', ' '),
            dot_size=dot_size,
            color_bar=False,
            show_legend=False,
            invert_y=False,
            show_ticks=True,
            **kwargs
        )
    return fig
def __create_base_image_data(
        self,
        base_image_path: str,
        data_x_min: int,
        data_x_max: int,
        data_y_min: int,
        data_y_max: int,
        invert_y: bool,
        clip: bool = True,
):
    """
    Load the base image and work out the coordinate extent it covers.

    :param base_image_path: the path of the TIFF image.
    :param data_x_min: the minimum x coordinate of the data.
    :param data_x_max: the maximum x coordinate of the data.
    :param data_y_min: the minimum y coordinate of the data.
    :param data_y_max: the maximum y coordinate of the data.
    :param invert_y: if True the two y entries of the returned boundary are swapped (max first).
    :param clip: crop the image to the data extent (widened by `PLOT_BASE_IMAGE_EXPANSION` when
        either minimum is greater than 0); otherwise the whole image is used.
    :return: a tuple `(image array, [x_start, x_end, y_first, y_second], value_range)`;
        `value_range` is None unless the TIFF's shaped metadata carries a 'value_range' entry.
    """
    base_im_value_range = None
    with tiff.TiffFile(base_image_path) as tif:
        base_image_data = tif.asarray()
        # The first two axes are indexed as (y, x), as in the slicing below.
        img_y_last = base_image_data.shape[0] - 1
        img_x_last = base_image_data.shape[1] - 1
        if clip:
            # The caller passes NumPy scalars (`x.min()` ...). Convert to built-in ints so that
            # unsigned values cannot wrap around when the expansion is subtracted and so that
            # float coordinates still give valid slice indices.
            x_lo, y_lo = int(np.floor(data_x_min)), int(np.floor(data_y_min))
            x_hi, y_hi = int(np.ceil(data_x_max)), int(np.ceil(data_y_max))
            if x_lo > 0 or y_lo > 0:
                x_lo -= PLOT_BASE_IMAGE_EXPANSION
                y_lo -= PLOT_BASE_IMAGE_EXPANSION
                x_hi += PLOT_BASE_IMAGE_EXPANSION
                y_hi += PLOT_BASE_IMAGE_EXPANSION
            # Keep the window inside the image: a negative start would slice from the end, and an
            # end beyond the image is silently truncated by slicing, which would make the reported
            # boundary larger than the pixels actually returned.
            x_lo, y_lo = max(0, x_lo), max(0, y_lo)
            x_hi, y_hi = min(x_hi, img_x_last), min(y_hi, img_y_last)
            base_image_data = base_image_data[y_lo:(y_hi + 1), x_lo:(x_hi + 1)]
            y_pair = [y_hi, y_lo] if invert_y else [y_lo, y_hi]
            base_im_boundary = [x_lo, x_hi] + y_pair
        else:
            y_pair = [img_y_last, 0] if invert_y else [0, img_y_last]
            base_im_boundary = [0, img_x_last] + y_pair
        shaped_metadata = tif.shaped_metadata
        if shaped_metadata is not None:
            metadata = shaped_metadata[0]
            if 'value_range' in metadata:
                base_im_value_range = metadata['value_range']
    return base_image_data, base_im_boundary, base_im_value_range
@download
@plot_scale
@reorganize_coordinate
def spatial_scatter(
        self,
        cells_key: Optional[list] = ["total_counts", "n_genes_by_counts"],
        ncols: Optional[int] = 2,
        dot_size: Optional[int] = None,
        palette: Optional[Union[str, list]] = 'stereo',
        width: Optional[int] = None,
        height: Optional[int] = None,
        x_label: Optional[Union[list, str]] = 'spatial1',
        y_label: Optional[Union[list, str]] = 'spatial2',
        title: Optional[str] = None,
        vmin: float = None,
        vmax: float = None,
        base_image: Optional[str] = None,
        base_im_cmap: Optional[str] = 'Greys',
        base_im_to_gray: bool = False,
        clip_base_image: bool = True,
        **kwargs
):
    """
    Spatial distribution of cell properties, `total_counts` and `n_genes_by_counts` by default.

    :param cells_key: a cell property name or a list of them, one plot is drawn per key.
    :param ncols: the number of plot columns.
    :param dot_size: the dot size.
    :param palette: a palette name or a list of colors.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param x_label: the x label, or a list with one label per plot.
    :param y_label: the y label, or a list with one label per plot.
    :param title: the title label; by default derived from `cells_key` with '_' replaced by spaces.
    :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
    :param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
    :param base_image: the path of the image to be displayed as background, it must already be registered to the same coordinate system as the data.
    :param base_im_cmap: the color map of the base image, only available when the base image is a gray scale image.
    :param base_im_to_gray: whether to convert the base image to gray scale if it is an RGB/RGBA image.
    :param clip_base_image: whether to crop the base image to the extent of the data.
    :param show_plotting_scale: whether to display the plotting scale.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param reorganize_coordinate: for data merged from several slices, whether to rearrange the coordinates of the obs(cells);
        a number such as 2 lays the slices out in 2 columns on the coordinate system, like:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        `False` leaves the coordinates unchanged.
    :param horizontal_offset_additional: extra horizontal offset inserted between slices when rearranging coordinates.
    :param vertical_offset_additional: extra vertical offset inserted between slices when rearranging coordinates.
    :param kwargs: forwarded to `multi_scatter`; `invert_y` (default True) is also used when loading the base image.
    :return: the figure returned by `multi_scatter`.
    """  # noqa
    pos_x = self.data.position[:, 0]
    pos_y = self.data.position[:, 1]
    x_min, x_max = pos_x.min(), pos_x.max()
    y_min, y_max = pos_y.min(), pos_y.max()

    # Without a base image: the plot is bounded by the data and drawn with square markers.
    boundary = [x_min, x_max, y_min, y_max]
    default_marker = 's'
    base_image_data = base_im_boundary = base_im_value_range = None
    if base_image is not None:
        base_image_data, base_im_boundary, base_im_value_range = self.__create_base_image_data(
            base_image, x_min, x_max, y_min, y_max, invert_y=kwargs.get('invert_y', True), clip=clip_base_image
        )
        # The image boundary may list y as (max, min); the plot boundary is always (min, max).
        y_edges = base_im_boundary[2:4]
        boundary = base_im_boundary[0:2] + [min(y_edges), max(y_edges)]
        default_marker = 'o'
    # An explicitly passed marker always wins.
    kwargs.setdefault('marker', default_marker)

    if isinstance(cells_key, str):
        cells_key = [cells_key]
    if title is None:
        title = [key.replace('_', ' ') for key in cells_key]

    n_plots = len(cells_key)
    if isinstance(x_label, str):
        x_label = [x_label] * n_plots
    if isinstance(y_label, str):
        y_label = [y_label] * n_plots

    hue = [self.data.cells.get_property(key) for key in cells_key]
    return multi_scatter(
        x=pos_x,
        y=pos_y,
        hue=hue,
        x_label=x_label,
        y_label=y_label,
        title=title,
        ncols=ncols,
        dot_size=dot_size,
        palette=palette,
        color_bar=True,
        width=width,
        height=height,
        vmin=vmin,
        vmax=vmax,
        boundary=boundary,
        base_image=base_image_data,
        base_im_cmap=base_im_cmap,
        base_im_boundary=base_im_boundary,
        base_im_value_range=base_im_value_range,
        base_im_to_gray=base_im_to_gray,
        **kwargs
    )
@download
@plot_scale
@reorganize_coordinate
def spatial_scatter_by_gene(
        self,
        gene_name: Union[str, list, np.ndarray],
        dot_size: Optional[int] = None,
        palette: Optional[Union[str, list]] = 'CET_L4',
        color_bar_reverse: Optional[bool] = True,
        width: Optional[int] = None,
        height: Optional[int] = None,
        x_label: Optional[str] = 'spatial1',
        y_label: Optional[str] = 'spatial2',
        title: Optional[str] = None,
        vmin: float = None,
        vmax: float = None,
        **kwargs
):
    """Draw the spatial distribution of expression quantity of the gene specified by gene names.

    :param gene_name: a gene or a list of genes you want to show.
    :param dot_size: the dot size, defaults to `None`.
    :param palette: a palette name or a list of colors, defaults to `'CET_L4'`.
    :param color_bar_reverse: if True, reverse the color bar, defaults to True.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param x_label: the x label.
    :param y_label: the y label.
    :param title: the title label; by default the gene names are used.
    :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
    :param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
    :param show_plotting_scale: whether to display the plotting scale.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
        if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        if set it to `False`, the coordinates will not be changed.
    :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
    :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
    :param kwargs: forwarded unchanged to `multi_scatter`.
    :return: the figure returned by `multi_scatter`.
    :raises IndexError: if any requested gene is not present in `gene_names`.
    """  # noqa
    self.data.array2sparse()
    if isinstance(gene_name, str):
        gene_name = [gene_name]

    all_genes = self.data.gene_names
    gene_idx = []
    missing = []
    for gn in gene_name:
        hits = np.argwhere(all_genes == gn)
        if len(hits) == 0:
            missing.append(gn)
        else:
            gene_idx.append(hits[0][0])
    if missing:
        # Same exception type the failed lookup used to raise, but naming the offending
        # genes instead of "index 0 is out of bounds".
        raise IndexError(f'gene(s) {missing} do not exist in expression matrix')

    # One row of expression values per requested gene.
    hue = self.data.exp_matrix[:, gene_idx].T
    fig = multi_scatter(
        x=self.data.position[:, 0],
        y=self.data.position[:, 1],
        hue=hue,
        x_label=[x_label] * len(gene_name),
        y_label=[y_label] * len(gene_name),
        title=gene_name if title is None else title,
        ncols=2,
        dot_size=dot_size,
        palette=palette,
        color_bar=True,
        color_bar_reverse=color_bar_reverse,
        width=width,
        height=height,
        vmin=vmin,
        vmax=vmax,
        **kwargs
    )
    return fig
@download
@plot_scale
@reorganize_coordinate
def gaussian_smooth_scatter_by_gene(
        self,
        gene_name: str = None,
        dot_size: Optional[int] = None,
        palette: Optional[Union[str, list]] = 'CET_L4',
        color_bar_reverse: Optional[bool] = True,
        width: Optional[int] = None,
        height: Optional[int] = None,
        x_label: Optional[list] = ['spatial1', 'spatial1'],
        y_label: Optional[list] = ['spatial2', 'spatial2'],
        title: Optional[list] = None,
        vmin: float = None,
        vmax: float = None,
        **kwargs
):
    """Draw the spatial distribution of expression quantity of one gene,
    only for Gaussian smoothing, showing the raw and the smoothed expression side by side.

    :param gene_name: specify the gene you want to draw, if `None` by default, will select randomly.
    :param dot_size: the marker size, defaults to `None`.
    :param palette: a palette name or a list of colors, defaults to `'CET_L4'`.
    :param color_bar_reverse: if True, reverse the color bar, defaults to True.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param x_label: list of x label.
    :param y_label: list of y label.
    :param title: list of title label (a list of size two); anything else falls back to
        '<gene>(raw)' and '<gene>(smoothed)'.
    :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
    :param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
    :param show_plotting_scale: whether to display the plotting scale.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
        if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        if set it to `False`, the coordinates will not be changed.
    :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
    :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
    :param kwargs: forwarded unchanged to `multi_scatter`.
    :return: the figure returned by `multi_scatter`.
    :raises ValueError: if the gene is missing from the current or from the raw expression matrix.
    """  # noqa
    gene_names = self.data.gene_names
    raw_data = self.data.tl.raw

    if gene_name is None:
        # Draw the random index from the array it is applied to; it used to be drawn from the
        # raw gene count, which can overrun `gene_names` when genes were filtered after the
        # raw checkpoint.
        idx = randint(0, len(gene_names) - 1)
        gene_name = gene_names[idx]
    else:
        if gene_name not in gene_names:
            raise ValueError(f'gene {gene_name} do not exist in expression matrix')
        idx = np.argwhere(gene_names == gene_name)[0][0]

    # Locate the gene in the raw data by name rather than reusing `idx`: the raw matrix may
    # hold a different set or order of genes than the current one.
    # NOTE(review): assumes `raw.gene_names` is a NumPy array aligned with the columns of
    # `raw.exp_matrix`, like `data.gene_names` is for `data.exp_matrix` - confirm.
    raw_hits = np.argwhere(raw_data.gene_names == gene_name)
    if len(raw_hits) == 0:
        raise ValueError(f'gene {gene_name} do not exist in raw expression matrix')
    raw_idx = raw_hits[0][0]

    raw_exp_data = raw_data.exp_matrix[:, raw_idx].T
    exp_data = self.data.exp_matrix[:, idx].T
    hue_list = [raw_exp_data, exp_data]
    if not (title and len(title) == 2):
        title = [f'{gene_name}(raw)', f'{gene_name}(smoothed)']
    fig = multi_scatter(
        x=self.data.position[:, 0],
        y=self.data.position[:, 1],
        hue=hue_list,
        x_label=x_label,
        y_label=y_label,
        title=title,
        ncols=2,
        dot_size=dot_size,
        palette=palette,
        color_bar=True,
        color_bar_reverse=color_bar_reverse,
        width=width,
        height=height,
        vmin=vmin,
        vmax=vmax,
        **kwargs
    )
    return fig
@download
def violin(
        self,
        keys: Union[str, Sequence[str]] = [TOTAL_COUNTS, N_GENES_BY_COUNTS, PCT_COUNTS_MT],
        x_label: Optional[str] = '',
        y_label: Optional[list] = None,
        show_stripplot: Optional[bool] = False,
        jitter: Optional[float] = 0.2,
        dot_size: Optional[float] = 0.8,
        log: Optional[bool] = False,
        rotation_angle: Optional[int] = 0,
        group_by: Optional[str] = None,
        multi_panel: bool = None,
        scale: Literal['area', 'count', 'width'] = 'width',
        ax: Optional[Axes] = None,
        order: Optional[Iterable[str]] = None,
        use_raw: Optional[bool] = False,
        palette: Optional[str] = None,
        title: Optional[str] = None,
):
    """
    Violin plot showing the distribution of quality control indexes.

    Every argument is passed through unchanged to `violin_distribution`.

    :param keys: keys for accessing variables of .cells.
    :param x_label: the x label.
    :param y_label: the y label.
    :param show_stripplot: whether to overlay a stripplot of specific percentage values.
    :param jitter: adjust the dispersion of points.
    :param dot_size: the dot size.
    :param log: plot on a logarithmic axis.
    :param rotation_angle: rotation of the xtick labels.
    :param group_by: the key of the observation grouping to consider.
    :param multi_panel: display keys in multiple panels also when `group_by` is not None.
    :param scale: the method used to scale the width of each violin. With 'width' (the default)
        each violin has the same width, with 'area' each violin has the same area,
        with 'count' the width of a violin corresponds to its number of observations.
    :param ax: a matplotlib axes object, only works if plotting a single component.
    :param order: the order in which to show the categories.
    :param use_raw: whether to use the raw attribute of data, defaults to False.
    :param palette: the color theme.
        For more color theme selection reference: https://seaborn.pydata.org/tutorial/color_palettes.html
    :param title: the title.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :return: whatever `violin_distribution` returns.
    """
    from .violin import violin_distribution

    options = dict(
        keys=keys,
        x_label=x_label,
        y_label=y_label,
        show_stripplot=show_stripplot,
        jitter=jitter,
        dot_size=dot_size,
        log=log,
        rotation_angle=rotation_angle,
        group_by=group_by,
        multi_panel=multi_panel,
        scale=scale,
        ax=ax,
        order=order,
        use_raw=use_raw,
        palette=palette,
        title=title,
    )
    return violin_distribution(self.data, **options)
@reorganize_coordinate
def interact_spatial_scatter(
        self,
        inline: Optional[bool] = True,
        width: Optional[int] = 600,
        height: Optional[int] = 600,
        bgcolor: Optional[str] = '#2F2F4F',
        poly_select: Optional[bool] = False
):
    """
    Interactive spatial distribution.

    :param inline: display inside the notebook; when `False`, `show()` is called on the plot's `figure`.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param bgcolor: the background color.
    :param poly_select: return a `PolySelection` plot instead of an `InteractiveScatter`.
    :param reorganize_coordinate: for data merged from several slices, whether to rearrange the coordinates of the obs(cells);
        a number such as 2 lays the slices out in 2 columns on the coordinate system, like:
                    ---------------
                    | data1 data2
                    | data3 data4
                    | data5 ...
                    | ...   ...
                    ---------------
        `False` leaves the coordinates unchanged.
    :param horizontal_offset_additional: extra horizontal offset inserted between slices when rearranging coordinates.
    :param vertical_offset_additional: extra vertical offset inserted between slices when rearranging coordinates.
    :return: the interactive plot object.
    """  # noqa
    from .interact_plot.interactive_scatter import InteractiveScatter

    common = dict(width=width, height=height, bgcolor=bgcolor)
    # NOTE(review): this object is replaced below when `poly_select` is True; it is still
    # constructed first because its constructor's side effects are not visible from here.
    plot = InteractiveScatter(self.data, **common)
    if poly_select:
        from stereo.plots.interact_plot.poly_selection import PolySelection
        plot = PolySelection(self.data, **common)
    if not inline:
        plot.figure.show()
    return plot
# def batches_umap(
# self,
# res_key: str,
# title: Optional[str] = 'umap of each batch',
# x_label: Optional[str] = 'umap1',
# y_label: Optional[str] = 'umap2',
# bfig_title: Optional[str] = 'all batches',
# dot_size: Optional[int] = 1,
# palette: Optional[Union[str, list, dict]] = 'stereo_30',
# width: Optional[int] = None,
# height: Optional[int] = None
# ):
# """
# Plot batch umap
# :param res_key: the result key of UMAP.
# :param title: the plot titles.
# :param x_label: the x label.
# :param y_label: the y label.
# :param bfig_title: the big figure title.
# :param dot_size: the dot size.
# :param palette: a palette name, a list of colors whose length is equal to the batches,
# or a dict whose keys are batch numbers and values are colors.
# :param width: the figure width in pixels.
# :param height: the figure height in pixels.
# """
# import holoviews as hv
# import panel as pn
# from bokeh.models import Title
# pn.extension()
# hv.extension('bokeh')
# assert self.data.cells.batch is not None, "there is no batches number list"
# if width is None or height is None:
# main_width, main_height = 500, 500
# sub_width, sub_height = 200, 200
# else:
# main_width = width
# main_height = height
# sub_width = np.ceil(width * 0.4).astype(np.int32)
# sub_height = np.ceil(height * 0.4).astype(np.int32)
# umap_res = self.check_res_key(res_key)
# umap_res = umap_res.rename(columns={0: 'x', 1: 'y'})
# umap_res['batch'] = self.data.cells.batch.astype(np.uint16)
# batch_number_unique = np.unique(umap_res['batch'])
# batch_count = len(batch_number_unique)
# cmap = stereo_conf.get_colors(palette, batch_count, order=batch_number_unique)
# cmap_dict = {bn: c for bn, c in zip(batch_number_unique, cmap)}
# fig_all = umap_res.hvplot.scatter(
# x='x', y='y', c='batch', cmap=cmap_dict, cnorm='eq_hist',
# ).opts(
# width=main_width,
# height=main_height,
# invert_yaxis=True,
# xlabel=x_label,
# ylabel=y_label,
# size=dot_size,
# toolbar='disable',
# colorbar=False,
# show_legend=False
# )
# bfig_all = hv.render(fig_all)
# bfig_all.axis.major_tick_line_alpha = 0
# bfig_all.axis.minor_tick_line_alpha = 0
# bfig_all.axis.major_label_text_alpha = 0
# bfig_all.axis.axis_line_alpha = 0
# bfig_all.title = Title(text=bfig_title, align='center')
# bfig_batches = []
# pn_rows = []
# for i, bn, c in zip(range(batch_count), batch_number_unique, cmap):
# sub_umap_res = umap_res[umap_res.batch == bn]
# fig = sub_umap_res.hvplot.scatter(
# x='x', y='y',
# c='batch', color=c, cnorm='eq_hist',
# ).opts(
# width=sub_width,
# height=sub_height,
# xaxis=None,
# yaxis=None,
# invert_yaxis=True,
# size=(dot_size / 3),
# toolbar='disable',
# colorbar=False,
# show_legend=False
# )
# bfig = hv.render(fig)
# bn = str(bn)
# bfig.title = Title(text=f'sn: {self.data.sn[bn]}', align='center')
# bfig_batches.append(bfig)
# if ((i + 1) % 2) == 0 or i == (batch_count - 1):
# pn_rows.append(pn.Row(*bfig_batches))
# bfig_batches.clear()
# return pn.Column(
# f"\n# {title}",
# pn.Row(
# pn.Column(bfig_all),
# pn.Column(*pn_rows)
# )
# )
@download
def batches_umap(
        self,
        res_key: str,
        title: Optional[str] = 'umap of each batch',
        x_label: Optional[str] = 'umap1',
        y_label: Optional[str] = 'umap2',
        main_title: Optional[str] = 'batches',
        sub_titles: Optional[list] = None,
        dot_size: Optional[int] = 3,
        palette: Optional[Union[str, list, dict]] = 'stereo_30',
        width: Optional[int] = None,
        height: Optional[int] = None,
        sub_cols: Optional[int] = 2,
        **kwargs
):
    """
    Plot the umap of batches: the leftmost plot shows all batches together,
    the remaining plots show one batch each.

    :param res_key: the result key of UMAP.
    :param title: the title of the whole figure.
    :param x_label: the x label.
    :param y_label: the y label.
    :param main_title: the title of the leftmost plot.
    :param sub_titles: a list of titles of the plots except the leftmost one,
        by default 'batch-<batch number>' is used.
    :param dot_size: the dot size.
    :param palette: a palette name, a list of colors whose length is equal to the batches,
        or a dict whose keys are batch numbers and values are colors.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param sub_cols: the number of columns for the plots except the leftmost one.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param kwargs: forwarded to `base_scatter`; a `show_legend` entry is dropped, legends are always off.
    :return: the matplotlib figure.
    """
    assert self.data.cells.batch is not None, \
        "there are no batch numbers, it may not be a data merged from multiple slices."

    # `rename` returns a new frame, so adding the batch column leaves the stored result untouched.
    embedding = self.check_res_key(res_key).rename(columns={0: 'x', 1: 'y'})
    embedding['batch'] = self.data.cells.batch
    batches = natsorted(np.unique(embedding['batch']))
    batch_count = len(batches)
    colors = stereo_conf.get_colors(palette, batch_count, order=batches)
    color_of = dict(zip(batches, colors))

    main_width_default, main_height_default = 500, 500
    sub_width_default, sub_height_default = 250, 250

    # Rows needed to hold one sub plot per batch; a single row is widened to two.
    full_rows, remainder = divmod(batch_count, sub_cols)
    sub_rows = full_rows + 1 if remainder != 0 else full_rows
    if sub_rows == 1:
        sub_rows = 2

    if width is None:
        width = main_width_default + sub_width_default * sub_cols
        sub_width = sub_width_default
    else:
        sub_width = np.ceil(width / (sub_cols + 2)).astype(np.int32)
    if height is None:
        height = max(main_height_default, sub_height_default * sub_rows)
        sub_height = sub_width
    else:
        sub_height = np.ceil(height / sub_rows).astype(np.int32)

    # The main plot takes whatever width the sub plot columns leave over.
    width_ratios = [width - sub_width * sub_cols] + [sub_width] * sub_cols
    height_ratios = [sub_height] * sub_rows

    fig = plt.figure(figsize=(width / 100, height / 100))
    gs = fig.add_gridspec(sub_rows, sub_cols + 1, width_ratios=width_ratios, height_ratios=height_ratios)
    if title is not None:
        fig.suptitle(title, fontsize=20, fontweight='bold', va='top', ha='center')

    # `show_legend` is passed explicitly below, a duplicate in kwargs would clash.
    kwargs.pop('show_legend', None)

    main_slot = gs[0:2, 0] if sub_rows >= 2 else gs[:, 0]
    ax_main = fig.add_subplot(main_slot)
    base_scatter(
        embedding['x'],
        embedding['y'],
        hue=embedding['batch'],
        palette=color_of,
        title=main_title,
        color_bar=False,
        x_label=x_label,
        y_label=y_label,
        dot_size=dot_size,
        width=width,
        height=height,
        ax=ax_main,
        show_legend=False,
        **kwargs
    )

    for i, (batch, color) in enumerate(zip(batches, colors)):
        batch_points = embedding[embedding.batch == batch]
        row, col = divmod(i, sub_cols)
        # Column 0 belongs to the main plot, sub plots start at column 1.
        ax = fig.add_subplot(gs[row, col + 1])
        base_scatter(
            batch_points['x'],
            batch_points['y'],
            hue=batch_points['batch'],
            palette={batch: color},
            color_bar=False,
            dot_size=dot_size,
            width=sub_width,
            height=sub_height,
            ax=ax,
            show_legend=False,
            **kwargs
        )
        sub_title = f'batch-{batch}' if sub_titles is None else sub_titles[i]
        ax.set_title(sub_title, fontsize=10, fontweight='bold')

    fig.tight_layout(pad=1)
    return fig
@download
def umap(
self,
gene_names: Optional[Union[list, np.ndarray, str]] = None,
res_key: str = 'umap',
cluster_key: Optional[str] = None,
title: Optional[Union[str, list]] = None,
x_label: Optional[Union[str, list]] = 'umap1',
y_label: Optional[Union[str, list]] = 'umap2',
dot_size: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
palette: Optional[Union[int, list]] = None,
vmin: float = None,
vmax: float = None,
**kwargs
):
"""
Scatter plot of UMAP after reducing dimensionalities.
:param gene_names: the list of gene names.
:param cluster_key: the result key of clustering.
:param res_key: the result key of UMAP.
:param title: the plot title.
:param x_label: the x label.
:param y_label: the y label.
:param dot_size: the dot size.
:param width: the figure width in pixels.
:param height: the figure height in pixels.
:param palette: a palette name of a list of colors.
:param out_path: the path to save the figure.
:param out_dpi: the dpi when the figure is saved.
:param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
:param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
""" # noqa
res = self.check_res_key(res_key)
if palette is None:
palette = 'stereo_30' if cluster_key else 'stereo'
if cluster_key:
cluster_res = self.check_res_key(cluster_key)
n = len(set(cluster_res['group']))
if title is None:
title = cluster_key
# if not palette:
# palette = stereo_conf.get_colors('stereo_30' if colors == 'stereo' else colors, n)
return base_scatter(
res.values[:, 0],
res.values[:, 1],
hue=cluster_res['group'],
palette=palette,
title=title,
color_bar=False,
x_label=x_label,
y_label=y_label,
dot_size=dot_size,
width=width,
height=height,
**kwargs)
else:
# self.data.array2sparse()
if gene_names is None:
raise ValueError('gene name must be set if cluster_key is None')
if isinstance(gene_names, str):
gene_names = [gene_names]
return multi_scatter(
res.values[:, 0],
res.values[:, 1],
hue=self.data.sub_exp_matrix_by_name(gene_name=gene_names).T,
palette=palette,
title=gene_names if title is None else title,
x_label=[x_label for i in range(len(gene_names))],
y_label=[y_label for i in range(len(gene_names))],
dot_size=dot_size,
color_bar=True,
width=width,
height=height,
vmin=vmin,
vmax=vmax,
**kwargs
)
@download
@plot_scale
@reorganize_coordinate
def cluster_scatter(
        self,
        res_key: str,
        groups: Optional[Union[str, list, np.ndarray]] = None,
        show_others: Optional[bool] = None,
        others_color: Optional[str] = '#828282',
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        dot_size: Optional[int] = None,
        palette: Optional[Union[str, dict, list]] = 'stereo_30',
        invert_y: Optional[bool] = True,
        hue_order: Optional[set] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        base_image: Optional[str] = None,
        base_im_cmap: Optional[str] = 'Greys',
        base_im_to_gray: bool = False,
        clip_base_image: bool = True,
        **kwargs
):
    """
    Spatial scatter distribution of clusters.

    :param res_key: cluster result key.
    :param groups: the group names.
    :param show_others: whether to show others when groups is not None.
        by default, if `base_image` is None, `show_others` is True, otherwise `show_others` is False.
    :param others_color: the color of others, only available when `groups` is not None and `show_others` is True.
    :param title: the plot title, defaults to None to be set as `res_key`, set it to False to disable the title.
    :param x_label: the x label.
    :param y_label: the y label.
    :param dot_size: the dot size, defaults to `PLOT_SCATTER_SIZE_FACTOR` divided by the number of cells in the cluster result.
    :param palette: a palette name, a list of colors whose length at least equal to the groups to be shown or
        a dict whose keys are the groups and values are the colors.
    :param invert_y: whether to invert y-axis, it is also used when creating the base image data.
    :param hue_order: the order of the groups passed to the scatter, defaults to the naturally sorted unique groups,
        it is recomputed when `groups` selects only a part of the cells.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param base_image: the path of mask image to be displayed as background, it must already be registered to the same coordinate system as the data.
    :param base_im_cmap: the color map of the base image, only available when base image is gray scale image.
    :param base_im_to_gray: whether to convert the base image to gray scale if base image is RGB/RGBA image.
    :param clip_base_image: forwarded as `clip` when creating the base image data
        (presumably whether to clip the image to the extent of the data, confirm against `__create_base_image_data`).
    :param show_plotting_scale: whether to display the plotting scale.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
        if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                        ---------------
                        | data1 data2
                        | data3 data4
                        | data5 ...
                        | ...   ...
                        ---------------
        if set it to `False`, the coordinates will not be changed.
    :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
    :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
    :param kwargs: extra keyword arguments forwarded to `base_scatter`, a `marker` entry overrides the default marker.

    :return: Spatial scatter distribution of clusters.
    """  # noqa
    res = self.check_res_key(res_key)
    # work on a copy because the labels may be overwritten with 'others' below
    group_list = res['group'].to_numpy(copy=True)
    if hue_order is None:
        hue_order = natsorted(np.unique(group_list))

    x = self.data.position[:, 0]
    y = self.data.position[:, 1]
    # the boundary is taken from all cells, before any filtering by `groups`
    x_min, x_max = int(x.min()), int(x.max())
    y_min, y_max = int(y.min()), int(y.max())
    boundary = [x_min, x_max, y_min, y_max]
    marker = 's'
    if dot_size is None:
        dot_size = PLOT_SCATTER_SIZE_FACTOR / group_list.size

    if groups is not None:
        if isinstance(groups, str):
            groups = [groups]
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0
        isin = np.isin(group_list, groups)
        if not np.all(isin):
            if show_others is None:
                # without a background image the unselected cells give context,
                # with one they would hide it
                show_others = base_image is None
            if show_others:
                group_list[~isin] = 'others'
                n = np.unique(group_list).size
                hue_order = natsorted(np.unique(group_list[isin])) + ['others']
                # n - 1 colors for the selected groups, then the dedicated color for 'others'
                palette = stereo_conf.get_colors(palette, n=n - 1, order=hue_order)
                palette.append(others_color)
            else:
                group_list = group_list[isin]
                n = np.unique(group_list).size
                hue_order = natsorted(np.unique(group_list))
                palette = stereo_conf.get_colors(palette, n=n, order=hue_order)
                x = x[isin]
                y = y[isin]

    base_im_boundary = None
    base_image_data = None
    base_im_value_range = None
    if base_image is not None:
        # `invert_y` is a named parameter so it never appears in kwargs,
        # the value given by the caller has to be passed explicitly
        base_image_data, base_im_boundary, base_im_value_range = self.__create_base_image_data(
            base_image, x_min, x_max, y_min, y_max, invert_y=invert_y, clip=clip_base_image
        )
        boundary = base_im_boundary[0:2] + [min(base_im_boundary[2:4]), max(base_im_boundary[2:4])]
        marker = '.'

    # an explicit marker wins, it is removed from kwargs to avoid passing it twice
    marker = kwargs.pop('marker', marker)

    if title is None:
        title = res_key
    elif title is False:
        title = None

    fig = base_scatter(
        x, y,
        hue=group_list,
        palette=palette,
        title=title,
        x_label=x_label,
        y_label=y_label,
        dot_size=dot_size,
        marker=marker,
        invert_y=invert_y,
        hue_order=hue_order,
        width=width,
        height=height,
        boundary=boundary,
        base_image=base_image_data,
        base_im_cmap=base_im_cmap,
        base_im_boundary=base_im_boundary,
        base_im_value_range=base_im_value_range,
        base_im_to_gray=base_im_to_gray,
        **kwargs
    )
    return fig
@download
def marker_genes_text(
        self,
        res_key: str,
        groups: Union[str, Sequence[str]] = 'all',
        markers_num: Optional[int] = 20,
        sort_key: Optional[str] = 'scores',
        ascend: Optional[bool] = False,
        fontsize: Optional[int] = 8,
        ncols: Optional[int] = 4,
        sharey: Optional[bool] = True,
        width: Optional[int] = None,
        height: Optional[int] = None,
        **kwargs
):
    """
    Scatter plot of marker genes.

    Looks up the marker genes result and forwards it, together with
    all the display options, to `marker_genes.marker_genes_text`.

    :param res_key: the result key of marker genes.
    :param groups: the group names.
    :param markers_num: top N genes to show in each cluster.
    :param sort_key: the sort key for getting top N marker genes, default `'scores'`.
    :param ascend: whether to sort by ascending.
    :param fontsize: the font size.
    :param ncols: number of plot columns.
    :param sharey: share scale or not.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param kwargs: extra keyword arguments forwarded unchanged to the plotting function.
    """
    # imported under an alias so that it does not read like a recursive call
    from .marker_genes import marker_genes_text as draw_marker_genes_text

    marker_res = self.check_res_key(res_key)
    return draw_marker_genes_text(
        marker_res,
        groups=groups,
        markers_num=markers_num,
        sort_key=sort_key,
        ascend=ascend,
        fontsize=fontsize,
        ncols=ncols,
        sharey=sharey,
        width=width,
        height=height,
        **kwargs
    )
@download
def marker_genes_heatmap(
        self,
        res_key: str,
        cluster_res_key: str = 'cluster',
        markers_num: Optional[int] = 5,
        sort_key: Optional[str] = 'scores',
        ascend: Optional[bool] = False,
        show_labels: Optional[bool] = True,
        show_group: Optional[bool] = True,
        show_group_txt: Optional[bool] = True,
        cluster_colors_array: Optional[bool] = None,
        min_value: Optional[int] = None,
        max_value: Optional[int] = None,
        gene_list: Optional[list] = None,
        do_log: Optional[bool] = True,
        width: Optional[int] = None,
        height: Optional[int] = None
):
    """
    Heatmap plot of marker genes.

    :param res_key: the result key of marker genes.
    :param cluster_res_key: the result key of clustering, only used as a fallback
        when the marker genes result does not record the cluster result key in its `parameters`,
        the recorded key always takes precedence.
    :param markers_num: top N marker genes.
    :param sort_key: sorted by which key.
    :param ascend: whether to sort by ascending.
    :param show_labels: show labels or not.
    :param show_group: show group or not.
    :param show_group_txt: show group names or not.
    :param cluster_colors_array: whether to show color scale.
    :param min_value: minimum value of scale.
    :param max_value: maximum value of scale.
    :param gene_list: gene name list.
    :param do_log: perform normalization if log1p before plotting, or not.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :raises ValueError: if the marker genes result or the cluster result is not found.
    """
    from .marker_genes import marker_genes_heatmap

    maker_res = self.check_res_key(res_key)
    try:
        # prefer the cluster result which the marker genes were actually computed on
        cluster_res_key = maker_res['parameters']['cluster_res_key']
    except KeyError:
        # nothing recorded in the result: keep the key given by the caller
        pass
    cluster_res = self.check_res_key(cluster_res_key)
    cluster_res = cluster_res.set_index(['bins'])
    fig = marker_genes_heatmap(
        self.data,
        cluster_res,
        maker_res,
        markers_num=markers_num,
        sort_key=sort_key,
        ascend=ascend,
        show_labels=show_labels,
        show_group=show_group,
        show_group_txt=show_group_txt,
        cluster_colors_array=cluster_colors_array,
        min_value=min_value,
        max_value=max_value,
        gene_list=gene_list,
        do_log=do_log,
        width=width,
        height=height
    )
    return fig
@download
def marker_genes_scatter(
        self,
        res_key: str,
        markers_num: Optional[int] = 10,
        genes: Optional[Sequence[str]] = None,
        groups: Optional[Sequence[str]] = None,
        values_to_plot: Optional[
            Literal[
                'scores',
                'logfoldchanges',
                'pvalues',
                'pvalues_adj',
                'log10_pvalues',
                'log10_pvalues_adj',
            ]
        ] = None,
        sort_by: Literal[
            'scores',
            'logfoldchanges',
            'pvalues',
            'pvalues_adj'
        ] = 'scores',
        width: Optional[int] = None,
        height: Optional[int] = None
):
    """Scatter of marker genes

    Builds a `MarkerGenesScatterPlot` from the data and the marker genes result
    and returns whatever its `plot_scatter` method returns.

    :param res_key: the result key of marker genes.
    :param markers_num: top N markers, defaults to 10.
    :param genes: name of genes which would be shown on plot, markers_num is ignored if it is set, defaults to None.
    :param groups: cell types which would be shown on plot, all cell types would be shown if set it to None, defaults to None.
    :param values_to_plot: specify the value to color the plot, the mean expression in group would be set by default.
        available values include: [scores, logfoldchanges, pvalues, pvalues_adj, log10_pvalues, log10_pvalues_adj].
    :param sort_by: specify the value which sort by when selecting top N markers, defaults to 'scores'
        available values include: [scores, logfoldchanges, pvalues, pvalues_adj].
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    """  # noqa
    from .marker_genes import MarkerGenesScatterPlot

    plotter = MarkerGenesScatterPlot(self.data, self.check_res_key(res_key))
    plot_options = dict(
        markers_num=markers_num,
        genes=genes,
        groups=groups,
        values_to_plot=values_to_plot,
        sort_by=sort_by,
        width=width,
        height=height,
    )
    return plotter.plot_scatter(**plot_options)
def check_res_key(self, res_key):
    """
    Fetch a tool result by its key, failing loudly when it is absent.

    :param res_key: result key
    :return: the object stored under `res_key` in `self.data.tl.result`
    :raises ValueError: if `res_key` is not contained in `self.data.tl.result`
    """
    results = self.data.tl.result
    # guard clause: a missing key means the corresponding tool has not been run
    if res_key not in results:
        raise ValueError(f'{res_key} result not found, please run tool before plot')
    return results[res_key]
@download
def hotspot_local_correlations(
        self,
        res_key: str = 'spatial_hotspot',
        width: Optional[int] = None,
        height: Optional[int] = None
):
    """
    Visualize module scores with spatial position.

    The plot is drawn by the `plot_local_correlations` method of the hotspot result,
    then the current matplotlib figure is resized and returned.

    :param res_key: the result key of spatial hotspot.
    :param width: the figure width in pixels, it is divided by 100 to get inches,
        a value smaller than 100 is replaced by the default of 15 inches.
    :param height: the figure height in pixels, it is divided by 100 to get inches,
        a value smaller than 100 is replaced by the default of 12 inches.
        If either `width` or `height` is None, both of them fall back to the defaults.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :return: the current matplotlib figure.
    """
    hotspot_res = self.check_res_key(res_key)

    # figure size in inches
    if width is not None and height is not None:
        width_in = width / 100 if width >= 100 else 15
        height_in = height / 100 if height >= 100 else 12
    else:
        width_in, height_in = 15, 12

    hotspot_res.plot_local_correlations()
    figure = plt.gcf()
    figure.set_size_inches(width_in, height_in)
    return figure
@download
def hotspot_modules(
        self,
        res_key: str = 'spatial_hotspot',
        ncols: Optional[int] = 2,
        dot_size: Optional[int] = None,
        palette: Optional[Union[str, list]] = 'stereo',
        width: Optional[int] = None,
        height: Optional[int] = None,
        title: Optional[str] = None,
        vmin: float = None,
        vmax: float = None,
        **kwargs
):
    """
    Plot hotspot modules

    One scatter is drawn per module, numbered from 1 to the largest module id,
    all of them share the same color scale.

    :param res_key: the result key of spatial hotspot.
    :param ncols: the number of columns.
    :param dot_size: the dot size.
    :param palette: a palette name or a list of colors, defaults to `'stereo'`.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param out_path: the path to save the figure.
    :param out_dpi: the dpi when the figure is saved.
    :param title: the plot title, defaults to `module 1`, `module 2`, ... for each sub plot.
    :param vmin: The value representing the lower limit of the color scale. Values smaller than vmin are plotted with the same color as vmin.
        Defaults to the 1st percentile of all the module scores.
    :param vmax: The value representing the higher limit of the color scale. Values greater than vmax are plotted with the same color as vmax.
        Defaults to the 99th percentile of all the module scores.
    :param kwargs: extra keyword arguments forwarded to `multi_scatter`.
    """  # noqa
    res = self.check_res_key(res_key)
    module_ids = range(1, res.modules.max() + 1)
    scores = [res.module_scores[module] for module in module_ids]

    # compare against None rather than truthiness so that an explicit 0 is respected
    if vmin is None:
        vmin = np.percentile(scores, 1)
    if vmax is None:
        vmax = np.percentile(scores, 99)
    if title is None:
        title = [f"module {module}" for module in module_ids]

    fig = multi_scatter(
        x=res.latent.iloc[:, 0],
        y=res.latent.iloc[:, 1],
        hue=scores,
        title=title,
        ncols=ncols,
        dot_size=dot_size,
        palette=palette,
        color_bar=True,
        vmin=vmin,
        vmax=vmax,
        width=width,
        height=height,
        **kwargs
    )
    return fig
def scenic_regulons(
        self,
        res_key: str,
):
    """
    Plot scenic regulons

    For each regulon, a spatial scatter of the cells colored by the matching
    column of the AUC matrix is drawn and shown immediately, nothing is returned.

    :param res_key: result key.
    """
    scenic_res = self.check_res_key(res_key)
    regulons = scenic_res["regulons"]
    auc_mtx = scenic_res["auc_mtx"]

    # regulons and the columns of auc_mtx are matched by position
    for idx in range(len(regulons)):
        regulon_scores = auc_mtx.iloc[:, idx]
        # the color scale is limited to the 1st..99th percentile of the scores
        lower = np.percentile(regulon_scores, 1)
        upper = np.percentile(regulon_scores, 99)

        plt.scatter(
            x=self.data.position[:, 0],
            y=self.data.position[:, 1],
            s=8,
            c=regulon_scores,
            vmin=lower,
            vmax=upper,
            edgecolors='none'
        )

        # strip the frame and the ticks, only the dots and the title remain
        for spine in plt.gca().spines.values():
            spine.set_visible(False)
        plt.xticks([])
        plt.yticks([])
        plt.title('Regulon {}'.format(auc_mtx.columns[idx]))
        plt.show()
def scenic_clustermap(
        self,
        res_key: str,
):
    """
    Plot scenic cluster

    Draws a 12x12 seaborn clustermap of the AUC matrix stored in the result
    and shows it immediately, nothing is returned.

    :param res_key: result key.
    """
    res = self.check_res_key(res_key)
    auc_mtx = res["auc_mtx"]
    # seaborn is already imported as `sns` at the top of this module
    sns.clustermap(auc_mtx, figsize=(12, 12))
    plt.show()
@reorganize_coordinate
def cells_plotting(
        self,
        color_by: Literal['total_counts', 'n_genes_by_counts', 'gene', 'cluster'] = 'total_counts',
        color_key: Optional[str] = None,
        bgcolor: Optional[str] = '#2F2F4F',
        palette: Optional[Union[str, list, dict]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        fg_alpha: Optional[float] = 0.5,
        base_image: Optional[str] = None,
        base_im_to_gray: bool = False,
        use_raw: bool = True,
        show: bool = True
):
    """Plot the cells.

    :param color_by: specify the way of coloring, default to 'total_counts'.
        if set to 'gene', you need to specify a gene name by `color_key`.
        if set to 'cluster', you need to specify the key to get cluster result by `color_key`.
    :param color_key: the key to get the data to color the plot, it is ignored when the `color_by` is set to 'total_counts' or 'n_genes_by_counts'.
    :param bgcolor: set background color.
    :param palette: color theme,
        when `color_by` is 'cluster', it can be a palette name, a list of colors whose length equal to the groups,
        or a dict whose keys are the groups and values are colors,
        when other `color_by` is set, it only can be a palette name.
    :param width: the figure width in pixels.
    :param height: the figure height in pixels.
    :param fg_alpha: the transparency of foreground image, between 0 and 1, defaults to 0.5
        this is the colored image of the cells.
    :param base_image: the path of the ssdna image after calibration, defaults to None
        it will be located behind the image of the cells.
    :param base_im_to_gray: whether to convert the base image to gray scale if base image is RGB/RGBA image.
    :param use_raw: whether to use raw data, defaults to True if .raw is present.
    :param show: show the figure directly or get the figure object, defaults to True.
        If set to False, you need to call the `show` method of the figure object to show the figure.
    :param reorganize_coordinate: if the data is merged from several slices, whether to reorganize the coordinates of the obs(cells),
        if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                        ---------------
                        | data1 data2
                        | data3 data4
                        | data5 ...
                        | ...   ...
                        ---------------
        if set it to `False`, the coordinates will not be changed.
    :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
    :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
    :raises TypeError: if `color_by` is 'cluster' or 'gene' and `color_key` is not a string.

    :return: the figure object if `show` is set to False, otherwise, the result of showing the figure.

    .. note::

        Exporting
        ------------------

        This plot can be exported as PNG, SVG or PDF.

        You need the following necessary dependencies to support exporting:

            conda install -c conda-forge selenium firefox geckodriver cairosvg

        On Linux, you may need to install some additional libraries to support the above dependencies,
        for example, on Ubuntu, the following libraries need to be installed:

            sudo apt-get install libgtk-3-dev libasound2-dev

        or on CentOS, are as follows:

            sudo yum install gtk3-devel alsa-lib-devel

        On other Linux distributions, you may need to install the corresponding libraries according to the error message.

        There are two ways to export the plot, one is to manipulate on browser when you run it on jupyter notebook,
        another is to call the method `save_plot <stereo.plots.plot_cells.PlotCells.save_plot.html>`_ of this figure object.

        Example code for the second way:

        .. code-block:: python

            fig = data.plt.cells_plotting(show=False)
            fig.show()
            fig.save_plot('path/to/save/plot.pdf')

    """  # noqa
    from .plot_cells import PlotCells

    # coloring by gene or by cluster is meaningless without a key to look the values up
    requires_key = color_by in ('cluster', 'gene')
    if requires_key and not isinstance(color_key, str):
        raise TypeError(f"the 'color_key' must be the type of string, but now is {type(color_key)}.")

    cells_plot = PlotCells(
        self.data,
        color_by=color_by,
        color_key=color_key,
        bgcolor=bgcolor,
        palette=palette,
        width=width,
        height=height,
        fg_alpha=fg_alpha,
        base_image=base_image,
        base_im_to_gray=base_im_to_gray,
        use_raw=use_raw
    )
    return cells_plot.show() if show else cells_plot
@download
def correlation_heatmap(
        self,
        width: Optional[int] = None,
        height: Optional[int] = None,
        title: str = 'Correlation Heatmap',
        x_label: str = 'x',
        y_label: str = 'y',
        cmap: str = 'coolwarm'
):
    """
    Heatmap of the pairwise correlation computed from `self.data.to_df().corr()`.

    The correlation matrix is drawn with `seaborn.clustermap`, the color scale is fixed to [-1, 1].

    :param width: the figure width, defaults to 6, it is passed unchanged as `figsize` to seaborn,
        so it is presumably in inches rather than pixels.
    :param height: the figure height, defaults to 6, same unit as `width`.
    :param title: the title of the heatmap.
    :param x_label: the x label.
    :param y_label: the y label.
    :param cmap: the color map of the heatmap, defaults to `'coolwarm'`.
    :return: the figure of the clustermap.
    """
    correlation_matrix = self.data.to_df().corr()
    figsize = (
        6 if width is None else width,
        6 if height is None else height,
    )
    grid = sns.clustermap(
        correlation_matrix,
        # a tiny ratio leaves practically no room for the dendrograms
        dendrogram_ratio=0.00001,
        cbar_pos=(1.05, 0.5, 0.05, 0.36),
        figsize=figsize,
        vmax=1,
        vmin=-1,
        cmap=cmap
    )
    heatmap_ax = grid.ax_heatmap
    heatmap_ax.set_title(title, fontweight='bold', fontsize=13)
    heatmap_ax.set_xlabel(x_label, fontweight='bold', fontsize=10)
    heatmap_ax.set_ylabel(y_label, fontweight='bold', fontsize=10)
    return grid.figure