diff --git a/README.md b/README.md index 7474e1a..c41c566 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ # GREEN architecture ![CI](https://github.com/Roche/neuro-green/actions/workflows/lint_and_test.yaml/badge.svg) --- + +## About the architecture +The model is a deep learning architecture designed for EEG data that combines wavelet transforms and Riemannian geometry. It is capable of learning from raw EEG data and can be applied to both classification and regression tasks. The model is composed of the following layers: +It is based on the following layers: + - Convolution: Uses complex-valued Gabor wavelets with parameters that are learned during training. + - Pooling: Derives features from the wavelet-transformed signal, such as covariance matrices. + - Shrinkage layer: applies [shrinkage](https://scikit-learn.org/1.5/modules/covariance.html#basic-shrinkage) to the covariance matrices. + - Riemannian Layers: Applies transformations to the matrices, leveraging the geometry of the Symmetric Positive Definite (SPD) manifold. + - Fully Connected Layers: Standard fully connected layers for final processing. + ![alt text](assets/concept_figure.png) @@ -9,18 +19,59 @@ ``` pip install -e . ``` + +## Dependencies +``` +scikit-learn +torch +geotorch +lightning +mne +``` + +## Examples +Examples illustrating how to train the presented model can be found in the `green/research_code` folder. The notebook `example.ipynb` shows how to train the model on raw EEG data. And the notebook `example_wo_wav.ipynb` shows how to train a submodel that uses covariance matrices as input. + +In addition, being pure PyTorch, the GREEN model can easily be integrated to [`braindecode`](https://braindecode.org/stable/index.html) routines. + +```python +import torch +from braindecode import EEGRegressor +from green.wavelet_layers import RealCovariance +from green.research_code.pl_utils import get_green + +green_model = get_green( + n_freqs=5, # Learning 5 wavelets + n_ch=22, # EEG data with 22 channels + sfreq=100, # Sampling frequency of 100 Hz + dropout=0.5, # Dropout rate of 0.5 in FC layers + hidden_dim=[100], # Use 100 units in the hidden layer + pool_layer=RealCovariance(), # Compute covariance after wavelet transform + bi_out=[20], # Use a BiMap layer outputing a 20x20 matrix + out_dim=1, # Output dimension of 1, for regression +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +EarlyStopping(monitor="valid_loss", patience=10, load_best=True) +clf = EEGClassifier( + module=green_model, + criterion=torch.nn.CrossEntropyLoss, + optimizer=torch.optim.AdamW, + device=device, + callbacks=[], # Callbacks can be added here, e.g. EarlyStopping +) +``` + ## Citation When using our code, please cite the reference article: ``` bibtex -@article {Paillard2024.05.14.594142, +@article {paillard_2024_green, author = {Paillard, Joseph and Hipp, Joerg F and Engemann, Denis A}, title = {GREEN: a lightweight architecture using learnable wavelets and Riemannian geometry for biomarker exploration}, - elocation-id = {2024.05.14.594142}, year = {2024}, doi = {10.1101/2024.05.14.594142}, URL = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142}, - eprint = {https://www.biorxiv.org/content/early/2024/05/14/2024.05.14.594142.full.pdf}, journal = {bioRxiv} } ``` \ No newline at end of file diff --git a/green/research_code/example.ipynb b/green/research_code/example.ipynb index f5b0705..2eb57bb 100644 --- a/green/research_code/example.ipynb +++ b/green/research_code/example.ipynb @@ -1,5 +1,14 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training GREEN with `lightning`\n", + "\n", + "The notebook is a simple example of how to train the GREEN model. It uses dummy data, `mne.Epochs`." + ] + }, { "cell_type": "code", "execution_count": 1, diff --git a/green/research_code/example_wo_wav.ipynb b/green/research_code/example_wo_wav.ipynb index f369978..44d143d 100644 --- a/green/research_code/example_wo_wav.ipynb +++ b/green/research_code/example_wo_wav.ipynb @@ -1,5 +1,14 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a GREEN submodel without learning wavelets\n", + "\n", + "This notebook provides a minimal example of training a Green submodel without learning wavelets. The difference with the full model is that inputs are covariance matrices instead of raw EEG data. " + ] + }, { "cell_type": "code", "execution_count": 1, @@ -11,8 +20,7 @@ "import torch\n", "\n", "from research_code.pl_utils import get_green_g2, GreenClassifierLM\n", - "from research_code.crossval_utils import pl_crossval\n", - "\n" + "from research_code.crossval_utils import pl_crossval" ] }, {