# Source code for stereo.utils.data_helper

#!/usr/bin/env python3
# coding: utf-8
"""
@author: Ping Qiu  qiuping1@genomics.cn
@last modified by: Ping Qiu
@file: data_helper.py
@time: 2021/3/14 16:11
"""
from math import ceil
from typing import Optional
from typing import Union
from functools import singledispatch

import anndata as ad
import numba as nb
import numpy as np
import pandas as pd
import scipy.sparse as sp
from natsort import natsorted

from stereo.core.cell import Cell
from stereo.core.gene import Gene
from stereo.core.stereo_exp_data import StereoExpData, AnnBasedStereoExpData


def select_group(groups, cluster, all_groups):
    """
    Build a mask selecting the rows of `cluster` that belong to the requested groups.

    :param groups: a single group name, or an iterable of group names.
    :param cluster: a table whose 'group' column holds the group of each row.
    :param all_groups: the collection of valid group names.
    :return: the result of `cluster['group'].isin(...)` for the requested groups.
    :raises ValueError: if a requested group is not contained in `all_groups`.
    """
    requested = groups
    if isinstance(requested, str):
        # a bare string is a single group name, not an iterable of characters
        requested = [requested]
    for name in requested:
        if name in all_groups:
            continue
        raise ValueError(f"cluster {name} is not in all cluster.")
    return cluster['group'].isin(requested)


def get_cluster_res(adata, data_key='clustering'):
    """
    Fetch the cluster labels stored under `adata.uns[data_key]` as categorical strings.

    :param adata: an object whose `uns[data_key]` has a `cluster` table with a 'cluster' column.
    :param data_key: the key of the clustering result in `adata.uns`.
    :return: the 'cluster' column converted to str and then to a categorical array.
    """
    labels = adata.uns[data_key].cluster['cluster']
    as_text = labels.astype(str)
    return as_text.astype('category').values


def get_position_array(data, obs_key='spatial'):
    """
    Return the first two columns of `data.obsm[obs_key]` as a numpy array.

    :param data: an object exposing an `obsm` mapping.
    :param obs_key: the key of the coordinates in `data.obsm`.
    :return: an array made of columns 0 and 1 of the stored coordinates.
    """
    coordinates = np.array(data.obsm[obs_key])
    return coordinates[:, :2]


def exp_matrix2df(data: "StereoExpData", cell_name: Optional[np.ndarray] = None, gene_name: Optional[np.ndarray] = None):
    """
    Convert the expression matrix of `data` (or a selection of it) to a dense DataFrame.

    :param data: the data object providing `exp_matrix`, the cell names and the gene names.
    :param cell_name: the cells to keep, in the order they should appear as rows;
        all cells are kept when it is None.
    :param gene_name: the genes to keep, in the order they should appear as columns;
        all genes are kept when it is None.
    :return: a DataFrame indexed by cell name whose columns are gene names.
    :raises IndexError: if a requested cell or gene does not exist in `data`.
    """

    def first_positions(all_names, wanted, kind):
        # Build the name -> position table once instead of scanning `all_names`
        # for every requested name; `setdefault` keeps the first occurrence of a
        # repeated name.
        lookup = {}
        for position, name in enumerate(all_names):
            lookup.setdefault(name, position)
        positions = []
        for name in wanted:
            if name not in lookup:
                raise IndexError(f"{kind} '{name}' is not in the data.")
            positions.append(lookup[name])
        return positions

    cell_index = first_positions(data.cells.cell_name, cell_name, "cell") if cell_name is not None else None
    gene_index = first_positions(data.genes.gene_name, gene_name, "gene") if gene_name is not None else None

    x = data.exp_matrix
    if cell_index is not None:
        x = x[cell_index, :]
    if gene_index is not None:
        x = x[:, gene_index]
    if not isinstance(x, np.ndarray):
        # anything that is not a numpy array is converted through `toarray()`
        x = x.toarray()

    index = cell_name if cell_name is not None else data.cell_names
    columns = gene_name if gene_name is not None else data.gene_names
    return pd.DataFrame(data=x, index=index, columns=columns)


def get_top_marker(g_name: str, marker_res: dict, sort_key: str, ascend: bool = False, top_n: int = 10):
    """
    Pick the leading rows of one group's marker table, ranked by `sort_key`.

    :param g_name: the key of the group in `marker_res`.
    :param marker_res: a mapping from group name to its marker DataFrame.
    :param sort_key: the column to rank by.
    :param ascend: whether to rank in ascending order.
    :param top_n: how many rows to take after ranking.
    :return: the first `top_n` ranked rows, without the rows whose `sort_key` is missing.
    """
    group_markers: pd.DataFrame = marker_res[g_name]
    ranked = group_markers.sort_values(by=sort_key, ascending=ascend)
    leading = ranked.head(top_n)
    return leading.dropna(axis=0, subset=[sort_key])

@singledispatch
def _union_merge(arr1, arr2, col1, col2):
    """
    Stack two matrices row-wise over the union of their column names.

    The rows of `arr1` come first, followed by the rows of `arr2`; a column that
    one of the matrices does not have is filled with 0 for its rows:
               a  b        b  c
            0  1  3     0  1  3
            1  2  4     1  2  4

    To:  a  b  c
       [[1. 3. 0.]
       [2. 4. 0.]
       [0. 1. 3.]
       [0. 2. 4.]]

    The implementation is selected by the type of `arr1`; `numpy.ndarray` and
    `scipy.sparse.csr_matrix` are supported.

    :param arr1: the first matrix.
    :param arr2: the second matrix.
    :param col1: the column names of `arr1`.
    :param col2: the column names of `arr2`.
    :return: a tuple `(columns, matrix)`.
    :raises TypeError: if the type of `arr1` has no registered implementation.
    """
    # Fail loudly here: callers unpack the result into two names, so silently
    # returning None would only surface later as an unrelated unpacking error.
    raise TypeError(
        f"_union_merge does not support matrices of type '{type(arr1).__name__}', "
        "expected numpy.ndarray or scipy.sparse.csr_matrix."
    )


@_union_merge.register(np.ndarray)
def _union_merge_array(arr1: np.ndarray, arr2: np.ndarray, col1: np.ndarray, col2: np.ndarray):
    """
    Dense implementation of `_union_merge`.

    When `col1` and `col2` are identical the matrices are simply concatenated and
    `col1` is returned unchanged. Otherwise the columns of the result are
    `np.union1d(col1, col2)` (sorted and unique) and missing values are 0.

    :param arr1: the first matrix.
    :param arr2: the second matrix.
    :param col1: the column names of `arr1`.
    :param col2: the column names of `arr2`.
    :return: a tuple `(columns, merged_matrix)`.
    """
    if (col1.size == col2.size) and np.all(col1 == col2):
        return col1, np.concatenate([arr1, arr2])

    new_col = np.union1d(col1, col2)
    new_col_index = pd.Index(new_col)
    row_count1 = arr1.shape[0]
    # Promote to the common dtype (as np.concatenate does on the fast path) so
    # that the values of `arr2` are not truncated to the dtype of `arr1`.
    merged_arr = np.zeros(
        [row_count1 + arr2.shape[0], new_col.size],
        dtype=np.result_type(arr1.dtype, arr2.dtype)
    )
    # get_indexer gives the position of every old column inside new_col without
    # materialising a (len(new_col) x len(col)) comparison matrix.
    merged_arr[:row_count1, new_col_index.get_indexer(col1)] = arr1
    merged_arr[row_count1:, new_col_index.get_indexer(col2)] = arr2
    return new_col, merged_arr


@_union_merge.register(sp.csr_matrix)
def _union_merge_csr_matrix(mtx1: sp.csr_matrix, mtx2: sp.csr_matrix, col1: np.ndarray, col2: np.ndarray):
    """
    Sparse (CSR) implementation of `_union_merge`.

    When `col1` and `col2` are identical the matrices are simply stacked and
    `col1` is returned unchanged. Otherwise the columns of the result are
    `np.union1d(col1, col2)` (sorted and unique); the stored entries of each
    matrix keep their row and are moved to the position of their column in the
    union. Explicitly stored zeros are dropped from the result.

    Column names are expected to be unique within `col1` and within `col2`.

    :param mtx1: the first matrix.
    :param mtx2: the second matrix.
    :param col1: the column names of `mtx1`.
    :param col2: the column names of `mtx2`.
    :return: a tuple `(columns, merged_matrix)`, `merged_matrix` being a CSR matrix.
    """
    if (col1.size == col2.size) and np.all(col1 == col2):
        return col1, sp.vstack([mtx1, mtx2])

    new_col = np.union1d(col1, col2)
    new_col_index = pd.Index(new_col)

    remapped = []
    for mtx, col in ((mtx1, col1), (mtx2, col2)):
        # Only the column indices change: translate each stored index to the
        # position of the same column name in new_col. Work is O(nnz), no row
        # is ever densified.
        new_indices = new_col_index.get_indexer(col)[mtx.indices]
        remapped.append(
            sp.csr_matrix((mtx.data, new_indices, mtx.indptr), shape=(mtx.shape[0], new_col.size))
        )

    # vstack allocates fresh arrays, so the in-place clean-up below cannot touch
    # the buffers shared with mtx1 / mtx2.
    merged = sp.vstack(remapped, format='csr')
    merged.sum_duplicates()  # also sorts the column indices inside every row
    merged.eliminate_zeros()
    return new_col, merged


def reorganize_data_coordinates(
        cells_batch: np.ndarray,
        data_position: np.ndarray,
        data_position_offset: dict = None,
        data_position_min: dict = None,
        reorganize_coordinate: Union[bool, int] = 2,
        horizontal_offset_additional: Union[int, float] = 0,
        vertical_offset_additional: Union[int, float] = 0
):
    """
    Lay the batches out on a grid with `reorganize_coordinate` columns by shifting their coordinates.

    `data_position` is modified in place. Nothing is done when `reorganize_coordinate` is falsy.

    :param cells_batch: the batch of each cell.
    :param data_position: the coordinates of each cell, x in column 0 and y in column 1.
    :param data_position_offset: the offset applied to each batch by a previous call, it is
        subtracted before the new layout is computed.
    :param data_position_min: the minimum coordinates of each batch; when it is None they are
        computed here and subtracted from the coordinates of the batch.
    :param reorganize_coordinate: the number of columns of the grid.
    :param horizontal_offset_additional: the additional gap added per column.
    :param vertical_offset_additional: the additional gap added per row.
    :return: a tuple `(data_position, data_position_offset, data_position_min)`.
    """
    if not reorganize_coordinate:
        return data_position, data_position_offset, data_position_min

    batches = natsorted(np.unique(cells_batch))
    column_count = reorganize_coordinate
    row_count = ceil(len(batches) / reorganize_coordinate)
    # Slot k + 1 holds the extent of grid column / row k; slot 0 always stays 0.
    column_widths = [0] * (column_count + 1)
    row_heights = [0] * (row_count + 1)
    # cells_batch is never modified, so the rows of every batch are located only once
    batch_rows = [np.where(cells_batch == bno)[0] for bno in batches]

    if data_position_min is None:
        data_position_min = {}
        for bno, idx in zip(batches, batch_rows):
            lowest = np.min(data_position[idx], axis=0)
            data_position[idx] -= lowest
            data_position_min[bno] = lowest

    for i, (bno, idx) in enumerate(zip(batches, batch_rows)):
        previous_offset = 0 if data_position_offset is None else data_position_offset[bno]
        data_position[idx] -= previous_offset
        grid_row, grid_column = divmod(i, reorganize_coordinate)
        xs = data_position[idx][:, 0]
        ys = data_position[idx][:, 1]
        width = xs.max() - xs.min() + 1
        height = ys.max() - ys.min() + 1
        column_widths[grid_column + 1] = max(column_widths[grid_column + 1], width)
        row_heights[grid_row + 1] = max(row_heights[grid_row + 1], height)

    new_position_offset = {}
    for i, (bno, idx) in enumerate(zip(batches, batch_rows)):
        grid_row, grid_column = divmod(i, reorganize_coordinate)
        x_add = column_widths[grid_column]
        if grid_column > 0:
            x_add = x_add + (sum(column_widths[0:grid_column]) + horizontal_offset_additional * grid_column)
        y_add = row_heights[grid_row]
        if grid_row > 0:
            y_add = y_add + (sum(row_heights[0:grid_row]) + vertical_offset_additional * grid_row)
        shift = np.array([x_add, y_add], dtype=data_position.dtype)
        data_position[idx] += shift
        new_position_offset[bno] = shift
    return data_position, new_position_offset, data_position_min

def __parse_space_between(space_between: str):
    """
    Convert a distance such as '10nm', '1.5um' or '2mm' to a number of nanometers.

    :param space_between: a non-negative decimal number immediately followed by one of the
        units nm, um, mm, cm, dm, m; '0' (without unit) and None both mean no distance.
    :return: the distance in nanometers, as a float.
    :raises ValueError: if `space_between` is not in the expected format.
    """
    import re

    if space_between is None or space_between == '0':
        return 0.0

    # `(?:\.\d+)?` accepts one optional fractional part with any number of
    # digits, e.g. '1.25um', and rejects malformed numbers such as '1.2.3nm'.
    match = re.match(r"^(\d+(?:\.\d+)?)(nm|um|mm|cm|dm|m)$", space_between)
    if match is None:
        raise ValueError(f"Invalid space between: '{space_between}'")

    number, unit = match.groups()
    nanometers_per_unit = {'nm': 1.0, 'um': 1e3, 'mm': 1e6, 'cm': 1e7, 'dm': 1e8, 'm': 1e9}
    return float(number) * nanometers_per_unit[unit]

@singledispatch
def merge(
        *data_list: Union[StereoExpData, AnnBasedStereoExpData],
        reorganize_coordinate: Union[bool, int] = False,
        horizontal_offset_additional: Union[int, float] = 0,
        vertical_offset_additional: Union[int, float] = 0,
        space_between: Optional[str] = '0',
        var_type: str = "intersect",
        batch_tags: Union[list, np.ndarray, pd.Series] = None
) -> Union[StereoExpData, AnnBasedStereoExpData]:
    """
    Merge several slices of data.

    The implementation is selected by the type of the first slice.

    :param data_list: several slices of data to be merged, at least two slices.
    :param reorganize_coordinate: whether to reorganize the coordinates of the obs(cells),
            if set it to a number, like 2, the coordinates will be reorganized to 2 columns on coordinate system as below:
                            ---------------
                            | data1 data2
                            | data3 data4
                            | data5 ...
                            | ...   ...
                            ---------------
            if set to `False`, the coordinates may overlap between slices.
    :param horizontal_offset_additional: the additional offset between each slice on horizontal direction while reorganizing coordinates.
    :param vertical_offset_additional: the additional offset between each slice on vertical direction while reorganizing coordinates.
    :param space_between: the distance between each slice, like '10nm', '1um', ...,
            it will be used for calculating the z-coordinate of each slice.
    :param var_type: how to combine the genes of the slices, `'intersect'` keeps the genes shared by all slices,
            `'union'` keeps the genes found in any slice.
    :param batch_tags: the batch tag of each slice, in the same order as `data_list`,
            a slice without a tag is tagged with its position in `data_list`.

    :return: A merged StereoExpData object.
    """  # noqa
    pass
@merge.register(StereoExpData)
def __merge_for_stereo_exp_data(
        *data_list: StereoExpData,
        reorganize_coordinate: Union[bool, int] = False,
        horizontal_offset_additional: Union[int, float] = 0,
        vertical_offset_additional: Union[int, float] = 0,
        space_between: Optional[str] = '0',
        var_type: str = "intersect",
        batch_tags: Union[list, np.ndarray, pd.Series] = None
):
    """
    Merge several slices of `StereoExpData`, see `merge` for the meaning of the parameters.

    Each input slice is modified: its `cells.batch` is set to its batch tag and, when the first
    slice is sparse, `array2sparse()` is called on it. The cell names of the merged data are the
    original names suffixed with `-<position of the slice>`.

    :return: a new StereoExpData created with `merged=True`.
    :raises Exception: if fewer than two slices are given or `var_type` is neither
        'intersect' nor 'union'.
    """
    from copy import deepcopy

    if data_list is None or len(data_list) < 2:
        raise Exception("At least two slices of data need to be input.")
    space_between = __parse_space_between(space_between)
    new_data = StereoExpData(merged=True)
    new_data.sn = {}
    current_position_z = 0
    # the storage format of the first slice decides the format of all the others
    issparse = data_list[0].issparse()
    for i, data in enumerate(data_list):
        batch = i if batch_tags is None or i >= len(batch_tags) else batch_tags[i]
        data.cells.batch = batch
        cell_names = np.char.add(data.cells.cell_name, f"-{i}")
        if issparse:
            data.array2sparse()
        new_data.sn[str(batch)] = data.sn
        if i == 0:
            new_data.exp_matrix = data.exp_matrix.copy()
            new_data.cells = Cell(cell_name=cell_names, cell_border=data.cells.cell_border, batch=data.cells.batch)
            new_data.genes = Gene(gene_name=data.gene_names)
            new_data.cells._obs = data.cells._obs.copy(deep=True)
            new_data.cells._obs.index = cell_names
            new_data.position = data.position
            if data.position_z is None:
                new_data.position_z = np.repeat([[0]], repeats=data.position.shape[0], axis=0).astype(
                    data.position.dtype)
            else:
                new_data.position_z = data.position_z
            new_data.bin_type = data.bin_type
            new_data.bin_size = data.bin_size
            new_data.offset_x = data.offset_x
            new_data.offset_y = data.offset_y
            # Copy: the entries are updated below while aggregating the other
            # slices, which must not leak into the attr of the first slice.
            new_data.attr = deepcopy(data.attr)
        else:
            current_obs = data.cells._obs.copy()
            current_obs.index = cell_names
            new_data.cells._obs = pd.concat([new_data.cells._obs, current_obs])
            if new_data.cell_borders is not None and data.cell_borders is not None:
                new_data.cells.cell_border = np.concatenate([new_data.cells.cell_border, data.cells.cell_border])
            new_data.position = np.concatenate([new_data.position, data.position])
            if data.position_z is None:
                # NOTE(review): data.attr['resolution'] is read even when space_between is 0,
                # confirm that attr is always set for slices without position_z.
                current_position_z += space_between / data.attr['resolution']
                new_data.position_z = np.concatenate(
                    [new_data.position_z, np.repeat([[current_position_z]], repeats=data.position.shape[0], axis=0)])
            else:
                new_data.position_z = np.concatenate([new_data.position_z, data.position_z])
            if var_type == "intersect":
                new_data.genes.gene_name, ind1, ind2 = \
                    np.intersect1d(new_data.genes.gene_name, data.genes.gene_name, return_indices=True)
                if issparse:
                    new_data.exp_matrix = sp.vstack([new_data.exp_matrix[:, ind1], data.exp_matrix[:, ind2]])
                else:
                    new_data.exp_matrix = np.concatenate([new_data.exp_matrix[:, ind1], data.exp_matrix[:, ind2]])
            elif var_type == "union":
                new_data.genes.gene_name, new_data.exp_matrix = _union_merge(
                    new_data.exp_matrix, data.exp_matrix, new_data.genes.gene_name, data.genes.gene_name
                )
            else:
                raise Exception(f"got an unexpected var_type: {var_type}")
            if new_data.offset_x is not None and data.offset_x is not None:
                new_data.offset_x = min(new_data.offset_x, data.offset_x)
            if new_data.offset_y is not None and data.offset_y is not None:
                new_data.offset_y = min(new_data.offset_y, data.offset_y)
            if new_data.attr is not None and data.attr is not None:
                for key, value in data.attr.items():
                    if key in ('minX', 'minY'):
                        new_data.attr[key] = min(new_data.attr[key], value)
                    elif key in ('maxX', 'maxY'):
                        new_data.attr[key] = max(new_data.attr[key], value)
                    elif key == 'minExp':
                        new_data.attr['minExp'] = new_data.exp_matrix.min()
                    elif key == 'maxExp':
                        new_data.attr['maxExp'] = new_data.exp_matrix.max()
                    elif key == 'resolution':
                        new_data.attr['resolution'] = value
    if reorganize_coordinate:
        new_data.position, new_data.position_offset, new_data.position_min = reorganize_data_coordinates(
            new_data.cells.batch, new_data.position, new_data.position_offset, new_data.position_min,
            reorganize_coordinate, horizontal_offset_additional, vertical_offset_additional
        )
    return new_data


@merge.register(AnnBasedStereoExpData)
def __merge_for_ann_based_stereo_exp_data(
        *data_list: AnnBasedStereoExpData,
        reorganize_coordinate: Union[bool, int] = False,
        horizontal_offset_additional: Union[int, float] = 0,
        vertical_offset_additional: Union[int, float] = 0,
        space_between: Optional[str] = '0',
        var_type: str = "intersect",
        batch_tags: Union[list, np.ndarray, pd.Series] = None
):
    """
    Merge several slices of `AnnBasedStereoExpData`, see `merge` for the meaning of the parameters.

    The underlying AnnData objects are concatenated with `anndata.concat`, the batch tag of each
    slice (converted to str) is stored in the 'batch' column of obs and appended to the cell names
    after a '-'. Any `var_type` other than 'union' results in an inner join of the genes.

    :return: a new AnnBasedStereoExpData whose `merged` is True.
    :raises Exception: if fewer than two slices are given.
    """
    from copy import deepcopy

    if data_list is None or len(data_list) < 2:
        raise Exception("At least two slices of data need to be input.")
    space_between = __parse_space_between(space_between)
    current_position_z = 0
    batches = []
    sn = {}
    adata_list = []
    position_z_list = []
    offset_x = None
    offset_y = None
    attr = None
    for i, data in enumerate(data_list):
        if batch_tags is None or i >= len(batch_tags):
            batch = str(i)
        else:
            batch = str(batch_tags[i])
        batches.append(batch)
        sn[batch] = data.sn
        adata_list.append(data.adata)
        if data.position_z is None:
            if i == 0:
                position_z = np.repeat([[0]], repeats=data.position.shape[0], axis=0).astype(
                    data.position.dtype)
            else:
                # NOTE(review): data.attr['resolution'] is read even when space_between is 0,
                # confirm that attr is always set for slices without position_z.
                current_position_z += space_between / data.attr['resolution']
                position_z = np.repeat([[current_position_z]], repeats=data.position.shape[0], axis=0)
        else:
            position_z = data.position_z
        position_z_list.append(position_z)
        if i == 0:
            offset_x = data.offset_x
            offset_y = data.offset_y
            # Copy: the entries are updated below while aggregating the other
            # slices, which must not leak into the attr of the first slice.
            attr = deepcopy(data.attr)
        else:
            if offset_x is not None and data.offset_x is not None:
                offset_x = min(offset_x, data.offset_x)
            if offset_y is not None and data.offset_y is not None:
                offset_y = min(offset_y, data.offset_y)
            if attr is not None and data.attr is not None:
                for key, value in data.attr.items():
                    if key in ('minX', 'minY'):
                        attr[key] = min(attr[key], value)
                    elif key in ('maxX', 'maxY'):
                        attr[key] = max(attr[key], value)
                    elif key == 'minExp':
                        attr['minExp'] = min(attr['minExp'], data.exp_matrix.min())
                    elif key == 'maxExp':
                        attr['maxExp'] = max(attr['maxExp'], data.exp_matrix.max())
                    elif key == 'resolution':
                        attr['resolution'] = value
    adata_merged = ad.concat(
        adata_list,
        join='inner' if var_type != 'union' else 'outer',
        axis=0,
        label='batch',
        keys=batches,
        index_unique='-',
        merge='first',
        uns_merge='first'
    )
    bin_type = data_list[0].bin_type
    bin_size = data_list[0].bin_size
    spatial_key = data_list[0].spatial_key
    new_data = AnnBasedStereoExpData(
        based_ann_data=adata_merged,
        bin_type=bin_type,
        bin_size=bin_size,
        spatial_key=spatial_key
    )
    new_data.merged = True
    new_data.offset_x = offset_x
    new_data.offset_y = offset_y
    new_data.attr = attr
    new_data.sn = sn
    # the z-coordinates are only set when the merged coordinates have exactly two columns
    if new_data.adata.obsm[spatial_key].shape[1] == 2:
        position_z = np.concatenate(position_z_list, axis=0)
        new_data.position_z = position_z
    if reorganize_coordinate:
        new_data.position, new_data.position_offset, new_data.position_min = reorganize_data_coordinates(
            new_data.cells.batch, new_data.position, new_data.position_offset, new_data.position_min,
            reorganize_coordinate, horizontal_offset_additional, vertical_offset_additional
        )
    return new_data
@singledispatch
def split(data: Union[StereoExpData, AnnBasedStereoExpData] = None):
    """
    Split a data object which is merged from different batches of data, according to the batch number.

    The implementation is selected by the type of `data`.

    :param data: a merged data object.

    :return: A split data list.
    """
    pass
@split.register(StereoExpData)
def split_for_stereo_exp_data(data: StereoExpData = None):
    """
    Split a merged `StereoExpData` into one `StereoExpData` per batch.

    For every batch the cells, coordinates and expression matrix are subset, the offset recorded
    in `data.position_offset` (if any) is subtracted from the coordinates, and the results
    recorded in `data.tl.key_record` are carried over.

    :param data: a merged data object.
    :return: a list with one data object per batch, or None if `data` is None.
    """
    if data is None:
        return None

    from copy import deepcopy
    from .pipeline_utils import cell_cluster_to_gene_exp_cluster

    all_data = []
    # data.array2sparse()
    batch = np.unique(data.cells.batch)
    result = data.tl.result
    for bno in batch:
        cell_idx = np.where(data.cells.batch == bno)[0]
        cell_names = data.cell_names[cell_idx]
        new_data = StereoExpData(
            bin_type=data.bin_type,
            bin_size=data.bin_size,
            cells=deepcopy(data.cells),
            genes=deepcopy(data.genes)
        )
        new_data.cells = new_data.cells.sub_set(cell_idx)
        if data.position_offset is not None:
            new_data.position = data.position[cell_idx] - data.position_offset[bno]
        else:
            new_data.position = data.position[cell_idx]
        new_data.position_z = data.position_z[cell_idx]
        new_data.exp_matrix = data.exp_matrix[cell_idx]
        new_data.tl.key_record = deepcopy(data.tl.key_record)
        new_data.sn = data.sn[bno]
        for key, all_res_key in data.tl.key_record.items():
            if len(all_res_key) == 0:
                continue
            if key == 'hvg':
                for res_key in all_res_key:
                    new_data.tl.result[res_key] = result[res_key]
            elif key in ['pca', 'cluster', 'umap', 'totalVI']:
                # these results are subset to the rows of the current batch
                for res_key in all_res_key:
                    new_data.tl.result[res_key] = result[res_key].iloc[cell_idx]
                    new_data.tl.result[res_key].reset_index(drop=True, inplace=True)
            elif key == 'neighbors':
                # NOTE(review): slicing min_idx:max_idx only selects exactly the cells of this
                # batch if they occupy a contiguous range of rows - confirm.
                min_idx = cell_idx.min()
                max_idx = cell_idx.max() + 1
                for res_key in all_res_key:
                    connectivities = result[res_key]['connectivities']
                    nn_dist = result[res_key]['nn_dist']
                    new_data.tl.result[res_key] = {
                        'neighbor': result[res_key]['neighbor'],
                        'connectivities': connectivities[min_idx:max_idx, min_idx:max_idx],
                        'nn_dist': nn_dist[min_idx:max_idx, min_idx:max_idx]
                    }
            elif key == 'marker_genes':
                for res_key in all_res_key:
                    new_data.tl.result[res_key] = result[res_key]
            elif key == 'sct':
                for res_key in all_res_key:
                    cells_bool_list = np.isin(result[res_key][0]['umi_cells'], cell_names)
                    # sct `counts` and `data` should have same shape
                    new_data.tl.result[res_key] = (
                        new_data,
                        {
                            'cells': cell_names,
                            'genes': result[res_key][0]['umi_genes'],
                            'filtered_corrected_counts': result[res_key][0]['counts'][cells_bool_list, :],
                            'filtered_normalized_counts': result[res_key][0]['data'][cells_bool_list, :]
                        }
                    )
            elif key == 'tsne':
                # NOTE(review): unlike 'umap', the 'tsne' result is passed on without being
                # subset to the cells of this batch - confirm this is intended.
                for res_key in all_res_key:
                    new_data.tl.result[res_key] = result[res_key]
            elif key == 'gene_exp_cluster':
                # recomputed from the cluster results after the loop
                continue
            else:
                for res_key in all_res_key:
                    new_data.tl.result[res_key] = result[res_key]
        if data.tl.raw is not None:
            new_data.tl.raw = data.tl.raw.tl.filter_cells(cell_list=cell_names, inplace=False)
        if 'gene_exp_cluster' in data.tl.key_record:
            for cluster_res_key in data.tl.key_record['cluster']:
                gene_exp_cluster_res = cell_cluster_to_gene_exp_cluster(new_data, cluster_res_key)
                if gene_exp_cluster_res is not False:
                    new_data.tl.result[f"gene_exp_{cluster_res_key}"] = gene_exp_cluster_res
        all_data.append(new_data)

    return all_data


@split.register(AnnBasedStereoExpData)
def split_for_ann_based_stereo_exp_data(data: AnnBasedStereoExpData = None):
    """
    Split a merged `AnnBasedStereoExpData` into one `AnnBasedStereoExpData` per batch.

    For every batch the underlying AnnData is subset on the 'batch' column of obs, and the offset
    recorded in `data.position_offset` (if any) is subtracted from the coordinates.

    :param data: a merged data object.
    :return: a list with one data object per batch, or None if `data` is None.
    """
    if data is None:
        return None

    from copy import deepcopy
    from .pipeline_utils import cell_cluster_to_gene_exp_cluster

    all_data = []
    # data.array2sparse()
    batch = np.unique(data.cells.batch)
    for bno in batch:
        adata = data.adata[data.adata.obs['batch'] == bno]
        # give the subset its own uns container instead of the one of the merged data
        adata.uns = adata.uns.copy()
        new_data = AnnBasedStereoExpData(based_ann_data=adata, spatial_key=data.spatial_key)
        new_data.tl.key_record = deepcopy(data.tl.key_record)
        new_data.sn = data.sn[bno]
        if data.position_offset is not None:
            new_data.position = new_data.position - data.position_offset[bno]
        if data.tl.raw is not None:
            new_data.tl.raw = data.tl.raw.tl.filter_cells(cell_list=new_data.cells.cell_name, inplace=False)
        if 'gene_exp_cluster' in data.tl.key_record:
            for cluster_res_key in data.tl.key_record['cluster']:
                gene_exp_cluster_res = cell_cluster_to_gene_exp_cluster(new_data, cluster_res_key)
                if gene_exp_cluster_res is not False:
                    new_data.tl.result[f"gene_exp_{cluster_res_key}"] = gene_exp_cluster_res
        all_data.append(new_data)

    return all_data