# Source code for stereo.tools.cell_correct

# @FileName : cell_correct.py
# @Time     : 2022-05-26 14:14:27
# @Author   : TanLiWei
# @Email    : tanliwei@genomics.cn
import os
import time
from multiprocessing import cpu_count
from typing import Literal

import numba
import numpy as np
import pandas as pd
from gefpy import (
    cgef_writer_cy,
    bgef_writer_cy,
    cgef_adjust_cy,
    gef_to_gem_cy
)

from ..algorithm import cell_correction_fast
from ..algorithm import cell_correction_fast_by_mask
from ..algorithm.cell_correction import CellCorrection
from ..algorithm.draw_contours import DrawContours
from ..io import read_gef
from ..log_manager import logger
from ..utils.time_consume import TimeConsume
from ..utils.time_consume import log_consumed_time


@log_consumed_time
@numba.njit(cache=True, nogil=True)
def generate_cell_and_dnb(adjusted_data: np.ndarray):
    """
    Group adjusted DNB records by cell label.

    :param adjusted_data: 2-D array whose columns are, in order, x, y,
        UMICount, label and geneid (the column order built by
        ``CellCorrect.generate_adjusted_cgef``).
    :return: a tuple ``(cell_data, dnb_data)`` where

        * ``cell_data`` holds one ``(cellid, offset, count)`` tuple per distinct
          label: ``offset`` is the index of the cell's first record in
          ``dnb_data`` and ``count`` is the number of consecutive records
          belonging to it;
        * ``dnb_data`` holds one ``(x, y, UMICount, geneid)`` tuple per input
          row, ordered by label.

        Both lists are empty when ``adjusted_data`` has no rows.
    """
    # Sort rows by label so that all records of a cell are contiguous.
    cells_idx_sorted = np.argsort(adjusted_data[:, 3])
    adjusted_data = adjusted_data[cells_idx_sorted]
    cell_data = []
    dnb_data = []
    last_cell = -1
    cellid = -1
    offset = -1
    count = -1
    for i, row in enumerate(adjusted_data):
        current_cell = row[3]
        if current_cell != last_cell:
            # A new label starts: flush the previous cell, if there is one.
            if last_cell >= 0:
                cell_data.append((cellid, offset, count))
            cellid, offset, count = current_cell, i, 1
            last_cell = current_cell
        else:
            count += 1
        dnb_data.append((row[0], row[1], row[2], row[4]))
    # Flush the last cell. For empty input there is no cell, and appending the
    # (-1, -1, -1) sentinel would not fit the caller's uint32 record type.
    if adjusted_data.shape[0] > 0:
        cell_data.append((cellid, offset, count))
    return cell_data, dnb_data


class CellCorrect(object):
    """
    Correct cell segmentation for one sample and write the results to disk.

    The sample is given as a GEM or BGEF file (a GEM is converted to BGEF on
    construction), together with a mask and/or an uncorrected cellbin GEF.
    Correction is run by :meth:`correcting` with the 'GMM', 'FAST' or 'EDM'
    method. All intermediate and final files are written into ``out_dir`` and
    are named after the input file's prefix (see :meth:`get_file_name`).
    """

    def __init__(self, gem_path=None, bgef_path=None, raw_cgef_path=None, mask_path=None, out_dir=None):
        """
        :param gem_path: path to a GEM file, used to generate a BGEF when ``bgef_path`` is not given.
        :param bgef_path: path to a BGEF file.
        :param raw_cgef_path: path to an uncorrected cellbin GEF; generated from ``mask_path`` when needed.
        :param mask_path: path to the cell segmentation mask.
        :param out_dir: output directory; defaults to ``./cell_correct_result_<timestamp>``.
        :raises ValueError: if neither ``gem_path`` nor ``bgef_path`` is given.
        """
        self.tc = TimeConsume()
        self.gem_path = gem_path
        self.bgef_path = bgef_path
        self.raw_cgef_path = raw_cgef_path
        self.mask_path = mask_path
        # Path of the expanded mask produced by the 'EDM' method (set in ``correcting``).
        self.new_mask_path = None
        self.out_dir = out_dir
        self.cad = cgef_adjust_cy.CgefAdjust()
        self.gene_names = None
        self.check_input()
        self.sn = self.get_sn()

    def check_input(self):
        """
        Validate the expression input and prepare the output directory.

        Creates ``out_dir`` (with a timestamped default name when unset). When
        only a GEM file was given, it is also converted to BGEF.

        :raises ValueError: if neither a GEM nor a BGEF file was given.
        """
        if self.bgef_path is None and self.gem_path is None:
            # ValueError subclasses Exception, so existing ``except Exception`` handlers keep working.
            raise ValueError("must to input gem file or bgef file")

        if self.out_dir is None:
            now = time.strftime("%Y%m%d%H%M%S")
            self.out_dir = f"./cell_correct_result_{now}"
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(self.out_dir, exist_ok=True)

        if self.bgef_path is None:
            self.bgef_path = self.generate_bgef()

    def get_sn(self):
        """Return the sample name: the input file's basename up to its first '.'."""
        return self.get_file_name()

    def get_file_name(self, ext=None):
        """
        Build an output file name from the sample's input file.

        The prefix is the basename of the BGEF file (or of the GEM file when no
        BGEF path is set), truncated at its first '.'.

        :param ext: extension to append, with or without a leading '.';
            ``None`` or ``""`` returns the bare prefix.
        :return: ``"<prefix>.<ext>"``, or ``"<prefix>"`` when there is no extension.
        """
        ext = ext.lstrip('.') if ext is not None else ""
        source_path = self.bgef_path if self.bgef_path is not None else self.gem_path
        file_prefix = os.path.basename(source_path).split('.')[0]
        return f"{file_prefix}.{ext}" if ext else file_prefix

    @log_consumed_time
    def generate_bgef(self, threads=10):
        """
        Convert the GEM file into ``<out_dir>/<prefix>.bgef`` (bin size 1), replacing any existing file.

        :param threads: thread count passed to gefpy's BGEF writer.
        :return: path of the generated BGEF.
        """
        file_name = self.get_file_name('bgef')
        bgef_path = os.path.join(self.out_dir, file_name)
        if os.path.exists(bgef_path):
            os.remove(bgef_path)
        bgef_writer_cy.generate_bgef(self.gem_path, bgef_path, n_thread=threads, bin_sizes=[1])
        return bgef_path

    def generate_cgef_with_mask(self, mask_path, ext_in_ext):
        """
        Generate ``<out_dir>/<prefix>.<ext_in_ext>.cellbin.gef`` from the BGEF and ``mask_path``.

        Any existing file with that name is removed first. The ``[256, 256]``
        argument is passed through to gefpy's cgef writer unchanged.

        :return: path of the generated cellbin GEF.
        """
        file_name = self.get_file_name(f'{ext_in_ext}.cellbin.gef')
        cgef_path = os.path.join(self.out_dir, file_name)
        logger.info(f"start to generate cellbin gef ({cgef_path})")
        if os.path.exists(cgef_path):
            os.remove(cgef_path)
        tk = self.tc.start()
        cgef_writer_cy.generate_cgef(cgef_path, self.bgef_path, mask_path, [256, 256])
        logger.info(
            f"generate cellbin gef finished, time consumed : {self.tc.get_time_consumed(key=tk, restart=False)}")
        return cgef_path

    def get_data_from_bgef_and_cgef(self, bgef_path, cgef_path, sample_n=-1):
        """
        Read per-DNB cell data from a BGEF / cellbin GEF pair.

        :param sample_n: if > 0, randomly sample this many records without
            replacement. It is capped at the number of available records.
        :return: ``(genes, data)``. ``genes`` has columns ``geneid`` (row index)
            and ``geneID``. ``data`` has columns
            ``['geneID', 'x', 'y', 'UMICount', 'label', 'geneid']``.
        """
        tk = self.tc.start()
        logger.info("start to get data from bgef and cgef")
        genes, data = self.cad.get_cell_data(bgef_path, cgef_path)
        logger.info(f"get data finished, time consumed : {self.tc.get_time_consumed(tk)}")
        genes = pd.DataFrame(genes, columns=['geneID']).reset_index().rename(columns={'index': 'geneid'})
        data = pd.DataFrame(data.tolist(), dtype='int32').rename(columns={'midcnt': 'UMICount', 'cellid': 'label'})
        # Attach gene names to every record (inner join on the numeric gene index).
        data = pd.merge(data, genes, on=['geneid'])[['geneID', 'x', 'y', 'UMICount', 'label', 'geneid']]
        if sample_n > 0:
            # Sampling without replacement cannot draw more rows than exist.
            sample_n = min(sample_n, len(data))
            logger.info(f"sample {sample_n} from raw data")
            data = data.sample(sample_n, replace=False)
        logger.info(f"merged genes to data, time consumed : {self.tc.get_time_consumed(tk)}")

        return genes, data

    @log_consumed_time
    def generate_raw_data(self, sample_n=-1):
        """
        Load the uncorrected cell data, generating the raw cellbin GEF from the mask first if needed.

        :return: ``(genes, raw_data)`` as returned by :meth:`get_data_from_bgef_and_cgef`.
        """
        if self.raw_cgef_path is None:
            self.raw_cgef_path = self.generate_cgef_with_mask(self.mask_path, 'raw')

        logger.info("start to generate raw data")
        genes, raw_data = self.get_data_from_bgef_and_cgef(self.bgef_path, self.raw_cgef_path, sample_n=sample_n)
        return genes, raw_data

    @log_consumed_time
    def generate_raw_gem(self, raw_data: pd.DataFrame):
        """
        Write ``raw_data`` as a tab-separated ``<prefix>.raw.gem``.

        :return: path of the written GEM file.
        """
        file_name = self.get_file_name('raw.gem')
        raw_gem_path = os.path.join(self.out_dir, file_name)
        raw_data.to_csv(raw_gem_path, sep="\t", index=False, columns=['geneID', 'x', 'y', 'UMICount', 'label'])
        return raw_gem_path

    @log_consumed_time
    def generate_adjusted_cgef(self, adjusted_data: pd.DataFrame, outline_path):
        """
        Write corrected cell data as ``<prefix>.adjusted.cellbin.gef``.

        :param adjusted_data: frame with at least x, y, UMICount, label and geneid columns.
        :param outline_path: cell outline file passed to the writer; omitted when None.
        :return: path of the written cellbin GEF.
        """
        adjusted_data_np = adjusted_data[['x', 'y', 'UMICount', 'label', 'geneid']].to_numpy(dtype=np.uint32)
        cell_data, dnb_data = generate_cell_and_dnb(adjusted_data_np)
        cell_type = np.dtype({
            'names': ['cellid', 'offset', 'count'],
            'formats': [np.uint32, np.uint32, np.uint32]
        }, align=True)
        dnb_type = np.dtype({
            'names': ['x', 'y', 'count', 'gene_id'],
            'formats': [np.int32, np.int32, np.uint16, np.uint32]
        }, align=True)
        cell = np.array(cell_data, dtype=cell_type)
        dnb = np.array(dnb_data, dtype=dnb_type)
        file_name = self.get_file_name('adjusted.cellbin.gef')
        cgef_file_adjusted = os.path.join(self.out_dir, file_name)
        if os.path.exists(cgef_file_adjusted):
            os.remove(cgef_file_adjusted)
        if outline_path is not None:
            self.cad.write_cgef_adjustdata(cgef_file_adjusted, cell, dnb, outline_path)
        else:
            self.cad.write_cgef_adjustdata(cgef_file_adjusted, cell, dnb)
        logger.info(f"generate adjusted cellbin gef finished ({cgef_file_adjusted})")
        return cgef_file_adjusted

    @log_consumed_time
    def generate_adjusted_gem(self, adjusted_data: pd.DataFrame):
        """
        Write corrected data as a tab-separated ``<prefix>.adjusted.gem``.

        A ``tag`` column is included when present.

        :return: path of the written GEM file.
        """
        file_name = self.get_file_name("adjusted.gem")
        gem_file_adjusted = os.path.join(self.out_dir, file_name)
        columns = ['geneID', 'x', 'y', 'UMICount', 'label']
        if 'tag' in adjusted_data.columns:
            columns.append('tag')
        adjusted_data.to_csv(gem_file_adjusted, sep="\t", index=False, columns=columns)
        logger.info(f"generate adjusted gem finished ({gem_file_adjusted})")
        return gem_file_adjusted

    @log_consumed_time
    def cgef_to_gem(self, cgef_path):
        """
        Convert a cellbin GEF (plus the sample's BGEF) into ``<prefix>.adjusted.gem``.

        :return: path of the GEM file.
        """
        file_name = self.get_file_name("adjusted.gem")
        gem_file_adjusted = os.path.join(self.out_dir, file_name)
        obj = gef_to_gem_cy.gefToGem(gem_file_adjusted, self.sn)
        obj.cgef2gem(cgef_path, self.bgef_path)
        return gem_file_adjusted

    @log_consumed_time
    def bgef_to_gem(self, mask_path):
        """
        Convert the sample's BGEF with ``mask_path`` into ``<prefix>.adjusted.gem``.

        :return: path of the GEM file (previously None; returned for consistency with :meth:`cgef_to_gem`).
        """
        file_name = self.get_file_name("adjusted.gem")
        gem_file_adjusted = os.path.join(self.out_dir, file_name)
        obj = gef_to_gem_cy.gefToGem(gem_file_adjusted, self.sn)
        obj.bgef2cgem(mask_path, self.bgef_path)
        return gem_file_adjusted

    def __set_processes_count(self, process_count, method):
        """
        Resolve the worker count for ``method``.

        * 'GMM': None or 0 -> min(10, cpu_count()); negative or above cpu_count() -> cpu_count().
        * 'FAST': always 1.
        * 'EDM': None or 0 -> 1; negative or above cpu_count() -> cpu_count().
        * any other method: ``process_count`` is returned unchanged.

        :raises TypeError: if ``process_count`` is neither None nor an int.
        """
        if process_count is not None and not isinstance(process_count, int):
            raise TypeError("the type of prameter 'process_count' must be int.")

        if method == 'FAST':
            return 1
        if method not in ('GMM', 'EDM'):
            return process_count
        n_cpu = cpu_count()
        if process_count is None or process_count == 0:
            # GMM defaults to at most 10 workers; EDM defaults to a single process.
            return min(10, n_cpu) if method == 'GMM' else 1
        if process_count < 0 or process_count > n_cpu:
            return n_cpu
        return process_count

    @log_consumed_time
    def correcting(self,
                   threshold=20,
                   process_count=None,
                   only_save_result=False,
                   sample_n=-1,
                   method='EDM',
                   distance=10,
                   **kwargs
                   ):
        """
        Run cell correction and produce ``<prefix>.adjusted.cellbin.gef``.

        * 'GMM' / 'FAST': load the raw cell data (the raw cellbin GEF is
          generated from the mask when not given), correct it, draw cell
          contours and write the adjusted cellbin GEF.
        * 'EDM': expand the mask with ``cell_correction_fast_by_mask.main``
          (the result path is stored in ``self.new_mask_path``) and generate
          the adjusted cellbin GEF from it.

        :param threshold: threshold passed to the 'GMM' correction.
        :param process_count: worker count, resolved per method.
        :param only_save_result: if True, return the adjusted CGEF path instead of reading it back.
        :param sample_n: if > 0, randomly sample this many raw records ('GMM'/'FAST' only).
        :param method: 'GMM', 'FAST' or 'EDM' (case-insensitive); None means 'EDM'.
        :param distance: outspread distance passed to the 'EDM' method.
        :param kwargs: ``n_split_data_jobs`` (default -1) is forwarded to the 'EDM' method; other keys are ignored.
        :return: ``read_gef(<adjusted cgef>, bin_type='cell_bins')``, or the adjusted CGEF path
            when ``only_save_result`` is True.
        :raises ValueError: for an unknown ``method``, or when the required mask / raw CGEF is missing.
        """
        if method is None:
            method = 'EDM'
        method = method.upper()
        if method not in ('GMM', 'FAST', 'EDM'):
            raise ValueError(
                f"Unexpected value({method}) for parameter method, available values include ['GMM', 'FAST', 'EDM'].")
        process_count = self.__set_processes_count(process_count, method)

        if method == 'EDM':
            # Fail fast instead of handing None to the mask-based correction.
            if self.mask_path is None:
                raise ValueError("the 'EDM' method requires 'mask_path'.")
            n_split_data_jobs = kwargs.get('n_split_data_jobs', -1)
            self.new_mask_path = cell_correction_fast_by_mask.main(
                self.mask_path,
                n_jobs=process_count,
                distance=distance,
                out_path=self.out_dir,
                n_split_data_jobs=n_split_data_jobs
            )
            cgef_file_adjusted = self.generate_cgef_with_mask(self.new_mask_path, 'adjusted')
        else:
            # Raw data comes from the raw cellbin GEF, which needs either that file or a mask to build it.
            if self.mask_path is None and self.raw_cgef_path is None:
                raise ValueError(f"the '{method}' method requires 'mask_path' or 'raw_cgef_path'.")
            _, raw_data = self.generate_raw_data(sample_n)
            if method == 'GMM':
                correction = CellCorrection(self.mask_path, raw_data, threshold, process_count,
                                            err_log_dir=self.out_dir)
                adjusted_data = correction.cell_correct()
            else:
                adjusted_data = cell_correction_fast.cell_correct(raw_data, self.mask_path)
            dc = DrawContours(adjusted_data, self.out_dir)
            outline_path = dc.get_contours()
            cgef_file_adjusted = self.generate_adjusted_cgef(adjusted_data, outline_path)

        if only_save_result:
            return cgef_file_adjusted
        return read_gef(cgef_file_adjusted, bin_type='cell_bins')


@log_consumed_time
def cell_correct(out_dir: str,
                 threshold: int = 20,
                 gem_path: str = None,
                 bgef_path: str = None,
                 raw_cgef_path: str = None,
                 mask_path: str = None,
                 process_count: int = None,
                 only_save_result: bool = False,
                 method: Literal['GMM', 'FAST', 'EDM'] = 'EDM',
                 distance: int = 10,
                 **kwargs
                 ):
    """
    Correct cells using one of the following file combinations:

        * GEM and mask
        * BGEF and mask
        * GEM and raw CGEF (not corrected)
        * BGEF and raw CGEF (not corrected)

    The 'EDM' method always needs a mask.

    :param out_dir: the path to save intermediate results, such as the mask (if generated from a ssDNA image),
        the BGEF (generated from GEM), the CGEF (generated from GEM and mask), etc., and the final corrected result.
    :param threshold: threshold size, only used by the 'GMM' method, defaults to 20.
    :param gem_path: the path to the GEM file.
    :param bgef_path: the path to the BGEF file.
    :param raw_cgef_path: the path to a CGEF file which has not been corrected.
    :param mask_path: the path to the mask file.
    :param process_count: the number of processes or threads started during correction, defaults to None.
        When None or 0, it is set to min(10, CPU count) for 'GMM' and to 1 for 'EDM'.
        A negative value, or one larger than the CPU count, uses all cores.
        'FAST' always runs with 1.
    :param only_save_result: if `True`, only save the result to disk; if `False`, return a StereoExpData object.
    :param method: the correction method, one of 'GMM', 'FAST' or 'EDM', defaults to 'EDM'.
    :param distance: outspread distance based on the cellular contour of the cell segmentation image, in pixels,
        only available for the 'EDM' method.
    :param kwargs: extra options forwarded to `CellCorrect.correcting`, e.g. `sample_n` (randomly sample this many
        raw records, 'GMM'/'FAST' only) and `n_split_data_jobs` (passed to the 'EDM' method).

    :return: A StereoExpData object if `only_save_result` is set to `False`, otherwise the path of the corrected
        CGEF file.
    """  # noqa
    cc = CellCorrect(gem_path=gem_path, bgef_path=bgef_path, raw_cgef_path=raw_cgef_path, mask_path=mask_path,
                     out_dir=out_dir)
    return cc.correcting(threshold=threshold, process_count=process_count, only_save_result=only_save_result,
                         method=method, distance=distance, **kwargs)