"""Source code for stereo.algorithm.paste.methods: PASTE pairwise and center slice alignment."""

from copy import deepcopy
from typing import (
    List,
    Tuple,
    Optional,
    Union
)

import numpy as np
# from anndata import AnnData
import ot
from sklearn.decomposition import NMF

from stereo.core.cell import Cell
from stereo.core.gene import Gene
from stereo.core.stereo_exp_data import StereoExpData
from stereo.log_manager import LogManager
from stereo.log_manager import logger
from .helper import (
    kl_divergence_backend,
    to_dense_array,
    extract_data_matrix
)


def pairwise_align(
    sliceA: StereoExpData,
    sliceB: StereoExpData,
    alpha: float = 0.1,
    dissimilarity: str = 'kl',
    use_rep: Optional[str] = None,
    G_init=None,
    a_distribution=None,
    b_distribution=None,
    norm: bool = False,
    numItermax: int = 200,
    filter_gene: bool = True,
    backend=ot.backend.NumpyBackend(),
    use_gpu: bool = False,
    return_obj: bool = False,
    verbose: bool = False,
    gpu_verbose: bool = True,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
    """
    Calculates and returns optimal alignment of two slices.

    Note: when ``filter_gene`` is ``True`` both input slices are modified **in place**
    (restricted to their common genes).

    :param sliceA: Slice A to align.
    :param sliceB: Slice B to align.
    :param alpha: Alignment tuning parameter, 0 <= alpha <= 1. The expression cost is weighted by ``1 - alpha``
        and the spatial (Gromov-Wasserstein) term by ``alpha``.
    :param dissimilarity: Expression dissimilarity measure. ``'euclidean'`` (or ``'euc'``) uses ``ot.dist``;
        any other value (e.g. ``'kl'``) uses a KL divergence on the expression matrices with a 0.01 pseudocount.
    :param use_rep: If ``None``, the expression matrix is used to calculate dissimilarity between spots,
        otherwise the representation named ``use_rep`` (as resolved by ``extract_data_matrix``).
    :param G_init: Optional initial mapping (array-like, shape ``(n_spots_A, n_spots_B)``) used in FGW-OT;
        it is rescaled to sum to 1. Default is the outer product of the two spot distributions.
    :param a_distribution: Optional distribution (array-like) of sliceA spots; default is uniform.
    :param b_distribution: Optional distribution (array-like) of sliceB spots; default is uniform.
    :param norm: If ``True``, divide each spatial distance matrix by its smallest positive entry so the closest
        pair of spots is at distance 1. Otherwise, spatial distances remain unchanged.
    :param numItermax: Max number of iterations during FGW-OT.
    :param filter_gene: If ``True``, filter both slices in place to the genes they share before aligning.
    :param backend: Type of backend to run calculations. For list of backends available on system:
        ``ot.backend.get_backend_list()``.
    :param use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
    :param return_obj: If ``True``, additionally returns objective function output of FGW-OT.
    :param verbose: If ``True``, FGW-OT is verbose.
    :param gpu_verbose: If ``True``, print whether gpu is being used to user.
    :param kwargs: Accepted for backward compatibility; currently ignored.

    :raises ValueError: If either slice has no spots or no genes.

    :return: Alignment of spots as a numpy array of shape ``(n_spots_A, n_spots_B)``.
        If ``return_obj = True``, additionally returns objective function output of FGW-OT.
    """  # noqa
    # Determine if gpu or cpu is being used
    if use_gpu:
        try:
            import torch
            backend = ot.backend.TorchBackend()
        except Exception:
            logger.warning("We currently only have gpu support for Pytorch. Please install torch.")
            backend = ot.backend.NumpyBackend()

        if isinstance(backend, ot.backend.TorchBackend):
            if torch.cuda.is_available():
                if gpu_verbose:
                    logger.info("gpu is available, using gpu.")
            else:
                if gpu_verbose:
                    logger.warning("gpu is not available, resorting to torch cpu.")
                use_gpu = False
        else:
            logger.warning(
                "We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). "
                "Reverting to selected backend cpu.")
            use_gpu = False
    else:
        if gpu_verbose:
            logger.info("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")

    if filter_gene:
        common_genes = np.intersect1d(sliceA.genes.gene_name, sliceB.genes.gene_name)
        LogManager.stop_logging()
        try:
            sliceA.tl.filter_genes(gene_list=common_genes)
            sliceB.tl.filter_genes(gene_list=common_genes)
        finally:
            # Always re-enable logging, even if filtering raises.
            LogManager.start_logging()

    # check if slices are valid
    for s in [sliceA, sliceB]:
        if s.shape[0] == 0 or s.shape[1] == 0:
            raise ValueError(f"Found empty `StereoExpData`:\n{s}.")

    # Backend
    nx = backend
    on_torch_gpu = isinstance(nx, ot.backend.TorchBackend) and use_gpu

    # Calculate spatial distances (positions are copied so the slices are not touched)
    coordinatesA = nx.from_numpy(deepcopy(sliceA.position))
    coordinatesB = nx.from_numpy(deepcopy(sliceB.position))
    if isinstance(nx, ot.backend.TorchBackend):
        coordinatesA = coordinatesA.float()
        coordinatesB = coordinatesB.float()
    D_A = ot.dist(coordinatesA, coordinatesA, metric='euclidean')
    D_B = ot.dist(coordinatesB, coordinatesB, metric='euclidean')

    if on_torch_gpu:
        D_A = D_A.cuda()
        D_B = D_B.cuda()

    # Calculate expression dissimilarity
    A_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceA, use_rep)))
    B_X = nx.from_numpy(to_dense_array(extract_data_matrix(sliceB, use_rep)))

    if on_torch_gpu:
        A_X = A_X.cuda()
        B_X = B_X.cuda()

    if dissimilarity.lower() in ('euclidean', 'euc'):
        M = ot.dist(A_X, B_X)
    else:
        # Pseudocount keeps the KL divergence finite for zero counts.
        s_A = A_X + 0.01
        s_B = B_X + 0.01
        M = kl_divergence_backend(s_A, s_B)
        M = nx.from_numpy(M)

    if on_torch_gpu:
        M = M.cuda()

    # init distributions
    if a_distribution is None:
        a = nx.ones((sliceA.shape[0],)) / sliceA.shape[0]
    else:
        a = nx.from_numpy(a_distribution)

    if b_distribution is None:
        b = nx.ones((sliceB.shape[0],)) / sliceB.shape[0]
    else:
        b = nx.from_numpy(b_distribution)

    if on_torch_gpu:
        a = a.cuda()
        b = b.cuda()

    if norm:
        D_A /= nx.min(D_A[D_A > 0])
        D_B /= nx.min(D_B[D_B > 0])

    # Run OT
    if G_init is not None:
        G_init = nx.from_numpy(G_init)
        if isinstance(nx, ot.backend.TorchBackend):
            G_init = G_init.float()
            if use_gpu:
                # Tensor.cuda() returns a copy; the result must be kept.
                G_init = G_init.cuda()

    pi, logw = my_fused_gromov_wasserstein(
        M, D_A, D_B, a, b, G_init=G_init, loss_fun='square_loss', alpha=alpha, log=True,
        numItermax=numItermax, verbose=verbose, use_gpu=use_gpu)
    pi = nx.to_numpy(pi)
    obj = nx.to_numpy(logw['fgw_dist'])
    if isinstance(backend, ot.backend.TorchBackend) and use_gpu:
        torch.cuda.empty_cache()

    if return_obj:
        return pi, obj
    return pi
def center_align(
    initial_slice: StereoExpData,
    slices: List[StereoExpData],
    lmbda=None,
    alpha: float = 0.1,
    n_components: int = 15,
    threshold: float = 0.001,
    max_iter: int = 10,
    nmf_max_iter: int = 200,
    dissimilarity: str = 'kl',
    norm: bool = False,
    random_seed: Optional[int] = None,
    pis_init: Optional[List[np.ndarray]] = None,
    distributions=None,
    backend=ot.backend.NumpyBackend(),
    use_gpu: bool = False,
    verbose: bool = False,
    gpu_verbose: bool = True,
) -> Tuple[StereoExpData, List[np.ndarray]]:
    """
    Computes center alignment of slices.

    Note: ``initial_slice`` and every element of ``slices`` are filtered **in place** to the genes they all share.

    :param initial_slice: Slice to use as the initialization for center alignment; Make sure to include gene
        expression and spatial information.
    :param slices: List of slices to use in the center alignment. Must not be empty.
    :param lmbda: Optional list of probability weights assigned to each slice; If ``None``, use uniform weights.
    :param alpha: Alignment tuning parameter. Note: 0 <= alpha <= 1.
    :param n_components: Number of components in NMF decomposition.
    :param threshold: Convergence threshold: iteration stops once the absolute change of the weighted
        objective between two consecutive iterations is not larger than this value.
    :param max_iter: Maximum number of iterations for our center alignment algorithm.
    :param nmf_max_iter: Maximum number of iterations of every NMF fit (initial and per-iteration).
    :param dissimilarity: Expression dissimilarity measure: ``'kl'`` or ``'euclidean'``.
    :param norm: If ``True``, scales spatial distances such that the closest spots are at distance 1.
        Otherwise, spatial distances remain unchanged.
    :param random_seed: Set random seed for reproducibility.
    :param pis_init: Initial list of mappings between ``initial_slice`` and each of ``slices``.
        Otherwise, default will automatically calculate mappings.
    :param distributions: Optional list of spot distributions, one per slice. Otherwise, default is uniform.
    :param backend: Type of backend to run calculations. For list of backends available on system:
        ``ot.backend.get_backend_list()``.
    :param use_gpu: If ``True``, use gpu. Otherwise, use cpu. Currently we only have gpu support for Pytorch.
    :param verbose: If ``True``, FGW-OT is verbose.
    :param gpu_verbose: If ``True``, print whether gpu is being used to user.

    :raises ValueError: If ``slices`` is empty.

    :return:
        - Inferred center slice: a deep copy of ``initial_slice`` whose expression matrix is ``W @ H``, with
          ``attr['paste_W']``, ``attr['paste_H']``, ``attr['full_rank']`` and ``attr['obj']`` set.
        - List of pairwise alignment mappings of the center slice (rows) to each input slice (columns).
    """  # noqa
    if not slices:
        raise ValueError("`slices` must contain at least one StereoExpData.")

    # Determine if gpu or cpu is being used
    if use_gpu:
        try:
            import torch
            backend = ot.backend.TorchBackend()
        except Exception:
            logger.warning("We currently only have gpu support for Pytorch. Please install torch.")
            backend = ot.backend.NumpyBackend()

        if isinstance(backend, ot.backend.TorchBackend):
            if torch.cuda.is_available():
                if gpu_verbose:
                    logger.info("gpu is available, using gpu.")
            else:
                if gpu_verbose:
                    logger.warning("gpu is not available, resorting to torch cpu.")
                use_gpu = False
        else:
            logger.warning(
                "We currently only have gpu support for Pytorch, please set backend = ot.backend.TorchBackend(). "
                "Reverting to selected backend cpu.")
            use_gpu = False
    else:
        if gpu_verbose:
            logger.info("Using selected backend cpu. If you want to use gpu, set use_gpu = True.")

    if lmbda is None:
        lmbda = len(slices) * [1 / len(slices)]

    if distributions is None:
        distributions = len(slices) * [None]

    # get common genes
    common_genes = initial_slice.genes.gene_name
    for s in slices:
        common_genes = np.intersect1d(common_genes, s.genes.gene_name)

    # subset common genes (in place)
    LogManager.stop_logging()
    try:
        initial_slice.tl.filter_genes(gene_list=common_genes)
        for s in slices:
            s.tl.filter_genes(gene_list=common_genes)
    finally:
        # Always re-enable logging, even if filtering raises.
        LogManager.start_logging()
    logger.info('Filtered all slices for common genes. There are ' + str(len(common_genes)) + ' common genes.')

    # Run initial NMF
    if dissimilarity.lower() in ('euclidean', 'euc'):
        model = NMF(n_components=n_components, init='random', random_state=random_seed,
                    max_iter=nmf_max_iter, verbose=verbose)
    else:
        model = NMF(n_components=n_components, solver='mu', beta_loss='kullback-leibler', init='random',
                    random_state=random_seed, max_iter=nmf_max_iter, verbose=verbose)

    if pis_init is None:
        pis = [None] * len(slices)
        W = model.fit_transform(initial_slice.exp_matrix)
    else:
        pis = pis_init
        # Initialise from the weighted, mapped expression of the input slices.
        W = model.fit_transform(initial_slice.shape[0] * sum(
            [lmbda[i] * np.dot(pis[i], to_dense_array(slices[i].exp_matrix)) for i in range(len(slices))]))
    H = model.components_
    center_coordinates = initial_slice.position

    if not isinstance(center_coordinates, np.ndarray):
        logger.warning("Warning: initial_slice.position is not of type numpy array.")

    # Minimize R: alternate OT (mappings) and NMF (center expression) until the objective stabilises.
    iteration_count = 0
    R = 0
    R_diff = 100
    while R_diff > threshold and iteration_count < max_iter:
        logger.info(f"Iteration: {iteration_count}")
        pis, r = center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu,
                           dissimilarity=dissimilarity, norm=norm, G_inits=pis, distributions=distributions,
                           verbose=verbose)
        W, H = center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, dissimilarity=dissimilarity,
                          max_iter=nmf_max_iter, verbose=verbose)
        R_new = np.dot(r, lmbda)
        iteration_count += 1
        R_diff = abs(R - R_new)
        logger.info(f"Objective {R_new}")
        logger.info(f"Difference: {R_diff}")
        R = R_new

    center_slice = deepcopy(initial_slice)
    center_slice.exp_matrix = np.dot(W, H)
    if center_slice.attr is None:
        center_slice.attr = {}
    center_slice.attr['paste_W'] = W
    center_slice.attr['paste_H'] = H
    center_slice.attr['full_rank'] = center_slice.shape[0] * sum(
        [lmbda[i] * np.dot(pis[i], to_dense_array(slices[i].exp_matrix)) for i in range(len(slices))])
    center_slice.attr['obj'] = R
    center_slice.array2sparse()
    return center_slice, pis
# --------------------------- HELPER METHODS -----------------------------------

def center_ot(W, H, slices, center_coordinates, common_genes, alpha, backend, use_gpu, dissimilarity='kl',
              norm=False, G_inits=None, distributions=None, verbose=False):
    """
    Align the current center slice (expression ``W @ H`` at ``center_coordinates``) to every slice.

    :param G_inits: Optional list of initial mappings, one per slice (entries may be ``None``).
    :param distributions: Optional list of spot distributions, one per slice.
    :return: Tuple ``(pis, r)``: the list of mappings returned by ``pairwise_align`` and a numpy array
        of the corresponding FGW objective values.
    """
    center_slice = StereoExpData()
    center_slice.exp_matrix = np.dot(W, H)
    center_slice.genes = Gene(gene_name=common_genes)
    center_slice.position = center_coordinates

    if distributions is None:
        distributions = len(slices) * [None]
    if G_inits is None:
        # Default used to crash on G_inits[i]; no initial mapping means uniform init.
        G_inits = len(slices) * [None]

    pis = []
    r = []
    logger.info('Solving Pairwise Slice Alignment Problem.')
    for i in range(len(slices)):
        p, r_q = pairwise_align(center_slice, slices[i], filter_gene=False, alpha=alpha,
                                dissimilarity=dissimilarity, norm=norm, return_obj=True, G_init=G_inits[i],
                                b_distribution=distributions[i], backend=backend, use_gpu=use_gpu,
                                verbose=verbose, gpu_verbose=False)
        pis.append(p)
        r.append(r_q)
    return pis, np.array(r)


def center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, dissimilarity='kl', max_iter=200,
               verbose=False):
    """
    Refit the center expression by NMF of ``n * sum_i lmbda[i] * pis[i] @ X_i``, with ``n = W.shape[0]``.

    ``H`` is not used; it is kept for signature compatibility.

    :return: Tuple ``(W_new, H_new)`` of the new NMF factors.
    """
    logger.info('Solving Center Mapping NMF Problem.')
    n = W.shape[0]
    B = n * sum([lmbda[i] * np.dot(pis[i], to_dense_array(slices[i].exp_matrix)) for i in range(len(slices))])
    if dissimilarity.lower() in ('euclidean', 'euc'):
        model = NMF(n_components=n_components, init='random', random_state=random_seed, max_iter=max_iter,
                    verbose=verbose)
    else:
        model = NMF(n_components=n_components, solver='mu', beta_loss='kullback-leibler', init='random',
                    random_state=random_seed, max_iter=max_iter, verbose=verbose)
    W_new = model.fit_transform(B)
    H_new = model.components_
    return W_new, H_new


def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init=None, loss_fun='square_loss', alpha=0.5, armijo=False,
                                log=False, numItermax=200, tol_rel=1e-9, tol_abs=1e-9, use_gpu=False, **kwargs):
    """
    Adapted fused_gromov_wasserstein with the added capability of defining a G_init (inital mapping).
    Also added capability of utilizing different POT backends to speed up computation.

    The conditional-gradient solver is run with linear cost ``(1 - alpha) * M`` and the Gromov-Wasserstein
    term (built from ``C1``, ``C2``) weighted by ``alpha``. ``G_init``, when given, is rescaled to sum to 1;
    otherwise the outer product of ``p`` and ``q`` is used. Extra ``kwargs`` are forwarded to ``ot.optim.cg``.

    :return: The transport plan, or ``(plan, log)`` when ``log=True``; ``log['fgw_dist']`` holds the final loss.

    For more info, see: https://pythonot.github.io/gen_modules/ot.gromov.html
    """
    p, q = ot.utils.list_to_array(p, q)

    p0, q0, C10, C20, M0 = p, q, C1, C2, M
    nx = ot.backend.get_backend(p0, q0, C10, C20, M0)

    constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, loss_fun)

    if G_init is None:
        G0 = p[:, None] * q[None, :]
    else:
        G0 = (1 / nx.sum(G_init)) * G_init
        if use_gpu:
            G0 = G0.cuda()

    def f(G):
        return ot.gromov.gwloss(constC, hC1, hC2, G)

    def df(G):
        return ot.gromov.gwggrad(constC, hC1, hC2, G)

    if loss_fun == 'kl_loss':
        armijo = True  # there is no closed form line-search with KL

    if armijo:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
            return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs)
    else:
        def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
            return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=nx, **kwargs)

    if log:
        res, log = ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True,
                               numItermax=numItermax, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
        log['fgw_dist'] = log['loss'][-1]
        return res, log

    return ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, numItermax=numItermax,
                       stopThr=tol_rel, stopThr2=tol_abs, **kwargs)


def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
    """
    Solve the linesearch in the FW iterations

    Parameters
    ----------

    G : array-like, shape(ns,nt)
        The transport map at a given iteration of the FW
    deltaG : array-like (ns,nt)
        Difference between the optimal map found by linearization in the FW algorithm and the value at a given
        iteration
    cost_G : float
        Value of the cost at `G`
    C1 : array-like (ns,ns), optional
        Structure matrix in the source domain.
    C2 : array-like (nt,nt), optional
        Structure matrix in the target domain.
    M : array-like (ns,nt)
        Cost matrix between the features.
    reg : float
        Regularization parameter.
    alpha_min : float, optional
        Minimum value for alpha
    alpha_max : float, optional
        Maximum value for alpha
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    alpha : float
        The optimal step size of the FW
    fc : int
        nb of function call. Useless here
    cost_G : float
        The value of the cost for the next iteration

    .. _references-solve-linesearch:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    if nx is None:
        G, deltaG, C1, C2, M = ot.utils.list_to_array(G, deltaG, C1, C2, M)

        # A scalar M cannot take part in backend detection.
        if isinstance(M, (int, float)):
            nx = ot.backend.get_backend(G, deltaG, C1, C2)
        else:
            nx = ot.backend.get_backend(G, deltaG, C1, C2, M)

    dot = nx.dot(nx.dot(C1, deltaG), C2.T)
    a = -2 * reg * nx.sum(dot * deltaG)
    b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))

    alpha = ot.optim.solve_1d_linesearch_quad(a, b)
    if alpha_min is not None or alpha_max is not None:
        alpha = np.clip(alpha, alpha_min, alpha_max)

    # the new cost is deduced from the line search quadratic function
    cost_G = cost_G + a * (alpha ** 2) + b * alpha

    return alpha, 1, cost_G