Skip to content

Commit

Permalink
Update process_multiSample.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 18, 2024
1 parent ac5bed3 commit c965c90
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions mNSF/process_multiSample.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,33 +168,41 @@ def get_chunked_data(X, Y, nchunk, method='random'):
if method == 'spatial':
# Calculate grid size for approximately equal chunks
grid_size = int(np.sqrt(nchunk))
x_coords = X.iloc[:, 0]
y_coords = X.iloc[:, 1]
x_coords = X.iloc[:, 0].values
y_coords = X.iloc[:, 1].values

# Create bins for x and y coordinates
x_bins = np.linspace(np.min(x_coords), np.max(x_coords), grid_size + 1)
y_bins = np.linspace(np.min(y_coords), np.max(y_coords), grid_size + 1)
# Calculate quantiles instead of linear bins
x_quantiles = np.percentile(x_coords, np.linspace(0, 100, grid_size + 1))
y_quantiles = np.percentile(y_coords, np.linspace(0, 100, grid_size + 1))

# Assign spots to spatial bins
chunk_id = []
for i in range(nspot):
x_bin = np.digitize(x_coords[i], x_bins) - 1
y_bin = np.digitize(y_coords[i], y_bins) - 1
x_bin = np.digitize(x_coords[i], x_quantiles) - 1
y_bin = np.digitize(y_coords[i], y_quantiles) - 1
if x_bin >= grid_size: x_bin = grid_size - 1
if y_bin >= grid_size: y_bin = grid_size - 1
chunk_id.append(x_bin * grid_size + y_bin)

chunk_id = np.array(chunk_id)
unique_chunks = np.unique(chunk_id)

# Create chunks from spatial bins
# Ensure we don't exceed nchunk
if len(np.unique(chunk_id)) > nchunk:
unique_counts = np.bincount(chunk_id)
valid_chunks = np.argsort(unique_counts)[-nchunk:]
mask = np.isin(chunk_id, valid_chunks)
chunk_id[~mask] = valid_chunks[0]

# Create chunks
unique_chunks = np.unique(chunk_id)
for k in unique_chunks:
mask = chunk_id == k
Y_chunk = D_unchunked['Y'][mask, :]
X_chunk = D_unchunked['X'][mask, :]
X_chunk = X.iloc[mask]
Y_chunk = pd.DataFrame(Y_chunk)
X_chunk = pd.DataFrame(X_chunk)
D = get_D(X_chunk, Y_chunk, rescale_spatial_coords=False)
list_D_sampleTmp.append(D)
list_X_sampleTmp.append(X_chunk)
list_X_sampleTmp.append(X_chunk)
elif method == 'random':
indices = np.random.permutation(nspot)
nspot_perChunk = int(nspot/nchunk)
Expand Down

0 comments on commit c965c90

Please sign in to comment.