import os
from joblib import (
Parallel,
delayed,
cpu_count
)
import numpy as np
from stereo.log_manager import logger
from stereo.core import StPipeline
from stereo.core.result import MSDataPipeLineResult
from stereo.plots.decorator import download, download_only
class _scope_slice(object):
def __getitem__(self, item):
if isinstance(item, (int, np.integer, str, np.str_)):
return [item]
else:
return item
class MSDataPipeLine(object):
    """
    Dispatch analysis (`tl`) or plotting (`plt`) calls over a multi-slice data object.

    Any attribute that is not defined on this class is resolved lazily in
    `__getattr__`: first as a multi-slice algorithm/plot, otherwise as a
    wrapper that runs the single-slice method of the same name either once on
    the merged data (`mode="integrate"`) or once per slice (`mode="isolated"`).

    :param _ms_data: the multi-slice data object this pipeline operates on.
    """

    # Name of the attribute on a data object that holds the wrapped pipeline ('tl' or 'plt').
    ATTR_NAME = 'tl'
    # Class of the single-slice pipeline; filled in by `__init__` from the first slice.
    BASE_CLASS = None

    # Attributes assigned in `__init__` (private ones in their name-mangled form).
    # `__getattr__` must never treat these as algorithm names: if one of them is
    # missing (e.g. `__init__` has not run yet) resolving it would access
    # `self.ms_data` again and recurse without end.
    _INSTANCE_ATTRS = (
        'ms_data',
        '_result',
        '_result_keys',
        '_key_record',
        '_MSDataPipeLine__mode',
        '_MSDataPipeLine__scope',
    )

    def __init__(self, _ms_data):
        super().__init__()
        self.ms_data = _ms_data
        self._result = MSDataPipeLineResult(self.ms_data)
        self._result_keys = dict()
        self._key_record = dict()
        self.__mode = "integrate"
        self.__scope = slice(None)
        # Remember the concrete single-slice pipeline class so its methods can be looked up by name.
        self.__class__.BASE_CLASS = getattr(self.ms_data[0], self.__class__.ATTR_NAME).__class__

    @property
    def result(self):
        """The result container created for this pipeline in `__init__`."""
        return self._result

    @property
    def key_record(self):
        """The key-record dictionary of this pipeline."""
        return self._key_record

    @key_record.setter
    def key_record(self, key_record):
        self._key_record = key_record

    @property
    def result_keys(self):
        """Mapping of scope key to the list of result keys recorded for that scope."""
        return self._result_keys

    @result_keys.setter
    def result_keys(self, result_keys):
        # Only keep keys that are actually present in `self.result`.
        self._result_keys = self._reset_result_keys(result_keys)

    @property
    def mode(self):
        """Default run mode used when a call does not pass `mode` explicitly."""
        return self.__mode

    @mode.setter
    def mode(self, mode):
        self.__mode = mode

    @property
    def scope(self):
        """Default scope used when a call does not pass `scope` explicitly."""
        return self.__scope

    @scope.setter
    def scope(self, scope):
        self.__scope = scope

    def _reset_result_keys(self, origin_result_keys: dict = None):
        """
        Filter a `{scope_key: [result_key, ...]}` mapping down to keys present in `self.result`.

        :param origin_result_keys: mapping to filter; `None` is treated as an empty mapping.
        :return: a new dict with the same scope keys, each mapped to the list of
            result keys that are contained in `self.result[scope_key]`.
        """
        if origin_result_keys is None:
            return {}
        return {
            scope_key: [rk for rk in scope_result_keys if rk in self.result[scope_key]]
            for scope_key, scope_result_keys in origin_result_keys.items()
        }

    @staticmethod
    def _suffix_out_path(call_kwargs, idx):
        """Rewrite `call_kwargs['out_path']` in place to `<name>_<idx><ext>` when it is set."""
        out_path = call_kwargs.get('out_path', None)
        if out_path is not None:
            path_name, ext = os.path.splitext(out_path)
            call_kwargs['out_path'] = f'{path_name}_{idx}{ext}'

    def _use_integrate_method(self, item, *args, **kwargs):
        """
        Run the method named `item` once, on the merged data of the selected scope.

        :param item: name of the algorithm or plot function to run.
        :param args: positional arguments forwarded to the method.
        :param kwargs: keyword arguments forwarded to the method; `mode` is
            discarded and `scope` (default `slice(None)`) selects the slices.
        :return: whatever the resolved method returns.
        :raises AttributeError: if `item` cannot be resolved to a callable.
        """
        kwargs.pop("mode", None)
        # `pop` with a default: a missing `scope` means "all slices" instead of a KeyError.
        scope = kwargs.pop("scope", slice(None))

        scoped_view = self.ms_data[scope]
        if len(scoped_view) == len(self.ms_data):
            # The scope covers every slice: work on the full object, merging on demand.
            ms_data_view = self.ms_data
            if ms_data_view.merged_data is None:
                ms_data_view.integrate()
        else:
            ms_data_view = scoped_view
            scope_key = self.ms_data.generate_scope_key(ms_data_view._names)
            self.ms_data.scopes_data[scope_key] = ms_data_view.merged_data

        attr_name = self.__class__.ATTR_NAME
        new_attr = getattr(self.__class__.BASE_CLASS, item, None)
        if new_attr is not None:
            logger.info(f'data_obj(idx=0) in ms_data start to run {item}')
            # `new_attr` comes from the class, so the pipeline of the merged data is passed as `self`.
            return new_attr(getattr(ms_data_view.merged_data, attr_name), *args, **kwargs)

        # Not a method of the base pipeline class: fall back to the dynamic helpers.
        merged_data = ms_data_view.merged_data
        if attr_name == "tl":
            from stereo.algorithm.algorithm_base import AlgorithmBase
            new_attr = AlgorithmBase.get_attribute_helper(item, merged_data, merged_data.tl.result)
            if new_attr:
                logger.info(f'register algorithm {item} to {type(merged_data)}-{id(merged_data)}')
                return new_attr(*args, **kwargs)
        else:
            from stereo.plots.plot_base import PlotBase
            plot_func = PlotBase.get_attribute_helper(item, merged_data, merged_data.tl.result)
            # Check before decorating, so an unresolved name is not hidden behind the wrapper.
            if plot_func:
                new_attr = download(plot_func)
                logger.info(f'register plot_func {item} to {type(merged_data)}-{id(merged_data)}')
                return new_attr(*args, **kwargs)

        raise AttributeError(
            f"'{type(self).__name__}' object cannot resolve '{item}' to an algorithm or plot function"
        )

    def _run_isolated_method(self, item, *args, **kwargs):
        """
        Run the method named `item` separately on every slice of the selected scope.

        Tasks are dispatched through joblib's threading backend. Return values of
        the individual runs are discarded. For plots, a given `out_path` gets the
        slice index appended to its file name so the outputs do not overwrite
        each other.

        :param item: name of the algorithm or plot function to run.
        :param args: positional arguments forwarded to every run.
        :param kwargs: keyword arguments forwarded to every run; `mode` is
            discarded and `scope` (default `slice(None)`) selects the slices.
        :raises Exception: if `item` cannot be resolved for a slice.
        """
        kwargs.pop("mode", None)
        # `pop` with a default: a missing `scope` means "all slices" instead of a KeyError.
        ms_data_view = self.ms_data[kwargs.pop("scope", slice(None))]

        attr_name = self.__class__.ATTR_NAME
        new_attr = getattr(self.__class__.BASE_CLASS, item, None)

        if attr_name == 'tl':
            # At least one worker: `Parallel` rejects `n_jobs=0` (empty data list).
            n_jobs = max(1, min(len(ms_data_view.data_list), cpu_count()))
        else:
            n_jobs = 1

        if new_attr:
            def log_delayed_task(idx, obj, *arg, **kwargs):
                logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
                if attr_name == 'plt':
                    self._suffix_out_path(kwargs, idx)
                # `new_attr` comes from the class, so the slice's own pipeline is passed as `self`.
                new_attr(getattr(obj, attr_name), *arg, **kwargs)
        else:
            if attr_name == 'tl':
                from stereo.algorithm.algorithm_base import AlgorithmBase
                base = AlgorithmBase
            else:
                from stereo.plots.plot_base import PlotBase
                base = PlotBase

            def log_delayed_task(idx, obj, *arg, **kwargs):
                logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
                helper = base.get_attribute_helper(item, obj, obj.tl.result)
                # Check before decorating, so an unresolved name is reported instead of wrapped.
                if not helper:
                    raise Exception(f"can not resolve `{item}` for data_obj(idx={idx})")
                if base.__name__ == 'PlotBase':
                    self._suffix_out_path(kwargs, idx)
                    helper = download_only(helper)
                helper(*arg, **kwargs)

        Parallel(n_jobs=n_jobs, backend='threading', verbose=100)(
            delayed(log_delayed_task)(idx, obj, *args, **kwargs)
            for idx, obj in enumerate(ms_data_view.data_list)
        )

    def __getattr__(self, item):
        """
        Resolve an attribute that normal lookup did not find.

        :param item: the attribute name.
        :return: a multi-slice algorithm/plot callable if one is registered under
            `item`; otherwise a wrapper that runs the single-slice method `item`
            in "integrate" or "isolated" mode.
        :raises AttributeError: for names starting with `__` and for this
            object's own instance attributes when they are not set.
        """
        dict_attr = self.__dict__.get(item, None)
        if dict_attr:
            return dict_attr

        # start with __ may not be our algorithm function, and will cause import problem;
        # our own instance attributes are never algorithm names either (see `_INSTANCE_ATTRS`).
        if item.startswith('__') or item in self._INSTANCE_ATTRS:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        if self.__class__.ATTR_NAME == 'tl':
            from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
            run_method = MSDataAlgorithmBase.get_attribute_helper(item, self.ms_data, self.result)
            if run_method:
                return run_method
        elif self.__class__.ATTR_NAME == 'plt':
            from stereo.plots.ms_plot_base import MSDataPlotBase
            run_method = MSDataPlotBase.get_attribute_helper(item, self.ms_data, self.ms_data.tl.result)
            if run_method:
                return download(run_method)

        def temp(*args, **kwargs):
            # Fall back to the globally configured scope/mode when the call does not override them.
            if "scope" not in kwargs:
                kwargs["scope"] = self.__scope
            if "mode" not in kwargs:
                kwargs["mode"] = self.__mode

            if kwargs["mode"] == "integrate":
                return self._use_integrate_method(item, *args, **kwargs)
            elif kwargs["mode"] == "isolated":
                self._run_isolated_method(item, *args, **kwargs)
            else:
                raise Exception("`mode` should be one of [`integrate`, `isolated`]")

        return temp

    def set_scope_and_mode(
        self,
        scope: slice = slice(None),
        mode: str = "integrate"
    ):
        """
        Set the `scope` and `mode` globally for Multi-slice analysis.

        The same values are also applied to the sibling pipeline of the data
        object (`plt` when this is the `tl` pipeline and vice versa).

        :param scope: the scope, defaults to slice(None)
        :param mode: the mode, defaults to "integrate"
        """
        assert mode in ("integrate", "isolated"), 'mode should be one of [`integrate`, `isolated`]'
        self.__mode = mode
        self.__scope = scope
        if self.__class__.ATTR_NAME == 'tl':
            self.ms_data.plt.scope = scope
            self.ms_data.plt.mode = mode
        elif self.__class__.ATTR_NAME == 'plt':
            self.ms_data.tl.scope = scope
            self.ms_data.tl.mode = mode
# Shared `_scope_slice` instance for writing scopes with subscript syntax:
# a single int/str subscript comes back as a one-element list, any other
# subscript (e.g. a slice or a list) comes back unchanged.
slice_generator = _scope_slice()