From 51acb4e382ebe74bb164f4167d0d5562f41eb291 Mon Sep 17 00:00:00 2001
From: Yi Wang <37149810+yiwang12@users.noreply.github.com>
Date: Mon, 18 Nov 2024 20:35:16 +0100
Subject: [PATCH] Update process_multiSample.py

---
 mNSF/process_multiSample.py | 90 +++++++++++++++++++++++++------------
 1 file changed, 61 insertions(+), 29 deletions(-)

diff --git a/mNSF/process_multiSample.py b/mNSF/process_multiSample.py
index beec292..6b40219 100644
--- a/mNSF/process_multiSample.py
+++ b/mNSF/process_multiSample.py
@@ -165,44 +165,76 @@ def get_chunked_data(X, Y, nchunk, method='random'):
     nspot = X.shape[0]
     D_unchunked = get_D(X, Y)
 
-    if method == 'spatial':
-        # Calculate grid size for approximately equal chunks
-        grid_size = int(np.sqrt(nchunk))
-        x_coords = X.iloc[:, 0].values
-        y_coords = X.iloc[:, 1].values
+    if method == 'balanced_kmeans':
+        from sklearn.cluster import KMeans
+        coords = X.iloc[:, :2].values
 
-        # Calculate quantiles instead of linear bins
-        x_quantiles = np.percentile(x_coords, np.linspace(0, 100, grid_size + 1))
-        y_quantiles = np.percentile(y_coords, np.linspace(0, 100, grid_size + 1))
+        # Initialize with regular k-means
+        kmeans = KMeans(n_clusters=nchunk, random_state=42)
+        labels = kmeans.fit_predict(coords)
+        centers = kmeans.cluster_centers_
 
-        # Assign spots to spatial bins
-        chunk_id = []
-        for i in range(nspot):
-            x_bin = np.digitize(x_coords[i], x_quantiles) - 1
-            y_bin = np.digitize(y_coords[i], y_quantiles) - 1
-            if x_bin >= grid_size: x_bin = grid_size - 1
-            if y_bin >= grid_size: y_bin = grid_size - 1
-            chunk_id.append(x_bin * grid_size + y_bin)
-
-        chunk_id = np.array(chunk_id)
+        # Target size for each cluster
+        target_size = nspot // nchunk
 
-        # Ensure we don't exceed nchunk
-        if len(np.unique(chunk_id)) > nchunk:
-            unique_counts = np.bincount(chunk_id)
-            valid_chunks = np.argsort(unique_counts)[-nchunk:]
-            mask = np.isin(chunk_id, valid_chunks)
-            chunk_id[~mask] = valid_chunks[0]
+        # Iteratively balance clusters
+        max_iter = 20
+        for _ in range(max_iter):
+            # Calculate distances to all centers
+            distances = np.zeros((nspot, nchunk))
+            for i in range(nchunk):
+                distances[:, i] = np.sum((coords - centers[i])**2, axis=1)
+
+            # Sort points by distance to assigned cluster
+            cluster_sizes = np.bincount(labels, minlength=nchunk)
+            new_labels = np.zeros(nspot, dtype=int)
+
+            # For each cluster
+            unassigned = np.ones(nspot, dtype=bool)
+
+            # First, handle clusters that are too small
+            for i in range(nchunk):
+                if cluster_sizes[i] < target_size:
+                    # Get closest unassigned points
+                    dist_to_cluster = distances[unassigned, i]
+                    n_needed = target_size - cluster_sizes[i]
+                    closest = np.argsort(dist_to_cluster)[:n_needed]
+                    unassigned_indices = np.where(unassigned)[0]
+                    points_to_assign = unassigned_indices[closest]
+                    new_labels[points_to_assign] = i
+                    unassigned[points_to_assign] = False
+
+            # Then assign remaining points to closest clusters that aren't full
+            remaining_points = np.where(unassigned)[0]
+            for point in remaining_points:
+                # Find closest cluster that isn't full
+                dist_to_clusters = distances[point]
+                for closest in np.argsort(dist_to_clusters):
+                    if np.sum(new_labels == closest) < target_size:
+                        new_labels[point] = closest
+                        break
+
+            # Update centers
+            old_centers = centers.copy()
+            for i in range(nchunk):
+                if np.sum(new_labels == i) > 0:
+                    centers[i] = coords[new_labels == i].mean(axis=0)
+
+            # Check for convergence
+            if np.allclose(old_centers, centers):
+                break
+
+            labels = new_labels
 
-        # Create chunks
-        unique_chunks = np.unique(chunk_id)
-        for k in unique_chunks:
-            mask = chunk_id == k
+        # Create chunks based on final labels
+        for k in range(nchunk):
+            mask = labels == k
             Y_chunk = D_unchunked['Y'][mask, :]
             X_chunk = X.iloc[mask]
             Y_chunk = pd.DataFrame(Y_chunk)
             D = get_D(X_chunk, Y_chunk, rescale_spatial_coords=False)
             list_D_sampleTmp.append(D)
-            list_X_sampleTmp.append(X_chunk)
+            list_X_sampleTmp.append(X_chunk)
     elif method == 'random':
         indices = np.random.permutation(nspot)
         nspot_perChunk = int(nspot/nchunk)
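For reference, a minimal usage sketch of the new 'balanced_kmeans' chunking mode (not part of the patch). The toy inputs are invented, and the assumption that get_chunked_data returns the per-chunk lists built in the hunk above (list_D_sampleTmp, list_X_sampleTmp) is not shown in this excerpt:

    import numpy as np
    import pandas as pd
    from mNSF import process_multiSample

    # Hypothetical toy sample: X holds 2D spot coordinates, Y holds per-spot counts.
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.uniform(size=(200, 2)), columns=["x", "y"])
    Y = pd.DataFrame(rng.poisson(1.0, size=(200, 50)))

    # Split the sample into 4 spatially coherent chunks of roughly equal size.
    # Assumed return value: the per-chunk D dictionaries and coordinate frames.
    list_D, list_X = process_multiSample.get_chunked_data(X, Y, nchunk=4, method='balanced_kmeans')

Unlike the removed quantile-grid 'spatial' method, the balanced k-means variant enforces a target chunk size of nspot // nchunk, so each chunk passed downstream sees a comparable number of spots.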