Update process_multiSample.py
yiwang12 authored Nov 18, 2024
1 parent c965c90 commit 51acb4e
Showing 1 changed file with 61 additions and 29 deletions.
mNSF/process_multiSample.py (61 additions, 29 deletions)
@@ -165,44 +165,76 @@ def get_chunked_data(X, Y, nchunk, method='random'):
     nspot = X.shape[0]
     D_unchunked = get_D(X, Y)

-    if method == 'spatial':
-        # Calculate grid size for approximately equal chunks
-        grid_size = int(np.sqrt(nchunk))
-        x_coords = X.iloc[:, 0].values
-        y_coords = X.iloc[:, 1].values
+    if method == 'balanced_kmeans':
+        from sklearn.cluster import KMeans
+        coords = X.iloc[:, :2].values

-        # Calculate quantiles instead of linear bins
-        x_quantiles = np.percentile(x_coords, np.linspace(0, 100, grid_size + 1))
-        y_quantiles = np.percentile(y_coords, np.linspace(0, 100, grid_size + 1))
+        # Initialize with regular k-means
+        kmeans = KMeans(n_clusters=nchunk, random_state=42)
+        labels = kmeans.fit_predict(coords)
+        centers = kmeans.cluster_centers_

-        # Assign spots to spatial bins
-        chunk_id = []
-        for i in range(nspot):
-            x_bin = np.digitize(x_coords[i], x_quantiles) - 1
-            y_bin = np.digitize(y_coords[i], y_quantiles) - 1
-            if x_bin >= grid_size: x_bin = grid_size - 1
-            if y_bin >= grid_size: y_bin = grid_size - 1
-            chunk_id.append(x_bin * grid_size + y_bin)
-
-        chunk_id = np.array(chunk_id)
+        # Target size for each cluster
+        target_size = nspot // nchunk

-        # Ensure we don't exceed nchunk
-        if len(np.unique(chunk_id)) > nchunk:
-            unique_counts = np.bincount(chunk_id)
-            valid_chunks = np.argsort(unique_counts)[-nchunk:]
-            mask = np.isin(chunk_id, valid_chunks)
-            chunk_id[~mask] = valid_chunks[0]
+        # Iteratively balance clusters
+        max_iter = 20
+        for _ in range(max_iter):
+            # Squared distances from every spot to every cluster center
+            distances = np.zeros((nspot, nchunk))
+            for i in range(nchunk):
+                distances[:, i] = np.sum((coords - centers[i])**2, axis=1)
+
+            # Rebuild the assignment from scratch each iteration
+            cluster_sizes = np.bincount(labels, minlength=nchunk)
+            new_labels = np.zeros(nspot, dtype=int)
+            unassigned = np.ones(nspot, dtype=bool)
+
+            # First, top up clusters that are below the target size with
+            # their closest unassigned points
+            for i in range(nchunk):
+                if cluster_sizes[i] < target_size:
+                    dist_to_cluster = distances[unassigned, i]
+                    n_needed = target_size - cluster_sizes[i]
+                    closest = np.argsort(dist_to_cluster)[:n_needed]
+                    unassigned_indices = np.where(unassigned)[0]
+                    points_to_assign = unassigned_indices[closest]
+                    new_labels[points_to_assign] = i
+                    unassigned[points_to_assign] = False
+
+            # Then send each remaining point to the closest cluster that
+            # still has room, counting only points already assigned; if
+            # every cluster is full, fall back to the overall closest one
+            remaining_points = np.where(unassigned)[0]
+            for point in remaining_points:
+                dist_to_clusters = distances[point]
+                for closest in np.argsort(dist_to_clusters):
+                    if np.sum(new_labels[~unassigned] == closest) < target_size:
+                        new_labels[point] = closest
+                        unassigned[point] = False
+                        break
+                else:
+                    new_labels[point] = np.argmin(dist_to_clusters)
+                    unassigned[point] = False
+
+            # Update centers from the new assignment
+            old_centers = centers.copy()
+            for i in range(nchunk):
+                if np.sum(new_labels == i) > 0:
+                    centers[i] = coords[new_labels == i].mean(axis=0)
+
+            # Adopt the new assignment, then stop once centers converge
+            labels = new_labels
+            if np.allclose(old_centers, centers):
+                break

# Create chunks
-        unique_chunks = np.unique(chunk_id)
-        for k in unique_chunks:
-            mask = chunk_id == k
+        # Create chunks based on final labels
+        for k in range(nchunk):
+            mask = labels == k
             Y_chunk = D_unchunked['Y'][mask, :]
             X_chunk = X.iloc[mask]
             Y_chunk = pd.DataFrame(Y_chunk)
             D = get_D(X_chunk, Y_chunk, rescale_spatial_coords=False)
             list_D_sampleTmp.append(D)
             list_X_sampleTmp.append(X_chunk)
     elif method == 'random':
         indices = np.random.permutation(nspot)
         nspot_perChunk = int(nspot/nchunk)
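For context, here is a minimal sketch of how the new chunking method might be invoked. The toy data and shapes are hypothetical, and it assumes get_chunked_data is importable from mNSF.process_multiSample and returns the per-chunk structures it accumulates; neither detail is shown in this diff:

    import numpy as np
    import pandas as pd
    from mNSF import process_multiSample

    # Hypothetical toy sample: 200 spots with 2-D coordinates, 50 genes
    X = pd.DataFrame(np.random.rand(200, 2), columns=['x', 'y'])
    Y = pd.DataFrame(np.random.poisson(1.0, size=(200, 50)))

    # Split into 4 spatially coherent, roughly equal-sized chunks
    chunks = process_multiSample.get_chunked_data(X, Y, nchunk=4,
                                                  method='balanced_kmeans')

Compared with the quantile-grid binning it replaces, the balanced k-means pass keeps each chunk spatially contiguous while pushing every chunk toward the same size (nspot // nchunk), which spreads the downstream per-chunk work evenly.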
