-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cf5fa64
commit 257d78f
Showing
3 changed files
with
341 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,334 @@ | ||
# Incertidumbre en clasificación | ||
|
||
En los problemas de clasificación generalmente no conseguimos clasificación perfecta, | ||
y buscamos que nuestras probabilidades estimadas expresen qué grado de confianza tenemos | ||
en que ocurra cada categorías. Para lograr esto, necesitamos checar que las probabilidades | ||
que producimos expresan esta incertidumbre correctamente, y no son simplemente *scores* | ||
de creencia de que una clase u otra va a ocurrir. | ||
|
||
En esta sección trataremos de cómo checar estas probabilidades, y qué técnicas podemos | ||
usar para expresar confiablemente (y con supuestos mínimos) incertidumbre al hacer predicciones de clase. | ||
|
||
```{r, include = FALSE} | ||
ggplot2::theme_set(ggplot2::theme_minimal(base_size = 13)) | ||
knitr::opts_chunk$set(fig.width=4, fig.height=3) | ||
cbb_palette <- c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7") | ||
scale_colour_discrete <- function(...) { | ||
ggplot2::scale_colour_manual(..., values = cbb_palette) | ||
} | ||
``` | ||
|
||
## Calibración de probabilidades | ||
|
||
En un principio podemos considerar | ||
las probabilidades estimadas en nuestros predictores | ||
como *scores* que califican qué tan creíble es que | ||
una observación particular sea de una categoría o clase dada. Sin embargo, | ||
quisiéramos interpretarlas también de manera frecuentista, como frecuencia | ||
relativa de ocurrencia de eventos a largo plazo. Esto nos permite utilizarlas en | ||
procesos *downstream* y en la toma de decisiones de manera más efectiva. | ||
|
||
Por ejemplo, | ||
en pronósticos meteorológicos, todos los días expresan una probabilidad de lluvia. Resulta ser que entre aquellos días donde los metereológos dicen que hay 10% de lluvia, en 1 de cada 10 llueve en realidad. Entre aquellos días donde dicen que hay 90% de probabilidad de lluvia, aproximadamente 9 de cada 10 días llueve. | ||
Esto quiere decir que estos pronósticos probabilísticos de la meteorología están bien calibrados, | ||
independientemente de nuestra interpretación de esas probabilidades como grados de creencia. | ||
|
||
Igual que buscábamos que nuestros intervalos predictivos estuvieran bien calibrados | ||
en los problemas de regresión, también buscamos que nuestras | ||
predictores en clasificación estén **probabilísticamente bien calibrados**. Es | ||
decir: lo que decimos que puede ocurrir con 10% de probabilidad ocurre efectivamente 1 de cada 10 veces, si decimos 20% entonces ocurre 2 de 20, etc. | ||
|
||
|
||
Podemos hacer esto realizando pruebas de la *calibración* | ||
de las probabilidades que arroja el modelo. | ||
Esto quiere decir que si el modelo nos dice que la probabilidad de que la clase 1 es 0.8, | ||
entonces si tenemos un número grande de estos casos (con probabilidad 0.8), alrededor | ||
de 80\% de éstos tienen que ser positivos. | ||
|
||
|
||
#### Ejemplo: diabetes {-} | ||
|
||
Podemos checar la calibración de nuestro modelo para el ejemplo de diabetes de la sección | ||
anterior | ||
|
||
```{r, message=FALSE, warning=FALSE} | ||
library(tidyverse) | ||
library(tidymodels) | ||
library(gt) | ||
diabetes_pr <- as_tibble(MASS::Pima.te) | ||
diabetes_pr$id <- 1:nrow(diabetes_pr) | ||
flujo_diabetes <- read_rds("cache/flujo_ajustado_diabetes.rds") | ||
``` | ||
Aquí están las probabilidades estimadas | ||
de tener diabetes sobre la muestra de prueba: | ||
|
||
```{r, message=FALSE, warning = FALSE} | ||
proba_mod <- predict(flujo_diabetes, diabetes_pr, type = "prob") | ||
dat_calibracion <- tibble(obs = diabetes_pr |> pull(type), | ||
probabilidad = proba_mod$.pred_Yes) |> | ||
mutate(y = ifelse(obs == "Yes", 1, 0)) | ||
dat_calibracion |> head() |> gt() | ||
``` | ||
|
||
Para ver si estas probabilidades son realistas, podemos por ejemplo hacer | ||
una gráfica como la que sigue: | ||
|
||
```{r} | ||
ggplot(dat_calibracion, aes(x = probabilidad, y = y)) + | ||
geom_jitter(width = 0, height = 0.02, alpha = 0.2) + | ||
geom_smooth(method = "loess", span = 0.7, colour = "red", se = FALSE) + | ||
geom_abline() + | ||
coord_equal() | ||
``` | ||
Y en esta gráfica verificamos que los promedios locales de proporciones de 0-1's son | ||
consistentes con las probabilidades que estimamos. | ||
|
||
## Gráficas de calibración binaria | ||
|
||
Otra manera de hacer esta gráfica | ||
es cortando las probabilidades en cubetas y calculando intervalos de credibilidad | ||
para cada estimación: con esto checamos si el observado es consistente con | ||
las probabilidades de clase. | ||
|
||
::: callout-note | ||
# Pruebas de calibración | ||
|
||
Sobre una muestra de prueba y para un problema de clasificación | ||
binaria, donde $\hat{p} (x)$ es la probabilidad estimada de la clase 1: | ||
|
||
1. Producimos las probabilidades de clase $\hat{p}(\mathbf{x}^{(i)}).$ | ||
2. Ordenamos estas probabilidades de clase de la más grande a la más chica, y las | ||
agrupamos en cubetas $C_1, C_2, \ldots, C_r$. | ||
3. Calculamos en cada cubeta la probabilidad de clase promedio y el porcentaje de | ||
casos de clase 1. | ||
|
||
Si en cada cubeta la probabilidad de clase promedio y el porcentaje de casos de clase | ||
1 son consistentes (similares módulo variación muestral), entonces decimos que | ||
nuestras probabilidades pasan esta prueba de calibración. | ||
::: | ||
|
||
### Ejemplo: diabetes {-} | ||
|
||
|
||
```{r} | ||
# usamos intervalos suavizados (bayesiano beta-binomial) en lugar de los basados | ||
# en los errores estándar sqrt(p*(1-p) / n) | ||
calibracion_gpos <- dat_calibracion |> | ||
mutate(proba_grupo = cut(probabilidad, | ||
quantile(probabilidad, seq(0, 1, 0.1)), include.lowest = TRUE)) |> | ||
group_by(proba_grupo) |> | ||
summarise(prob_media = mean(probabilidad), | ||
n = n(), | ||
obs = sum(y), .groups = "drop") |> | ||
mutate(obs_prop = (obs + 1) / (n + 2), | ||
inferior = qbeta(0.05, obs + 1, n - obs + 2), | ||
superior = qbeta(0.95, obs + 1, n - obs + 2)) | ||
calibracion_gpos |> gt() |> fmt_number(where(is_double), decimals = 3) | ||
``` | ||
```{r} | ||
ggplot(calibracion_gpos, | ||
aes(x = prob_media, y = obs_prop, ymin = inferior, ymax = superior)) + | ||
geom_abline() + | ||
geom_linerange() + | ||
geom_point(colour = "red") + coord_equal() + | ||
xlab("Probabilidad de clase") + | ||
ylab("Proporción observada") + | ||
labs(subtitle = "Intervalos de 90% para prop observada") | ||
``` | ||
Y con esto verificamos que calibración del modelo es razonable, y que es razonable | ||
usar estas probabilidades para tomar decisiones o incluir en ejercicios de simulación. | ||
|
||
**Observación**: | ||
1. Si las probabilidades no están calibradas, y las queremos | ||
utilizar como tales (no simplemente como *scores*), entonces puede ser | ||
necesario hacer un paso adicional de calibración, con una muestra | ||
separada de calibración (ver por ejemplo @kuhn, sección 11.1). | ||
2. En este ejemplo construimos intervalos para las proporciones observadas | ||
usando intervalos bayesianos. Es posible usar intervalos normales o t (usando el | ||
error estándar), pero estos intervalos tienen cobertura mala para proporciones | ||
muy chicas o muy grandes [Binomial proportion wikipedia](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). Nuestro ejemplo es similar a los intervalos de Agresti-Coull. | ||
|
||
|
||
## Ejemplo: artículos de ropa | ||
|
||
Regresamos a nuestro modelo para clasificar imágenes de artículos de ropa: | ||
|
||
```{r, message = FALSE} | ||
library(keras) | ||
ropa_datos <- dataset_fashion_mnist() | ||
ropa_prueba <- ropa_datos$test | ||
# estas son las categorias: | ||
articulos <- c("playera/top", "pantalón", "suéter", "vestido", "abrigo", "sandalia", | ||
"camisa", "tenis", "bolsa", "bota") | ||
etiquetas_tbl <- tibble(codigo = 0:9, articulo = articulos) | ||
x_prueba <- ropa_prueba$x / 255 | ||
y_prueba <- to_categorical(ropa_prueba$y, 10) | ||
modelo <- load_model_tf("cache/red_ropa_1") | ||
``` | ||
|
||
|
||
```{r, message = FALSE} | ||
preds_mat <- predict(modelo, x_prueba) | ||
colnames(preds_mat) <- articulos | ||
preds_tbl <- as_tibble(preds_mat) |> | ||
mutate(id = 1:nrow(preds_mat), .before = 1) |> | ||
mutate(codigo = ropa_prueba$y) |> | ||
left_join(etiquetas_tbl) | ||
preds_tbl |> head() |> gt() |> fmt_number(where(is_double), decimals = 3) | ||
``` | ||
|
||
Nuestra gráfica de calibración para la categoría *camisa* es: | ||
|
||
```{r} | ||
preds_tbl |> select(camisa, articulo) |> | ||
mutate(obs_camisa = articulo == "camisa") |> | ||
mutate(grupo_prob = cut_number(camisa, 40)) |> | ||
group_by(grupo_prob) |> | ||
summarise(n = n(), | ||
pred_camisa = mean(camisa), | ||
prop_camisa = mean(obs_camisa)) |> | ||
mutate(ee = sqrt(prop_camisa * (1 - prop_camisa) / n)) |> | ||
ggplot(aes(x = pred_camisa, y = prop_camisa, ymin = prop_camisa - 2*ee, | ||
ymax = prop_camisa + 2* ee)) + | ||
geom_abline() + | ||
geom_linerange() + | ||
geom_point(colour = "red") + xlim(0,1) + ylim(0,1) | ||
``` | ||
|
||
Vemos que la calibración no es muy mala, al menos para la categoría de *camisa*. Podemos | ||
checar otras categorías de esta manera. Calibrar estas probabilidades puede ser más difícil, | ||
pero podemos construir regiones conformes con garantías de cobertura como mostramos abajo: | ||
|
||
## Regiones conformes para clasificación | ||
|
||
Podemos construir conjuntos de predicción con la técnica de predicción conforme | ||
con muestra de prueba, siguiendo ideas de [@gentle21], ver [aquí](http://people.eecs.berkeley.edu/~angelopoulos/blog/posts/gentle-intro/) también. | ||
|
||
En este caso, no es necesario tener probabilidades calibradas, pero los conjuntos | ||
de predicción que obtendremos tendrán garantía de cobertura correcta promedio. | ||
|
||
|
||
```{r} | ||
preds_larga_tbl <- preds_tbl |> | ||
pivot_longer(`playera/top`:bota, names_to = "articulo_prob", values_to = "prob") |> | ||
group_by(id) |> | ||
arrange(id, desc(prob)) |> | ||
mutate(clase_verdadera = articulo == articulo_prob) |> | ||
mutate(acumulado = cumsum(clase_verdadera)) | ||
e_valores <- preds_larga_tbl |> | ||
filter(clase_verdadera) |> | ||
group_by(id) |> | ||
summarise(e = sum(prob)) | ||
q <- quantile(e_valores$e, prob = 0.05) | ||
e_valores |> ggplot(aes(x = e)) + geom_histogram() + | ||
geom_vline(xintercept = q, colour = "red") | ||
``` | ||
|
||
```{r} | ||
conjuntos_conf <- preds_larga_tbl |> group_by(id) |> | ||
filter(prob > q) |> | ||
select(id, articulo = articulo_prob, prob) | ||
conjuntos_conf | ||
``` | ||
|
||
Y así se ve el tamaño de los conjuntos conformes. La mayoría consiste de una sola | ||
clase (con probabilidad de 95%), pero muchos tienen 2 o 3 categorías posibles, de modo | ||
que el desempeño de nuestro modelo no es excelente. | ||
|
||
```{r} | ||
conjuntos_conf |> count(id) |> | ||
group_by(n) |> count() | ||
``` | ||
|
||
|
||
Por ejemplo: | ||
|
||
```{r} | ||
filter(conjuntos_conf, id == 21) | ||
``` | ||
|
||
```{r, message = FALSE} | ||
library(imager) | ||
plot(as.cimg(t(ropa_prueba$x[21,,])), axes = FALSE, main = articulos[ropa_prueba$y[21] + 1]) | ||
``` | ||
|
||
|
||
```{r} | ||
id_1 <- 27 | ||
ejemplo_conf <- filter(conjuntos_conf |> ungroup(), id == id_1) | ||
ejemplo_conf | ||
plot(as.cimg(t(ropa_prueba$x[id_1,,])), axes = FALSE, main = articulos[ropa_prueba$y[id_1] + 1]) | ||
``` | ||
|
||
## Calibración de probabilidades | ||
|
||
Cuando la calibración no es muy buena, es posible seguir el camino de | ||
predicción conforme de clase, o podemos también intentar un proceso de | ||
calibración. La idea es construir una funcion ${f}_cal$ tal que si | ||
$p'(x) = f_{cal}(p(x))$ , las probabilidades dadas por $p'(x)$ tienen | ||
buena calibración. | ||
|
||
Los métodos más comunes son: | ||
|
||
- Aplicar regresión logística (quizá con splines, por ejemplo), usando | ||
la probabilidad/score del modelo original como variable de entrada, y la respuesta como la variable de salida. | ||
- Aplicar regresión isotónica, que es similar pero restringe a que la calibración | ||
preserve el orden de las probabilidades/scores del modelo original. | ||
|
||
Existen otros métodos ( ver por ejemplo https://www.tidymodels.org/learn/models/calibration/) | ||
|
||
Veamos un ejemplo donde queremos contruir un modelo para predecir que células | ||
están correctamente segmentadas y cuáles no, bajo un método que se llama | ||
High content screening. Las clases son PS (poorly segmented) y WS (well segmented): | ||
|
||
```{r} | ||
data(cells) | ||
dim(cells) | ||
cells$case <- NULL | ||
cells |> count(class) | ||
``` | ||
|
||
|
||
```{r} | ||
library(probably) | ||
library(discrim) | ||
set.seed(128) | ||
split <- initial_split(cells, strata = class) | ||
cells_tr <- training(split) | ||
cells_te <- testing(split) | ||
cells_rs <- vfold_cv(cells_tr, v = 10, strata = class) | ||
discrim_wflow <- | ||
workflow() |> | ||
add_formula(class ~ .) |> | ||
add_model(discrim_linear() |> set_mode("classification")) | ||
metricas <- metric_set(roc_auc, brier_class) | ||
discrim_res <- | ||
discrim_wflow |> | ||
fit_resamples(resamples = cells_rs, | ||
metrics = metricas, | ||
control = control_resamples(save_pred = TRUE)) | ||
collect_metrics(discrim_res) | ||
``` | ||
|
||
|
||
```{r} | ||
cal_plot_breaks(discrim_res, num_breaks = 15) | ||
``` | ||
Corremos ahora validación con regresión logística com ejemplo: | ||
|
||
```{r} | ||
logit_val <- probably::cal_validate_logistic(discrim_res, | ||
metrics = metricas, save_pred = TRUE) | ||
collect_metrics(logit_val) | ||
``` | ||
|
||
Y el resultado es ahora como sigue: | ||
|
||
```{r} | ||
collect_predictions(logit_val) |> | ||
filter(.type == "calibrated") |> | ||
cal_plot_breaks(truth = class, estimate = .pred_PS) | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters