Skip to content

Commit

Permalink
add content to readme
Browse files Browse the repository at this point in the history
  • Loading branch information
jpaillard committed Dec 18, 2024
1 parent 21ee855 commit 57b80d9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
57 changes: 54 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# GREEN architecture
![CI](https://github.com/Roche/neuro-green/actions/workflows/lint_and_test.yaml/badge.svg)
---

## About the architecture
The model is a deep learning architecture designed for EEG data that combines wavelet transforms and Riemannian geometry. It is capable of learning from raw EEG data and can be applied to both classification and regression tasks. The model is composed of the following layers:
It is based on the following layers:
- Convolution: Uses complex-valued Gabor wavelets with parameters that are learned during training.
- Pooling: Derives features from the wavelet-transformed signal, such as covariance matrices.
- Shrinkage layer: applies [shrinkage](https://scikit-learn.org/1.5/modules/covariance.html#basic-shrinkage) to the covariance matrices.
- Riemannian Layers: Applies transformations to the matrices, leveraging the geometry of the Symmetric Positive Definite (SPD) manifold.
- Fully Connected Layers: Standard fully connected layers for final processing.

![alt text](assets/concept_figure.png)


Expand All @@ -9,18 +19,59 @@
```
pip install -e .
```

## Dependencies
```
scikit-learn
torch
geotorch
lightning
mne
```

## Examples
Examples illustrating how to train the presented model can be found in the `green/research_code` folder. The notebook `example.ipynb` shows how to train the model on raw EEG data. And the notebook `example_wo_wav.ipynb` shows how to train a submodel that uses covariance matrices as input.

In addition, being pure PyTorch, the GREEN model can easily be integrated to [`braindecode`](https://braindecode.org/stable/index.html) routines.

```python
import torch
from braindecode import EEGRegressor
from green.wavelet_layers import RealCovariance
from green.research_code.pl_utils import get_green

green_model = get_green(
n_freqs=5, # Learning 5 wavelets
n_ch=22, # EEG data with 22 channels
sfreq=100, # Sampling frequency of 100 Hz
dropout=0.5, # Dropout rate of 0.5 in FC layers
hidden_dim=[100], # Use 100 units in the hidden layer
pool_layer=RealCovariance(), # Compute covariance after wavelet transform
bi_out=[20], # Use a BiMap layer outputing a 20x20 matrix
out_dim=1, # Output dimension of 1, for regression
)

device = "cuda" if torch.cuda.is_available() else "cpu"
EarlyStopping(monitor="valid_loss", patience=10, load_best=True)
clf = EEGClassifier(
module=green_model,
criterion=torch.nn.CrossEntropyLoss,
optimizer=torch.optim.AdamW,
device=device,
callbacks=[], # Callbacks can be added here, e.g. EarlyStopping
)
```

## Citation
When using our code, please cite the reference article:

``` bibtex
@article {Paillard2024.05.14.594142,
@article {paillard_2024_green,
author = {Paillard, Joseph and Hipp, Joerg F and Engemann, Denis A},
title = {GREEN: a lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration},
elocation-id = {2024.05.14.594142},
year = {2024},
doi = {10.1101/2024.05.14.594142},
URL = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142},
eprint = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142.full.pdf},
journal = {bioRxiv}
}
```
9 changes: 9 additions & 0 deletions green/research_code/example.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training GREEN with `lightning`\n",
"\n",
"The notebook is a simple example of how to train the GREEN model. It uses dummy data, `mne.Epochs`."
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down
12 changes: 10 additions & 2 deletions green/research_code/example_wo_wav.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a GREEN submodel without learning wavelets\n",
"\n",
"This notebook provides a minimal example of training a Green submodel without learning wavelets. The difference with the full model is that inputs are covariance matrices instead of raw EEG data. "
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand All @@ -11,8 +20,7 @@
"import torch\n",
"\n",
"from research_code.pl_utils import get_green_g2, GreenClassifierLM\n",
"from research_code.crossval_utils import pl_crossval\n",
"\n"
"from research_code.crossval_utils import pl_crossval"
]
},
{
Expand Down

0 comments on commit 57b80d9

Please sign in to comment.