def reorder_chunked_results(factors, list_chunk_mapping, list_X_unchunked):
    """
    Reorder 2D factors matrix from chunked analysis back to original order.

    The rows of ``factors`` are consumed sequentially: sample by sample and,
    within a sample, chunk by chunk, in the order the chunks appear under
    ``mapping['original_indices']``. Each chunk's rows are written to the
    output at ``chunk index + offset of the sample``, where the offset is the
    total length of all preceding entries of ``list_X_unchunked``.

    Args:
        factors: 2D numpy array of factors (n_spots x n_factors)
        list_chunk_mapping: list of dictionaries, one per sample; each must
            provide an ``'original_indices'`` entry that iterates over one
            index sequence per chunk (list, tuple, numpy array, ...).
            NOTE(review): indices are presumably 0-based positions local to
            the sample -- confirm against the code that builds the mapping.
        list_X_unchunked: list of original coordinate dataframes for each
            sample; only ``len()`` of each entry is used here.

    Returns:
        numpy array: factors reordered to match original indices, with shape
        (sum of sample lengths, n_factors) and the dtype of ``factors``.
        Rows not referenced by any chunk stay zero.
    """
    total_spots = sum(len(X) for X in list_X_unchunked)
    n_factors = factors.shape[1]

    # Initialize output array
    reordered = np.zeros((total_spots, n_factors), dtype=factors.dtype)

    # Read cursor into the (chunk-ordered) input factors array
    current_pos = 0
    # Row offset of the current sample in the output; kept as a running
    # total instead of re-summing all previous sample lengths every time.
    sample_start = 0

    for mapping, X_orig in zip(list_chunk_mapping, list_X_unchunked):
        for chunk_indices in mapping['original_indices']:
            # Coerce to an array so the offset below is an element-wise add
            # even when the chunk's indices are a plain list or tuple
            # (list + int would raise TypeError).
            local_indices = np.asarray(chunk_indices)
            chunk_size = len(local_indices)
            if chunk_size == 0:
                # Nothing to place; also avoids fancy-indexing with the
                # float-typed array that an empty list converts to.
                continue

            chunk_factors = factors[current_pos:current_pos + chunk_size]

            # Shift sample-local indices to positions in the stacked output
            reordered[local_indices + sample_start] = chunk_factors

            current_pos += chunk_size

        sample_start += len(X_orig)

    return reordered