Merge branch 'merge-med-shifts' into 'main'
MedShifts

See merge request hi-dkfz/iml/failure-detection-benchmark!16
tbung committed Oct 2, 2023
2 parents 3500f91 + f7324fb commit cf57729
Showing 84 changed files with 7,014 additions and 105 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,6 +16,8 @@ typings
experiments.json
_version.py
.coverage
bak_*

data_folder/
experiments_folder/
experiments_test/
3 changes: 3 additions & 0 deletions README.md
@@ -78,6 +78,9 @@ If you use fd-shifts please cite our [paper](https://openreview.net/pdf?id=YnkGM
}
```

> **Note**
> This repository also contains the benchmarks for our follow-up study ["Understanding Silent Failures in Medical Image Classification"](https://arxiv.org/abs/2307.14729). For the visualization tool presented in that work please see [sf-visuals](https://github.com/IML-DKFZ/sf-visuals).

## Table Of Contents

<!--toc:start-->
58 changes: 58 additions & 0 deletions docs/datasets.md
@@ -65,3 +65,61 @@ website](https://image-net.org/download-images.php) and move or symlink it to
## Training on CAMELYON-17-Wilds, iWildCam-2020-Wilds

These datasets will be downloaded automatically.

## Dermoscopy Data

You can download the individual datasets from their respective websites:

- [PH2](https://www.fc.up.pt/addi/ph2%20database.html) (extract to `$DATASET_ROOT_DIR/ph2`)
- [HAM10000](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DBW86T) (extract to `$DATASET_ROOT_DIR/ham10000`)
- [derm7pt](http://derm.cs.sfu.ca/) (extract to `$DATASET_ROOT_DIR/d7p`)
- [isic2020](https://challenge.isic-archive.com/data/#2020) (extract to `$DATASET_ROOT_DIR/isic_2020`)

After extracting each dataset to its folder, run the data preprocessing:

```bash
fd_shifts prepare --dataset dermoscopy
```
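
If you want to sanity-check the layout first, here is a minimal sketch (hypothetical helper, assuming `DATASET_ROOT_DIR` is set as an environment variable, as the paths above imply):

```python
import os
from pathlib import Path

# hypothetical check that the dermoscopy folders listed above exist
root = Path(os.environ["DATASET_ROOT_DIR"])
for folder in ["ph2", "ham10000", "d7p", "isic_2020"]:
    status = "ok" if (root / folder).is_dir() else "MISSING"
    print(f"{root / folder}: {status}")
```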

## Microscopy Data

Download the dataset from its website and extract it to `$DATASET_ROOT_DIR/rxrx1`.

- [RxRx1](https://www.rxrx.ai/rxrx1#Download)

Then run the data preprocessing:

```bash
fd_shifts prepare --dataset microscopy
```

## Chest X-Ray Data

You can download the individual datasets from their respective websites:

> **Note**
> MIMIC requires you to apply for credentialed access.

- [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) (extract to `$DATASET_ROOT_DIR/chexpert`)
- [MIMIC](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) (extract to `$DATASET_ROOT_DIR/mimic`)
- [NIH14](https://www.nih.gov/news-events/news-releases/nih-clinical-center-provides-one-largest-publicly-available-chest-x-ray-datasets-scientific-community) (extract to `$DATASET_ROOT_DIR/nih14`)

After extracting each dataset to its folder, run the data preprocessing:

```bash
fd_shifts prepare --dataset xray
```

## Lung CT Data

Download the dataset from its website and extract it to `$DATASET_ROOT_DIR/lidc_idri`.

- [LIDC-IDRI](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254)

Prepare the dataset by following the instructions in [the separate LIDC readme](../fd_shifts/loaders/preparation/lidc-idri/README.md).

Finally, run the data preprocessing:

```bash
fd_shifts prepare --dataset lung_ct
```
24 changes: 23 additions & 1 deletion fd_shifts/analysis/__init__.py
@@ -49,6 +49,7 @@ class ExperimentData:
mcd_external_confids_dist: npt.NDArray[Any] | None = None

_mcd_correct: npt.NDArray[Any] | None = field(default=None)
_mcd_labels: npt.NDArray[Any] | None = field(default=None)
_correct: npt.NDArray[Any] | None = field(default=None)

@property
@@ -77,6 +78,14 @@ def mcd_correct(self) -> npt.NDArray[Any] | None:
return None
return (np.argmax(self.mcd_softmax_mean, axis=1) == self.labels).astype(int)

@property
def mcd_labels(self) -> npt.NDArray[Any] | None:
if self._mcd_labels is not None:
return self._mcd_labels
if self.mcd_softmax_mean is None:
return None
return self.labels

def dataset_name_to_idx(self, dataset_name: str) -> int:
if dataset_name == "val_tuning":
return 0
@@ -408,6 +417,7 @@ def _get_confidence_scores(self, study_data: ExperimentData):

self.method_dict[query_confid] = {}
self.method_dict[query_confid]["confids"] = confids
self.method_dict[query_confid]["labels"] = confid_score.labels
self.method_dict[query_confid]["correct"] = confid_score.correct
self.method_dict[query_confid]["metrics"] = confid_score.metrics
self.method_dict[query_confid]["predict"] = confid_score.predict
@@ -423,6 +433,7 @@ def _get_confidence_scores(self, study_data: ExperimentData):
self.method_dict[query_confid]["correct"] = confid_score.correct
self.method_dict[query_confid]["metrics"] = confid_score.metrics
self.method_dict[query_confid]["predict"] = confid_score.predict
self.method_dict[query_confid]["labels"] = confid_score.labels

def _compute_performance_metrics(self, softmax, labels, correct):
performance_metrics = {}
@@ -437,6 +448,13 @@ def _compute_performance_metrics(self, softmax, labels, correct):
)
if "accuracy" in self.query_performance_metrics:
performance_metrics["accuracy"] = np.sum(correct) / correct.size
if "b-accuracy" in self.query_performance_metrics:
accuracies_list = []
for cla in np.unique(labels):
is_class = labels == cla
accuracy_class = np.mean(correct[is_class])
accuracies_list.append(accuracy_class)
performance_metrics["b-accuracy"] = np.mean(accuracies_list)
if "brier_score" in self.query_performance_metrics:
if "new_class" in self.study_name or "openset" in self.study_name:
performance_metrics["brier_score"] = None
@@ -474,6 +492,7 @@ def _compute_confid_metrics(self):
eval = ConfidEvaluator(
confids=confid_dict["confids"],
correct=confid_dict["correct"],
labels=confid_dict.get("labels"),
query_metrics=self.query_confid_metrics,
query_plots=self.query_plots,
bins=self.calibration_bins,
@@ -534,6 +553,7 @@ def _compute_confid_metrics(self):
query_metrics=self.query_confid_metrics,
query_plots=self.query_plots,
bins=self.calibration_bins,
labels=confid_dict["labels"],
)
self.threshold_plot_dict = {}
self.plot_threshs = []
@@ -579,6 +599,7 @@ def _compute_confid_metrics(self):
query_metrics=self.query_confid_metrics,
query_plots=self.query_plots,
bins=self.calibration_bins,
labels=confid_dict["labels"],
)
true_thresh = eval.get_val_risk_scores(
self.rstar, 0.1, no_bound_mode=True
@@ -821,7 +842,7 @@ def main(

analysis_out_dir = out_path

query_performance_metrics = ["accuracy", "nll", "brier_score"]
query_performance_metrics = ["accuracy", "b-accuracy", "nll", "brier_score"]
query_confid_metrics = [
"failauc",
"failap_suc",
@@ -830,6 +851,7 @@
"mce",
"ece",
"e-aurc",
"b-aurc",
"aurc",
"fpr@95tpr",
"risk@100cov",
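The `b-accuracy` metric added above is balanced accuracy: the mean of the per-class accuracies, which is robust to class imbalance. A minimal standalone sketch of the same computation (illustrative arrays, not tied to the class above):

```python
import numpy as np

def balanced_accuracy(labels: np.ndarray, correct: np.ndarray) -> float:
    # mean of per-class accuracies, mirroring the b-accuracy branch above
    return float(np.mean([correct[labels == c].mean() for c in np.unique(labels)]))

labels = np.array([0, 0, 0, 1])
correct = np.array([1, 1, 1, 0])  # all class-0 samples correct, the class-1 sample wrong
print(balanced_accuracy(labels, correct))  # 0.5, while plain accuracy is 0.75
```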
6 changes: 5 additions & 1 deletion fd_shifts/analysis/confid_scores.py
@@ -300,13 +300,15 @@ def __init__(
assert study_data.mcd_softmax_dist is not None
self.softmax = study_data.mcd_softmax_mean
self.correct = study_data.mcd_correct
self.labels = study_data.mcd_labels

self.confid_args = (
study_data.mcd_softmax_mean,
study_data.mcd_softmax_dist,
)
self.performance_args = (
study_data.mcd_softmax_mean,
study_data.labels,
study_data.mcd_labels,
study_data.mcd_correct,
)

@@ -327,6 +329,7 @@ def __init__(
else:
self.softmax = study_data.softmax_output
self.correct = study_data.correct
self.labels = study_data.labels
self.confid_args = (study_data.softmax_output,)
self.performance_args = (
study_data.softmax_output,
@@ -396,6 +399,7 @@ def __init__(
study_data.labels,
study_data.correct,
)
self.labels = study_data.labels

self.confid_args, self.confid_func = parse_secondary_confid(
query_confid, analysis
20 changes: 15 additions & 5 deletions fd_shifts/analysis/eval_utils.py
@@ -23,6 +23,7 @@ def _get_tb_hparams(cf):
def monitor_eval(
running_confid_stats,
running_perf_stats,
running_labels,
query_confid_metrics,
query_monitor_plots,
do_plot=True,
@@ -31,6 +32,7 @@ def monitor_eval(
out_metrics = {}
out_plots = {}
bins = 20
labels_cpu = torch.stack(running_labels, dim=0).cpu().data.numpy()

# currently not implemented for mcd_softmax_mean
for perf_key, perf_list in running_perf_stats.items():
@@ -67,6 +69,7 @@ def monitor_eval(
eval = ConfidEvaluator(
confids=confids_cpu,
correct=correct_cpu,
labels=labels_cpu,
query_metrics=query_confid_metrics,
query_plots=query_monitor_plots,
bins=bins,
@@ -112,7 +115,7 @@ def monitor_eval(


class ConfidEvaluator:
def __init__(self, confids, correct, query_metrics, query_plots, bins):
def __init__(self, confids, correct, labels, query_metrics, query_plots, bins):
self.confids = confids[~np.isnan(confids)]
self.correct = correct[~np.isnan(confids)]
self.query_metrics = query_metrics
@@ -125,8 +128,10 @@ def __init__(self, confids, correct, query_metrics, query_plots, bins):
self.rc_curve = None
self.precision_list = None
self.recall_list = None

self.stats_cache = StatsCache(self.confids, self.correct, self.bins)
self.labels = labels
self.stats_cache = StatsCache(
self.confids, self.correct, self.bins, self.labels
)

def get_metrics_per_confid(self):
out_metrics = {}
@@ -149,12 +154,17 @@ def get_metrics_per_confid(self):
self.stats_cache
)

if "aurc" in self.query_metrics or "e-aurc" in self.query_metrics:
if (
"aurc" in self.query_metrics
or "e-aurc" in self.query_metrics
or "b-aurc" in self.query_metrics
):
if self.rc_curve is None:
self.get_rc_curve_stats()
if "aurc" in self.query_metrics:
out_metrics["aurc"] = get_metric_function("aurc")(self.stats_cache)

if "b-aurc" in self.query_metrics:
out_metrics["b-aurc"] = get_metric_function("b-aurc")(self.stats_cache)
if "e-aurc" in self.query_metrics:
out_metrics["e-aurc"] = get_metric_function("e-aurc")(self.stats_cache)

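With the signature change above, callers of `ConfidEvaluator` now pass `labels` explicitly. A hypothetical call for illustration (array values made up; metric names as registered in `metrics.py` below):

```python
import numpy as np
from fd_shifts.analysis.eval_utils import ConfidEvaluator

evaluator = ConfidEvaluator(
    confids=np.array([0.9, 0.8, 0.6, 0.4]),
    correct=np.array([1, 1, 0, 1]),
    labels=np.array([0, 0, 1, 1]),
    query_metrics=["aurc", "b-aurc", "e-aurc"],
    query_plots=[],
    bins=20,
)
metrics = evaluator.get_metrics_per_confid()
```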
88 changes: 88 additions & 0 deletions fd_shifts/analysis/metrics.py
@@ -52,6 +52,7 @@ class StatsCache:
confids: npt.NDArray[Any]
correct: npt.NDArray[Any]
n_bins: int
labels: npt.NDArray[Any] | None = None

@cached_property
def roc_curve_stats(self) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
@@ -62,6 +63,83 @@ def roc_curve_stats(self) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
def residuals(self) -> npt.NDArray[Any]:
return 1 - self.correct

@cached_property
def brc_curve_stats(self) -> tuple[list[float], list[float], list[float]]:
coverages = []
balanced_risks = []
error_per_class = {}
risk_per_class = {}
assert self.labels is not None, "labels must be set"
assert len(self.labels) == len(
self.confids
), "labels must be same size as confids"
# calculate risk per class
n_residuals = len(self.residuals)
idx_sorted = np.argsort(self.confids)
n_remaining_per_class = {}
# coverage = number of samples
coverage = n_residuals
# calculate baselines:
# errors per class and image count per class (total number of images/errors in that class)
# risk per class: errors in the class divided by images remaining in the class
# if no images of a class remain, its risk is set to None and filtered out before taking the mean
for label in np.unique(self.labels):
idx_class = np.where(self.labels == label)[0]
error_per_class[label] = sum(self.residuals[idx_class])
n_remaining_per_class[label] = len(idx_class)
if n_remaining_per_class[label] == 0:
risk_per_class[label] = None
else:
risk_per_class[label] = (
error_per_class[label] / n_remaining_per_class[label]
)
# starting point of the curve: full coverage and the corresponding balanced risk
coverages.append(coverage / n_residuals)
balanced_risks.append(
np.array([x for x in risk_per_class.values() if x is not None]).mean()
)
weights = []
tmp_weight = 0
for i in range(0, len(idx_sorted) - 1):
coverage = coverage - 1
# determine which class the removed image belongs to
label = int(self.labels[idx_sorted[i]])
# subtract the removed sample's residual from its class (1 if it was an error, 0 otherwise)
error_per_class[label] = (
error_per_class[label] - self.residuals[idx_sorted[i]]
)
# decrement the remaining image count of that class
n_remaining_per_class[label] = n_remaining_per_class[label] - 1
assert (
n_remaining_per_class[label] >= 0
), "Remaining images should be larger 0"
# if no images remain in a class, its risk is set to None
# otherwise the class risk is the remaining errors divided by the remaining images
if n_remaining_per_class[label] < 1:
risk_per_class[label] = None
else:
risk_per_class[label] = error_per_class[label] / (
n_remaining_per_class[label]
)
assert risk_per_class[label] >= 0, "Risk can never be below 0"
tmp_weight += 1
if i == 0 or self.confids[idx_sorted[i]] != self.confids[idx_sorted[i - 1]]:
coverages.append(coverage / n_residuals)
balanced_risks.append(
np.array(
[x for x in risk_per_class.values() if x is not None]
).mean()
)
weights.append(tmp_weight / n_residuals)
tmp_weight = 0
# add a well-defined final point to the RC-curve.
if tmp_weight > 0:
coverages.append(0)
balanced_risks.append(balanced_risks[-1])
weights.append(tmp_weight / n_residuals)
return coverages, balanced_risks, weights

@cached_property
def rc_curve_stats(self) -> tuple[list[float], list[float], list[float]]:
coverages = []
@@ -274,6 +352,16 @@ def aurc(stats_cache: StatsCache) -> float:
)


@register_metric_func("b-aurc")
@may_raise_sklearn_exception
def baurc(stats_cache: StatsCache) -> float:
_, risks, weights = stats_cache.brc_curve_stats
return (
sum([(risks[i] + risks[i + 1]) * 0.5 * weights[i] for i in range(len(weights))])
* AURC_DISPLAY_SCALE
)


@register_metric_func("e-aurc")
@may_raise_sklearn_exception
def eaurc(stats_cache: StatsCache) -> float:
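The new `b-aurc` metric integrates the balanced risk-coverage curve from `brc_curve_stats` with the trapezoidal rule, weighting each segment by the fraction of samples it spans, and scales the result by `AURC_DISPLAY_SCALE` just like `aurc`. A hypothetical invocation (assuming `StatsCache` is constructed directly as the dataclass shown above and `get_metric_function` is the registry lookup used in `eval_utils.py`):

```python
import numpy as np
from fd_shifts.analysis.metrics import StatsCache, get_metric_function

cache = StatsCache(
    confids=np.array([0.9, 0.8, 0.6, 0.4]),
    correct=np.array([1, 1, 0, 1]),
    n_bins=20,
    labels=np.array([0, 0, 1, 1]),
)
print(get_metric_function("b-aurc")(cache))  # balanced AURC, scaled by AURC_DISPLAY_SCALE
```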