from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Dict, Union, Literal, Optional
from copy import deepcopy
import numpy as np
import pandas as pd
from . import(
StPipeline, AnnBasedStPipeline,
StereoExpData, AnnBasedStereoExpData
)
from .ms_pipeline import MSDataPipeLine
from ..plots.plot_collection import PlotCollection
def _default_idx() -> int:
    """
    Return the initial value of the automatic-name counter.

    `__get_auto_key` increments the counter before reading it, so starting
    from -1 makes the first generated name '0' instead of '1'.
    """
    initial_counter = -1
    return initial_counter
@dataclass
class _MSDataView(object):
    """
    A view over a subset of the samples held by an `MSData` object.

    A view is what indexing with a list-like key or a slice returns. It stores
    the selected samples together with their names and keeps a reference to
    the `MSData` object the selection was taken from.
    """

    _msdata: MSData = None
    _names: List[str] = field(default_factory=list)
    _data_list: List[StereoExpData] = field(default_factory=list)
    _name_dict: Dict[str, StereoExpData] = field(default_factory=dict)
    _merged_data: StereoExpData = None
    # un-annotated, so `dataclass` does not treat them as fields; filled lazily by `tl` / `plt`
    _tl = None
    _plt = None

    def __post_init__(self):
        """Build the name -> sample lookup from the parallel `_names` / `_data_list` lists."""
        self._name_dict.update(zip(self._names, self._data_list))

    def __get_data_list(self, key_idx_list):
        """
        Resolve a (possibly nested) collection of names and positions.

        :param key_idx_list: iterable of names, integer positions or nested list-likes of those.
        :return: a tuple `(samples, names)` in the order the keys were given.
        :raises KeyError: when an entry has an unsupported type or a name is unknown.
        """
        picked_data, picked_names = [], []
        for entry in key_idx_list:
            if isinstance(entry, (str, np.str_)):
                picked_data.append(self._name_dict[entry])
                picked_names.append(entry)
                continue
            if isinstance(entry, (int, np.integer)):
                position = int(entry)
                picked_data.append(self._data_list[position])
                picked_names.append(self._names[position])
                continue
            if not isinstance(entry, (list, tuple, np.ndarray, pd.Index)):
                raise KeyError(entry)
            nested_data, nested_names = self.__get_data_list(entry)
            picked_data += nested_data
            picked_names += nested_names
        return picked_data, picked_names

    def __check_slice(self, slice_obj: slice):
        """
        Translate a slice whose bounds may be sample names into a positional slice.

        A name bound is replaced by the position of that name; a name that is
        not known is replaced by `None` (an open bound).

        :raises TypeError: when `slice_obj` is not a slice or its step is not an integer.
        """
        if not isinstance(slice_obj, slice):
            raise TypeError(f'{slice_obj} should be slice')

        def to_position(bound):
            # only string bounds need translating, everything else is passed through untouched
            if not isinstance(bound, (str, np.str_)):
                return bound
            return self._names.index(bound) if bound in self._name_dict else None

        new_start = to_position(slice_obj.start)
        new_stop = to_position(slice_obj.stop)
        step = slice_obj.step
        if not (step is None or isinstance(step, (int, np.integer))):
            raise TypeError(f'slice.step should be int')
        return slice(new_start, new_stop, step)

    def __getitem__(
        self,
        key: Union[str, int, list, tuple, np.ndarray, pd.Index, slice]
    ) -> Union[StereoExpData, _MSDataView]:
        """
        Return a single sample for a name / position, or a new view for a list-like key or a slice.

        :raises KeyError: when the key type is not supported.
        """
        if isinstance(key, (str, np.str_)):
            return self._name_dict[key]
        if isinstance(key, (int, np.integer)):
            return self._data_list[key]
        if isinstance(key, (list, tuple, np.ndarray, pd.Index)):
            sub_data, sub_names = self.__get_data_list(key)
        elif isinstance(key, slice):
            positional = self.__check_slice(key)
            sub_data = self._data_list[positional]
            sub_names = self._names[positional]
        else:
            raise KeyError(key)
        return _MSDataView(_msdata=self._msdata, _data_list=sub_data, _names=sub_names)

    @property
    def tl(self):
        """Pipeline object bound to this view, created on first access."""
        if self._tl is None:
            self._tl = TL(self)
        return self._tl

    @property
    def plt(self):
        """Plotting pipeline object bound to this view, created on first access."""
        if self._plt is None:
            self._plt = PLT(self)
        return self._plt

    @property
    def data_list(self):
        """The samples contained in this view."""
        return self._data_list

    @property
    def names(self):
        """The names of the samples contained in this view."""
        return self._names

    @property
    def num_slice(self):
        """Number of samples in this view."""
        return len(self._data_list)

    def __str__(self):
        return f'data_list: {len(self._data_list)}'

    def __len__(self):
        return len(self._data_list)

    @property
    def merged_data(self):
        """
        The integrated sample of this view.

        When nothing has been stored yet, the owning `MSData` is asked to
        integrate the samples named by this view and the result is cached.
        """
        if self._merged_data is None:
            self._merged_data = self._msdata.integrate(scope=self._names)
        return self._merged_data

    @merged_data.setter
    def merged_data(self, merged_data):
        self._merged_data = merged_data

    def to_msdata(self) -> MSData:
        """
        Build an independent `MSData` from this view.

        Samples, merged sample, names and relationship info are deep-copied;
        `_var_type` and `_relationship` are taken from the owning `MSData`.
        """
        copied_data = deepcopy(self._data_list)
        copied_merged = deepcopy(self._merged_data)
        copied_names = deepcopy(self._names)
        owner = self._msdata
        return MSData(
            _data_list=copied_data,
            _merged_data=copied_merged,
            _names=copied_names,
            _var_type=owner._var_type,
            _relationship=owner._relationship,
            _relationship_info=deepcopy(owner._relationship_info)
        )
# Attribute names grouped as "non editable"; nothing in the code below reads this set
# NOTE(review): presumably enforced elsewhere - confirm before relying on it.
_NON_EDITABLE_ATTRS = {'data_list', 'names', '_obs', '_var', '_relationship', '_relationship_info'}
# Values accepted by the `relationship` setter of `_MSDataStruct`.
_RELATIONSHIP_ENUM = {'continuous', 'time_series', 'other'}
@dataclass
class _MSDataStruct(object):
    """
    `MSData` is a composite structure of several `StereoExpData` organized by some relationship.

    Parameters
    ----------
    data_list: List[StereoExpData] `stereo_exp_data` array
        An array of `stereo_exp_data` organized by some relationship defined by `_relationship` and `_relationship_info`
    merged_data: `stereo_exp_data` object
        A `stereo_exp_data` merged from `data_list`, used for batches integration.
    names: List[str] `stereo_exp_data` array's names
        An array of `stereo_exp_data`s' unique names.
    obs: pandas.DataFrame = None
        `pandas.DataFrame` describes all the cells or bins observed, indexes mean cell names or bin names, columns mean
        some math statistic or types produced by bio-information algorithm.
    var: pd.DataFrame = None
        `pandas.DataFrame` describes genes, similar to `_obs`.
    _var_type: str = 'intersect'
        Which claims that `_var` is intersected by lots of `genes` from different samples.
    relationship: str = 'other'
        Relationship about samples in `_data_list`.
    _relationship_info: object
        Relationship extra info.
    tl: object
        `MSData` algorithms collections, include all tools methods inherited from `stereo_exp_data` and multi-samples
        methods. Methods from `stereo_exp_data` will be organized with multiple threads while running.
    plt: object
        `MSData` plot methods collections, same as `tl`.

    Examples
    --------
    Constructing MSData from two `stereo_exp_data`s.

    >>> from stereo.io.reader import read_gef
    >>> from stereo.core.ms_data import MSData
    >>> data1 = read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.gef")
    >>> data2 = read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.tissue.gef")
    >>> ms_data = MSData(_data_list=[data1, data2], _names=['raw', 'tissue'], _relationship='other', _var_type='intersect') # noqa
    >>> ms_data
    ms_data: {'raw': (9004, 25523), 'tissue': (9111, 20816)}
    num_slice: 2
    names: ['raw', 'tissue']
    obs: ['test_obs_1']
    var: ['test_var_1']
    relationship: other
    var_type: intersect to 20760
    tl.result: defaultdict(<class 'list'>, {})

    Constructing MSData one by one using add method.

    >>> from stereo.core.ms_data import MSData
    >>> from stereo.io.reader import read_gef
    >>> ms_data = MSData()
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.tissue.gef")
    >>> ms_data
    ms_data: {'0': (9004, 25523), '1': (9111, 20816)}
    num_slice: 2
    names: ['0', '1']
    obs: ['test_obs_1']
    var: ['test_var_1']
    relationship: other
    var_type: intersect to 20760
    tl.result: defaultdict(<class 'list'>, {})

    Slice features like python list.

    >>> from stereo.core.ms_data import MSData
    >>> from stereo.io.reader import read_gef
    >>> ms_data = MSData()
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.tissue.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.tissue.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.gef")
    >>> ms_data += read_gef("../demo_data/SS200000135TL_D1/SS200000135TL_D1.tissue.gef")
    >>> ms_data[3:]
    _MSDataView(_names=['3', '4', '5'], _data_list=[StereoExpData object with n_cells X n_genes = 9111 X 20816
    bin_type: bins
            bin_size: 100
    offset_x = 0
    offset_y = 2
    cells: ['cell_name']
    genes: ['gene_name'], ...)

    Selecting several samples by name, like a DataFrame.

    >>> ms_data[['1', '4']]
    _MSDataView(_names=['1', '4'], _data_list=[StereoExpData object with n_cells X n_genes = 9111 X 20816
    bin_type: bins
            bin_size: 100
    offset_x = 0
    offset_y = 2
    cells: ['cell_name']
    genes: ['gene_name'], ...)
    """

    # base attributes
    # TODO temporarily length 10
    _data_list: List[StereoExpData] = field(default_factory=list)
    _merged_data: StereoExpData = None
    _names: List[str] = field(default_factory=list)
    _obs: pd.DataFrame = None
    _var: pd.DataFrame = None
    _var_type: str = 'intersect'
    _relationship: str = 'other'
    # TODO not define yet
    _relationship_info: dict = field(default_factory=dict)

    # code-supported attributes
    # name -> sample
    _name_dict: Dict[str, StereoExpData] = field(default_factory=dict)
    # id(sample) -> name
    _data_dict: Dict[int, str] = field(default_factory=dict)
    # counter behind the automatically generated names
    __idx_generator: int = _default_idx()

    def __check_data_list(self, data_list):
        """
        Check that `data_list` is a list whose elements all have exactly the same type.

        :return: `data_list` unchanged.
        :raises TypeError: when `data_list` is not a list or mixes element types.
        """
        if not isinstance(data_list, list):
            raise TypeError('data_list must be a list object')
        if data_list:
            expected_type = type(data_list[0])
            for data in data_list[1:]:
                # exact type match on purpose: a subclass counts as a different type here
                if type(data) is not expected_type:
                    raise TypeError(
                        'each data in data_list must be the same type, '
                        'available types: StereoExpData and AnnBasedStereoExpData'
                    )
        return data_list

    def __post_init__(self) -> object:
        """Give every unnamed sample an automatic name and build the lookup dictionaries."""
        while len(self._data_list) > len(self._names):
            self._names.append(self.__get_auto_key())
        if not self._name_dict or not self._data_dict:
            self.reset_name(default_key=False)
        self.__check_data_list(self._data_list)
        return self

    def __iter__(self):
        return iter(self._data_list)

    @property
    def data_list(self):
        return self._data_list

    @data_list.setter
    def data_list(self, data_list: List[StereoExpData]):
        self.__check_data_list(data_list)
        assert len(data_list) == len(self._names), 'length of data_list must be equal to length of names'
        self._data_list = list(data_list)

    @property
    def merged_data(self):
        return self._merged_data

    @merged_data.setter
    def merged_data(self, value: StereoExpData):
        self._merged_data = value

    @property
    def names(self):
        return self._names

    @names.setter
    def names(self, value: List[str]):
        if len(value) != len(self._data_list):
            raise Exception('new names\' length should be same as data_list')
        self._names = list(value)
        # rebuild `_name_dict` / `_data_dict` so that they follow the new names
        self.reset_name(default_key=False)

    @property
    def var_type(self):
        return self._var_type

    @var_type.setter
    def var_type(self, value: str):
        if value not in {'intersect', 'union'}:
            raise Exception(f"new var_type must be in {('intersect', 'union')}")
        self._var_type = value

    @property
    def relationship(self):
        return self._relationship

    @relationship.setter
    def relationship(self, value: str):
        if value not in _RELATIONSHIP_ENUM:
            raise Exception(f'new relationship must be in {_RELATIONSHIP_ENUM}')
        self._relationship = value

    @property
    def relationship_info(self):
        return self._relationship_info

    @relationship_info.setter
    def relationship_info(self, value: dict):
        self._relationship_info = value

    def reset_position(self, mode='integrate'):
        """
        Call `reset_position` on the merged sample (mode 'integrate' and a merged sample exists),
        otherwise on every single sample.
        """
        if mode == 'integrate' and self.merged_data:
            self.merged_data.reset_position()
        else:
            for data in self.data_list:
                data.reset_position()

    def __len__(self):
        return len(self._data_list)

    def __copy__(self) -> object:
        # TODO temporarily can not copy, deepcopy, return self
        return self

    def __deepcopy__(self, _) -> object:
        return self

    def get_data_list(self, key_idx_list):
        """
        Resolve a (possibly nested) collection of names and positions.

        :param key_idx_list: iterable of names, integer positions or nested list-likes of those.
        :return: a tuple `(samples, names)` in the order the keys were given.
        :raises KeyError: when an entry has an unsupported type or a name is unknown.
        """
        data_list = []
        names = []
        for ki in key_idx_list:
            if isinstance(ki, (str, np.str_)):
                data_list.append(self._name_dict[ki])
                names.append(ki)
            elif isinstance(ki, (int, np.integer)):
                ki = int(ki)
                data_list.append(self._data_list[ki])
                names.append(self._names[ki])
            elif isinstance(ki, (list, tuple, np.ndarray, pd.Index)):
                temp_data_list, temp_names = self.get_data_list(ki)
                data_list.extend(temp_data_list)
                names.extend(temp_names)
            else:
                raise KeyError(ki)
        return data_list, names

    def __check_slice(self, slice_obj: slice):
        """
        Translate a slice whose bounds may be sample names into a positional slice.

        A name bound is replaced by the position of that name; an unknown name
        is replaced by `None` (an open bound).

        :raises TypeError: when `slice_obj` is not a slice or its step is not an integer.
        """
        if not isinstance(slice_obj, slice):
            raise TypeError(f'{slice_obj} should be slice')

        def to_position(bound):
            if not isinstance(bound, (str, np.str_)):
                return bound
            return self._names.index(bound) if bound in self._name_dict else None

        new_start = to_position(slice_obj.start)
        new_stop = to_position(slice_obj.stop)
        if slice_obj.step is not None and not isinstance(slice_obj.step, (int, np.integer)):
            raise TypeError('slice.step should be int')
        return slice(new_start, new_stop, slice_obj.step)

    def __getitem__(
        self,
        key: Union[str, int, list, tuple, np.ndarray, pd.Index, slice]
    ) -> Union[StereoExpData, _MSDataView]:
        """
        Return a single sample for a name / position, or a `_MSDataView` for a list-like key or a slice.

        :raises KeyError: when the key type is not supported.
        """
        if isinstance(key, (str, np.str_)):
            return self._name_dict[key]
        elif isinstance(key, (int, np.integer)):
            return self._data_list[key]
        elif isinstance(key, (list, tuple, np.ndarray, pd.Index)):
            data_list, names = self.get_data_list(key)
            return _MSDataView(_msdata=self, _data_list=data_list, _names=names)
        elif isinstance(key, slice):
            key = self.__check_slice(key)
            data_list = self._data_list[key]
            names = self._names[key]
            return _MSDataView(_msdata=self, _data_list=data_list, _names=names)
        else:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Replace the sample stored under an existing name / position, or add `value` under a new name.

        :raises IndexError: when an integer key is not smaller than the number of samples.
        """
        assert isinstance(key, (int, np.integer, str, np.str_))
        assert isinstance(value, StereoExpData)
        if key in self._name_dict:
            key = self._names.index(key)
        if isinstance(key, (int, np.integer)):
            if key >= self.num_slice:
                raise IndexError("list index out of range")
            old_obj = self._data_list[key]
            name = self._names[key]
            self._data_list[key] = value
            self._name_dict[name] = value
            self._data_dict.pop(id(old_obj))
            self._data_dict[id(value)] = name
        else:
            self.__real_add(value, key)

    def __add__(self, other):
        """Append `other` in place under an automatic name and return `self`."""
        assert isinstance(other, StereoExpData)
        self.__real_add(other)
        return self

    def __contains__(self, item) -> bool:
        """
        Membership test by sample name or by sample object (identity).

        :raises TypeError: when `item` is neither a string nor a `StereoExpData`.
        """
        # isinstance instead of an exact type check, so numpy strings are accepted as names too
        if isinstance(item, (str, np.str_)):
            return item in self._name_dict
        elif isinstance(item, StereoExpData):
            return id(item) in self._data_dict
        else:
            raise TypeError('In-Expression: only supports `name` or `StereoExpData-object`')

    def add_data(self, data=None, names=None, **kwargs) -> object:
        """
        Add one or several samples, given as objects or as file paths.

        :param data: a `StereoExpData`, a file path, or a list of either.
        :param names: name(s) for the new sample(s); automatic names are used when omitted.
        :param kwargs: reader options (`bin_size`, `bin_type`, `spatial_key`) used when `data` holds paths.
        :return: `self`.
        :raises Exception: when `data` is empty or the lengths of `data` and `names` differ.
        :raises TypeError: when the type of `data` is not supported.
        """
        if not data:
            raise Exception('`data` is empty or None')
        if isinstance(data, StereoExpData):
            return self.__add_data_objs([data], [names] if names else None)
        if isinstance(data, str):
            return self.__add_data_paths([data], [names] if names else None, **kwargs)
        if isinstance(data, list):
            # `names` is optional: only compare the lengths when it was actually given
            if names is not None and len(data) != len(names):
                raise Exception('length of objs and length of keys must equal')
            if isinstance(data[0], StereoExpData):
                return self.__add_data_objs(data, names)
            if isinstance(data[0], str):
                return self.__add_data_paths(data, names, **kwargs)
        raise TypeError('`data` must be a StereoExpData object, a file path or a list of them')

    def del_data(self, name):
        """
        Remove the sample stored under `name`.

        :raises KeyError: when `name` is unknown.
        """
        if name not in self._name_dict:
            raise KeyError(name)
        # delete by position so that `_names` and `_data_list` always stay aligned
        position = self._names.index(name)
        obj = self._name_dict.pop(name)
        del self._names[position]
        del self._data_list[position]
        # keep the id -> name entry while the very same object is still stored at another position
        if not any(remaining is obj for remaining in self._data_list):
            self._data_dict.pop(id(obj), None)

    def __delitem__(self, key):
        self.del_data(key)

    def __add_data_objs(self, data_list: List[StereoExpData], keys: List[str] = None) -> object:
        """
        Append already loaded samples after checking that neither the keys nor the objects are present yet.

        :raises KeyError: when one of `keys` already exists.
        :raises Exception: when one of the objects is already contained.
        """
        if keys:
            for key in keys:
                if key in self._names:
                    raise KeyError(f'key={key} already exists')
        for data_obj in data_list:
            if data_obj in self:
                raise Exception('data object already exists')
        for idx, data_obj in enumerate(data_list):
            self.__real_add(data_obj, keys[idx] if keys and idx < len(keys) else None)
        return self

    def __add_data_paths(self, file_path_list: List[str], keys: List[str] = None, **kwargs) -> object:
        """
        Read every file of `file_path_list` and append the resulting samples.

        The options `bin_size`, `bin_type` and `spatial_key` may each be omitted, given as one value
        used for every file, or given as a list / tuple with one value per file.

        :raises ValueError: when a per-file option does not have one value per file.
        :raises Exception: when a file format is not supported.
        """
        from stereo.io.reader import read_gef, read_gem, read_ann_h5ad

        n_files = len(file_path_list)

        # TODO mixed file format, how to handle arguments
        per_file = {}
        for option_name in ('bin_size', 'bin_type', 'spatial_key'):
            option = kwargs.get(option_name, None)
            if option is None:
                continue
            if isinstance(option, (list, tuple, np.ndarray, pd.Index)):
                if len(option) != n_files:
                    raise ValueError(f'length of `{option_name}` must be equal to the number of files')
                per_file[option_name] = list(option)
            else:
                # a single value applies to every file
                per_file[option_name] = [option] * n_files

        data_list = []
        for idx, file_path in enumerate(file_path_list):
            bin_size = per_file['bin_size'][idx] if 'bin_size' in per_file else 100
            bin_type = per_file['bin_type'][idx] if 'bin_type' in per_file else 'bins'
            if file_path.endswith('.gef'):
                data_list.append(read_gef(file_path, bin_size=bin_size, bin_type=bin_type))
            elif file_path.endswith(('.gem', '.gem.gz')):
                data_list.append(read_gem(file_path, bin_size=bin_size, bin_type=bin_type))
            elif file_path.endswith('.h5ad'):
                # forward only what the caller supplied, the reader's own defaults cover the rest
                h5ad_options = {name: values[idx] for name, values in per_file.items()}
                data_list.append(read_ann_h5ad(file_path, **h5ad_options))
            else:
                raise Exception(f'file format({file_path}) not support')
        return self.__add_data_objs(data_list, keys)

    def __get_auto_key(self) -> str:
        """Advance the counter and return it as the next automatic name."""
        self.__idx_generator += 1
        return str(f'{self.__idx_generator}')

    def __real_add(self, obj: StereoExpData, key: Union[str, None] = None) -> object:
        """Register `obj` under `key`; an empty or already used key is replaced by an automatic one."""
        if not key:
            key = self.__get_auto_key()
        while key in self._name_dict:
            key = self.__get_auto_key()
        self._name_dict[key] = obj
        self._data_dict[id(obj)] = key
        self._names.append(key)
        self._data_list.append(obj)
        return self

    @property
    def obs(self) -> pd.DataFrame:
        """Cell-level table of the merged sample, or an empty frame when there is none."""
        if self._merged_data:
            return self._merged_data.cells.to_df()
        return pd.DataFrame()

    @property
    def var(self) -> pd.DataFrame:
        """Gene-level table of the merged sample, or an empty frame when there is none."""
        if self.merged_data:
            return self._merged_data.genes.to_df()
        return pd.DataFrame()

    def var_percent(self):
        """
        For each sample, return the fraction of its genes that are part of `var`.

        A sample without genes contributes 0.0.
        """
        var_set = set(self.var.index)
        percent_list = []
        for data_obj in self._data_list:
            gene_names = data_obj.gene_names
            n_genes = len(gene_names)
            n_shared = sum(1 for gene in gene_names if gene in var_set)
            percent_list.append(n_shared / n_genes if n_genes else 0.0)
        return percent_list

    @property
    def shape(self) -> dict:
        """Mapping of sample name to the `shape` of that sample."""
        return {name: data_obj.shape for name, data_obj in zip(self._names, self._data_list)}

    @property
    def num_slice(self):
        return len(self)

    def rename(self, mapper: Dict[str, str]) -> object:
        """
        Rename samples.

        :param mapper: mapping of existing name -> new name. Every new name must be unique and
            must not be in use already, every existing name must be known.
        :return: `self`.
        :raises Exception: when `mapper` is empty, no sample is stored, or a rule above is violated.
        """
        if not mapper:
            raise Exception('`rename_keys` is empty or None')
        elif not self._name_dict:
            raise Exception('`ms_data` is empty')

        new_names = list(mapper.values())
        unique_new_names = set(new_names)
        if len(unique_new_names) != len(new_names):
            raise Exception('`rename_keys` with values target at same obj')

        # avoid circle-renaming, we only support rename to a new key
        clashing_names = unique_new_names & self._name_dict.keys()
        if clashing_names:
            raise Exception(f'{clashing_names} already exists, can not rename to!')

        known_old_names = mapper.keys() & self._name_dict.keys()
        if len(known_old_names) != len(mapper.keys()):
            raise Exception(f'some keys in {mapper.keys()} not exist in ms_data')

        for src, dst in mapper.items():
            src_obj = self._name_dict.pop(src)
            self._name_dict[dst] = src_obj
            self._data_dict[id(src_obj)] = dst

        # rebuild the ordered name list from the id -> name mapping
        self._names = [self._data_dict[id(obj)] for obj in self._data_list]
        return self

    def reset_name(self, start_idx=None, default_key=True) -> object:
        """
        Rebuild `_name_dict` and `_data_dict` from `_data_list`.

        :param start_idx: value the automatic-name counter is reset to, defaults to `_default_idx()`.
        :param default_key: True to rename every sample with an automatic name, False to keep
            the names currently stored in `_names`.
        :return: `self`.
        """
        self.__idx_generator = _default_idx() if start_idx is None else start_idx
        self._name_dict, self._data_dict = dict(), dict()
        for idx, obj in enumerate(self._data_list):
            name = self.__get_auto_key() if default_key else self._names[idx]
            self._name_dict[name] = obj
            self._data_dict[id(obj)] = name
            self._names[idx] = name
        # drop names that no longer have a sample
        if len(self._data_list) < len(self._names):
            self._names = self._names[0:len(self._data_list)]
        return self
class ScopesData(dict):
    """
    Dictionary of `StereoExpData` objects that is bound to an `MSData` object.

    Storing a sample with `scopes[key] = value` also installs a callback on
    `value.tl.result` which records result keys in `ms_data.tl.result_keys[key]`.
    """

    def __init__(self, ms_data: MSData, *args, **kwargs):
        self._ms_data = ms_data
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """
        Store `value` under `key` and hook its result-key callback.

        :raises TypeError: when `value` is not a `StereoExpData`.
        """
        if not isinstance(value, StereoExpData):
            raise TypeError(f'value must be a StereoExpData object')

        def set_result_key_method(result_key):
            # move `result_key` to the end of the list recorded for this scope
            recorded = self._ms_data.tl.result_keys.setdefault(key, [])
            if result_key in recorded:
                recorded.remove(result_key)
            recorded.append(result_key)

        value.tl.result.set_result_key_method = set_result_key_method
        return super().__setitem__(key, value)
@dataclass
class MSData(_MSDataStruct):
    __doc__ = _MSDataStruct.__doc__

    # un-annotated, so `dataclass` does not treat them as fields; filled lazily by `tl` / `plt`
    _tl = None
    _plt = None
    # scope key -> integrated sample of that scope
    _scopes_data: Dict[str, StereoExpData] = None

    def __post_init__(self) -> object:
        """Make sure `_scopes_data` is a `ScopesData` bound to this object, then run the base initialisation."""
        if self._scopes_data is None:
            self._scopes_data = ScopesData(self)
        else:
            self._scopes_data = self.__reset_scopes_data(self._scopes_data)
        super().__post_init__()
        return self

    def __reset_scopes_data(self, value):
        """
        Convert a plain dict into a `ScopesData` bound to this object.

        :raises TypeError: when `value` is not a dict.
        """
        if not isinstance(value, dict):
            raise TypeError('value must be a dict object')
        if isinstance(value, ScopesData):
            return value
        scopes_data = ScopesData(self)
        for scope_key, scope_data in value.items():
            scopes_data[scope_key] = scope_data
        return scopes_data

    def __select_scope(self, scope):
        """Return `(data_list, names)` of the samples selected by `scope` (None selects all of them)."""
        if scope is None:
            return self.data_list, self._names
        if isinstance(scope, (str, np.str_, int, np.integer)):
            # a single name / position selects exactly one sample
            return self.get_data_list([scope])
        selection = self[scope]
        return selection.data_list, selection.names

    @property
    def tl(self):
        """Pipeline object bound to this `MSData`, created on first access."""
        if self._tl is None:
            self._tl = TL(self)
        return self._tl

    @property
    def plt(self):
        """Plotting pipeline object bound to this `MSData`, created on first access."""
        if self._plt is None:
            self._plt = PLT(self)
        return self._plt

    @property
    def scopes_data(self):
        return self._scopes_data

    @scopes_data.setter
    def scopes_data(self, value):
        self._scopes_data = self.__reset_scopes_data(value)

    @property
    def mss(self):
        return self.tl.result

    def generate_scope_key(self, scope=None):
        """
        Build the key under which the integrated sample of `scope` is stored.

        The key has the form 'scope_[i,j,...]' where i, j, ... are positions of the selected samples.
        A string that is not a sample name is returned unchanged, and so is any scope
        that can not be resolved.

        :param scope: position, name, slice or list-like of positions / names; None means all samples.
        """
        if scope is None:
            scope = slice(None)
        try:
            if isinstance(scope, (int, np.integer)):
                return f"scope_[{scope}]"
            if isinstance(scope, (str, np.str_)):
                if scope in self._name_dict:
                    return f"scope_[{self._names.index(scope)}]"
                return scope
            if isinstance(scope, slice):
                names = self[scope]._names
            elif isinstance(scope, (list, tuple, np.ndarray, pd.Index)):
                _, names = self.get_data_list(scope)
            else:
                return scope
            return f"scope_[{','.join(str(self._names.index(name)) for name in names)}]"
        except Exception:
            # best effort: an unresolvable scope is used as its own key
            return scope

    def remove_scopes_data(self, scope):
        """Forget the integrated sample of `scope` and the result keys recorded for it."""
        scope_key = self.generate_scope_key(scope)
        if scope_key in self._scopes_data:
            del self._scopes_data[scope_key]
        if scope_key in self.tl.result_keys:
            del self.tl.result_keys[scope_key]

    def integrate(self, scope=None, remove_existed=False, **kwargs):
        """
        Integrate some single-samples specified by `scope` to a merged one.

        :param scope: Which scope of samples to be integrated, defaults to None.
                        Each integrate sample is saved in memory, performing this function
                        by passing duplicate `scope` will return the saved one.
        :param remove_existed: Whether to remove the saved integrate sample when passing a duplicate `scope`, defaults to False.
        :param kwargs: forwarded to `stereo.utils.data_helper.merge`; `var_type` and `batch_tags` are ignored.
        :return: the integrated sample.
        """
        from stereo.utils.data_helper import merge

        if self._var_type not in {"union", "intersect"}:
            raise Exception("Please specify the operation on samples with the parameter '_var_type'")
        # these two are always derived from this object
        kwargs.pop('var_type', None)
        kwargs.pop('batch_tags', None)

        if remove_existed:
            self.remove_scopes_data(scope)

        scope_key = self.generate_scope_key(scope)
        if scope_key in self._scopes_data:
            return self._scopes_data[scope_key]

        data_list, names = self.__select_scope(scope)

        if len(data_list) > 1:
            batch_tags = None if scope is None else [self._names.index(name) for name in names]
            merged_data = merge(*data_list, var_type=self._var_type, batch_tags=batch_tags, **kwargs)
        else:
            # a single sample is copied and tagged the way `merge` tags a batch
            merged_data = deepcopy(data_list[0])
            batch = self._names.index(names[0])
            merged_data.cells.cell_name = np.char.add(merged_data.cells.cell_name, f'-{batch}')
            merged_data.cells.batch = batch

        # only 'batch' is kept in obs and only 'real_gene_name' in var
        obs_columns = merged_data.cells._obs.columns.drop('batch')
        if len(obs_columns) > 0:
            merged_data.cells._obs.drop(columns=obs_columns, inplace=True)
        var_columns = merged_data.genes._var.columns
        if 'real_gene_name' in var_columns:
            var_columns = var_columns.drop('real_gene_name')
        if len(var_columns) > 0:
            merged_data.genes._var.drop(columns=var_columns, inplace=True)

        merged_data.tl.review_key_record()

        self._scopes_data[scope_key] = merged_data
        # identity / type checks: `==` on an array-like scope would compare element-wise
        if scope is None or (isinstance(scope, slice) and scope == slice(None)):
            self._merged_data = merged_data
        return merged_data

    def split_after_batching_integrate(self):
        """
        Replace the single samples by the result of splitting the merged sample, then clear `merged_data`.

        :raises NotImplementedError: when `_var_type` is 'union'.
        """
        if self._var_type == "union":
            raise NotImplementedError("Split a union data not implemented yet")
        from stereo.utils.data_helper import split
        self._data_list = split(self.merged_data)
        self.reset_name(default_key=False)
        self.merged_data = None

    def to_integrate(
            self,
            scope: slice,
            res_key: str,
            _from: slice,
            type: Literal['obs', 'var'] = 'obs',
            item: Optional[Union[list, np.ndarray, str]] = None,
            fill=np.nan,
            cluster: bool = True
    ):
        """
        Integrate an obs column or a var column from some single-samples specified by `_from` to the merged sample.

        :param scope: Which integrate mss group to save result.
        :param res_key: New column name in merged sample obs or var.
        :param _from: Where to get the single-sample target information.
        :param type: obs or var level, defaults to 'obs'.
        :param item: The column names in single-sample obs or var, defaults to the value of `res_key`.
        :param fill: Default value when the merged sample has no corresponding item, defaults to np.nan.
        :param cluster: Whether it is a clustering result, defaults to True.

        .. note::

            The length of `scope` must be equal to `_from`.

            The `type` just only supports 'obs' currently.

        Examples
        --------
        Constructing MSData from 5 single-samples.

        >>> import stereo as st
        >>> data1 = st.io.read_h5ad('../data/10.h5ad')
        >>> data2 = st.io.read_h5ad('../data/11.h5ad')
        >>> data3 = st.io.read_h5ad('../data/12.h5ad')
        >>> data4 = st.io.read_h5ad('../data/13.h5ad')
        >>> data5 = st.io.read_h5ad('../data/14.h5ad')
        >>> ms_data = data1 + data2 + data3 + data4 + data5
        >>> ms_data
        ms_data: {'0': (493, 30254), '1': (285, 30254), '2': (753, 30254), '3': (731, 30254), '4': (412, 30254)}
        num_slice: 5
        names: ['0', '1', '2', '3', '4']
        obs: []
        var: []
        relationship: other
        var_type: intersect to 0
        mss: []

        Integrating all samples to a merged one.

        >>> ms_data.integrate()

        Integrating an obs column named as 'celltype' from first three samples to the merged sample, to name as 'celltype'

        >>> from stereo.core.ms_pipeline import slice_generator
        >>> ms_data.to_integrate(res_key='celltype', scope=slice_generator[0:3], _from=slice_generator[0:3], type='obs', item=['celltype'] * 3)

        Integrating an obs column named as 'celltype' from all samples to the merged sample, to name as 'celltype'

        >>> from stereo.core.ms_pipeline import slice_generator
        >>> ms_data.to_integrate(res_key='celltype', scope=slice_generator[:], _from=slice_generator[:], type='obs', item=['celltype'] * ms_data.num_slice)
        """
        scope_names = self[scope]._names
        assert scope_names == self[_from]._names, f"`scope`: {scope} should equal with _from: {_from}"
        # apply the documented default before validating, `len(None)` would raise otherwise
        if item is None:
            item = res_key
        assert isinstance(item, str) or len(item) == len(scope_names), "`item`'s length not equal to _from"

        scope_key = self.generate_scope_key(scope_names)
        assert scope_key in self._scopes_data or self._merged_data is not None, \
            "`to_integrate` need running function `integrate` first"

        if type == 'var':
            raise NotImplementedError
        if type != 'obs':
            raise Exception(f"`type`: {type} not in ['obs', 'var'], this should not happens!")

        # every distinct integrated sample that has to receive the new column
        targets = []
        if scope_key in self._scopes_data:
            targets.append(self._scopes_data[scope_key])
        if self._merged_data is not None and all(self._merged_data is not target for target in targets):
            targets.append(self._merged_data)

        for target in targets:
            target.cells[res_key] = fill

        data_list = self[scope]._data_list
        if isinstance(item, str):
            item = [item] * len(data_list)

        for idx, stereo_exp_data in enumerate(data_list):
            res: pd.Series = stereo_exp_data.cells[item[idx]]
            sample_idx = self._names.index(scope_names[idx])
            # cell names of an integrated sample carry the sample position as a suffix
            new_index = res.index.astype('str') + f'-{sample_idx}'
            for target in targets:
                index_intersect = np.intersect1d(new_index, target.cell_names)
                isin = new_index.isin(index_intersect)
                target.cells.loc[new_index[isin], res_key] = res[isin].to_numpy()

        if cluster:
            for target in targets:
                target.tl.reset_key_record('cluster', res_key)
                target.tl.result.set_result_key_method(res_key)
                target.cells[res_key] = target.cells[res_key].astype('category')

    def to_isolated(
            self,
            scope: slice,
            res_key: str,
            to: slice,
            type: Literal['obs', 'var'] = 'obs',
            item: Optional[Union[list, np.ndarray, str]] = None,
            fill=np.nan
    ):
        """
        Copy a result from mss group specified by scope to some single-samples specified by `to`.

        :param scope: Which integrate mss group to get result.
        :param res_key: the key to get result from mss group.
        :param to: which single-samples are the result copy to.
        :param type: obs or var level, defaults to 'obs'
        :param item: New column name in obs of single-sample, defaults to the value of `res_key`.
        :param fill: Default value when the single-sample has no corresponding item, defaults to np.nan

        .. note::

            The length of `scope` must be equal to `to`.

            Only supports clustering result when `type` is 'obs' and hvg result when `type` is 'var'.

            Parameter `item` only available for obs type.

        Examples
        --------
        Constructing MSData from 5 single-samples.

        >>> import stereo as st
        >>> data1 = st.io.read_h5ad('../data/10.h5ad')
        >>> data2 = st.io.read_h5ad('../data/11.h5ad')
        >>> data3 = st.io.read_h5ad('../data/12.h5ad')
        >>> data4 = st.io.read_h5ad('../data/13.h5ad')
        >>> data5 = st.io.read_h5ad('../data/14.h5ad')
        >>> ms_data = data1 + data2 + data3 + data4 + data5
        >>> ms_data
        ms_data: {'0': (493, 30254), '1': (285, 30254), '2': (753, 30254), '3': (731, 30254), '4': (412, 30254)}
        num_slice: 5
        names: ['0', '1', '2', '3', '4']
        obs: []
        var: []
        relationship: other
        var_type: intersect to 0
        mss: []

        Integrating all samples to a merged one.

        >>> ms_data.integrate()

        ... did a clustering, the key of result is 'leiden' ...

        Copy the 'leiden' result to first three samples, to name as 'leiden'.

        >>> from stereo.core.ms_pipeline import slice_generator
        >>> ms_data.to_isolated(scope=slice_generator[0:3], res_key='leiden', to=slice_generator[0:3], type='obs', item=['leiden'] * 3)

        Copy the 'leiden' result to all samples, to name as 'leiden'.

        >>> from stereo.core.ms_pipeline import slice_generator
        >>> ms_data.to_isolated(scope=slice_generator[:], res_key='leiden', to=slice_generator[:], type='obs', item=['leiden'] * ms_data.num_slice)
        """
        scope_names = self[scope]._names
        assert scope_names == self[to]._names, f"`scope`: {scope} should equal with to: {to}"
        # apply the documented default before validating, `len(None)` would raise otherwise
        if item is None:
            item = res_key
        assert isinstance(item, str) or len(item) == len(scope_names), "`item`'s length not equal to `to`"

        scope_key = self.generate_scope_key(scope_names)
        merged_res: pd.DataFrame = self.tl.result[scope_key][res_key].copy(deep=True)
        if type == "obs":
            # TODO: only support cluster data
            if "bins" not in merged_res.columns or "group" not in merged_res.columns:
                raise Exception("Only soupport cluster result currently.")
            merged_res.set_index('bins', inplace=True)
        elif type == "var":
            # TODO: only support hvg data
            merged_res.index = self._scopes_data[scope_key].genes.gene_name

        data_list = self[scope]._data_list
        if isinstance(item, str):
            item = [item] * len(data_list)

        # filling missing values with NaN is a no-op, whichever NaN object was passed
        fill_is_nan = isinstance(fill, (float, np.floating)) and np.isnan(fill)

        for idx, stereo_exp_data in enumerate(data_list):
            if type == 'obs':
                column_name = item[idx]
                cells = stereo_exp_data.cells
                original_index = cells._obs.index
                # temporarily use the '<cell name>-<batch>' names of the integrated sample for alignment
                cells._obs.index = np.char.add(
                    np.char.add(original_index.to_numpy().astype('U'), '-'),
                    cells['batch']
                )
                try:
                    cells._obs[column_name] = merged_res['group']
                    column = cells._obs[column_name]
                    if not fill_is_nan:
                        if column.dtype.name == 'category' and fill not in column.cat.categories:
                            column = column.cat.add_categories(fill)
                        column = column.fillna(fill)
                    if column.dtype.name == 'category':
                        column = column.cat.remove_unused_categories()
                    cells._obs[column_name] = column
                finally:
                    # always give the sample its own cell names back
                    cells._obs.index = original_index
            elif type == 'var':
                gene_names = stereo_exp_data.genes.gene_name
                intersect = np.intersect1d(gene_names, merged_res.index)
                result_df = pd.DataFrame(fill, index=gene_names, columns=merged_res.columns)
                for column in merged_res.columns:
                    if merged_res[column].dtype == np.dtype(bool):
                        # boolean columns default to False instead of `fill`
                        result_df[column] = False
                    result_df.loc[intersect, column] = merged_res.loc[intersect, column]
                stereo_exp_data.tl.result[item[idx]] = result_df
                stereo_exp_data.tl.reset_key_record('hvg', item[idx])
            else:
                raise Exception(f"`type`: {type} not in ['obs', 'var'], this should not happens!")

    @staticmethod
    def to_msdata(
            data: StereoExpData,
            batch_key: str = 'batch',
            relationship: Optional[str] = 'other',
            var_type: Optional[str] = 'intersect'
    ):
        """
        Split `data` into one sample per category of the `batch_key` column and wrap them in an `MSData`.

        :param data: the sample to split.
        :param batch_key: name of the cells column that holds the batch of every cell.
        :param relationship: `_relationship` of the new `MSData`.
        :param var_type: `_var_type` of the new `MSData`.
        :raises KeyError: when `batch_key` is not in `data.cells`.
        """
        if batch_key not in data.cells:
            raise KeyError(f"The batch key '{batch_key}' is not in cells or obs.")

        from stereo.preprocess.filter import filter_by_clusters

        batch_codes = data.cells[batch_key].astype('category').cat.categories
        sub_data_list = []
        sub_data_names = []
        for batch_code in batch_codes:
            sub_data_list.append(filter_by_clusters(data, batch_key, groups=batch_code, inplace=False))
            sub_data_names.append(batch_code)
        return MSData(_data_list=sub_data_list, _names=sub_data_names, _relationship=relationship, _var_type=var_type)

    def __str__(self):
        merged_repr = None if self._merged_data is None else f"id({id(self._merged_data)})"
        scopes_repr = [f"{key}:id({id(value)})" for key, value in self._scopes_data.items()]
        mss_repr = [f"{key}:{value}" for key, value in self.tl.result_keys.items()]
        lines = [
            f"ms_data: {self.shape}",
            f"num_slice: {self.num_slice}",
            f"names: {self.names}",
            f"merged_data: {merged_repr}",
            f"obs: {self.obs.columns.to_list()}",
            f"var: {self.var.columns.to_list()}",
            f"relationship: {self.relationship}",
            f"var_type: {self._var_type} to {len(self.var.index)}",
            f"current_mode: {self.tl.mode}",
            f"current_scope: {self.generate_scope_key(self.tl.scope)}",
            f"scopes_data: {scopes_repr}",
            f"mss: {mss_repr}",
        ]
        return "\n".join(lines) + "\n"

    def __repr__(self):
        return self.__str__()

    def write(self, filename, to_mudata=False):
        """
        Write this object to `filename`.

        :param to_mudata: False to write with `write_h5ms`, True to write with `write_h5mu`
            (whose return value is passed on).
        """
        if not to_mudata:
            from stereo.io.writer import write_h5ms
            write_h5ms(self, filename)
        else:
            from stereo.io.writer import write_h5mu
            return write_h5mu(self, filename)
# Two subclasses of `MSDataPipeLine` created dynamically; they differ only in `ATTR_NAME`.
# `TL` is instantiated by the `tl` properties above and `PLT` by the `plt` properties.
# NOTE(review): `BASE_CLASS` is None here - presumably assigned elsewhere, confirm.
TL = type('TL', (MSDataPipeLine,), {'ATTR_NAME': 'tl', "BASE_CLASS": None})
PLT = type('PLT', (MSDataPipeLine,), {'ATTR_NAME': 'plt', "BASE_CLASS": None})