diff --git a/docs/release-notes/0.10.12.md b/docs/release-notes/0.11.0.md
similarity index 89%
rename from docs/release-notes/0.10.12.md
rename to docs/release-notes/0.11.0.md
index bc43d0a8..5a55807c 100644
--- a/docs/release-notes/0.10.12.md
+++ b/docs/release-notes/0.11.0.md
@@ -1,4 +1,4 @@
-### 0.10.12 {small}`the-future`
+### 0.11.0 {small}`the-future`
 
 ```{rubric} Features
 ```
diff --git a/docs/release-notes/index.md b/docs/release-notes/index.md
index 1faf4152..65f29b12 100644
--- a/docs/release-notes/index.md
+++ b/docs/release-notes/index.md
@@ -2,9 +2,11 @@
 
 # Release notes
 
-## Version 0.10.0
-```{include} /release-notes/0.10.12.md
+## Version 0.11.0
+```{include} /release-notes/0.11.0.md
 ```
+
+## Version 0.10.0
 ```{include} /release-notes/0.10.11.md
 ```
 ```{include} /release-notes/0.10.10.md
diff --git a/pyproject.toml b/pyproject.toml
index 9a5612e8..46b0e92c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,7 @@ lint.ignore = [
 "src/rapids_singlecell/preprocessing/_harmonypy_gpu.py" = ["PLR0917"]
 "src/rapids_singlecell/decoupler_gpu/_method_mlm.py" = ["PLR0917"]
 "src/rapids_singlecell/decoupler_gpu/_method_wsum.py" = ["PLR0917"]
+"src/rapids_singlecell/preprocessing/_harmony/__init__.py" = ["PLR0917"]
 
 [tool.ruff.lint.isort]
 known-first-party = ["rapids_singlecell"]
diff --git a/src/rapids_singlecell/preprocessing/_harmony/__init__.py b/src/rapids_singlecell/preprocessing/_harmony/__init__.py
new file mode 100644
index 00000000..7dee07bd
--- /dev/null
+++ b/src/rapids_singlecell/preprocessing/_harmony/__init__.py
@@ -0,0 +1,446 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import cupy as cp
+import numpy as np
+from cuml import KMeans as cumlKMeans
+
+from ._fuses import _calc_R, _get_factor, _get_pen, _log_div_OE, _R_multi_m
+from ._kernels._normalize import _get_normalize_kernel_optimized
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+def _normalize_cp_p1(X: cp.ndarray) -> cp.ndarray:
+    """
+    Normalize rows of a matrix (L1 norm) using an optimized kernel with shared memory and warp shuffle.
+
+    Parameters:
+        X (cp.ndarray): Input 2D array.
+
+    Returns:
+        cp.ndarray: Row-normalized 2D array.
+    """
+    assert X.ndim == 2, "Input must be a 2D array."
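+    # The kernel launches one 32-thread block (a single warp) per row and
+    # rescales that row in place by the reciprocal of its L1 norm.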
+
+    rows, cols = X.shape
+
+    # Fixed block size of 32 (one warp)
+    block_dim = 32
+    grid_dim = rows  # One block per row
+
+    normalize_p1 = _get_normalize_kernel_optimized(X.dtype)
+    # Launch the kernel
+    normalize_p1((grid_dim,), (block_dim,), (X, rows, cols))
+    return X
+
+
+def _normalize_cp(X: cp.ndarray, p: int = 2) -> cp.ndarray:
+    if p == 2:
+        return X / cp.linalg.norm(X, ord=2, axis=1, keepdims=True).clip(min=1e-12)
+    else:
+        return _normalize_cp_p1(X)
+
+
+def _get_batch_codes(batch_mat: pd.DataFrame, batch_key: str | list[str]) -> pd.Series:
+    if isinstance(batch_key, str):
+        batch_vec = batch_mat[batch_key]
+
+    elif len(batch_key) == 1:
+        batch_key = batch_key[0]
+
+        batch_vec = batch_mat[batch_key]
+
+    else:
+        df = batch_mat[batch_key].astype("str")
+        batch_vec = df.apply(lambda row: ",".join(row), axis=1)
+
+    return batch_vec.astype("category")
+
+
+def _one_hot_tensor_cp(X: pd.Series) -> cp.ndarray:
+    ids = cp.array(X.cat.codes.values.copy(), dtype=cp.int32).reshape(-1)
+    n_col = X.cat.categories.size
+    Phi = cp.eye(n_col)[ids]
+
+    return Phi
+
+
+def harmonize(
+    Z: cp.ndarray,
+    batch_mat: pd.DataFrame,
+    batch_key: str | list[str],
+    *,
+    n_clusters: int | None = None,
+    max_iter_harmony: int = 10,
+    max_iter_clustering: int = 200,
+    tol_harmony: float = 1e-4,
+    tol_clustering: float = 1e-5,
+    ridge_lambda: float = 1.0,
+    sigma: float = 0.1,
+    block_proportion: float = 0.05,
+    theta: float = 2.0,
+    tau: int = 0,
+    correction_method: str = "fast",
+    random_state: int = 0,
+    verbose: bool = True,
+) -> cp.ndarray:
+    """
+    Integrate data using the Harmony algorithm.
+
+    Parameters
+    ----------
+
+    Z
+        The input embedding with rows for cells (N) and columns for embedding coordinates (d).
+
+    batch_mat
+        The cell barcode information as a data frame, with rows for cells (N) and columns for cell attributes.
+
+    batch_key
+        Cell attribute(s) from ``batch_mat`` used to identify batches.
+
+    n_clusters
+        Number of clusters used in the Harmony algorithm. If ``None``, choose the minimum of 100 and N / 30.
+
+    max_iter_harmony
+        Maximum number of iterations for running Harmony if not converged.
+
+    max_iter_clustering
+        Within each Harmony iteration, maximum number of iterations for the clustering step if not converged.
+
+    tol_harmony
+        Convergence tolerance for Harmony, measured on the objective function values.
+
+    tol_clustering
+        Convergence tolerance for the clustering step within each Harmony iteration, measured on the objective function values.
+
+    ridge_lambda
+        Hyperparameter of the ridge regression used in the correction step.
+
+    sigma
+        Weight of the entropy term in the objective function.
+
+    block_proportion
+        Proportion of cells updated per block in one update operation of the clustering step.
+
+    theta
+        Weight of the diversity penalty term in the objective function.
+
+    tau
+        Discounting factor on ``theta``. By default, there is no discounting.
+
+    correction_method
+        Method to use for the correction step: ``original`` for the original method, ``fast`` for the improved method. By default, the improved method is used.
+
+    random_state
+        Random seed for reproducing results.
+
+    verbose
+        If ``True``, print verbose output.
+
+    Returns
+    -------
+    The integrated embedding by Harmony, of the same shape as the input embedding.
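+
+    Notes
+    -----
+    Each Harmony iteration alternates a clustering step (regularized soft
+    k-means with a batch-diversity penalty) and a correction step (per-cluster
+    ridge regression that removes batch effects from the embedding).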
+
+    Examples
+    --------
+    >>> import anndata
+    >>> adata = anndata.read_h5ad("filename.h5ad")
+    >>> X_harmony = harmonize(adata.obsm['X_pca'], adata.obs, 'Channel')
+
+    >>> adata = anndata.read_h5ad("filename.h5ad")
+    >>> X_harmony = harmonize(adata.obsm['X_pca'], adata.obs, ['Channel', 'Lab'])
+    """
+
+    Z_norm = _normalize_cp(Z)
+    n_cells = Z.shape[0]
+
+    batch_codes = _get_batch_codes(batch_mat, batch_key)
+    n_batches = batch_codes.cat.categories.size
+    N_b = cp.array(batch_codes.value_counts(sort=False).values, dtype=Z.dtype)
+    Pr_b = (N_b.reshape(-1, 1) / len(batch_codes)).astype(Z.dtype)
+
+    Phi = _one_hot_tensor_cp(batch_codes).astype(Z.dtype)
+    if n_clusters is None:
+        n_clusters = int(min(100, n_cells / 30))
+    theta = (cp.ones(n_batches) * theta).astype(Z.dtype)
+
+    if tau > 0:
+        # Discount theta as in the original Harmony: 1 - exp(-(N_b / (K * tau))^2)
+        theta = theta * (1 - cp.exp(-((N_b / (n_clusters * tau)) ** 2)))
+
+    theta = theta.reshape(1, -1)
+    assert 0 < block_proportion <= 1
+    assert correction_method in {"fast", "original"}
+
+    cp.random.seed(random_state)
+
+    # Initialize centroids
+    R, E, O, objectives_harmony = _initialize_centroids(
+        Z_norm,
+        n_clusters,
+        sigma,
+        Pr_b,
+        Phi,
+        theta,
+        random_state,
+    )
+    if verbose:
+        print("Initialization is completed.")
+
+    for i in range(max_iter_harmony):
+        _clustering(
+            Z_norm,
+            Pr_b,
+            Phi,
+            R,
+            E,
+            O,
+            n_clusters,
+            theta,
+            tol_clustering,
+            objectives_harmony,
+            max_iter_clustering,
+            sigma,
+            block_proportion,
+        )
+
+        Z_hat = _correction(Z, R, Phi, O, ridge_lambda, correction_method)
+        Z_norm = _normalize_cp(Z_hat, p=2)
+        if verbose:
+            print(f"\tCompleted {i + 1} / {max_iter_harmony} iteration(s).")
+
+        if _is_convergent_harmony(objectives_harmony, tol=tol_harmony):
+            if verbose:
+                print(f"Reached convergence after {i + 1} iteration(s).")
+            break
+
+    return Z_hat
+
+
+def _initialize_centroids(
+    Z_norm: cp.ndarray,
+    n_clusters: int,
+    sigma: float,
+    Pr_b: cp.ndarray,
+    Phi: cp.ndarray,
+    theta: cp.ndarray,
+    random_state: int,
+    n_init: int = 10,
+) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, list]:
+    kmeans = cumlKMeans(
+        n_clusters=n_clusters, init="k-means||", max_iter=25, random_state=random_state
+    )
+    kmeans.fit(Z_norm)
+    Y = kmeans.cluster_centers_.astype(Z_norm.dtype)
+    Y_norm = _normalize_cp(Y, p=2)
+
+    # Initialize R
+    R = _calc_R(-2 / sigma, cp.dot(Z_norm, Y_norm.T))
+    R = _normalize_cp(R, p=1)
+
+    E = cp.dot(Pr_b, cp.sum(R, axis=0, keepdims=True))
+    O = cp.dot(Phi.T, R)
+
+    objectives_harmony = []
+    _compute_objective(Y_norm, Z_norm, R, theta, sigma, O, E, objectives_harmony)
+
+    return R, E, O, objectives_harmony
+
+
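+# The clustering step updates R in shuffled blocks: each block of cells is
+# removed from the batch statistics O and E, reassigned with the diversity
+# penalty applied, and then added back (an online update, mirroring the
+# block update in the removed _harmonypy_gpu implementation).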
+def _clustering(
+    Z_norm: cp.ndarray,
+    Pr_b: cp.ndarray,
+    Phi: cp.ndarray,
+    R: cp.ndarray,
+    E: cp.ndarray,
+    O: cp.ndarray,
+    n_clusters: int,
+    theta: cp.ndarray,
+    tol: float,
+    objectives_harmony: list,
+    max_iter: int,
+    sigma: float,
+    block_proportion: float,
+):
+    n_cells = Z_norm.shape[0]
+    objectives_clustering = []
+    block_size = int(n_cells * block_proportion)
+    term = -2 / sigma
+    for _ in range(max_iter):
+        # Compute cluster centroids
+        Y = cp.dot(R.T, Z_norm)  # Compute centroids
+        Y_norm = _normalize_cp(Y, p=2)  # Normalize centroids
+
+        idx_list = cp.arange(n_cells)
+        cp.random.shuffle(idx_list)
+        pos = 0
+        while pos < len(idx_list):
+            idx_in = idx_list[pos : (pos + block_size)]
+            R_in = R[idx_in]  # Slice rows for R
+            Phi_in = Phi[idx_in]  # Slice rows for Phi
+            # O -= Phi_in.T @ R_in
+            cp.cublas.gemm("T", "N", Phi_in, R_in, alpha=-1, beta=1, out=O)
+            # E -= Pr_b @ R_in.sum(axis=0)
+            cp.cublas.gemm(
+                "N",
+                "N",
+                Pr_b,
+                cp.sum(R_in, axis=0, keepdims=True),
+                alpha=-1,
+                beta=1,
+                out=E,
+            )
+
+            # Update and normalize R
+            R_out = _calc_R(term, cp.dot(Z_norm[idx_in], Y_norm.T))
+
+            # Precompute penalty term and apply it
+            penalty_term = _get_pen(E, O, theta.T)
+            omega = cp.dot(Phi_in, penalty_term)
+            R_out *= omega
+
+            # Normalize R_out and update R
+            R_out = _normalize_cp(R_out, p=1)
+
+            R[idx_in] = R_out
+
+            # Add the updated block back into O and E
+            # O += Phi_in.T @ R_out
+            cp.cublas.gemm("T", "N", Phi_in, R_out, alpha=1, beta=1, out=O)
+            # E += Pr_b @ R_out.sum(axis=0)
+            cp.cublas.gemm(
+                "N",
+                "N",
+                Pr_b,
+                cp.sum(R_out, axis=0, keepdims=True),
+                alpha=1,
+                beta=1,
+                out=E,
+            )
+            pos += block_size
+        _compute_objective(Y_norm, Z_norm, R, theta, sigma, O, E, objectives_clustering)
+
+        if _is_convergent_clustering(objectives_clustering, tol):
+            objectives_harmony.append(objectives_clustering[-1])
+            break
+
+
+def _correction(
+    X: cp.ndarray,
+    R: cp.ndarray,
+    Phi: cp.ndarray,
+    O: cp.ndarray,
+    ridge_lambda: float,
+    correction_method: str,
+) -> cp.ndarray:
+    if correction_method == "fast":
+        return _correction_fast(X, R, Phi, O, ridge_lambda)
+    else:
+        return _correction_original(X, R, Phi, ridge_lambda)
+
+
+def _correction_original(
+    X: cp.ndarray, R: cp.ndarray, Phi: cp.ndarray, ridge_lambda: float
+) -> cp.ndarray:
+    n_cells = X.shape[0]
+    n_clusters = R.shape[1]
+    n_batches = Phi.shape[1]
+    Phi_1 = cp.concatenate((cp.ones((n_cells, 1), dtype=X.dtype), Phi), axis=1)
+
+    Z = X.copy()
+    id_mat = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)
+    id_mat[0, 0] = 0
+    Lambda = ridge_lambda * id_mat
+    for k in range(n_clusters):
+        Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
+        inv_mat = cp.linalg.inv(cp.dot(Phi_t_diag_R, Phi_1) + Lambda)
+        W = cp.dot(inv_mat, cp.dot(Phi_t_diag_R, X))
+        W[0, :] = 0
+        Z -= cp.dot(Phi_t_diag_R.T, W)
+
+    return Z
+
+
+def _correction_fast(
+    X: cp.ndarray, R: cp.ndarray, Phi: cp.ndarray, O: cp.ndarray, ridge_lambda: float
+) -> cp.ndarray:
+    n_cells = X.shape[0]
+    n_clusters = R.shape[1]
+    n_batches = Phi.shape[1]
+    Phi_1 = cp.concatenate((cp.ones((n_cells, 1), dtype=X.dtype), Phi), axis=1)
+    Z = X.copy()
+    P = cp.eye(n_batches + 1, n_batches + 1, dtype=X.dtype)
+    for k in range(n_clusters):
+        O_k = O[:, k]
+        N_k = cp.sum(O_k)
+
+        # Invert the arrowhead matrix [[N_k, O_k^T], [O_k, diag(O_k + lambda)]]
+        # analytically via its Schur complement instead of a dense inverse
+        factor = _get_factor(O_k, ridge_lambda)
+        c = N_k + cp.sum(-factor * O_k**2)
+        c_inv = 1 / c
+
+        P[0, 1:] = -factor * O_k
+
+        P_t_B_inv = cp.zeros((factor.size + 1, factor.size + 1), dtype=X.dtype)
+
+        # Set diagonal entries
+        P_t_B_inv[0, 0] = c_inv
+        P_t_B_inv[1:, 1:] = cp.diag(factor)
+
+        # Set off-diagonal entries
+        P_t_B_inv[1:, 0] = P[0, 1:] * c_inv
+        inv_mat = cp.dot(P_t_B_inv, P)
+
+        Phi_t_diag_R = Phi_1.T * R[:, k].reshape(1, -1)
+        W = cp.dot(inv_mat, cp.dot(Phi_t_diag_R, X))
+        W[0, :] = 0
+
+        Z -= cp.dot(Phi_t_diag_R.T, W)
+
+    return Z
+
+
+def _compute_objective(
+    Y_norm: cp.ndarray,
+    Z_norm: cp.ndarray,
+    R: cp.ndarray,
+    theta: cp.ndarray,
+    sigma: float,
+    O: cp.ndarray,
+    E: cp.ndarray,
+    objective_arr: list,
+) -> None:
+    kmeans_error = cp.sum(_R_multi_m(R, cp.dot(Z_norm, Y_norm.T)))
+    R = R / R.sum(axis=1, keepdims=True)
+    entropy = cp.sum(R * cp.log(R + 1e-12))
+    entropy_term = sigma * entropy
+    diversity_penalty = sigma * cp.sum(cp.dot(theta, _log_div_OE(O, E)))
+    objective = kmeans_error + entropy_term + diversity_penalty
+    objective_arr.append(objective)
+
+
+def _is_convergent_harmony(objectives_harmony: list, tol: float) -> bool:
+    if len(objectives_harmony) < 2:
+        return False
+
+    obj_old = objectives_harmony[-2]
+    obj_new = objectives_harmony[-1]
+
+    return (obj_old - obj_new) < tol * np.abs(obj_old)
+
+
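+# Convergence of the clustering step is judged on a sliding window: the sum of
+# the last `window_size` objectives is compared with the sum of the preceding
+# `window_size` objectives, relative to the magnitude of the older sum.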
+def _is_convergent_clustering(
+    objectives_clustering: list, tol: float, window_size: int = 3
+) -> bool:
+    if len(objectives_clustering) < window_size + 1:
+        return False
+    obj_old = 0.0
+    obj_new = 0.0
+    for i in range(window_size):
+        obj_old += objectives_clustering[-2 - i]
+        obj_new += objectives_clustering[-1 - i]
+
+    return (obj_old - obj_new) < tol * np.abs(obj_old)
diff --git a/src/rapids_singlecell/preprocessing/_harmony/_fuses.py b/src/rapids_singlecell/preprocessing/_harmony/_fuses.py
new file mode 100644
index 00000000..f7059c2b
--- /dev/null
+++ b/src/rapids_singlecell/preprocessing/_harmony/_fuses.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+import cupy as cp
+
+
+@cp.fuse
+def _get_factor(O_k: cp.ndarray, ridge_lambda: float) -> cp.ndarray:
+    return 1 / (O_k + ridge_lambda)
+
+
+@cp.fuse
+def _get_pen(E: cp.ndarray, O: cp.ndarray, theta: cp.ndarray) -> cp.ndarray:
+    return cp.power(cp.divide(E + 1, O + 1), theta)
+
+
+@cp.fuse
+def _calc_R(term: cp.ndarray, mm: cp.ndarray) -> cp.ndarray:
+    return cp.exp(term * (1 - mm))
+
+
+@cp.fuse
+def _log_div_OE(O: cp.ndarray, E: cp.ndarray) -> cp.ndarray:
+    return O * cp.log((O + 1) / (E + 1))
+
+
+@cp.fuse
+def _R_multi_m(R: cp.ndarray, other: cp.ndarray) -> cp.ndarray:
+    return R * 2 * (1 - other)
diff --git a/src/rapids_singlecell/preprocessing/_harmony/_kernels/_normalize.py b/src/rapids_singlecell/preprocessing/_harmony/_kernels/_normalize.py
new file mode 100644
index 00000000..45bee2f0
--- /dev/null
+++ b/src/rapids_singlecell/preprocessing/_harmony/_kernels/_normalize.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+from cuml.common.kernel_utils import cuda_kernel_factory
+
+normalize_kernel_optimized = r"""
+({0} * X, int rows, int cols) {
+    __shared__ {0} shared[32];  // Shared memory for partial sums (one per thread)
+
+    int row = blockIdx.x;   // One block per row
+    int tid = threadIdx.x;  // Thread index within the block
+
+    // Ensure we're within matrix bounds
+    if (row >= rows) return;
+
+    // Step 1: Compute partial sums within each thread
+    {0} norm = 0.0;
+    for (int col = tid; col < cols; col += blockDim.x) {
+        norm += fabs(X[row * cols + col]);  // Manhattan (L1) norm
+    }
+
+    // Store partial sum in shared memory
+    shared[tid] = norm;
+    __syncthreads();
+
+    // Step 2: Reduce the partial sums with warp shuffles
+    #pragma unroll
+    for (int offset = 16; offset > 0; offset /= 2) {
+        shared[tid] += __shfl_down_sync(0xFFFFFFFF, shared[tid], offset);
+    }
+    __syncthreads();
+
+    // First thread calculates the final norm
+    if (tid == 0) {
+        {0} final_norm = shared[0];
+        final_norm = fmax(final_norm, 1e-12);  // fmax (not fmaxf) keeps double precision
+        shared[0] = 1.0 / final_norm;  // Store reciprocal for normalization
+    }
+    __syncthreads();
+
+    // Step 3: Normalize the row
+    for (int col = tid; col < cols; col += blockDim.x) {
+        X[row * cols + col] *= shared[0];
+    }
+}
+"""
+
+
+def _get_normalize_kernel_optimized(dtype):
+    return cuda_kernel_factory(
+        normalize_kernel_optimized, (dtype,), "normalize_kernel_optimized"
+    )
diff --git a/src/rapids_singlecell/preprocessing/_harmony_integrate.py b/src/rapids_singlecell/preprocessing/_harmony_integrate.py
index 61c33d1c..0b3cc419 100644
--- a/src/rapids_singlecell/preprocessing/_harmony_integrate.py
+++ b/src/rapids_singlecell/preprocessing/_harmony_integrate.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
+import cupy as cp
 import numpy as np
 
 if TYPE_CHECKING:
@@ -15,6 +16,7 @@ def harmony_integrate(
     basis: str = "X_pca",
     adjusted_basis: str = "X_pca_harmony",
     dtype: type = np.float64,
+    correction_method: Literal["fast", "original"] = "original",
     **kwargs,
 ) -> None:
     """
@@ -43,6 +45,8 @@ def harmony_integrate(
     dtype
         The data type to use for the Harmony. If you use 32-bit you may experience
         numerical instability.
+    correction_method
+        Choose which method for the correction step: ``original`` for the original method, ``fast`` for the improved method.
     kwargs
-        Any additional arguments will be passed to ``harmonpy_gpu.run_harmony()``.
+        Any additional arguments will be passed to ``harmonize()``.
 
@@ -54,10 +58,13 @@ def harmony_integrate(
 
     different experiments are integrated.
     """
-    from . import _harmonypy_gpu
+    from ._harmony import harmonize
 
     X = adata.obsm[basis].astype(dtype)
+    if isinstance(X, np.ndarray):
+        X = cp.array(X)
+    harmony_out = harmonize(
+        X, adata.obs, key, correction_method=correction_method, **kwargs
+    )
 
-    harmony_out = _harmonypy_gpu.run_harmony(X, adata.obs, key, dtype=dtype, **kwargs)
-
-    adata.obsm[adjusted_basis] = harmony_out.Z_corr.T.get()
+    adata.obsm[adjusted_basis] = harmony_out.get()
diff --git a/src/rapids_singlecell/preprocessing/_harmonypy_gpu.py b/src/rapids_singlecell/preprocessing/_harmonypy_gpu.py
deleted file mode 100644
index 019eb7c7..00000000
--- a/src/rapids_singlecell/preprocessing/_harmonypy_gpu.py
+++ /dev/null
@@ -1,378 +0,0 @@
-# harmonypy - A data alignment algorithm.
-# Copyright (C) 2018  Ilya Korsunsky
-#               2019  Kamil Slowikowski
-#               2022  Severin Dicks
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-from __future__ import annotations
-
-import logging
-
-import cupy as cp
-import numpy as np
-import pandas as pd
-from cuml import KMeans
-
-# create logger
-logger = logging.getLogger("harmonypy_gpu")
-logger.setLevel(logging.DEBUG)
-ch = logging.StreamHandler()
-ch.setLevel(logging.DEBUG)
-formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-ch.setFormatter(formatter)
-logger.addHandler(ch)
-
-
-# from IPython.core.debugger import set_trace
-def run_harmony(
-    data_mat: np.ndarray,
-    meta_data: pd.DataFrame,
-    vars_use,
-    *,
-    theta=None,
-    lamb=None,
-    sigma=0.1,
-    nclust=None,
-    tau=0,
-    block_size=0.05,
-    max_iter_harmony=10,
-    max_iter_kmeans=20,
-    epsilon_cluster=1e-5,
-    epsilon_harmony=1e-4,
-    plot_convergence=False,
-    verbose=True,
-    reference_values=None,
-    cluster_prior=None,
-    random_state=0,
-    dtype=cp.float64,
-):
-    """Run Harmony."""
-    # theta = None
-    # lamb = None
-    # sigma = 0.1
-    # nclust = None
-    # tau = 0
-    # block_size = 0.05
-    # epsilon_cluster = 1e-5
-    # epsilon_harmony = 1e-4
-    # plot_convergence = False
-    # verbose = True
-    # reference_values = None
-    # cluster_prior = None
-    # random_state = 0
-
-    N = meta_data.shape[0]
-    if data_mat.shape[1] != N:
-        data_mat = data_mat.T
-
-    assert (
-        data_mat.shape[1] == N
-    ), "data_mat and meta_data do not have the same number of cells"
-
-    if nclust is None:
-        nclust = np.min([np.round(N / 30.0), 100]).astype(int)
-
-    if isinstance(sigma, float) and nclust > 1:
-        sigma = np.repeat(sigma, nclust)
-
-    if isinstance(vars_use, str):
-        vars_use = [vars_use]
-
-    phi = pd.get_dummies(meta_data[vars_use]).to_numpy().T
-    phi_n = meta_data[vars_use].describe().loc["unique"].to_numpy().astype(int)
-
-    if theta is None:
-        theta = np.repeat([1] * len(phi_n), phi_n)
-    elif isinstance(theta, float) or isinstance(theta, int):
-        theta = np.repeat([theta] * len(phi_n), phi_n)
-    elif len(theta) == len(phi_n):
-        theta = np.repeat([theta], phi_n)
-
-    assert len(theta) == np.sum(phi_n), "each batch variable must have a theta"
-
-    if lamb is None:
-        lamb = np.repeat([1] * len(phi_n), phi_n)
-    elif isinstance(lamb, float) or isinstance(lamb, int):
-        lamb = np.repeat([lamb] * len(phi_n), phi_n)
-    elif len(lamb) == len(phi_n):
-        lamb = np.repeat([lamb], phi_n)
-
-    assert len(lamb) == np.sum(phi_n), "each batch variable must have a lambda"
-
-    # Number of items in each category.
-    N_b = phi.sum(axis=1)
-    # Proportion of items in each category.
-    Pr_b = N_b / N
-
-    if tau > 0:
-        theta = theta * (1 - np.exp(-((N_b / (nclust * tau)) ** 2)))
-
-    lamb_mat = np.diag(np.insert(lamb, 0, 0))
-
-    phi_moe = np.vstack((np.repeat(1, N), phi))
-
-    cp.random.seed(random_state)
-
-    ho = Harmony(
-        data_mat,
-        phi,
-        phi_moe,
-        Pr_b,
-        sigma,
-        theta,
-        max_iter_harmony,
-        max_iter_kmeans,
-        epsilon_cluster,
-        epsilon_harmony,
-        nclust,
-        block_size,
-        lamb_mat,
-        verbose,
-        random_state,
-        dtype=dtype,
-    )
-
-    return ho
-
-
-class Harmony:
-    def __init__(
-        self,
-        Z,
-        Phi,
-        Phi_moe,
-        Pr_b,
-        sigma,
-        theta,
-        max_iter_harmony,
-        max_iter_kmeans,
-        epsilon_kmeans,
-        epsilon_harmony,
-        K,
-        block_size,
-        lamb,
-        verbose,
-        random_state,
-        dtype,
-    ):
-        self.Z_corr = cp.array(Z, dtype=dtype)
-        self.Z_orig = cp.array(Z, dtype=dtype)
-
-        self.Z_cos = self.Z_orig / self.Z_orig.max(axis=0)
-        self.Z_cos = self.Z_cos / cp.linalg.norm(self.Z_cos, ord=2, axis=0)
-
-        self.Phi = cp.array(Phi, dtype=dtype)
-        self.Phi_moe = cp.array(Phi_moe, dtype=dtype)
-        self.N = self.Z_corr.shape[1]
-        self.Pr_b = cp.array(Pr_b, dtype=dtype)
-        self.B = self.Phi.shape[0]  # number of batch variables
-        self.d = self.Z_corr.shape[0]
-        self.window_size = 3
-        self.epsilon_kmeans = epsilon_kmeans
-        self.epsilon_harmony = epsilon_harmony
-
-        self.lamb = cp.array(lamb, dtype=dtype)
-        self.sigma = cp.array(sigma, dtype=dtype)
-        self.sigma_prior = cp.array(sigma, dtype=dtype)
-        self.block_size = block_size
-        self.K = K  # number of clusters
-        self.max_iter_harmony = max_iter_harmony
-        self.max_iter_kmeans = max_iter_kmeans
-        self.verbose = verbose
-        self.theta = cp.array(theta, dtype=dtype)
-        self.random_state = random_state
-
-        self.objective_harmony = []
-        self.objective_kmeans = []
-        self.objective_kmeans_dist = []
-        self.objective_kmeans_entropy = []
-        self.objective_kmeans_cross = []
-        self.kmeans_rounds = []
-        self.dtype = dtype
-
-        self.allocate_buffers()
-        self.init_cluster()
-        self.harmonize(self.max_iter_harmony, self.verbose)
-
-    def result(self):
-        return self.Z_corr
-
-    def allocate_buffers(self):
-        self._scale_dist = cp.zeros((self.K, self.N), dtype=self.dtype)
-        self.dist_mat = cp.zeros((self.K, self.N), dtype=self.dtype)
-        self.O = cp.zeros((self.K, self.B), dtype=self.dtype)
-        self.E = cp.zeros((self.K, self.B), dtype=self.dtype)
-        self.W = cp.zeros((self.B + 1, self.d), dtype=self.dtype)
-        self.Phi_Rk = cp.zeros((self.B + 1, self.N), dtype=self.dtype)
-
-    def init_cluster(self):
-        # Start with cluster centroids
-        kmeans_obj = KMeans(
-            n_clusters=self.K, random_state=self.random_state, init="k-means||"
-        ).fit(self.Z_cos.T)
-        self.Y = kmeans_obj.cluster_centers_.T
-        # (1) Normalize
-        self.Y = self.Y / cp.linalg.norm(self.Y, ord=2, axis=0)
-        # (2) Assign cluster probabilities
-        self.dist_mat = 2 * (1 - cp.dot(self.Y.T, self.Z_cos))
-        self.R = -self.dist_mat
-        self.R = self.R / self.sigma[:, None]
-        self.R -= cp.max(self.R, axis=0)
-        self.R = cp.exp(self.R)
-        self.R = self.R / cp.sum(self.R, axis=0)
-        # (3) Batch diversity statistics
-        self.E = cp.outer(cp.sum(self.R, axis=1), self.Pr_b)
-        self.O = cp.inner(self.R, self.Phi)
-        self.compute_objective()
-        # Save results
-        self.objective_harmony.append(self.objective_kmeans[-1])
-
-    def compute_objective(self):
-        kmeans_error = cp.sum(cp.multiply(self.R, self.dist_mat))
-        # Entropy
-        _entropy = cp.sum(safe_entropy(self.R) * self.sigma[:, cp.newaxis])
-        # Cross Entropy
-        x = self.R * self.sigma[:, cp.newaxis]
-        y = cp.tile(self.theta[:, cp.newaxis], self.K).T
-        z = cp.log((self.O + 1) / (self.E + 1))
-        w = cp.dot(y * z, self.Phi)
-        _cross_entropy = cp.sum(x * w)
-        # Save results
-        self.objective_kmeans.append(kmeans_error + _entropy + _cross_entropy)
-        self.objective_kmeans_dist.append(kmeans_error)
-        self.objective_kmeans_entropy.append(_entropy)
-        self.objective_kmeans_cross.append(_cross_entropy)
-
-    def harmonize(self, iter_harmony=10, verbose=True):
-        converged = False
-        for i in range(1, iter_harmony + 1):
-            if verbose:
-                logger.info(f"Iteration {i} of {iter_harmony}")
-            # STEP 1: Clustering
-            self.cluster()
-            # STEP 2: Regress out covariates
-            # self.moe_correct_ridge()
-            self.Z_cos, self.Z_corr, self.W, self.Phi_Rk = moe_correct_ridge(
-                self.Z_orig,
-                self.Z_cos,
-                self.Z_corr,
-                self.R,
-                self.W,
-                self.K,
-                self.Phi_Rk,
-                self.Phi_moe,
-                self.lamb,
-            )
-            # STEP 3: Check for convergence
-            converged = self.check_convergence(1)
-            if converged:
-                if verbose:
-                    logger.info(
-                        "Converged after {} iteration{}".format(i, "s" if i > 1 else "")
-                    )
-                break
-        if verbose and not converged:
-            logger.info("Stopped before convergence")
-        return 0
-
-    def cluster(self):
-        # Z_cos has changed
-        # R is assumed to not have changed
-        # Update Y to match new integrated data
-        self.dist_mat = 2 * (1 - cp.dot(self.Y.T, self.Z_cos))
-        for i in range(self.max_iter_kmeans):
-            # print("kmeans {}".format(i))
-            # STEP 1: Update Y
-            self.Y = cp.dot(self.Z_cos, self.R.T)
-            self.Y = self.Y / cp.linalg.norm(self.Y, ord=2, axis=0)
-            # STEP 2: Update dist_mat
-            self.dist_mat = 2 * (1 - cp.dot(self.Y.T, self.Z_cos))
-            # STEP 3: Update R
-            self.update_R()
-            # STEP 4: Check for convergence
-            self.compute_objective()
-            if i > self.window_size:
-                converged = self.check_convergence(0)
-                if converged:
-                    break
-        self.kmeans_rounds.append(i)
-        self.objective_harmony.append(self.objective_kmeans[-1])
-        return 0
-
-    def update_R(self):
-        self._scale_dist = -self.dist_mat
-        self._scale_dist = self._scale_dist / self.sigma[:, None]
-        self._scale_dist -= cp.max(self._scale_dist, axis=0)
-        self._scale_dist = cp.exp(self._scale_dist)
-        # Update cells in blocks
-        update_order = cp.arange(self.N)
-        cp.random.shuffle(update_order)
-        n_blocks = cp.ceil(1 / self.block_size).astype(int)
-        blocks = cp.array_split(update_order, int(n_blocks))
-        for b in blocks:
-            # STEP 1: Remove cells
-            self.E -= cp.outer(cp.sum(self.R[:, b], axis=1), self.Pr_b)
-            self.O -= cp.dot(self.R[:, b], self.Phi[:, b].T)
-            # STEP 2: Recompute R for removed cells
-            self.R[:, b] = self._scale_dist[:, b]
-            self.R[:, b] = cp.multiply(
-                self.R[:, b],
-                cp.dot(
-                    cp.power((self.E + 1) / (self.O + 1), self.theta), self.Phi[:, b]
-                ),
-            )
-            self.R[:, b] = self.R[:, b] / cp.linalg.norm(self.R[:, b], ord=1, axis=0)
-            # STEP 3: Put cells back
-            self.E += cp.outer(cp.sum(self.R[:, b], axis=1), self.Pr_b)
-            self.O += cp.dot(self.R[:, b], self.Phi[:, b].T)
-        return 0
-
-    def check_convergence(self, i_type):
-        obj_old = 0.0
-        obj_new = 0.0
-        # Clustering, compute new window mean
-        if i_type == 0:
-            okl = len(self.objective_kmeans)
-            for i in range(self.window_size):
-                obj_old += self.objective_kmeans[okl - 2 - i]
-                obj_new += self.objective_kmeans[okl - 1 - i]
-            if abs(obj_old - obj_new) / abs(obj_old) < self.epsilon_kmeans:
-                return True
-            return False
-        # Harmony
-        if i_type == 1:
-            obj_old = self.objective_harmony[-2]
-            obj_new = self.objective_harmony[-1]
-            if (obj_old - obj_new) / abs(obj_old) < self.epsilon_harmony:
-                return True
-            return False
-        return True
-
-
-def safe_entropy(x: cp.array):
-    y = cp.multiply(x, cp.log(x))
-    y[~cp.isfinite(y)] = 0.0
-    return y
-
-
-def moe_correct_ridge(Z_orig, Z_cos, Z_corr, R, W, K, Phi_Rk, Phi_moe, lamb):
-    Z_corr = Z_orig.copy()
-    for i in range(K):
-        Phi_Rk = cp.multiply(Phi_moe, R[i, :])
-        x = cp.dot(Phi_Rk, Phi_moe.T) + lamb
-        W = cp.dot(cp.dot(cp.linalg.inv(x), Phi_Rk), Z_orig.T)
-        W[0, :] = 0  # do not remove the intercept
-        Z_corr -= cp.dot(W.T, Phi_Rk)
-    Z_cos = Z_corr / cp.linalg.norm(Z_corr, ord=2, axis=0)
-    return Z_cos, Z_corr, W, Phi_Rk
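
For reference, a minimal usage sketch of the new code path (file name, batch column, and data are illustrative, not part of this diff):

```python
import scanpy as sc

import rapids_singlecell as rsc

# Hypothetical AnnData with a precomputed PCA embedding and a "batch" column
adata = sc.read_h5ad("example.h5ad")

# Writes the Harmony-corrected embedding to adata.obsm["X_pca_harmony"]
rsc.pp.harmony_integrate(adata, key="batch", correction_method="fast")
```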