-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
29 changed files
with
5,466 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Created by .ignore support plugin (hsz.mobi) | ||
### Python template | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# IPython Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
### VirtualEnv template | ||
# Virtualenv | ||
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ | ||
[Bb]in | ||
[Ii]nclude | ||
[Ll]ib | ||
[Ll]ib64 | ||
[Ll]ocal | ||
[Ss]cripts | ||
pyvenv.cfg | ||
.venv | ||
pip-selfcheck.json | ||
|
||
### JetBrains template | ||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider | ||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||
|
||
# User-specific stuff | ||
.idea/**/workspace.xml | ||
.idea/**/tasks.xml | ||
.idea/**/usage.statistics.xml | ||
.idea/**/dictionaries | ||
.idea/**/shelf | ||
|
||
# AWS User-specific | ||
.idea/**/aws.xml | ||
|
||
# Generated files | ||
.idea/**/contentModel.xml | ||
|
||
# Sensitive or high-churn files | ||
.idea/**/dataSources/ | ||
.idea/**/dataSources.ids | ||
.idea/**/dataSources.local.xml | ||
.idea/**/sqlDataSources.xml | ||
.idea/**/dynamic.xml | ||
.idea/**/uiDesigner.xml | ||
.idea/**/dbnavigator.xml | ||
|
||
# Gradle | ||
.idea/**/gradle.xml | ||
.idea/**/libraries | ||
|
||
# Gradle and Maven with auto-import | ||
# When using Gradle or Maven with auto-import, you should exclude module files, | ||
# since they will be recreated, and may cause churn. Uncomment if using | ||
# auto-import. | ||
# .idea/artifacts | ||
# .idea/compiler.xml | ||
# .idea/jarRepositories.xml | ||
# .idea/modules.xml | ||
# .idea/*.iml | ||
# .idea/modules | ||
# *.iml | ||
# *.ipr | ||
|
||
# CMake | ||
cmake-build-*/ | ||
|
||
# Mongo Explorer plugin | ||
.idea/**/mongoSettings.xml | ||
|
||
# File-based project format | ||
*.iws | ||
|
||
# IntelliJ | ||
out/ | ||
|
||
# mpeltonen/sbt-idea plugin | ||
.idea_modules/ | ||
|
||
# JIRA plugin | ||
atlassian-ide-plugin.xml | ||
|
||
# Cursive Clojure plugin | ||
.idea/replstate.xml | ||
|
||
# SonarLint plugin | ||
.idea/sonarlint/ | ||
|
||
# Crashlytics plugin (for Android Studio and IntelliJ) | ||
com_crashlytics_export_strings.xml | ||
crashlytics.properties | ||
crashlytics-build.properties | ||
fabric.properties | ||
|
||
# Editor-based Rest Client | ||
.idea/httpRequests | ||
|
||
# Android studio 3.1+ serialized cache file | ||
.idea/caches/build_file_checksums.ser | ||
|
||
# idea folder, uncomment if you don't need it | ||
.idea | ||
|
||
wandb | ||
logs | ||
*.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# EchoGLAD: Hierarchical Graph Neural Networks for Left Ventricle Landmark Detection on Echocardiograms | ||
|
||
This repository provides the official PyTorch implementation of: | ||
|
||
Masoud Mokhtari, Mobina Mahdavi, Hooman Vaseli, Christina Luong, Teresa Tsang, Purang Abolmaesumi, and Renjie Liao, [EchoGLAD: Hierarchical Graph Neural Networks for Left Ventricle Landmark Detection on Echocardiograms](Linktobeadded) (MICCAI 2023) | ||
|
||
## Abstract | ||
The functional assessment of the left ventricle chamber of the heart requires detecting four landmark locations and measuring the internal dimension of the left ventricle and the approximate mass of the surrounding muscle. | ||
The key challenge of automating this task with machine learning is the sparsity of clinical labels, i.e., only a few landmark pixels in a high-dimensional image are annotated, leading many prior works to heavily rely on isotropic label smoothing. | ||
However, such a label smoothing strategy ignores the anatomical information of the image and induces some bias. | ||
To address this challenge, we introduce an **echo**cardiogram-based, hierarchical **g**raph neural network (GNN) for **l**eft ventricle **la**ndmark **d**etection (EchoGLAD). | ||
Our main contributions are: 1) a hierarchical graph representation learning framework for multi-resolution landmark detection via GNNs; 2) induced hierarchical supervision at different levels of granularity using a multi-level loss. | ||
We evaluate our model on a public and a private dataset under the in-distribution (ID) and out-of-distribution (OOD) settings. | ||
For the ID setting, we achieve the state-of-the-art mean absolute errors (MAEs) of 1.46 mm and 1.86 mm on the two datasets. | ||
Our model also shows better OOD generalization than prior works with a testing MAE of 4.3 mm. | ||
|
||
<p align="center"> | ||
<img src="./echoglad.PNG" title="GEMTrans overall architecture" width="700"/> | ||
</p> | ||
|
||
## Reproducing MICCAI 2023 Results | ||
|
||
To reproduce our MICCAI 2023 results on the publicly available UIC dataset, follow the steps below. Please note that the provided weights are for the model trained on the UIC dataset (table provided in the supplementary material). | ||
|
||
1. [Install the required packages](#requirements) | ||
2. [Download the dataset](#dataset) | ||
3. Update the dataset path in the `default.yml` config file (`configs/` directory) under `data.data_dir` | ||
4. Update the dataset labels path in the `default.yml` config file (`configs/` directory) under `data.data_info_file` | ||
4. Ensure the `model.checkpoint_path` in the config file points to `./trained_models/miccai2023.pth` | ||
5. Run the command: | ||
``` | ||
python run.py --config_path ./configs/default.yml --save_dir <dir_to_save_ouput_to> --eval_only True --eval_data_type test" | ||
``` | ||
|
||
## Requirements | ||
|
||
PyTorch and PyTorch Geometric must be separately installed based on your system requirements. In our implementation, we used Python 3.8.10 with the following: | ||
|
||
``` | ||
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 | ||
pip install torch_geometric==2.0.2 torch_scatter==2.0.9 torch_sparse==0.6.12 torch_cluster==1.5.9 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html | ||
``` | ||
|
||
To install the rest of the requirements (preferably in a virtual environment), run: | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Dataset | ||
|
||
We use the UIC public LV landmark dataset in this project. Access it [here](https://data.unityimaging.net/). Download the dataset to a convenient location and provide its directory in the config file as described in [Config File](#config-file). | ||
|
||
## Training | ||
|
||
To train the model (training + validation), create a training configuration yaml file similar to `/configs/default.yml`. Modify the config file as needed, consulting [Config File](#config-file). | ||
Then, run: | ||
|
||
``` | ||
python run.py --config_path <path_to_training_config> --save_dir <dir_to_save_ouput_to> | ||
``` | ||
|
||
## Evaluation | ||
|
||
To evaluate an already trained model, create a training configuration yaml file similar to `/configs/default.yml` that matches the trained model's specifications. Provide the path to the trained model using the `model.checkpoint_path` option in the config file. Finally, run: | ||
|
||
|
||
``` | ||
python run.py --config_path <path_to_eval_config> --save_dir <dir_to_save_ouput_to> --eval_only True --eval_data_type test" | ||
``` | ||
|
||
## Config File | ||
|
||
The default configuration can be found in `./configs/default.yml`. Below is a summary of some important configuration options: | ||
|
||
- **train** | ||
- *criterion* | ||
- *WeightedBceWithLogits*: | ||
- ones_weight: the weight given to positive landmark locations in the image for loss evaluation | ||
- *eval* | ||
- *standards*: Evaluation metrics computed for the model | ||
- *standard*: Metric used to decide when the best model checkpoint is saved (among the standards) | ||
- *maximize*: Determines whether the standard is to be maximized or minimized | ||
- **model** | ||
- *checkpoint_path*: Path to the saved model for inference or continued training | ||
- *embedder*: Contains the config for the initial CNN expanding the images channel-wise | ||
- *landmark*: | ||
- *name*: Indicates the node feature construction type | ||
- *num_gnn_layers*: Number of GNN layers used to process the hierarchical graph | ||
- **data** | ||
- *name*: Dataset name | ||
- *data_dir*: Path to the dataset | ||
- *data_info_file*: Path to the dataset labels dir | ||
- *use_coordinate_graph*: Indicates whether an average location node graph is created | ||
- *use_main_graph_only*: Indicates whether only a pixel level graph is created (as per ablation studies) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
model: &model | ||
checkpoint_path: | ||
embedder: | ||
name: cnn | ||
out_channels: [4] | ||
cnn_dropout_p: 0.1 | ||
pool_sizes: [1] | ||
kernel_sizes: [3] | ||
landmark: | ||
name: unet_hierarchical_patch | ||
encoder_embedding_widths: [128, 64, 32, 16, 8, 4, 2] | ||
encoder_embedding_dims: [8, 16, 32, 64, 128, 256, 512] | ||
gnn_dropout_p: 0.5 | ||
node_embedding_dim: 128 | ||
node_hidden_dim: 128 | ||
classifier_hidden_dim: 32 | ||
classifier_dropout_p: 0.5 | ||
num_gnn_layers: 3 | ||
gnn_jk_mode: last | ||
residual: True | ||
output_activation: "logit" | ||
|
||
train: &train | ||
seed: 200 | ||
num_epochs: 100 | ||
checkpoint_step: 20000 | ||
batch_size: 1 | ||
num_workers: 4 | ||
use_wandb: False | ||
wand_project_name: <wandb_project_name> | ||
wandb_mode: offline | ||
wandb_run_name: <wandb_run_name> | ||
wandb_log_steps: 1000 | ||
|
||
criterion: | ||
WeightedBceWithLogits: | ||
loss_weight: 1 | ||
reduction: none | ||
ones_weight: 9000 # 224x224 and 7aux graphs results in 72k nodes which divided by 8 ones, is 9000 | ||
ExpectedLandmarkMse: | ||
loss_weight: 10 | ||
|
||
optimizer: &optimizer | ||
name: adam | ||
lr: 0.001 # 1e-3 | ||
weight_decay: 0.0001 # 1e-4 | ||
|
||
lr_schedule: &lr_schedule | ||
name: 'reduce_lr_on_plateau' | ||
mode: 'min' | ||
factor: 0.5 # Factor by which the learning rate will be reduced | ||
patience: 2 # Number of epochs with no improvement after which learning rate will be reduced | ||
threshold: 0.01 # Threshold for measuring the new optimum, to only focus on significant changes | ||
min_lr: 0.000001 # 1e-6 | ||
verbose: True | ||
|
||
# Evaluation metrics | ||
eval: &eval | ||
# Report these metrics | ||
standards: [ "balancedaccuracy", "landmarkcoorderror"] | ||
# Save checkpoints based on this metric | ||
standard: "balancedaccuracy" | ||
# Save checkpoints based on whether the metric is to be maximized or minimized | ||
minimize: False | ||
|
||
data: &data | ||
name: uiclvlandmark | ||
data_dir: <path_to_data_dir> | ||
num_aux_graphs: 7 | ||
data_info_file: <path_to_labels_dir> # Please read the constructor of the corresponding dataset in datasets.py to see what's needed here | ||
main_graph_type: 'grid' # 'grid-diagonal' / 'grid' | ||
aux_graph_type: 'grid' # 'grid-diagonal' / 'grid' | ||
use_coordinate_graph: False | ||
use_connection_nodes: False | ||
use_main_graph_only: False | ||
# flip_p: 0.0 | ||
|
||
transform: &transform | ||
image_size: 224 | ||
make_gray: True # This needs to be set to TRUE for the UIC dataset |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.