diff --git a/mNSF/process_multiSample.py b/mNSF/process_multiSample.py
index 1e833bb..53ac8d7 100644
--- a/mNSF/process_multiSample.py
+++ b/mNSF/process_multiSample.py
@@ -108,6 +108,38 @@ def get_D_fromAnnData(ad):
     D["Z"]=D['X']
     return D
 
+def get_listX_chunked(list_X_, list_nchunk=None):
+    """
+    Split the spatial coordinates of each sample into smaller chunks.
+
+    Each coordinate array in list_X_ is split row-wise (by spot) into the
+    requested number of chunks; the last chunk absorbs any leftover spots.
+
+    Args:
+        list_X_: List of spatial coordinate arrays, one per sample
+        list_nchunk: List of integers giving the number of chunks per sample
+            (default is one chunk per sample)
+
+    Returns:
+        list_X_chunk: List of chunked coordinate arrays
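+
+    Example (illustrative sketch; `X1` and `X2` are hypothetical
+    spot-by-coordinate arrays with 100 and 80 rows, respectively):
+        >>> chunks = get_listX_chunked([X1, X2], list_nchunk=[2, 2])
+        >>> [c.shape[0] for c in chunks]
+        [50, 50, 40, 40]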
+    """
+    if list_nchunk is None: list_nchunk = [1]*len(list_X_)
+    nsample = len(list_X_)
+    # data chunking
+    list_X_chunk = list()
+    for ksample in range(0, nsample):
+        X = list_X_[ksample]
+        nchunk = list_nchunk[ksample]
+        nspot = X.shape[0]
+        nspot_perChunk = int(nspot/nchunk)
+        for kchunk in range(0, nchunk):
+            st = kchunk*nspot_perChunk
+            end_ = (kchunk+1)*nspot_perChunk
+            if kchunk == nchunk-1: end_ = nspot  # last chunk takes the remaining spots
+            X_chunk = X[st:end_,]
+            list_X_chunk.append(X_chunk)
+    return list_X_chunk
 
 def get_listD_chunked(list_D_,list_nchunk=None):
     """
@@ -134,7 +166,7 @@ def get_listD_chunked(list_D_,list_nchunk=None):
         nchunk = list_nchunk[ksample]
         X = D['X']
         Y = D['Y']
-        nspot = X.shape[1]
+        nspot = X.shape[0]
         nspot_perChunk = int(nspot/nchunk)
         for kchunk in range(0,nchunk):
             st = (kchunk)*nspot_perChunk
@@ -146,7 +178,60 @@ def get_listD_chunked(list_D_,list_nchunk=None):
             list_D_chunk.append(D_chunk)
     return list_D_chunk
 
-
+def get_listDtrain(list_D_, nbatch=1, list_nchunk=None):
+    """
+    Prepare the training data by creating TensorFlow Datasets.
+
+    This function first splits each sample into the requested number of chunks
+    and then converts each chunk's data dictionary into a TensorFlow Dataset
+    object, which is efficient for training.
+
+    Args:
+        list_D_: List of data dictionaries, one for each sample
+        nbatch: Number of batches to split each chunk into (default is 1)
+        list_nchunk: List of integers giving the number of chunks per sample
+            (default is one chunk per sample)
+
+    Returns:
+        list_Dtrain: List of TensorFlow Datasets, one per chunk
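+
+    Example (illustrative sketch; `D1` and `D2` are hypothetical data
+    dictionaries built with get_D, here split into two chunks each):
+        >>> list_Dtrain = get_listDtrain([D1, D2], nbatch=1, list_nchunk=[2, 2])
+        >>> len(list_Dtrain)  # one Dataset per chunk
+        4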
+    """
+    if list_nchunk is None: list_nchunk = [1]*len(list_D_)
+    list_Dtrain = list()
+    nsample = len(list_D_)
+    # data chunking
+    nsample_splitted = sum(list_nchunk)
+    list_D_chunk = list()
+    for ksample in range(0, nsample):
+        D = list_D_[ksample]
+        nchunk = list_nchunk[ksample]
+        X = D['X']
+        Y = D['Y']
+        nspot = X.shape[0]
+        nspot_perChunk = int(nspot/nchunk)
+        for kchunk in range(0, nchunk):
+            st = kchunk*nspot_perChunk
+            end_ = (kchunk+1)*nspot_perChunk
+            if kchunk == nchunk-1: end_ = nspot
+            X_chunk = X[st:end_,]
+            Y_chunk = Y[st:end_,]
+            D_chunk = get_D(X_chunk, Y_chunk)
+            list_D_chunk.append(D_chunk)
+    # convert each chunk into a (batched) TensorFlow Dataset
+    for ksample_splitted in range(0, nsample_splitted):
+        D_chunk = list_D_chunk[ksample_splitted]
+        Ntr = D_chunk["Y"].shape[0]  # number of spots in this chunk
+        Dtrain = Dataset.from_tensor_slices(D_chunk)
+        if nbatch == 1:
+            D_train = Dtrain.batch(Ntr)  # a single batch holding the whole chunk
+        else:
+            Ntr_batch = round(Ntr/nbatch)+1  # number of spots per batch
+            D_train = Dtrain.batch(Ntr_batch)
+        list_Dtrain.append(D_train)
+    return list_Dtrain
+
 
 def get_listSampleID(list_D_):
     """
@@ -173,6 +258,82 @@ def get_listSampleID(list_D_):
 
 
 
+def ini_multiSample(list_D_, L_, lik='nb', disp="default", chol=True):
+    """
+    Initialize mNSF (multi-sample Non-negative Spatial Factorization).
+
+    This function sets up the initial state for the mNSF model: it creates one
+    ProcessFactorization object per sample, fits a joint multi-sample
+    factorization on the concatenated data, copies the resulting parameters
+    back into the per-sample objects, and saves each of them to
+    'fit_<k>_restore.pkl'.
+
+    Args:
+        list_D_: List of data dictionaries, one for each sample
+        L_: Number of factors to use in the factorization
+        lik: Likelihood function ('nb' for negative binomial)
+        disp: Dispersion parameter for the negative binomial distribution
+        chol: Whether to use the Cholesky parameterization (passed to ProcessFactorization)
+
+    Returns:
+        list_fit_: List of initialized ProcessFactorization objects
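+
+    Example (illustrative sketch; `D1` and `D2` are hypothetical, un-chunked
+    data dictionaries built with get_D):
+        >>> list_fit = ini_multiSample([D1, D2], L_=5, lik='nb')
+        >>> len(list_fit)  # one ProcessFactorization object per sample
+        2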
+    """
+    list_X = list()
+    list_Z = list()
+    list_sampleID_ = list()
+    nsample_ = len(list_D_)
+    index__ = 0
+    for ksample in range(0, nsample_):
+        D = list_D_[ksample]
+        list_X.append(D['X'])
+        list_Z.append(D['Z'])
+        Ntr = D["Z"].shape[0]
+        list_sampleID_.append(np.arange(index__, index__+Ntr))
+        index__ = index__+Ntr
+    list_fit_ = list()
+    J_ = list_D_[0]["Y"].shape[1]
+    for ksample in range(0, nsample_):
+        D = list_D_[ksample]
+        fit = pf.ProcessFactorization(J_, L_, D['Z'], X=list_X[ksample], psd_kernel=ker, nonneg=True, lik=lik, disp=disp, chol=chol)
+        fit.init_loadings(D["Y"], X=D['X'], sz=D["sz"], shrinkage=0.3)
+        list_fit_.append(fit)
+        if ksample == 0:
+            X_concatenated = D['X']
+            Z_concatenated = D['Z']
+            Y_concatenated = D['Y']
+            sz_concatenated = D['sz']
+        else:
+            X_concatenated = np.concatenate((X_concatenated, D['X']), axis=0)
+            Z_concatenated = np.concatenate((Z_concatenated, D['Z']), axis=0)
+            Y_concatenated = np.concatenate((Y_concatenated, D['Y']), axis=0)
+            sz_concatenated = np.concatenate((sz_concatenated, D['sz']), axis=0)
+    fit_multiSample = pf_multiSample.ProcessFactorization_multiSample(J_, L_,
+        Z_concatenated,
+        nsample=nsample_,
+        psd_kernel=ker, nonneg=True, lik=lik)
+    fit_multiSample.init_loadings(Y_concatenated,
+        list_X=list_X,
+        list_Z=list_Z,
+        sz=sz_concatenated, shrinkage=0.3)
+    # copy the jointly initialized parameters back into the per-sample fits
+    for ksample in range(0, nsample_):
+        indices = list_sampleID_[ksample]
+        indices = indices.astype(int)
+        delta = fit_multiSample.delta.numpy()[:, indices]
+        beta0 = fit_multiSample.beta0.numpy()[(ksample*L_):((ksample+1)*L_), :]
+        beta = fit_multiSample.beta.numpy()[(ksample*L_):((ksample+1)*L_), :]
+        W = fit_multiSample.W.numpy()
+        list_fit_[ksample].delta.assign(delta)
+        list_fit_[ksample].beta0.assign(beta0)
+        list_fit_[ksample].beta.assign(beta)
+        list_fit_[ksample].W.assign(W)
+        save_object(list_fit_[ksample], 'fit_'+str(ksample+1)+'_restore.pkl')
+    return list_fit_
 
 def save_object(obj, filename):
     """
@@ -188,247 +349,35 @@ def save_object(obj, filename):
     with open(filename, 'wb') as outp:  # Overwrites any existing file.
         pickle.dump(obj, outp, pickle.HIGHEST_PROTOCOL)
 
-def get_listX_chunked(list_X_, list_nchunk=None):
-    """
-    Chunk spatial coordinates into smaller pieces.
-
-    Args:
-        list_X_: List of spatial coordinate arrays
-        list_nchunk: List of integers specifying number of chunks for each sample
-
-    Returns:
-        list_X_chunk: List of chunked coordinate arrays
-    """
-    if list_nchunk is None:
-        list_nchunk = [1] * len(list_X_)
-
-    list_X_chunk = []
-
-    for ksample, X in enumerate(list_X_):
-        nchunk = list_nchunk[ksample]
-        nspot = X.shape[0]
-        nspot_perChunk = int(nspot / nchunk)
-
-        for kchunk in range(nchunk):
-            start = kchunk * nspot_perChunk
-            end = (kchunk + 1) * nspot_perChunk if kchunk < nchunk - 1 else nspot
-            X_chunk = X[start:end]
-            list_X_chunk.append(X_chunk)
-
-    return list_X_chunk
-
-def get_listDtrain(list_D_, nbatch=1, list_nchunk=None):
-    """
-    Prepare the training data by creating TensorFlow Datasets from chunked data.
-
-    Args:
-        list_D_: List of data dictionaries
-        nbatch: Number of batches for each chunk
-        list_nchunk: List of integers specifying number of chunks for each sample
-
-    Returns:
-        list_Dtrain: List of TensorFlow Datasets for training
-    """
-    if list_nchunk is None:
-        list_nchunk = [1] * len(list_D_)
-
-    # First chunk the data
-    list_D_chunk = get_listD_chunked(list_D_, list_nchunk)
-    list_Dtrain = []
-
-    # Create TensorFlow datasets for each chunk
-    for D_chunk in list_D_chunk:
-        Ntr = D_chunk["Y"].shape[0]
-
-        # Convert dictionary to TensorFlow Dataset
-        Dtrain = Dataset.from_tensor_slices(D_chunk)
-
-        # Batch the data
-        if nbatch == 1:
-            D_train = Dtrain.batch(Ntr)
-        else:
-            Ntr_batch = round(Ntr/nbatch)
-            D_train = Dtrain.batch(Ntr_batch)
-
-        list_Dtrain.append(D_train)
-
-    return list_Dtrain
-
-def ini_multiSample(list_D_, L_, lik='nb', disp="default", chol=True):
-    """
-    Initialize mNSF model with proper handling of chunked data.
-
-    Args:
-        list_D_: List of data dictionaries (can be chunked)
-        L_: Number of factors
-        lik: Likelihood function type ('nb' for negative binomial)
-        disp: Dispersion parameter
-        chol: Whether to use Cholesky decomposition
-
-    Returns:
-        list_fit_: List of initialized ProcessFactorization objects
-    """
-    # First, reconstruct original data structure from chunks
-    original_D = []
-    current_idx = 0
-    sample_indices = []
-
-    # Get the number of unique samples by examining data structure
-    sample_sizes = []
-    current_sample = 0
-    current_size = 0
-
-    # Determine original sample sizes from the chunked data
-    for i, D in enumerate(list_D_):
-        if i > 0:
-            # Check if this chunk belongs to a new sample by comparing X coordinates
-            if not np.array_equal(D['X'][:1], list_D_[i-1]['X'][-1:]):
-                sample_sizes.append(current_size)
-                current_size = D['Y'].shape[0]
-                current_sample += 1
-            else:
-                current_size += D['Y'].shape[0]
-        else:
-            current_size = D['Y'].shape[0]
-    sample_sizes.append(current_size)
-
-    # Initialize lists to store concatenated data for each original sample
-    X_list = []
-    Y_list = []
-    Z_list = []
-    sz_list = []
-
-    # Reconstruct original samples from chunks
-    start_idx = 0
-    for size in sample_sizes:
-        end_idx = start_idx
-        X_temp = []
-        Y_temp = []
-        Z_temp = []
-        sz_temp = []
-
-        # Collect chunks belonging to this sample
-        while end_idx < len(list_D_) and sum(len(d['Y']) for d in Y_temp) < size:
-            X_temp.append(list_D_[end_idx]['X'])
-            Y_temp.append(list_D_[end_idx]['Y'])
-            Z_temp.append(list_D_[end_idx]['Z'])
-            sz_temp.append(list_D_[end_idx]['sz'])
-            end_idx += 1
-
-        # Concatenate chunks for this sample
-        X_list.append(np.concatenate(X_temp, axis=0))
-        Y_list.append(np.concatenate(Y_temp, axis=0))
-        Z_list.append(np.concatenate(Z_temp, axis=0))
-        sz_list.append(np.concatenate(sz_temp, axis=0))
-
-        # Store indices for this sample
-        sample_indices.append(np.arange(current_idx, current_idx + size))
-        current_idx += size
-        start_idx = end_idx
-
-    # Create original data dictionaries
-    original_D = [{'X': X, 'Y': Y, 'Z': Z, 'sz': sz}
-                  for X, Y, Z, sz in zip(X_list, Y_list, Z_list, sz_list)]
-
-    # Initialize list_fit_ using reconstructed original data
-    list_fit_ = []
-    nsample_ = len(original_D)
-    J_ = original_D[0]["Y"].shape[1]
-
-    # Concatenate all data for multi-sample initialization
-    X_concatenated = np.concatenate(X_list, axis=0)
-    Z_concatenated = np.concatenate(Z_list, axis=0)
-    Y_concatenated = np.concatenate(Y_list, axis=0)
-    sz_concatenated = np.concatenate(sz_list, axis=0)
-
-    # Initialize individual fits
-    for ksample in range(nsample_):
-        D = original_D[ksample]
-        fit = pf.ProcessFactorization(
-            J_, L_,
-            D['Z'],
-            X=D['X'],
-            psd_kernel=ker,
-            nonneg=True,
-            lik=lik,
-            disp=disp,
-            chol=chol
-        )
-        fit.init_loadings(D["Y"], X=D['X'], sz=D["sz"], shrinkage=0.3)
-        list_fit_.append(fit)
-
-    # Initialize multi-sample fit
-    fit_multiSample = pf_multiSample.ProcessFactorization_multiSample(
-        J_, L_,
-        Z_concatenated,
-        nsample=nsample_,
-        psd_kernel=ker,
-        nonneg=True,
-        lik=lik
-    )
-
-    fit_multiSample.init_loadings(
-        Y_concatenated,
-        list_X=X_list,
-        list_Z=Z_list,
-        sz=sz_concatenated,
-        shrinkage=0.3
-    )
-
-    # Transfer parameters to individual fits
-    for ksample in range(nsample_):
-        indices = sample_indices[ksample]
-        delta = fit_multiSample.delta.numpy()[:, indices]
-        beta0 = fit_multiSample.beta0.numpy()[ksample*L_:(ksample+1)*L_, :]
-        beta = fit_multiSample.beta.numpy()[ksample*L_:(ksample+1)*L_, :]
-        W = fit_multiSample.W.numpy()
-
-        list_fit_[ksample].delta.assign(delta)
-        list_fit_[ksample].beta0.assign(beta0)
-        list_fit_[ksample].beta.assign(beta)
-        list_fit_[ksample].W.assign(W)
-
-        # Save checkpoint
-        save_object(list_fit_[ksample], f'fit_{ksample+1}_restore.pkl')
-
-    return list_fit_
-
-def interpret_npf_v3(list_fit, list_X, list_nchunk=None, S=10, **kwargs):
-    """
-    Interpret the factorization results with support for chunked data.
-
-    Args:
-        list_fit: List of fitted ProcessFactorization objects
-        list_X: List of spatial coordinates
-        list_nchunk: List specifying number of chunks per sample
-        S: Number of samples to draw
-        **kwargs: Additional arguments for interpret_nonneg
-
-    Returns:
-        Dictionary containing interpretable results
-    """
-    # Get chunked coordinates
-    listX_chunked = get_listX_chunked(list_X, list_nchunk)
-
-    # Sample latent GPs for each chunk
-    Fhat_chunks = []
-    for ksample, X_chunk in enumerate(listX_chunked):
-        # Get the corresponding fit object (accounting for chunking)
-        fit_idx = ksample // (list_nchunk[ksample] if list_nchunk else 1)
-        Fhat_tmp = misc.t2np(
-            list_fit[fit_idx].sample_latent_GP_funcs(X_chunk, S=S, chol=False)
-        ).T
-        Fhat_chunks.append(Fhat_tmp)
-
-    # Concatenate all chunks
-    Fhat_c = np.concatenate(Fhat_chunks, axis=0)
-
-    # Interpret the results
-    return interpret_nonneg(
-        np.exp(Fhat_c),
-        list_fit[0].W.numpy(),
-        sort=False,
-        **kwargs
-    )
+def interpret_npf_v3(list_fit, list_X, list_nchunk=None, S=10, **kwargs):
+    """
+    Interpret the non-negative process factorization results.
+
+    This function samples from the learned Gaussian processes to generate
+    interpretable factors and loadings.
+
+    Args:
+        list_fit: List of fitted ProcessFactorization objects
+        list_X: List of spatial coordinates for each sample
+        list_nchunk: List of integers giving the number of chunks per sample
+            (default is one chunk per sample)
+        S: Number of samples to draw from the Gaussian processes
+        **kwargs: Additional keyword arguments to pass to interpret_nonneg
+
+    Returns:
+        Dictionary containing interpretable loadings W, factors eF, and total counts vector
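+
+    Example (illustrative sketch; assumes `list_fit` was initialized and trained
+    on the same samples whose full coordinate arrays are given in `list_X`):
+        >>> res = interpret_npf_v3(list_fit, list_X, list_nchunk=[2, 2], S=100)
+        >>> # res is the dictionary returned by interpret_nonneg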
+    """
+    listX_chunked = get_listX_chunked(list_X, list_nchunk)
+    if list_nchunk is None: list_nchunk = [1]*len(list_fit)
+    # map each chunk back to the sample (fit object) it was taken from
+    chunk_to_sample = [ks for ks, nchunk in enumerate(list_nchunk) for _ in range(nchunk)]
+    for kchunk in range(0, len(listX_chunked)):
+        fit_k = list_fit[chunk_to_sample[kchunk]]
+        Fhat_tmp = misc.t2np(fit_k.sample_latent_GP_funcs(listX_chunked[kchunk], S=S, chol=False)).T  # NxL
+        if kchunk == 0:
+            Fhat_c = Fhat_tmp
+        else:
+            Fhat_c = np.concatenate((Fhat_c, Fhat_tmp), axis=0)
+    return interpret_nonneg(np.exp(Fhat_c), list_fit[0].W.numpy(), sort=False, **kwargs)
+
+
 
 
 def interpret_nonneg(factors,loadings,lda_mode=False,sort=False):
 
@@ -486,6 +435,3 @@ def rescale_as_lda(factors,loadings,sort=False):
         return W[:,o],eF[:,o],eFsum
     else:
         return W,eF,eFsum,wsum
-
-
-
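
For review context, a minimal end-to-end sketch of how the helpers touched by this diff are intended to fit together. This is illustrative only: `X1`, `Y1`, `X2`, `Y2` are hypothetical coordinate and count matrices, and the training step that normally runs between initialization and interpretation is elided.

```python
from mNSF import process_multiSample

# Two hypothetical samples: coordinates (spots x 2) and counts (spots x genes).
list_D = [process_multiSample.get_D(X1, Y1), process_multiSample.get_D(X2, Y2)]

L = 5                 # number of factors
list_nchunk = [2, 2]  # split each sample into two chunks

# TensorFlow Datasets, one per chunk, used for training.
list_Dtrain = process_multiSample.get_listDtrain(list_D, nbatch=1, list_nchunk=list_nchunk)

# Per-sample ProcessFactorization objects, jointly initialized.
list_fit = process_multiSample.ini_multiSample(list_D, L, lik="nb")

# ... model training on list_Dtrain / list_fit happens here ...

# Interpret the (trained) factorization on the full coordinates.
list_X = [D["X"] for D in list_D]
res = process_multiSample.interpret_npf_v3(list_fit, list_X, list_nchunk=list_nchunk, S=100)
```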