From 655ae364d25d06a2bcff9a3790677de2d764c380 Mon Sep 17 00:00:00 2001 From: Yi Wang <37149810+yiwang12@users.noreply.github.com> Date: Thu, 14 Nov 2024 23:16:17 +0100 Subject: [PATCH] Update process_multiSample.py --- mNSF/process_multiSample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mNSF/process_multiSample.py b/mNSF/process_multiSample.py index cadac5a..1d9ba9c 100644 --- a/mNSF/process_multiSample.py +++ b/mNSF/process_multiSample.py @@ -213,12 +213,12 @@ def get_listDtrain(list_D_,nbatch=1,list_nchunk=None): if (kchunk==nchunk-1):end_=nspot X_chunk=X[st:end_,] Y_chunk=Y[st:end_,] - D_chunk = get_D(X,Y) + D_chunk = get_D(X_chunk,X_chunk) list_D_chunk.append(D_chunk) print("len(list_D_chunk)") print(len(list_D_chunk)) for ksample_splitted in range(0,nsample_splitted): - D_chunk=list_D_chunk[ksample] + D_chunk=list_D_chunk[ksample_splitted] Ntr = D_chunk["Y"].shape[0] # Number of observations in this sample # Convert dictionary to TensorFlow Dataset