diff --git a/mNSF/process_multiSample.py b/mNSF/process_multiSample.py
index 7547fa0..84e90fa 100644
--- a/mNSF/process_multiSample.py
+++ b/mNSF/process_multiSample.py
@@ -158,13 +158,21 @@ def get_chunked_data(X, Y, nchunk, method='random'):
         method: Chunking method - 'kmeans' or 'random' (default: 'random')
 
     Returns:
-        tuple: (list_D_sampleTmp, list_X_sampleTmp) containing chunked data dictionaries and coordinates
+        tuple: (list_D_sampleTmp, list_X_sampleTmp, chunk_mapping) where chunk_mapping contains:
+            - original_indices: list of arrays containing original indices for each chunk
+            - chunk_indices: array indicating which chunk each original spot belongs to
     """
 
     list_D_sampleTmp = []
     list_X_sampleTmp = []
     nspot = X.shape[0]
     D_unchunked = get_D(X, Y)
+    # Initialize mapping information
+    chunk_mapping = {
+        'original_indices': [],  # Will store original indices for each chunk
+        'chunk_indices': np.zeros(nspot, dtype=int)  # Will store chunk assignment for each original spot
+    }
+
     if method == 'balanced_kmeans':
         from sklearn.cluster import KMeans
         coords = X.iloc[:, :2].values
@@ -229,6 +237,8 @@ def get_chunked_data(X, Y, nchunk, method='random'):
         # Create chunks based on final labels
         for k in range(nchunk):
             mask = labels == k
+            chunk_mapping['original_indices'].append(np.where(mask)[0])
+            chunk_mapping['chunk_indices'][mask] = k
             Y_chunk = D_unchunked['Y'][mask, :]
             X_chunk = X.iloc[mask]
             Y_chunk = pd.DataFrame(Y_chunk)
@@ -238,11 +248,13 @@ def get_chunked_data(X, Y, nchunk, method='random'):
     elif method == 'random':
         indices = np.random.permutation(nspot)
         nspot_perChunk = int(nspot/nchunk)
-
         for k in range(0,nchunk):
             st = nspot_perChunk*k
             end_ = nspot_perChunk*(k+1)
             if(k==nchunk-1): end_ = nspot
+            chunk_indices = indices[st:end_]
+            chunk_mapping['original_indices'].append(chunk_indices)
+            chunk_mapping['chunk_indices'][chunk_indices] = k
             Y_chunk = D_unchunked['Y'][st:end_,:]
             X_chunk = D_unchunked['X'][st:end_,:]
             Y_chunk = pd.DataFrame(Y_chunk)
@@ -402,56 +414,37 @@ def interpret_npf_v3(list_fit,list_X,S=10,**kwargs):
     return interpret_nonneg(np.exp(Fhat_c),list_fit[0].W.numpy(),sort=False,**kwargs)
 
 
-def reorder_spatial_factors(factors, list_D, list_X_original):
+def reorder_chunked_results(factors, list_chunk_mappings, list_X_unchunked):
     """
-    Reorder factors to match the original spatial coordinates order.
+    Reorder 2D factors matrix from chunked analysis back to original order.
 
     Args:
-        factors: numpy array of factors from mNSF analysis
-        list_D: list of chunked data dictionaries
-        list_X_original: list of original spatial coordinate dataframes
+        factors: 2D numpy array of factors (n_spots x n_factors)
+        list_chunk_mappings: list of chunk mappings for each sample
+        list_X_unchunked: list of original coordinate dataframes for each sample
 
     Returns:
-        numpy array of reordered factors
+        numpy array: factors reordered to match original indices
     """
-    # Calculate chunks per sample based on the total number of chunks divided by number of samples
-    chunks_per_sample = len(list_D) // len(list_X_original)
-    reordered_factors = np.zeros_like(factors)
-    current_pos = 0
+    total_spots = sum(len(X) for X in list_X_unchunked)
+    reordered = np.zeros((total_spots, factors.shape[1]))
 
-    for sample_idx, X_orig in enumerate(list_X_original):
-        # Get start and end indices for this sample's chunks
-        start_chunk = sample_idx * chunks_per_sample
-        end_chunk = start_chunk + chunks_per_sample if sample_idx < len(list_X_original)-1 else len(list_D)
-
-        # Collect all coordinates and their corresponding factors for this sample
-        sample_coords = []
-        sample_factors = []
-
-        # Get coordinates and factors from each chunk of this sample
-        for chunk_idx in range(start_chunk, end_chunk):
-            chunk_start = sum(len(list_D[i]['X']) for i in range(chunk_idx))
-            chunk_end = chunk_start + len(list_D[chunk_idx]['X'])
-
-            sample_coords.extend(list_D[chunk_idx]['X'].values)
-            sample_factors.extend(factors[chunk_start:chunk_end])
-
-        sample_coords = np.array(sample_coords)
-        sample_factors = np.array(sample_factors)
-
-        # For each spot in the original data, find its position in the chunked data
-        for orig_idx, orig_spot in X_orig.iterrows():
-            # Find matching position in chunked data
-            matches = np.where(
-                (np.abs(sample_coords[:, 0] - orig_spot.iloc[0]) < 1e-10) &
-                (np.abs(sample_coords[:, 1] - orig_spot.iloc[1]) < 1e-10)
-            )[0]
-
-            if len(matches) > 0:
-                reordered_factors[current_pos] = sample_factors[matches[0]]
-                current_pos += 1
+    current_pos = 0  # Tracks position in input factors array
+    sample_offset = 0  # Row offset of the current sample in the output array
 
-    return reordered_factors
+    # Process each sample
+    for mapping, X_orig in zip(list_chunk_mappings, list_X_unchunked):
+        # Get original indices for each chunk in this sample
+        for orig_indices in mapping['original_indices']:
+            chunk_size = len(orig_indices)
+            # Get the chunk's factors
+            chunk_factors = factors[current_pos:current_pos + chunk_size]
+            # Place them in their original positions within this sample's block
+            # (original_indices are 0-based within each sample, so the
+            # per-sample offset is needed to avoid collisions across samples)
+            reordered[sample_offset + orig_indices] = chunk_factors
+            current_pos += chunk_size
+        sample_offset += len(X_orig)
+
+    return reordered
 
 
 def interpret_nonneg(factors,loadings,lda_mode=False,sort=False):
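Usage sketch (not part of the patch): the round trip below chunks two hypothetical samples and then restores the original row order with `reorder_chunked_results`. The `from mNSF import process_multiSample` import path, the per-sample call pattern, and the random placeholder data are assumptions for illustration; the factors matrix stands in for real mNSF output, and `get_chunked_data` is assumed to return the three-tuple described in its updated docstring.

```python
import numpy as np
import pandas as pd
from mNSF import process_multiSample  # assumed import path for this module

# Hypothetical inputs: one coordinate and one count DataFrame per sample.
rng = np.random.default_rng(0)
list_X = [pd.DataFrame(rng.random((100, 2))), pd.DataFrame(rng.random((80, 2)))]
list_Y = [pd.DataFrame(rng.poisson(2.0, (100, 50))),
          pd.DataFrame(rng.poisson(2.0, (80, 50)))]

# Chunk each sample, keeping one chunk_mapping per sample.
list_D, list_X_chunked, list_mappings = [], [], []
for X, Y in zip(list_X, list_Y):
    D_chunks, X_chunks, mapping = process_multiSample.get_chunked_data(
        X, Y, nchunk=2, method='random')
    list_D.extend(D_chunks)
    list_X_chunked.extend(X_chunks)
    list_mappings.append(mapping)

# ... fit mNSF on the chunked data; here a random placeholder factors matrix
# in chunk order (rows concatenated sample by sample, chunk by chunk) ...
factors_chunked = np.vstack([rng.random((len(Xc), 3)) for Xc in list_X_chunked])

# Undo the chunking so row i of `factors` matches row i of the
# concatenated original coordinates in list_X.
factors = process_multiSample.reorder_chunked_results(
    factors_chunked, list_mappings, list_X)
assert factors.shape == (180, 3)  # 100 + 80 spots, 3 factors
```

Because the mapping stores integer indices rather than re-matching spots by coordinates with a tolerance, the reordering is exact and linear in the number of spots, and duplicate or near-duplicate coordinates can no longer produce mismatches.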