
ECG2AF open-source weights and notebook #543

Merged: 46 commits, Jan 4, 2024
46 commits
00832eb
include fix
lucidtronix Sep 26, 2023
0101b46
include fix
lucidtronix Oct 6, 2023
72fa8fc
notebook demo
lucidtronix Oct 6, 2023
9ca05df
Update README.md
lucidtronix Oct 6, 2023
eabe16d
Update README.md
lucidtronix Oct 6, 2023
8badf16
notebook demo
lucidtronix Oct 6, 2023
6ceeb73
include fix
lucidtronix Oct 6, 2023
fd063da
include fix
lucidtronix Oct 6, 2023
c0eb2ff
include fix
lucidtronix Oct 6, 2023
957c939
xdl recipe
lucidtronix Oct 16, 2023
6e2c361
xdl recipe
lucidtronix Oct 16, 2023
d592784
xdl recipe
lucidtronix Oct 16, 2023
5f73bb4
xdl recipe
lucidtronix Oct 16, 2023
7333876
xdl recipe
lucidtronix Oct 17, 2023
302f3db
xdl recipe
lucidtronix Oct 17, 2023
57f4f38
xdl recipe
lucidtronix Oct 25, 2023
e7b7b2f
xdl recipe
lucidtronix Oct 25, 2023
ae0f15f
xdl recipe
lucidtronix Oct 25, 2023
211bc78
xdl recipe
lucidtronix Oct 25, 2023
619b490
xdl recipe
lucidtronix Oct 25, 2023
481f534
xdl recipe
lucidtronix Oct 25, 2023
a57acf8
xdl recipe
lucidtronix Oct 25, 2023
c6f44e3
xdl recipe
lucidtronix Oct 25, 2023
ad134d8
xdl recipe
lucidtronix Oct 25, 2023
ffceed5
xdl recipe
lucidtronix Oct 25, 2023
d48f936
xdl recipe
lucidtronix Oct 25, 2023
8d78ed0
xdl recipe
lucidtronix Oct 26, 2023
01a1196
fix
lucidtronix Nov 2, 2023
17d8f51
cleanup
lucidtronix Nov 3, 2023
3203ba7
cleanup
lucidtronix Nov 7, 2023
8916d4a
cleanup
lucidtronix Nov 8, 2023
3e027f3
cleanup
lucidtronix Nov 8, 2023
df784ab
cleanup
lucidtronix Nov 9, 2023
7d81d61
cleanup
lucidtronix Nov 27, 2023
3ccd91a
cleanup
lucidtronix Nov 30, 2023
e852d38
add ranod
lucidtronix Dec 6, 2023
42520b4
add ranod
lucidtronix Dec 6, 2023
9eaf159
Merge branch 'master' into sf_dock
lucidtronix Dec 16, 2023
fdc13e0
cleanup
lucidtronix Dec 16, 2023
b0decb2
cleanup
lucidtronix Dec 16, 2023
b090ec7
setup.py
lucidtronix Jan 2, 2024
1dfdd5c
setup.py
lucidtronix Jan 2, 2024
98a2678
setup.py
lucidtronix Jan 2, 2024
4649ea2
setup.py
lucidtronix Jan 2, 2024
a6dcfd8
setup.py
lucidtronix Jan 2, 2024
0f71349
setup.py
lucidtronix Jan 4, 2024
1 change: 0 additions & 1 deletion docker/vm_boot_images/config/tensorflow-requirements.txt
@@ -44,4 +44,3 @@ umap-learn[plot]
neurite
voxelmorph
pystrum

2 changes: 0 additions & 2 deletions ml4h/TensorMap.py
@@ -204,8 +204,6 @@ def __init__(
elif self.activation is None and (self.is_survival_curve() or self.is_time_to_event()):
self.activation = 'sigmoid'



if self.channel_map is None and self.is_time_to_event():
self.channel_map = DEFAULT_TIME_TO_EVENT_CHANNELS

4 changes: 2 additions & 2 deletions ml4h/data_descriptions.py
@@ -5,6 +5,7 @@
from typing import Callable, List, Union, Optional, Tuple, Dict, Any

import h5py
import datetime
import numcodecs
import numpy as np
import pandas as pd
@@ -331,10 +332,9 @@ def __init__(
):
"""
Gets data from a column of the provided DataFrame.
:param df: Must be multi-indexed with sample_id, loading_option
# TODO: allow multiple loading options
:param col: The column name to get data from
:param process_col: Function to turn the column value into Tensor
:param name: Optional overwrite of the df column name
"""
self.process_col = process_col or self._default_process_call
self.df = df
4 changes: 2 additions & 2 deletions ml4h/models/legacy_models.py
@@ -350,7 +350,7 @@ def make_hidden_layer_model(parent_model: Model, tensor_maps_in: List[TensorMap]
dummy_input = {tm.input_name(): np.zeros((1,) + parent_model.get_layer(tm.input_name()).input_shape[0][1:]) for tm in tensor_maps_in}
intermediate_layer_model = Model(inputs=parent_inputs, outputs=target_layer.output)
# If we do not predict here then the graph is disconnected, I do not know why?!
intermediate_layer_model.predict(dummy_input)
intermediate_layer_model.predict(dummy_input, verbose=0)
return intermediate_layer_model


@@ -1344,7 +1344,7 @@ def make_paired_autoencoder_model(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def embed_model_predict(model, tensor_maps_in, embed_layer, test_data, batch_size):
embed_model = make_hidden_layer_model(model, tensor_maps_in, embed_layer)
return embed_model.predict(test_data, batch_size=batch_size)
return embed_model.predict(test_data, batch_size=batch_size, verbose=0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions ml4h/plots.py
@@ -770,7 +770,7 @@ def plot_scatter(

ax1.set_xlabel("Predictions")
ax1.set_ylabel("Actual")
ax1.set_title(title)
ax1.set_title(f'{title} N = {len(prediction)}' )
ax1.legend(loc="lower right")

sns.distplot(prediction, label="Predicted", color="r", ax=ax2)
@@ -2253,7 +2253,7 @@ def plot_ecg_rest(
tensor_paths: List[str],
rows: List[int],
out_folder: str,
is_blind: bool,
is_blind: bool
) -> None:
"""Plots resting ECGs including annotations and LVH criteria

240 changes: 234 additions & 6 deletions ml4h/recipes.py

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions ml4h/tensormap/mgb/xdl.py
@@ -0,0 +1,45 @@
from typing import Dict

import h5py
import numpy as np
from ml4h.TensorMap import TensorMap, Interpretation

ecg_5000_std = TensorMap('ecg_5000_std', Interpretation.CONTINUOUS, shape=(5000, 12))

hypertension_icd_only = TensorMap(name='hypertension_icd_only', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_only': 0, 'hypertension_icd_only': 1})
hypertension_icd_bp = TensorMap(name='hypertension_icd_bp', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_bp': 0, 'hypertension_icd_bp': 1})
hypertension_icd_bp_med = TensorMap(name='hypertension_icd_bp_med', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_icd_bp_med': 0, 'hypertension_icd_bp_med': 1})
hypertension_med = TensorMap(name='start_fu_hypertension_med', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypertension_medication': 0, 'hypertension_medication': 1})

lvef = TensorMap(name='LVEF', interpretation=Interpretation.CONTINUOUS, channel_map={'LVEF': 0})

age = TensorMap(name='age_in_days', interpretation=Interpretation.CONTINUOUS, channel_map={'age_in_days': 0})
sex = TensorMap(name='sex', interpretation=Interpretation.CATEGORICAL, channel_map={'Female': 0, 'Male': 1})

cad = TensorMap(name='cad', interpretation=Interpretation.CATEGORICAL, channel_map={'no_cad': 0, 'cad': 1})
dm = TensorMap(name='dm', interpretation=Interpretation.CATEGORICAL, channel_map={'no_dm': 0, 'dm': 1})
hypercholesterolemia = TensorMap(name='hypercholesterolemia', interpretation=Interpretation.CATEGORICAL,
channel_map={'no_hypercholesterolemia': 0, 'hypercholesterolemia': 1})


def ecg_median_biosppy(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    tensor = np.zeros(tm.shape, dtype=np.float32)
    for lead in tm.channel_map:
        tensor[:, tm.channel_map[lead]] = hd5[f'{tm.path_prefix}{lead}']
    tensor = np.nan_to_num(tensor)
    return tensor

ecg_channel_map = {
'I': 0, 'II': 1, 'III': 2, 'aVR': 3, 'aVL': 4, 'aVF': 5,
'V1': 6, 'V2': 7, 'V3': 8, 'V4': 9, 'V5': 10, 'V6': 11,
}

ecg_biosppy_median_60bpm = TensorMap(
'median', Interpretation.CONTINUOUS, path_prefix='median_60bpm_', shape=(600, 12),
tensor_from_file=ecg_median_biosppy,
channel_map=ecg_channel_map,
)
15 changes: 10 additions & 5 deletions ml4h/tensormap/ukb/demographics.py
@@ -341,10 +341,11 @@ def alcohol_from_file(tm, hd5, dependents={}):
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)
# sex = TensorMap(
# 'Sex_Male_0_0', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG, path_prefix='categorical', annotation_units=2,
# channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
# )
sex_dummy1 = TensorMap(
'sex', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG,
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)
af_dummy2 = TensorMap(
'af_in_read', Interpretation.CATEGORICAL, path_prefix='categorical', storage_type=StorageType.CATEGORICAL_FLAG,
channel_map={'no_atrial_fibrillation': 0, 'atrial_fibrillation': 1},
@@ -354,7 +355,11 @@ def alcohol_from_file(tm, hd5, dependents={}):
path_prefix='categorical', annotation_units=2,
channel_map={'Sex_Female_0_0': 0, 'Sex_Male_0_0': 1}, loss='categorical_crossentropy',
)

sex_dummy3 = TensorMap(
'sex_from_wide', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG,
path_prefix='categorical', annotation_units=2,
channel_map={'female': 0, 'male': 1}, loss='categorical_crossentropy',
)
brain_volume = TensorMap(
'25010_Volume-of-brain-greywhite-matter_2_0', Interpretation.CONTINUOUS, path_prefix='continuous', normalization={'mean': 1165940.0, 'std': 111511.0},
channel_map={'25010_Volume-of-brain-greywhite-matter_2_0': 0}, loss='logcosh', loss_weight=0.1,
2 changes: 1 addition & 1 deletion ml4h/tensormap/ukb/dxa.py
@@ -98,7 +98,7 @@ def dxa_background_erase(tm, hd5, dependents={}):
)
dxa_11 = TensorMap(
'dxa_1_11',
shape=(896, 352, 1),
shape=(896, 384, 1),
path_prefix='ukb_dxa',
tensor_from_file=dxa_background_erase,
normalization=ZeroMeanStd1(),
7 changes: 5 additions & 2 deletions ml4h/tensormap/ukb/ecg.py
@@ -574,6 +574,7 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}):
metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS, normalization=Standardize(mean=0, std=10),
)


ecg_rest_median_576 = TensorMap(
'ecg_rest_median_576', Interpretation.CONTINUOUS, path_prefix='ukb_ecg_rest', shape=(576, 12), loss='logcosh',
activation='linear', tensor_from_file=_make_ecg_rest(), channel_map=ECG_REST_MEDIAN_LEADS,
@@ -595,8 +596,10 @@ def ecg_rest_section_to_segment(tm, hd5, dependents={}):
)

ecg_rest_median_raw_10_prediction = TensorMap(
'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear', normalization=ZeroMeanStd1(),
tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'], channel_map=ECG_REST_MEDIAN_LEADS,
'ecg_rest_median_raw_10', Interpretation.CONTINUOUS, shape=(600, 12), loss='logcosh', activation='linear',
normalization=ZeroMeanStd1(),
tensor_from_file=named_tensor_from_hd5('ecg_rest_median_raw_10_prediction'), metrics=['mse', 'mae'],
channel_map=ECG_REST_MEDIAN_LEADS,
)


65 changes: 65 additions & 0 deletions ml4h/tensormap/ukb/mri.py
@@ -253,6 +253,25 @@ def _slice_tensor_from_file(tm, hd5, dependents={}):
return _slice_tensor_from_file


def _random_slice_tensor(tensor_key, max_random=50):
    def _slice_tensor_from_file(tm, hd5, dependents={}):
        slice_index = np.random.randint(max_random)
        if tm.shape[-1] == 1:
            t = pad_or_crop_array_to_shape(
                tm.shape[:-1],
                np.array(hd5[tensor_key][..., slice_index], dtype=np.float32),
            )
            tensor = np.expand_dims(t, axis=-1)
        else:
            tensor = pad_or_crop_array_to_shape(
                tm.shape,
                np.array(hd5[tensor_key][..., slice_index], dtype=np.float32),
            )
        return tensor

    return _slice_tensor_from_file


def _segmented_dicom_slices(dicom_key_prefix, path_prefix='ukb_cardiac_mri', step=1, total_slices=50):
def _segmented_dicom_tensor_from_file(tm, hd5, dependents={}):
tensor = np.zeros(tm.shape, dtype=np.float32)
@@ -389,6 +408,12 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
tensor_from_file=_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0', 0),
)

lax_4ch_random_slice_3d = TensorMap(
'lax_4ch_random_slice_3d', Interpretation.CONTINUOUS, shape=(160, 224, 1),
normalization=ZeroMeanStd1(),
tensor_from_file=_random_slice_tensor('ukb_cardiac_mri/cine_segmented_lax_4ch/2/instance_0'),
)

lax_4ch_diastole_slice0_224_3d_augmented = TensorMap(
'lax_4ch_diastole_slice0_224_3d_augmented', Interpretation.CONTINUOUS, shape=(160, 224, 1),
normalization=ZeroMeanStd1(), augmentations=[_gaussian_noise, _make_rotate(-15, 15)],
@@ -415,6 +440,36 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_160_3d = TensorMap(
'lax_2ch_diastole_slice_224_160_3d',
Interpretation.CONTINUOUS,
shape=(224, 160, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_192_3d = TensorMap(
'lax_2ch_diastole_slice_224_192_3d',
Interpretation.CONTINUOUS,
shape=(224, 192, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_2ch_diastole_slice_224_224_3d = TensorMap(
'lax_2ch_diastole_slice_224_224_3d',
Interpretation.CONTINUOUS,
shape=(224, 224, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_2ch/2/instance_0', 0,
),
)
lax_3ch_diastole_slice0_3d = TensorMap(
'lax_3ch_diastole_slice0_3d',
Interpretation.CONTINUOUS,
@@ -425,6 +480,16 @@ def _mri_slice_blackout_tensor_from_file(tm, hd5, dependents={}):
'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0,
),
)
lax_3ch_diastole_slice_224_160_3d = TensorMap(
'lax_3ch_diastole_slice_224_160_3d',
Interpretation.CONTINUOUS,
shape=(224, 160, 1),
loss='logcosh',
normalization=ZeroMeanStd1(),
tensor_from_file=_slice_tensor(
'ukb_cardiac_mri/cine_segmented_lax_3ch/2/instance_0', 0,
),
)
cine_segmented_ao_dist_slice0_3d = TensorMap(
'cine_segmented_ao_dist_slice0_3d',
Interpretation.CONTINUOUS,
56 changes: 49 additions & 7 deletions model_zoo/ECG2AF/README.md
@@ -2,7 +2,44 @@
This directory contains models and code for predicting incident atrial fibrillation from 12 lead resting ECGs, as described in our
[Circulation paper](https://www.ahajournals.org/doi/full/10.1161/CIRCULATIONAHA.121.057480).

To perform inference with this model run:
The raw model files are stored with `git lfs`, so you must have it installed to localize the full ~200 MB files:
```bash
git lfs pull --include model_zoo/ECG2AF/ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5
git lfs pull --include model_zoo/ECG2AF/strip_*
```

To load the 12-lead model in a Jupyter notebook (running with the ml4h Docker image or the Python library installed), see the [example notebook](./ecg2af_infer.ipynb) or run:

```python
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.models.model_factory import get_custom_objects
from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2
from ml4h.tensormap.ukb.demographics import age_2_wide, af_dummy, sex_dummy3

output_tensormaps = {tm.output_name(): tm for tm in [mgb_afib_wrt_instance2, age_2_wide, af_dummy, sex_dummy3]}
custom_dict = get_custom_objects(list(output_tensormaps.values()))
model = load_model('./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5', custom_objects=custom_dict)
ecg = np.random.random((1, 5000, 12))
prediction = model(ecg)
```
If the above does not work, you may need to pass an absolute path to `load_model`.
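For example, a minimal way to resolve the path explicitly (reusing `custom_dict` from the snippet above):
```python
import os

model_path = os.path.abspath('ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5')
model = load_model(model_path, custom_objects=custom_dict)
```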

The model has four output heads: survival curve prediction for incident atrial fibrillation, classification of atrial fibrillation at the time of the ECG, sex classification, and age regression. Those outputs can be accessed with:
```python
for name, pred in zip(model.output_names, prediction):
    otm = output_tensormaps[name]
    if otm.is_survival_curve():
        intervals = otm.shape[-1] // 2
        days_per_bin = 1 + otm.days_window // intervals
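        # The survival head's first `intervals` values are per-interval survival
        # probabilities: their cumulative product gives the survival curve (each bin
        # spans roughly `days_per_bin` days), and 1 minus its final value is the
        # predicted AF risk over the full follow-up window.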
        predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
        print(f'AF Risk {otm} prediction is: {str(1 - predicted_survivals[0, -1])}')
    else:
        print(f'{otm} prediction is {pred}')
```


To perform command-line inference with this model, run:
```bash
python /path/to/ml4h/ml4h/recipes.py \
    --mode infer \
    ...
```
@@ -20,18 +57,23 @@ The model weights for the main model which performs incident atrial fibrillation
age regression, sex classification and prevalent (at the time of ECG) atrial fibrillation:
[ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5](./ecg_5000_survival_curve_af_quadruple_task_mgh_v2021_05_21.h5)

We also include single lead models for lead strip I:[strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5)
We also include single lead models for lead/strip I: [strip_I_survival_curve_af_v2021_06_15.h5](./strip_I_survival_curve_af_v2021_06_15.h5)
and II: [strip_II_survival_curve_af_v2021_06_15.h5](./strip_II_survival_curve_af_v2021_06_15.h5)
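
A minimal sketch of loading the lead I model, assuming it shares the survival-curve custom objects of the quadruple-task model and expects a single-lead input of shape `(1, 5000, 1)` (check `strip_model.input_shape` after loading):

```python
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.models.model_factory import get_custom_objects
from ml4h.tensormap.ukb.survival import mgb_afib_wrt_instance2

# Assumption: the single-lead strip models use the same survival-curve custom objects as the main model.
custom_dict = get_custom_objects([mgb_afib_wrt_instance2])
strip_model = load_model('./strip_I_survival_curve_af_v2021_06_15.h5', custom_objects=custom_dict)
single_lead_ecg = np.random.random((1, 5000, 1))  # assumed input shape: one lead of a 5000-sample ECG
strip_prediction = strip_model(single_lead_ecg)
```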

### Study Design
Flow chart of study design
![Flow chart of study design](./study_design.jpg)
### Study design
<div style="padding: 10px; background-color: white; display: inline-block;">
<img src="./study_design.jpg" alt="Flow chart of study design" />
</div>

### Performance
Risk stratification model comparison
![Risk stratification model comparison](./km.jpg)
<div style="padding: 10px; background-color: white; display: inline-block;">
<img src="./km.jpg" alt="Risk stratification model comparison" />
</div>

### Salience
Salience and Median waveforms from predicted risk extremes.
![Salience and Median waveforms](./salience.jpg)
### Architecture
1D Convolutional neural net architecture
![Convolutional neural net architecture](./architecture.png)