Skip to content

Commit

Permalink
Update process_multiSample.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 17, 2024
1 parent abea139 commit aeab5e9
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mNSF/process_multiSample.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import numpy as np
from tensorflow.data import Dataset
import pickle
import pandas as pd

# Define the kernel function for Gaussian Process
# MaternThreeHalves is a specific covariance function used in Gaussian processes
Expand Down Expand Up @@ -139,6 +140,31 @@ def get_listDtrain(list_D_,nbatch=1):
list_Dtrain.append(D_train)
return list_Dtrain

def get_vec_batch(nsample, nchunk):
vec_batch = []
for ksample in range(0,nsample):
vec_batch= vec_batch + [False] + [True] * (nchunk-1)
return vec_batch


def get_chunked_data(X, Y, nchunk):
list_D_sampleTmp=list()
list_X_sampleTmp=list()
nspot=X.shape[0]
nspot_perChunk = int(nspot/nchunk)
D_unchunked = get_D(X,Y)
for k in range(0,nchunk):
st = nspot_perChunk*k
end_ = nspot_perChunk*(k+1)
if(k==nchunk-1): end_ = nspot
Y_chunk = D_unchunked['Y'][st:end_,:]
X_chunk = D_unchunked['X'][st:end_,:]
Y_chunk = pd.DataFrame(Y_chunk)
X_chunk = pd.DataFrame(X_chunk)
D = get_D(X_chunk,Y_chunk,rescale_spatial_coords=False)
list_D_sampleTmp.append(D)
list_X_sampleTmp.append(X_chunk)
return list_D_sampleTmp, list_X_sampleTmp

def get_listSampleID(list_D_):
"""
Expand Down

0 comments on commit aeab5e9

Please sign in to comment.