
Commit

Add files via upload
yiwang12 authored Nov 12, 2024
1 parent e3d406a commit 40b9c75
Showing 3 changed files with 10 additions and 65 deletions.
environment.yml (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ dependencies:
   - pandas
   - pip
   - scanpy
-  - squidpy>=1.2.0
+  - squidpy
   - tensorflow=2.13
   - tensorflow-probability=0.21
   - pip:
mNSF/process_multiSample.py (60 changes: 3 additions & 57 deletions)
@@ -108,7 +108,7 @@ def get_D_fromAnnData(ad):
     return D


-def get_listD_chunked(list_D_,list_nchunk=[1]*len(list_D_)):
+def get_listDtrain(list_D_,nbatch=1):
     """
     Prepare the training data by creating TensorFlow Datasets.
@@ -124,66 +124,12 @@ def get_listD_chunked(list_D_,list_nchunk=[1]*len(list_D_)):
     """
     list_Dtrain=list()
     nsample=len(list_D_)
-    # data chunking
-    nsample_splitted = sum(list_nchunk)
-    list_D_chunk = list()
     for ksample in range(0,nsample):
         D=list_D_[ksample]
-        nchunk = list_nchunk[ksample]
-        X = D['X']
-        Y = D['Y']
-        nspot = X.shape[1]
-        nspot_perChunk = int(nspot/nchunk)
-        for kchunk in range(0,nchunk):
-            st = (kchunk-1)*nspot_perChunk
-            end_ = kchunk*nspot_perChunk
-            if (kchunk==nchunk-1):end_=nspot
-            X_chunk=X[st:end_,]
-            Y_chunk=Y[st:end_,]
-            D_chunk = get_D(X,Y)
-            list_D_chunk.append(D_chunk)
-    return list_D_chunk
-
-def get_listDtrain(list_D_,nbatch=1,list_nchunk=[False]*len(list_D_)):
-    """
-    Prepare the training data by creating TensorFlow Datasets.
-    This function converts the data dictionaries into TensorFlow Dataset objects,
-    which are efficient for training machine learning models.
-    Args:
-        list_D_: List of data dictionaries, one for each sample
-        nbatch: Number of batches to split the data into (default is 1)
-    Returns:
-        list_Dtrain: List of TensorFlow Datasets for training
-    """
-    list_Dtrain=list()
-    nsample=len(list_D_)
-    # data chunking
-    nsample_splitted = sum(list_nchunk)
-    list_D_chunk = list()
-    for ksample in range(0,nsample):
-        D=list_D_[ksample]
-        nchunk = list_nchunk[ksample]
-        X = D['X']
-        Y = D['Y']
-        nspot = X.shape[1]
-        nspot_perChunk = int(nspot/nchunk)
-        for kchunk in range(0,nchunk):
-            st = (kchunk-1)*nspot_perChunk
-            end_ = kchunk*nspot_perChunk
-            if (kchunk==nchunk-1):end_=nspot
-            X_chunk=X[st:end_,]
-            Y_chunk=Y[st:end_,]
-            D_chunk = get_D(X,Y)
-            list_D_chunk.append(D_chunk)
-    for ksample_splitted in range(0,nsample_splitted):
-        D_chunk=list_D_chunk[ksample]
-        Ntr = D_chunk["Y"].shape[0] # Number of observations in this sample
+        Ntr = D["Y"].shape[0] # Number of observations in this sample

         # Convert dictionary to TensorFlow Dataset
-        Dtrain = Dataset.from_tensor_slices(D_chunk)
+        Dtrain = Dataset.from_tensor_slices(D)

         # Batch the data
         if (nbatch==1): D_train = Dtrain.batch(round(Ntr)+1)
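For readers skimming the diff, here is a minimal standalone sketch of the Dataset construction that get_listDtrain performs for one sample after this change: the per-sample dictionary is wrapped in a tf.data.Dataset and, when nbatch is 1, the whole sample goes into a single batch. The dictionary keys and array shapes below are illustrative assumptions, not the package's exact data layout.

import numpy as np
import tensorflow as tf

# Hypothetical per-sample dictionary (keys and shapes are assumptions for illustration):
# "X" = spot coordinates (nspot x 2), "Y" = expression counts (nspot x ngenes).
D = {"X": np.random.rand(200, 2).astype("float32"),
     "Y": np.random.poisson(1.0, size=(200, 50)).astype("float32")}

Ntr = D["Y"].shape[0]                            # number of observations (spots) in this sample
Dtrain = tf.data.Dataset.from_tensor_slices(D)   # one element per spot, keyed like D

nbatch = 1
if nbatch == 1:
    # Single batch holding every spot, mirroring Dtrain.batch(round(Ntr) + 1) above
    D_train = Dtrain.batch(round(Ntr) + 1)
else:
    # Otherwise split the spots across roughly nbatch batches
    D_train = Dtrain.batch(int(np.ceil(Ntr / nbatch)))

for batch in D_train:
    print(batch["X"].shape, batch["Y"].shape)    # (200, 2) (200, 50) when nbatch == 1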
tests/test_small_run.py (13 changes: 6 additions & 7 deletions)
@@ -28,22 +28,21 @@ def _run(

    # step 0 Data loading
    D, X = load_data(data_dir, n_sample)
-   ninduced
    listDtrain = process_multiSample.get_listDtrain(D)
-   list_D_chunked = process_multiSample.get_listD_chunked(D)
-   for ksample in range(0, len(list_D_chunked)):
+
+   for ksample in range(0, len(D)):
        random.seed(10)
-       ninduced = round(list_D_chunked[ksample]["X"].shape[0] * 0.35)
+       ninduced = round(D[ksample]["X"].shape[0] * 0.35)
        D_tmp = D[ksample]
        D[ksample]["Z"] = D_tmp["X"][random.sample(range(0, D_tmp["X"].shape[0] - 1), ninduced), :]

    # step 1 initialize model
-   fit = process_multiSample.ini_multiSample(list_D_chunked, n_loadings, "nb", chol=False)
+   fit = process_multiSample.ini_multiSample(D, n_loadings, "nb", chol=False)

    # step 2 fit model

    (pp := (output_dir / "models" / "pp")).mkdir(parents=True, exist_ok=True)
-   fit = training_multiSample.train_model_mNSF(fit, pp, listDtrain, list_D_chunked, legacy=legacy, num_epochs=epochs)
+   fit = training_multiSample.train_model_mNSF(fit, pp, listDtrain, D, legacy=legacy, num_epochs=epochs)
    (output_dir / "list_fit_smallData.pkl").write_bytes(pickle.dumps(fit))

    # step 3 save results
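As a side note on the inducing-point step in the updated test: for each sample, roughly 35% of the spots are drawn at random and their coordinates are stored under "Z" as inducing-point locations before model initialization. A minimal sketch with made-up sample sizes, assuming D is a list of per-sample dictionaries whose "X" entry holds spot coordinates:

import random
import numpy as np

# Hypothetical input: two samples, each a dict with spot coordinates under "X".
D = [{"X": np.random.rand(100, 2)} for _ in range(2)]

for ksample in range(0, len(D)):
    random.seed(10)  # fixed seed, as in the test, so the chosen subset is reproducible
    ninduced = round(D[ksample]["X"].shape[0] * 0.35)   # ~35% of spots become inducing points
    D_tmp = D[ksample]
    # Sample ninduced distinct spot indices and keep their coordinates as "Z".
    idx = random.sample(range(0, D_tmp["X"].shape[0] - 1), ninduced)
    D[ksample]["Z"] = D_tmp["X"][idx, :]
    print(D[ksample]["Z"].shape)  # (35, 2) for a 100-spot sample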

