from typing import Tuple, Optional, Literal, Union, List
from copy import deepcopy
import numpy as np
import pandas as pd
from stereo.core.stereo_exp_data import StereoExpData
from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
from stereo.algorithm.spt.multiple_time import (
transfer_matrix,
generate_animate_input,
map_data
)
from stereo.algorithm.spt.single_time import Trainer
# from stereo.algorithm.spt.plot import PlotSpaTrack
from stereo.plots.plot_spa_track import PlotSpaTrack
class MSSpaTrack(MSDataAlgorithmBase):
    """
    SpaTrack workflow over several samples held in ``self.ms_data``.

    Every result produced by this class is stored in the dictionary
    ``self.pipeline_res['spa_track']``:

    * ``'cluster_res_key'``: the clustering key passed to :meth:`main`
    * ``'transfer_spatial_key'``: the spatial key used by the latest :meth:`transfer_matrix` call
    * ``'transfer_matrices'``: ``{(data1_name, data2_name): transfer matrix}``
    * ``'mapped_data'``: ``{(data1_name, data2_name): mapping result}``
    """

    def _spa_track_result(self) -> dict:
        """Return ``self.pipeline_res['spa_track']``, creating it first if it is missing."""
        # Explicit `in` / `[]` access instead of `setdefault`: `pipeline_res` is provided
        # by the base class and may be a custom mapping (assumption - confirm), for which
        # `dict.setdefault` could bypass overridden item access.
        if 'spa_track' not in self.pipeline_res:
            self.pipeline_res['spa_track'] = {}
        return self.pipeline_res['spa_track']

    def _sample_name(self, index: Union[str, int]) -> str:
        """Translate an integer position in ``self.ms_data`` into the sample name; names pass through."""
        # NumPy integers are not instances of the built-in `int`, so they are checked
        # explicitly; otherwise the index itself would end up being used as the name.
        if isinstance(index, (int, np.integer)):
            return self.ms_data.names[index]
        return index

    def main(
        self,
        cluster_res_key: str = 'cluster'
    ) -> 'MSSpaTrack':
        """
        Create an object of SpaTrack for multiple samples.

        Records the clustering key, attaches the plotting object as ``self.plot``
        and returns this object so that the other methods can be called on it.

        :param cluster_res_key: the key of clustering result to be used in cells/obs
        :return: this object.
        """
        spa_track_res = self._spa_track_result()
        self.cluster_res_key = cluster_res_key
        spa_track_res['cluster_res_key'] = cluster_res_key
        self.plot = PlotSpaTrack(ms_data=self.ms_data, pipeline_res=self.pipeline_res)
        return self

    def transfer_matrix(
        self,
        data1_index: Optional[Union[str, int]] = None,
        data2_index: Optional[Union[str, int]] = None,
        spatial_key: str = 'spatial',
        alpha: float = 0.1,
        epsilon: float = 0.01,
        rho: float = np.inf,
        G_1: Optional[np.ndarray] = None,
        G_2: Optional[np.ndarray] = None,
        **kwargs
    ) -> None:
        """
        Calculates transfer matrix between two time.

        The matrix is not returned; it is stored in
        ``self.pipeline_res['spa_track']['transfer_matrices'][(data1_name, data2_name)]``
        and `spatial_key` is recorded under ``'transfer_spatial_key'``.

        :param data1_index: The index in the ms_data of the first data.
        :param data2_index: The index in the ms_data of the second data.
        :param spatial_key: The key to get position information of cells, defaults to 'spatial'
        :param alpha: Alignment tuning parameter. Note:0 <= alpha <= 1.
                        When ``alpha = 0`` only the gene expression data is taken into account,
                        while ``alpha =1`` only the spatial coordinates are taken into account.
        :param epsilon: Weight for entropy regularization term, defaults to 0.01
        :param rho: Weight for KL divergence penalizing unbalanced transport, defaults to np.inf
        :param G_1: Distance matrix within spatial data 1 (spots, spots), defaults to None
        :param G_2: Distance matrix within spatial data 2 (spots, spots), defaults to None
        :param kwargs: Extra keyword arguments forwarded unchanged to the underlying
                        ``transfer_matrix`` function.
        :raises AssertionError: if `spatial_key` is None or is missing from the ``cells_matrix``
                        of either data.
        """
        # NOTE: the checks stay `assert`s so the raised exception type is unchanged for
        # existing callers; be aware that they are skipped when Python runs with -O.
        assert spatial_key is not None, 'spatial_key must be provided'
        data1 = self.ms_data[data1_index]
        data2 = self.ms_data[data2_index]
        assert spatial_key in data1.cells_matrix and spatial_key in data2.cells_matrix, \
            'spatial_key must be existed in both data'

        # Results are always keyed by sample names, whichever way the samples were addressed.
        pair = (self._sample_name(data1_index), self._sample_name(data2_index))

        spa_track_res = self._spa_track_result()
        if 'transfer_matrices' not in spa_track_res:
            spa_track_res['transfer_matrices'] = {}

        # Inside a method body this name resolves to the imported module-level
        # function, not to this method.
        tm = transfer_matrix(data1, data2, layer=None, spatial_key=spatial_key, alpha=alpha, epsilon=epsilon,
                             rho=rho, G_1=G_1, G_2=G_2, **kwargs)

        spa_track_res['transfer_spatial_key'] = spatial_key
        spa_track_res['transfer_matrices'][pair] = tm

    def map_data(
        self,
        data1_index: Union[str, int],
        data2_index: Union[str, int]
    ):
        """
        Map the cells of two data using their previously calculated transfer matrix.

        :meth:`transfer_matrix` must have been run for the same pair of data, in the same order.
        The result is also stored in
        ``self.pipeline_res['spa_track']['mapped_data'][(data1_name, data2_name)]``.

        :param data1_index: The index in the ms_data of the first data.
        :param data2_index: The index in the ms_data of the second data.
        :raises KeyError: if no transfer matrix has been calculated for this pair of data.
        :return: The value returned by the underlying ``map_data`` function
                    (its exact type is defined there).
        """
        data1 = self.ms_data[data1_index]
        data2 = self.ms_data[data2_index]
        pair = (self._sample_name(data1_index), self._sample_name(data2_index))

        spa_track_res = self._spa_track_result()
        matrices = spa_track_res.get('transfer_matrices', {})
        if pair not in matrices:
            raise KeyError(
                f"Cannot find the transfer matrix of {pair}, "
                "run `transfer_matrix` on these two data first."
            )

        # The module-level `map_data` function, not this method.
        pi = map_data(matrices[pair], data1, data2)

        if 'mapped_data' not in spa_track_res:
            spa_track_res['mapped_data'] = {}
        spa_track_res['mapped_data'][pair] = pi
        return pi

    def gr_training(
        self,
        data1_index: Union[str, int],
        data2_index: Union[str, int],
        tfs_path: Optional[str] = None,
        min_cells_1: Optional[int] = None,
        min_cells_2: Optional[int] = None,
        cell_select_per_time: int = 10,
        cell_generate_per_time: int = 500,
        train_ratio: float = 0.8,
        use_gpu: bool = True,
        random_state: int = 0,
        training_times: int = 10,
        iter_times: int = 30,
        mapping_num: int = 3000,
        filename: str = "weights.csv",
        lr_ratio: float = 0.1
    ) -> Trainer:
        """
        Create and run a trainer for gene regulatory network training in **2_time** mode(two samples).

        The cell mapping between the two data is obtained through :meth:`map_data`, so
        :meth:`transfer_matrix` must have been run for this pair of data beforehand.

        :param data1_index: The index in the ms_data of the first data
        :param data2_index: The index in the ms_data of the second data
        :param tfs_path: The path of the tf names file, defaults to None
        :param min_cells_1: The minimum number of cells for filtering the first data
        :param min_cells_2: The minimum number of cells for filtering the second data
        :param cell_select_per_time: The number of randomly selected cells at each time point, defaults to 10
        :param cell_generate_per_time: The number of cells generated at each time point, defaults to 500
        :param train_ratio: Ratio of training data, defaults to 0.8
        :param use_gpu: Whether to use gpu, by default, to use if available.
        :param random_state: Random seed of numpy and torch, fixed for reproducibility, defaults to 0
        :param training_times: Number of times to randomly initialize the model and retrain, defaults to 10
        :param iter_times: The number of iterations for each training model, defaults to 30
        :param mapping_num: The number of top weight pairs you want to extract, defaults to 3000
        :param filename: The name of the file to save the weights, defaults to "weights.csv"
        :param lr_ratio: The learning rate, defaults to 0.1
        :raises KeyError: if no transfer matrix has been calculated for this pair of data.
        :return: A trainer object for gene regulatory network training.
        """
        # The trainer is handed deep copies, so the data stored in ms_data
        # are not the objects it works on.
        data_list = [deepcopy(self.ms_data[data1_index]), deepcopy(self.ms_data[data2_index])]
        min_cells = [min_cells_1, min_cells_2]
        cell_mapping = self.map_data(data1_index, data2_index)

        trainer = Trainer(
            data_type="2_time",
            data=data_list,
            tfs_path=tfs_path,
            cell_mapping=cell_mapping,
            min_cells=min_cells,
            cell_select_per_time=cell_select_per_time,
            cell_generate_per_time=cell_generate_per_time,
            train_ratio=train_ratio,
            use_gpu=use_gpu,
            random_state=random_state
        )
        trainer.run(
            training_times=training_times,
            iter_times=iter_times,
            mapping_num=mapping_num,
            filename=filename,
            lr_ratio=lr_ratio
        )
        return trainer