Skip to content

Commit

Permalink
Update mnsf-tutorial-dlpfc.md
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Nov 18, 2024
1 parent 6c448ce commit 9adb4b4
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tutorial/mnsf-tutorial-dlpfc.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,12 @@ First, let's implement both optimization techniques:
list_D_chunked=list()
list_X_chunked=list()
for ksample in range(0,nsample):
Y = pd.read_csv(f'path/to/Y_sample{ksample+1}.csv')
X = pd.read_csv(f'path/to/X_sample{ksample+1}.csv')
list_D_sampleTmp,list_X_sampleTmp = process_multiSample.get_chunked_data(X,Y,nchunk)
list_D_chunked = list_D_chunked + list_D_sampleTmp
list_X_chunked = list_X_chunked + list_X_sampleTmp
Y=pd.read_csv(path.join('//dcs04/hansen/data/ywang/ST/DLPFC/processed_Data//Y_features_sele_sample'+str(ksample*4+1)+'_500genes.csv'))
X=pd.read_csv(path.join('//dcs04/hansen/data/ywang/ST/DLPFC/processed_Data///X_allSpots_sample'+str(ksample*4+1)+'.csv'))
list_D_sampleTmp,list_X_sampleTmp, chunk_mapping = process_multiSample.get_chunked_data(X.iloc[:,:],Y.iloc[:,:],nchunk,method = "random") #choose method = "balanced_kmeans" for chunking the spots based on the spatial coordinates
list_D = list_D + list_D_sampleTmp
list_X = list_X + list_X_sampleTmp
list_chunk_mapping.append(chunk_mapping)

# Extracts the training data from our processed data. This function prepares the data in the format required for model training.
list_Dtrain = process_multiSample.get_listDtrain(list_D_chunked)
Expand Down Expand Up @@ -392,6 +393,7 @@ After training, we can visualize the results. Here's how to plot the mNSF factor

```python
Fplot = misc.t2np(list_fit[0].sample_latent_GP_funcs(list_D_chunked[0]["X"], S=3, chol=False)).T
Fplot = process_multiSample.reorder_chunked_results(Fplot,list_chunk_mapping,list_X_unchunked)
hmkw = {"figsize": (4, 4), "bgcol": "white", "subplot_space": 0.1, "marker": "s", "s": 10}
fig, axes = visualize.multiheatmap(list_D[0]["X"], Fplot, (1, 2), cmap="RdBu", **hmkw)
```
Expand Down

0 comments on commit 9adb4b4

Please sign in to comment.