Skip to content

Commit

Permalink
Update process_multiSample.py
Browse files: browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 12, 2024
1 parent 71e34b5 commit f6e5770
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions mNSF/process_multiSample.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_D_fromAnnData(ad):
return D


def get_listDtrain(list_D_,nbatch=1):
def get_listDtrain(list_D_,nbatch=1,list_nchunk=[False]*len(list_D_)):
"""
Prepare the training data by creating TensorFlow Datasets.
Expand All @@ -124,12 +124,32 @@ def get_listDtrain(list_D_,nbatch=1):
"""
list_Dtrain=list()
nsample=len(list_D_)
# data chunking
nsample_splitted = sum(list_nchunk)
list_D_chunk = list()
for ksample in range(0,nsample):
D=list_D_[ksample]
Ntr = D["Y"].shape[0] # Number of observations in this sample
nchunk = list_nchunk[ksample]
X = D['X']
Y = D['Y']
nspot = X.shape[1]
nspot_perChunk = int(nspot/nchunk)
for kchunk in range(0,nchunk):
st = (kchunk-1)*nspot_perChunk
end_ = kchunk*nspot_perChunk
if (kchunk==nchunk-1):end_=nspot
X_chunk=X[st:end_,]
Y_chunk=Y[st:end_,]
D_chunk = list()
D_chunk["X"]=X
D_chunk["Y"]=Y
list_D_chunk.append(D_chunk)
for ksample_splitted in range(0,nsample_splitted):
D_chunk=list_D_chunk[ksample]
Ntr = D_chunk["Y"].shape[0] # Number of observations in this sample

# Convert dictionary to TensorFlow Dataset
Dtrain = Dataset.from_tensor_slices(D)
Dtrain = Dataset.from_tensor_slices(D_chunk)

# Batch the data
if (nbatch==1): D_train = Dtrain.batch(round(Ntr)+1)
Expand Down

0 comments on commit f6e5770

Please sign in to comment.