From f8a3bfc372bad6ac396decfd819c8d640d5988ff Mon Sep 17 00:00:00 2001 From: Chris Endemann Date: Thu, 19 Dec 2024 14:20:46 -0600 Subject: [PATCH] Update 7b-OOD-detection-softmax.md --- episodes/7b-OOD-detection-softmax.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/episodes/7b-OOD-detection-softmax.md b/episodes/7b-OOD-detection-softmax.md index ee6e7e2..7c435e1 100644 --- a/episodes/7b-OOD-detection-softmax.md +++ b/episodes/7b-OOD-detection-softmax.md @@ -23,6 +23,7 @@ exercises: 0 :::::::::::::::::::::::::::::::::::::::::::::::::: +## Leveraging softmax model outputs Softmax-based methods are among the most widely used techniques for out-of-distribution (OOD) detection, leveraging the probabilistic outputs of a model to differentiate between in-distribution (ID) and OOD data. These methods are inherently tied to models employing a softmax activation function in their final layer, such as logistic regression or neural networks with a classification output layer. The softmax function normalizes the logits (i.e., sum of neuron input without passing through activation function) in the final layer, squeezing the output into a range between 0 and 1. This is useful for interpreting the model’s predictions as probabilities. Softmax probabilities are computed as: @@ -31,9 +32,9 @@ $$ P(y = k \mid x) = \frac{\exp(f_k(x))}{ \sum_{j} \exp(f_j(x))} $$ - In this lesson, we will train a logistic regression model to classify images from the Fashion MNIST dataset and explore how its softmax outputs can signal whether a given input belongs to the ID classes (e.g., T-shirts or pants) or is OOD (e.g., sandals). While softmax is most naturally applied in models with a logistic activation, alternative approaches, such as applying softmax-like operations post hoc to models with different architectures, are occasionally used. However, these alternatives are less common and may require additional considerations. By focusing on logistic regression, we aim to illustrate the fundamental principles of softmax-based OOD detection in a simple and interpretable context before extending these ideas to more complex architectures. -### Prepare the ID (train and test) and OOD data + +## Prepare the ID (train and test) and OOD data In order to determine a threshold that can separate ID data from OOD data (or ensure new test data as ID), we need to sample data from both distributions. OOD data used should be representative of potential new classes (i.e., semanitic shift) that may be seen by your model, or distribution/covariate shifts observed in your application area. * ID = T-shirts/Blouses, Pants @@ -120,7 +121,7 @@ def plot_data_sample(train_data, ood_data): ``` Load and prepare the ID data (train+test containing shirts and pants) and OOD data (sandals) -**Why not just add the OOD class to training dataset?** +## Why not just add the OOD class to training dataset? OOD data is, by definition, not part of the training distribution. It could encompass anything outside the known classes, which means you'd need to collect a representative dataset for "everything else" to train the OOD class. This is practically impossible because OOD data is often diverse and unbounded (e.g., new species, novel medical conditions, adversarial examples). The key idea behind threshold-based methods is we want to vet our model against a small sample of potential risk-cases using known OOD data to determine an empirical threshold that *hopefully* extends to other OOD cases that may arise in real-world scenarios. @@ -136,7 +137,7 @@ Plot sample fig = plot_data_sample(train_data, ood_data) plt.show() ``` -## Visualizing OOD and ID data +## Visualizing OOD and ID data with PCA ### PCA PCA visualization can provide insights into how well a model is separating ID and OOD data. If the OOD data overlaps significantly with ID data in the PCA space, it might indicate that the model could struggle to correctly identify OOD samples. @@ -178,6 +179,7 @@ From this plot, we see that sandals are more likely to be confused as T-shirts t * **Over-reliance on linear relationships**: Part of this has to do with the fact that we're only looking at linear relationships and treating each pixel as its own input feature, which is usually never a great idea when working with image data. In our next example, we'll switch to the more modern approach of CNNs. * **Semantic gap != feature gap**: Another factor of note is that images that have a wide semantic gap may not necessarily translate to a wide gap in terms of the data's visual features (e.g., ankle boots and bags might both be small, have leather, and have zippers). Part of an effective OOD detection scheme involves thinking carefully about what sorts of data contanimations may be observed by the model, and assessing how similar these contaminations may be to your desired class labels. + ## Train and evaluate model on ID data ```python model = LogisticRegression(max_iter=10, solver='lbfgs', multi_class='multinomial').fit(train_data_flat, train_labels) # 'lbfgs' is an efficient solver that works well for small to medium-sized datasets. @@ -290,7 +292,7 @@ Unfortunately, we observe a significant amount of overlap between OOD data and h For pants, the problem is much less severe. It looks like a low threshold (on this T-shirt probability scale) can separate nearly all OOD samples from being pants. -### Setting a threshold +## Setting a threshold Let's put our observations to the test and produce a confusion matrix that includes ID-pants, ID-Tshirts, and OOD class labels. We'll start with a high threshold of 0.9 to see how that performs. ```python def softmax_thresh_classifications(probs, threshold): @@ -371,7 +373,7 @@ What threhsold is required to ensure that no OOD samples are incorrectly conside With a very conservative threshold, we can make sure very few OOD samples are incorrectly classified as ID. However, the flip side is that conservative thresholds tend to incorrectly classify many ID samples as being OOD. In this case, we incorrectly assume almost 20% of shirts are OOD samples. -## Iterative Threshold Determination +## Iterative threshold determination In practice, selecting an appropriate threshold is an iterative process that balances the trade-off between correctly identifying in-distribution (ID) data and accurately flagging out-of-distribution (OOD) data. Here's how you can iteratively determine the threshold: @@ -558,8 +560,9 @@ disp.plot(cmap=plt.cm.Blues) plt.title('Confusion Matrix for OOD and ID Classification') plt.show() ``` + #### Discuss -How might you use these tools to ensure that a model trained on health data from hospital A will reliably predict new test data from hospital B? +How might you use these tools to ensure that a model trained on health data from hospital A will reliably predict new test data from hospital B? :::::::::::::::::::::::::::::::::::::::: keypoints @@ -569,4 +572,4 @@ How might you use these tools to ensure that a model trained on health data from - While simple and widely used, softmax-based methods have limitations, including sensitivity to threshold choices and reduced reliability in high-dimensional settings. - Understanding softmax-based OOD detection lays the groundwork for exploring more advanced techniques like energy-based detection. -:::::::::::::::::::::::::::::::::::::::::::::::::: \ No newline at end of file +::::::::::::::::::::::::::::::::::::::::::::::::::