\section{Datasets}
Our experiments are carried out on natural distribution shifts across multiple domains, modalities, and model types.
We use the \textit{CIFAR-10.1} dataset~\citep{cifar101}, where shift arises from subtle changes in the dataset creation process;
the \textit{Camelyon17} dataset~\citep{camelyon} for metastases detection in histopathological slides collected from multiple hospitals; and the \textit{UCI Heart Disease} dataset~\citep{misc_heart_disease_45}, which contains tabular features and heart disease indicators collected across international health systems.
A summary of the three datasets, the shifts they exhibit, and the impact of each shift on base model performance is provided in \autoref{tab:datasets_main} below.
\begin{table}[!htb]
\centering
\setlength\tabcolsep{1.5pt}
\renewcommand{\arraystretch}{0.5}
\small
\caption{\small \textbf{Datasets:} We investigate three different forms of covariate shift in two unique data modalities.
To verify that these shifts are indeed harmful to the models, we report performance in both the shifted and unshifted domains.
Examples and further descriptions of unshifted/shifted splits of each dataset are given in \autoref{sec:data}.}
\resizebox{\textwidth}{!}{
\begin{tabular}{cccccc}
\toprule[1.5pt]
\thead{Domain/Task} & \thead{Dataset} & \thead{Shift} & \thead{Metric} & \thead{Performance \\ (Unshifted)} & \thead{Performance \\ (Shifted)} \\\midrule[1.5pt]
\thead{Natural Images \\\emph{Object Classification}} & \thead{CIFAR-10/10.1 \\~\citep{cifar101}} & \thead{Data Collection \\ Process} & \thead{Accuracy} & \thead{0.87 (ResNet-18)} & \thead{0.77 (ResNet-18)} \\\midrule
\thead{Histopathological Images \\ \emph{Metastases Detection}} & \thead{Camelyon-17 \\\citep{camelyon}} & \thead{Different \\Hospitals} & \thead{Accuracy} & \thead{0.93 (ResNet-18)} & \thead{0.81 (ResNet-18)} \\\midrule
\thead{Tabular Medical Data \\\emph{Angiographic Status}} & \thead{UCI Heart Disease \\\citep{misc_heart_disease_45}} & \thead{Different \\Countries} & \thead{AUROC} & \thead{0.88 (XGBoost)\\0.85 (MLP)} & \thead{0.70 (XGBoost)\\0.42 (MLP)} \\\bottomrule[1.5pt]
\end{tabular}}
\label{tab:datasets_main}
\end{table}
\section{Learning Constrained Disagreement}
Before presenting results for the Detectron test, we investigate the CDC algorithm (\autoref{alg:constrained}) in isolation to verify that it meets our desired criteria in \autoref{def:cdc}.
We train ensembles of $10$ CDCs using the \textit{disagreement cross entropy} (DCE) with CIFAR-10 as $\Pc$ and CIFAR-10.1 as $\Qc$ for 100 random runs at a sample size of $50$.
We use an ImageNet pretrained residual network fine-tuned on CIFAR-10 as our base model $f$; see \autoref{sec:expdet} for full training details.
CDCs are initialized from $f$ and use the same training algorithm/loss function when updated on samples from $\Pc$ (property 1 in \autoref{def:cdc}).
The results in \autoref{fig:disrates} empirically validate minimizing the DCE as a CDC learning objective.
The first observation is that when an unseen set is drawn from a shifted distribution $\Qc$, the empirical disagreement rate $\phi_\boldQ$ grows significantly larger
than the baseline disagreement rate $\phi_\boldP$ (property 3).
Next, we see that CDCs preserve accuracy on unseen data from the training distribution, regardless of the dataset they are trained to disagree on (property 2).
Finally, as the ensemble size increases (and disagreed-upon points are removed), the selective classification accuracy on the remaining agreed-upon points increases.
Furthermore, the concave behaviour of the selective classification curve indicates that the samples CDCs disagree on in early iterations are largely ones that $f$ would have misclassified,
while in later iterations fewer such samples are identified (see \autoref{fig:selective_cls} for a direct comparison of accuracy versus rejection rate for CDCs).
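For intuition, a minimal PyTorch-style sketch of a single CDC update is given below; it is illustrative only, using $-\log(1-p_{\hat{y}})$ on samples from $\Qc$ as a stand-in for the DCE, with the exact loss and training schedule described in \autoref{sec:expdet}.
\begin{verbatim}
# Illustrative only: the disagreement term below is a simple stand-in
# for the disagreement cross entropy (DCE).
import torch
import torch.nn.functional as F

def cdc_step(model, p_batch, q_batch, base_preds_q, optimizer):
    x_p, y_p = p_batch          # labelled samples from P (training distribution)
    x_q = q_batch               # unlabelled samples from the unseen set Q
    optimizer.zero_grad()

    # Property 1: keep the original training objective on data from P.
    loss_p = F.cross_entropy(model(x_p), y_p)

    # Disagreement term on Q: push probability mass away from the class
    # predicted by the frozen base model f.
    probs_q = F.softmax(model(x_q), dim=1)
    p_agree = probs_q.gather(1, base_preds_q.view(-1, 1)).squeeze(1)
    loss_q = -torch.log((1.0 - p_agree).clamp(min=1e-6)).mean()

    (loss_p + loss_q).backward()
    optimizer.step()
    return loss_p.item(), loss_q.item()
\end{verbatim}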
\begin{figure}[!htb]
\centering
\includegraphics[width=\textwidth]{images/acc_dis.pdf}
\caption{\small \textbf{Ensemble Size vs Properties of Constrained disagreement classifiers on CIFAR-10/10.1:}
(Left) We see that for all ensemble sizes, there is lower disagreement on unshifted data (CIFAR-10) compared to disagreement on shifted data (CIFAR-10.1).
(Center) Constrained disagreement does not compromise in-distribution performance as computed on unseen data.
(Right) As the ensemble grows, the selective classification accuracy computed on the set of test examples that all models agree on increases on both in-distribution and out-of-distribution data.
Confidence intervals are computed as $\pm$ one standard deviation across experiments.}
\label{fig:disrates}
\end{figure}
\begin{figure}[!htb]
\floatbox[{\capbeside\thisfloatsetup{capbesideposition={right,top},capbesidewidth=0.4\linewidth}}]{figure}[\FBwidth]
{\caption{\small \textbf{Selective Classification Behaviour of CDCs}: Comparing the accuracy versus rejection rate of selective classification on ID (CIFAR-10) and shifted (CIFAR-10.1)
sets, we see that the CDC algorithm constructs highly linear selective classifiers. As expected, the rejection rate required to reach high accuracy on ID data is significantly lower
than that required on shifted data.}
\label{fig:selective_cls}}
{\includegraphics[width=\linewidth]{images/selective_classification}}
\end{figure}
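For reference, the curves in \autoref{fig:selective_cls} can be produced by a routine of the following form (an illustrative NumPy sketch; the array arguments are hypothetical placeholders for per-model predictions and labels).
\begin{verbatim}
import numpy as np

def selective_accuracy_curve(base_preds, cdc_preds, labels):
    """Accuracy on points kept after rejecting every example that any of
    the first k CDCs disagrees with the base model on, for k = 1..N."""
    curve = []
    rejected = np.zeros(labels.shape[0], dtype=bool)
    for preds in cdc_preds:                   # one prediction array per CDC
        rejected |= (preds != base_preds)
        kept = ~rejected
        acc = (base_preds[kept] == labels[kept]).mean() if kept.any() else 1.0
        curve.append((rejected.mean(), acc))  # (rejection rate, accuracy)
    return curve
\end{verbatim}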
\section{Shift Detection Approach} We evaluate the \method\ in a standard two-sample testing scenario similar to prior work~\citep{zhao2022comparing}.
Given two datasets $\boldP = \{(x_i,y_i)\}_{i=1}^n$ ($x_i$ drawn from $\Pc$) and $\boldQ = \{\tilde{x}_i\}_{i=1}^m$ ($\tilde{x}_i$ drawn from $\Qc$) and classifier $f$,
we seek to rule out the null hypothesis ($\Pc = \Qc$) at the $5\%$ significance level.
To guarantee fixed significance we employ a permutation test by first sampling from the distribution of test statistics
derived by the \method\ under the null hypothesis $\Pc=\Qc$ (i.e., $\boldQ$ is drawn from $\Pc$).
Next, we compute a threshold over the empirical test statistic distribution that sets the false positive rate to $5\%$ (see \autoref{fig:detectron_permutation}).
This step can be performed before deployment as it only requires access to $\boldP$.
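Concretely, this calibration step can be sketched as follows (illustrative Python, assuming a scalar test statistic oriented so that larger values indicate shift):
\begin{verbatim}
import numpy as np

def calibrate_threshold(null_statistics, alpha=0.05):
    """Threshold such that at most an alpha fraction of null runs
    (Q drawn from P) exceed it, i.e. a false positive rate of alpha."""
    return np.quantile(np.asarray(null_statistics), 1.0 - alpha)

def reject_null(observed_statistic, threshold):
    """Flag covariate shift when the observed statistic exceeds the
    threshold calibrated on the null distribution."""
    return observed_statistic > threshold
\end{verbatim}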
To mimic deployment settings where we wish to identify covariate shift quickly,
we assume access to far more samples from $\Pc$ compared to $\Qc$.
For each dataset, we begin by training a base classifier on the unshifted dataset.
We evaluate the detection of covariate shift on 100 randomly selected test sets of $n = 10$, $20$, and $50$ samples from $\Qc$.
In all cases, we train ensembles with a maximum size of 5 (parameter $\aleph$ in \autoref{alg:detectron}).
To prevent CDCs from overfitting in the case of small test set sizes, we perform early stopping if in-distribution validation performance drops by over 5\% from the measured performance of the base classifier.
Hyperparameters and training details for all models can be found in \autoref{sec:expdet}.
\subsection{Evaluation} We report the \textit{True Positive Rate at 5\% Significance Level (TPR@5)} aggregated over $100$ randomly selected sets $\boldQ$.
This measures how often our method correctly identifies covariate shift ($\Pc\neq\Qc$) while incorrectly reporting shift at most 5\% of the time when $\Pc=\Qc$.
It is also referred to as the statistical power of the test at a significance level (probability of a type I error) of 5\%.
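As a concrete sketch (under the same assumption that larger test statistics indicate shift), TPR@5 can be computed from per-trial statistics as:
\begin{verbatim}
import numpy as np

def tpr_at_5(null_statistics, shifted_statistics, alpha=0.05):
    """Fraction of shifted trials flagged at a threshold calibrated to a
    5% false positive rate on null (unshifted) trials."""
    threshold = np.quantile(np.asarray(null_statistics), 1.0 - alpha)
    return float(np.mean(np.asarray(shifted_statistics) > threshold))
\end{verbatim}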
% (2) \textit{Area Under the Receiver Operating Characteristic Curve (AUC)}: To showcase the sensitivity of our method - a desirable characteristic in high-risk domains where false negatives are far more costly than false positives - we report the area under the true positive vs false positive curve generated by varying the significance level from 0 to 1. %This is identical to the conventional interpretation of the AUC for classification problems.
\subsection{Baselines} We compare the \method\ against several methods for OOD detection, uncertainty estimation, and covariate shift detection:
(1) \textit{Deep Ensembles}~\citep{trustuncert} with \textit{disagreement} scoring and (2) with \textit{entropy} scoring, as a direct ablation of the CDC approach;
(3) \textit{Black Box Shift Detection (BBSD)}~\citep{bbsd};
(4) \textit{Relative Mahalanobis Distance (RMD)}~\citep{relmahala};
(5) \textit{Classifier Two Sample Test (CTST)}~\citep{paz2017revisiting};
(6) \textit{Deep Kernel MMD (MMD-D)}~\citep{liu2020learning};
and (7) \textit{H-Divergence (H-Div)}~\citep{zhao2022comparing}.
For more information on baselines, see Appendix~\autoref{subsec:baselines}.
\section{Shift Detection Experiments}
Statistical power (TPR@5) results for sample sizes of 10, 20, and 50 using the \method\ (\autoref{alg:detectron})
on all datasets are shown in \autoref{tab:results}.
We report the mean and standard error of the TPR@5, computed over 100 random trials for each setting.
\input{results}
To expand on \autoref{tab:results}, we provide an extended analysis of the performance of the \method\ on the UCI Heart Disease dataset.
Using sample sizes ranging from 10 to 100, we compute the TPR@5 averaged over 100 trials and plot the results in \autoref{fig:uci_plot}.
As the \method\ is model agnostic, we use gradient-boosted trees (XGBoost~\citep{xgb}) for the \method\ and CTST,
while the remaining baselines, which require neural network models, use a two-layer MLP that achieves similar in-distribution performance.
\begin{figure}[!htb]
\centering
\includegraphics[width=\textwidth]{images/uci_plot.pdf}
\caption{\small \textbf{True positive rate at the 5\% significance level} for the \method\ and baseline methods for detection of covariate shift on the UCI heart disease dataset.
The \method\ (Entropy) is shown to uniformly outperform baselines.
Confidence intervals are excluded for visual clarity but are found in \autoref{tab:results}.}
\label{fig:uci_plot}
\end{figure}