# Source code for stereo.core.ms_pipeline

import os
from joblib import (
    Parallel,
    delayed,
    cpu_count
)

import numpy as np

from stereo.log_manager import logger
from stereo.core import StPipeline
from stereo.core.result import MSDataPipeLineResult
from stereo.plots.decorator import download, download_only


class _scope_slice(object):

    def __getitem__(self, item):
        if isinstance(item, (int, np.integer, str, np.str_)):
            return [item]
        else:
            return item


class MSDataPipeLine(object):
    """
    Pipeline front-end of a multi-sample data object.

    Names that are not defined on this class are resolved lazily by
    `__getattr__`: either to a multi-sample algorithm/plot found through
    `MSDataAlgorithmBase` / `MSDataPlotBase`, or to a dispatcher that runs the
    method called `item` in one of two modes:

    - ``"integrate"``: once, on the merged data of the selected scope;
    - ``"isolated"``: once per data object of the selected scope.
    """

    # Attribute name through which the matching pipeline is fetched from a data
    # object (`getattr(data, ATTR_NAME)`); the code distinguishes 'tl' and 'plt'.
    ATTR_NAME = 'tl'
    # Class whose own `__dict__` is searched first for a method named `item`.
    BASE_CLASS = StPipeline

    def __init__(self, _ms_data):
        """
        :param _ms_data: the multi-sample data object this pipeline works on.
        """
        super().__init__()
        self.ms_data = _ms_data
        self._result = MSDataPipeLineResult(self.ms_data)
        self._result_keys = dict()
        self._key_record = dict()
        self.__mode = "integrate"
        self.__scope = slice(None)

    @property
    def result(self):
        """The `MSDataPipeLineResult` created for `ms_data` (read-only)."""
        return self._result

    @property
    def key_record(self):
        """The key record dictionary."""
        return self._key_record

    @key_record.setter
    def key_record(self, key_record):
        self._key_record = key_record

    @property
    def result_keys(self):
        """Mapping of scope key to the list of result keys recorded for it."""
        return self._result_keys

    @result_keys.setter
    def result_keys(self, result_keys):
        # The assigned mapping is pruned right away, see `_reset_result_keys`.
        self._result_keys = result_keys
        self._reset_result_keys()

    @property
    def mode(self):
        """Default mode used when a call does not pass `mode`."""
        return self.__mode

    @mode.setter
    def mode(self, mode):
        self.__mode = mode

    @property
    def scope(self):
        """Default scope used when a call does not pass `scope`."""
        return self.__scope

    @scope.setter
    def scope(self, scope):
        self.__scope = scope

    def _reset_result_keys(self):
        """
        Keep, for every scope key, only the result keys still contained in
        `self.result[scope_key]`.

        The lists are replaced in place inside `self._result_keys`.
        """
        # Only existing keys are reassigned, so the dict size does not change
        # while it is being iterated.
        for scope_key, result_keys in self._result_keys.items():
            scope_result = self.result[scope_key]
            self._result_keys[scope_key] = [rk for rk in result_keys if rk in scope_result]

    @staticmethod
    def _suffix_out_path(kwargs, idx):
        """
        Insert ``_{idx}`` before the extension of ``kwargs['out_path']``.

        E.g. ``'a/b.png'`` with ``idx=2`` becomes ``'a/b_2.png'``. Nothing is
        changed when `out_path` is missing or None.
        """
        out_path = kwargs.get('out_path', None)
        if out_path is not None:
            path_name, ext = os.path.splitext(out_path)
            kwargs['out_path'] = f'{path_name}_{idx}{ext}'

    def _missing_method_error(self, item):
        """Build the error raised when `item` cannot be resolved to anything callable."""
        return AttributeError(
            f"'{item}' is neither defined on {self.__class__.BASE_CLASS.__name__} "
            f"nor available through get_attribute_helper"
        )

    def _use_integrate_method(self, item, *args, **kwargs):
        """
        Run the method `item` once on the merged data of the requested scope.

        :param item: name of the method to run.
        :param args: positional arguments forwarded to the method.
        :param kwargs: keyword arguments forwarded to the method; `scope`
            (defaults to ``slice(None)``) and `mode` are consumed here and not forwarded.
        :return: whatever the resolved method returns.
        :raises AttributeError: if `item` is found neither on `BASE_CLASS` nor
            through `AlgorithmBase` / `PlotBase`.
        """
        kwargs.pop("mode", None)
        scope = kwargs.pop("scope", slice(None))

        # When the scope selects every sample, work on `ms_data` itself and
        # make sure it has been integrated; otherwise work on the selection.
        scoped_view = self.ms_data[scope]
        if len(scoped_view) == len(self.ms_data):
            ms_data_view = self.ms_data
            if ms_data_view.merged_data is None:
                ms_data_view.integrate()
        else:
            ms_data_view = scoped_view

        scope_key = self.ms_data.generate_scope_key(ms_data_view._names)
        merged_data = ms_data_view.merged_data
        self.ms_data.scopes_data[scope_key] = merged_data

        attr_name = self.__class__.ATTR_NAME
        base_method = self.__class__.BASE_CLASS.__dict__.get(item, None)
        if base_method is not None:
            logger.info(f'data_obj(idx=0) in ms_data start to run {item}')
            # `base_method` is a plain function taken from the class dict, so
            # the pipeline of the merged data is passed explicitly as `self`.
            return base_method(getattr(merged_data, attr_name), *args, **kwargs)

        if attr_name == "tl":
            # Imported lazily, as in the rest of this module.
            from stereo.algorithm.algorithm_base import AlgorithmBase
            algorithm = AlgorithmBase.get_attribute_helper(item, merged_data, merged_data.tl.result)
            if algorithm:
                logger.info(f'register algorithm {item} to {type(merged_data)}-{id(merged_data)}')
                return algorithm(*args, **kwargs)
        else:
            from stereo.plots.plot_base import PlotBase
            plot_func = PlotBase.get_attribute_helper(item, merged_data, merged_data.tl.result)
            # Test the helper result before decorating it, so that a missing
            # plot function is detected instead of being wrapped.
            if plot_func:
                logger.info(f'register plot_func {item} to {type(merged_data)}-{id(merged_data)}')
                return download(plot_func)(*args, **kwargs)

        raise self._missing_method_error(item)

    def _run_isolated_method(self, item, *args, **kwargs):
        """
        Run the method `item` separately on every data object of the requested scope.

        The calls are dispatched through joblib with the threading backend.
        Return values of the individual calls are discarded.

        For plots, a given `out_path` gets the index of the data object
        appended before its extension, so each call uses its own path.

        :param item: name of the method to run.
        :param args: positional arguments forwarded to every call.
        :param kwargs: keyword arguments forwarded to every call; `scope`
            (defaults to ``slice(None)``) and `mode` are consumed here and not forwarded.
        :raises AttributeError: if `item` is found neither on `BASE_CLASS` nor
            through `AlgorithmBase` / `PlotBase` for a data object.
        """
        kwargs.pop("mode", None)
        ms_data_view = self.ms_data[kwargs.pop("scope", slice(None))]
        data_list = ms_data_view.data_list

        attr_name = self.__class__.ATTR_NAME
        is_tl = attr_name == 'tl'
        # joblib rejects n_jobs == 0, which an empty data list would produce.
        n_jobs = max(1, min(len(data_list), cpu_count())) if is_tl else 1

        base_method = self.__class__.BASE_CLASS.__dict__.get(item, None)
        if base_method:
            def log_delayed_task(idx, *arg, **kwargs):
                logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
                if attr_name == 'plt':
                    self._suffix_out_path(kwargs, idx)
                base_method(*arg, **kwargs)

            Parallel(n_jobs=n_jobs, backend='threading', verbose=100)(
                delayed(log_delayed_task)(idx, getattr(obj, attr_name), *args, **kwargs)
                for idx, obj in enumerate(data_list)
            )
            return

        if is_tl:
            from stereo.algorithm.algorithm_base import AlgorithmBase
            base = AlgorithmBase
        else:
            from stereo.plots.plot_base import PlotBase
            base = PlotBase

        def log_delayed_task(idx, obj, *arg, **kwargs):
            logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
            method = base.get_attribute_helper(item, obj, obj.tl.result)
            if not method:
                raise self._missing_method_error(item)
            # Decide on the plot path with the flag computed above: `PlotBase`
            # is only bound in this closure when the plot branch was taken.
            if not is_tl:
                self._suffix_out_path(kwargs, idx)
                method = download_only(method)
            method(*arg, **kwargs)

        Parallel(n_jobs=n_jobs, backend='threading', verbose=100)(
            delayed(log_delayed_task)(idx, obj, *args, **kwargs)
            for idx, obj in enumerate(data_list)
        )

    def __getattr__(self, item):
        """
        Resolve a name that normal attribute lookup did not find.

        Resolution order:

        1. a truthy entry of the instance `__dict__`;
        2. for 'tl': `MSDataAlgorithmBase.get_attribute_helper`; for 'plt':
           `MSDataPlotBase.get_attribute_helper`, wrapped with `download`;
        3. a dispatcher function that fills in the default `scope` / `mode`
           and runs `item` in integrate or isolated mode.

        :raises AttributeError: for names starting with ``__`` and for the
            attributes this method itself depends on.
        """
        dict_attr = self.__dict__.get(item, None)
        if dict_attr:
            return dict_attr

        # start with __ may not be our algorithm function, and will cause import problem.
        # `ms_data` / `_result` / `result` are read below: if they are missing
        # (instance not initialised yet) resolving them here would recurse forever.
        if item.startswith('__') or item in ('ms_data', '_result', 'result'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        if self.__class__.ATTR_NAME == 'tl':
            from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
            run_method = MSDataAlgorithmBase.get_attribute_helper(item, self.ms_data, self.result)
            if run_method:
                return run_method
        elif self.__class__.ATTR_NAME == 'plt':
            from stereo.plots.ms_plot_base import MSDataPlotBase
            run_method = MSDataPlotBase.get_attribute_helper(item, self.ms_data, self.ms_data.tl.result)
            if run_method:
                return download(run_method)

        def temp(*args, **kwargs):
            # Fall back to the globally configured scope and mode.
            kwargs.setdefault("scope", self.__scope)
            kwargs.setdefault("mode", self.__mode)

            mode = kwargs["mode"]
            if mode == "integrate":
                return self._use_integrate_method(item, *args, **kwargs)
            if mode == "isolated":
                # Isolated runs do not return a value.
                self._run_isolated_method(item, *args, **kwargs)
                return None
            raise ValueError("`mode` should be one of [`integrate`, `isolated`]")

        return temp

    def set_scope_and_mode(
        self,
        scope: slice = slice(None),
        mode: str = "integrate"
    ):
        """
        Set the `scope` and `mode` globally for Multi-slice analysis.

        The values are stored on this pipeline and mirrored to its counterpart
        on `ms_data` (`plt` when this is `tl`, `tl` when this is `plt`).

        :param scope: the scope, defaults to slice(None)
        :param mode: the mode, defaults to "integrate"
        :raises AssertionError: if `mode` is not "integrate" or "isolated".
        """
        # Kept as an assert so that callers keep seeing AssertionError.
        assert mode in ("integrate", "isolated"), 'mode should be one of [`integrate`, `isolated`]'
        self.__mode = mode
        self.__scope = scope
        if self.__class__.ATTR_NAME == 'tl':
            self.ms_data.plt.scope = scope
            self.ms_data.plt.mode = mode
        elif self.__class__.ATTR_NAME == 'plt':
            self.ms_data.tl.scope = scope
            self.ms_data.tl.mode = mode
# Shared module-level `_scope_slice` instance used to write scopes with
# subscript syntax, e.g. `slice_generator[:]` evaluates to `slice(None)` and
# `slice_generator[0]` to `[0]`.
slice_generator = _scope_slice()