From 257d78ff407725fb4c44dae86b53f55805bf4df8 Mon Sep 17 00:00:00 2001 From: Felipe Gonzalez Date: Sun, 17 Sep 2023 18:46:01 -0600 Subject: [PATCH] =?UTF-8?q?Agregar=20secci=C3=B3n=2010?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/notas-docker.yml | 8 +- notas/10-clasificacion-calibracion.qmd | 334 +++++++++++++++++++++++++ notas/_quarto.yml | 1 + 3 files changed, 341 insertions(+), 2 deletions(-) create mode 100644 notas/10-clasificacion-calibracion.qmd diff --git a/.github/workflows/notas-docker.yml b/.github/workflows/notas-docker.yml index 2c387a16..2f36f0d9 100644 --- a/.github/workflows/notas-docker.yml +++ b/.github/workflows/notas-docker.yml @@ -15,10 +15,14 @@ jobs: steps: - name: Install probably run: | - remotes::install_github("https://github.com/tidymodels/probably/", upgrade = "always", force = TRUE) + remotes::install_github("https://github.com/tidymodels/probably/", upgrade = "always", force = TRUE) + shell: Rscript {0} + - name: Install others + run: | + install.packages(c("discrim")) shell: Rscript {0} - uses: actions/checkout@v2 - - name: Render book + - name: Render book # Add any command line argument needed run: | quarto render notas/ --to html diff --git a/notas/10-clasificacion-calibracion.qmd b/notas/10-clasificacion-calibracion.qmd new file mode 100644 index 00000000..85d9f41c --- /dev/null +++ b/notas/10-clasificacion-calibracion.qmd @@ -0,0 +1,334 @@ +# Incertidumbre en clasificación + +En los problemas de clasificación generalmente no conseguimos clasificación perfecta, +y buscamos que nuestras probabilidades estimadas expresen qué grado de confianza tenemos +en que ocurra cada categorías. Para lograr esto, necesitamos checar que las probabilidades +que producimos expresan esta incertidumbre correctamente, y no son simplemente *scores* +de creencia de que una clase u otra va a ocurrir. + +En esta sección trataremos de cómo checar estas probabilidades, y qué técnicas podemos +usar para expresar confiablemente (y con supuestos mínimos) incertidumbre al hacer predicciones de clase. + +```{r, include = FALSE} +ggplot2::theme_set(ggplot2::theme_minimal(base_size = 13)) +knitr::opts_chunk$set(fig.width=4, fig.height=3) +cbb_palette <- c("#000000", "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7") +scale_colour_discrete <- function(...) { + ggplot2::scale_colour_manual(..., values = cbb_palette) +} +``` + +## Calibración de probabilidades + +En un principio podemos considerar +las probabilidades estimadas en nuestros predictores +como *scores* que califican qué tan creíble es que +una observación particular sea de una categoría o clase dada. Sin embargo, +quisiéramos interpretarlas también de manera frecuentista, como frecuencia +relativa de ocurrencia de eventos a largo plazo. Esto nos permite utilizarlas en +procesos *downstream* y en la toma de decisiones de manera más efectiva. + +Por ejemplo, +en pronósticos meteorológicos, todos los días expresan una probabilidad de lluvia. Resulta ser que entre aquellos días donde los metereológos dicen que hay 10% de lluvia, en 1 de cada 10 llueve en realidad. Entre aquellos días donde dicen que hay 90% de probabilidad de lluvia, aproximadamente 9 de cada 10 días llueve. +Esto quiere decir que estos pronósticos probabilísticos de la meteorología están bien calibrados, +independientemente de nuestra interpretación de esas probabilidades como grados de creencia. + +Igual que buscábamos que nuestros intervalos predictivos estuvieran bien calibrados +en los problemas de regresión, también buscamos que nuestras +predictores en clasificación estén **probabilísticamente bien calibrados**. Es +decir: lo que decimos que puede ocurrir con 10% de probabilidad ocurre efectivamente 1 de cada 10 veces, si decimos 20% entonces ocurre 2 de 20, etc. + + +Podemos hacer esto realizando pruebas de la *calibración* +de las probabilidades que arroja el modelo. +Esto quiere decir que si el modelo nos dice que la probabilidad de que la clase 1 es 0.8, +entonces si tenemos un número grande de estos casos (con probabilidad 0.8), alrededor +de 80\% de éstos tienen que ser positivos. + + +#### Ejemplo: diabetes {-} + +Podemos checar la calibración de nuestro modelo para el ejemplo de diabetes de la sección +anterior + +```{r, message=FALSE, warning=FALSE} +library(tidyverse) +library(tidymodels) +library(gt) +diabetes_pr <- as_tibble(MASS::Pima.te) +diabetes_pr$id <- 1:nrow(diabetes_pr) +flujo_diabetes <- read_rds("cache/flujo_ajustado_diabetes.rds") +``` +Aquí están las probabilidades estimadas +de tener diabetes sobre la muestra de prueba: + +```{r, message=FALSE, warning = FALSE} +proba_mod <- predict(flujo_diabetes, diabetes_pr, type = "prob") +dat_calibracion <- tibble(obs = diabetes_pr |> pull(type), + probabilidad = proba_mod$.pred_Yes) |> + mutate(y = ifelse(obs == "Yes", 1, 0)) +dat_calibracion |> head() |> gt() +``` + +Para ver si estas probabilidades son realistas, podemos por ejemplo hacer +una gráfica como la que sigue: + +```{r} +ggplot(dat_calibracion, aes(x = probabilidad, y = y)) + + geom_jitter(width = 0, height = 0.02, alpha = 0.2) + + geom_smooth(method = "loess", span = 0.7, colour = "red", se = FALSE) + + geom_abline() + + coord_equal() +``` +Y en esta gráfica verificamos que los promedios locales de proporciones de 0-1's son +consistentes con las probabilidades que estimamos. + +## Gráficas de calibración binaria + +Otra manera de hacer esta gráfica +es cortando las probabilidades en cubetas y calculando intervalos de credibilidad +para cada estimación: con esto checamos si el observado es consistente con +las probabilidades de clase. + +::: callout-note +# Pruebas de calibración + +Sobre una muestra de prueba y para un problema de clasificación +binaria, donde $\hat{p} (x)$ es la probabilidad estimada de la clase 1: + +1. Producimos las probabilidades de clase $\hat{p}(\mathbf{x}^{(i)}).$ +2. Ordenamos estas probabilidades de clase de la más grande a la más chica, y las +agrupamos en cubetas $C_1, C_2, \ldots, C_r$. +3. Calculamos en cada cubeta la probabilidad de clase promedio y el porcentaje de +casos de clase 1. + +Si en cada cubeta la probabilidad de clase promedio y el porcentaje de casos de clase +1 son consistentes (similares módulo variación muestral), entonces decimos que +nuestras probabilidades pasan esta prueba de calibración. +::: + +### Ejemplo: diabetes {-} + + +```{r} +# usamos intervalos suavizados (bayesiano beta-binomial) en lugar de los basados +# en los errores estándar sqrt(p*(1-p) / n) +calibracion_gpos <- dat_calibracion |> + mutate(proba_grupo = cut(probabilidad, + quantile(probabilidad, seq(0, 1, 0.1)), include.lowest = TRUE)) |> + group_by(proba_grupo) |> + summarise(prob_media = mean(probabilidad), + n = n(), + obs = sum(y), .groups = "drop") |> + mutate(obs_prop = (obs + 1) / (n + 2), + inferior = qbeta(0.05, obs + 1, n - obs + 2), + superior = qbeta(0.95, obs + 1, n - obs + 2)) +calibracion_gpos |> gt() |> fmt_number(where(is_double), decimals = 3) +``` +```{r} +ggplot(calibracion_gpos, + aes(x = prob_media, y = obs_prop, ymin = inferior, ymax = superior)) + + geom_abline() + + geom_linerange() + + geom_point(colour = "red") + coord_equal() + + xlab("Probabilidad de clase") + + ylab("Proporción observada") + + labs(subtitle = "Intervalos de 90% para prop observada") +``` +Y con esto verificamos que calibración del modelo es razonable, y que es razonable +usar estas probabilidades para tomar decisiones o incluir en ejercicios de simulación. + +**Observación**: +1. Si las probabilidades no están calibradas, y las queremos +utilizar como tales (no simplemente como *scores*), entonces puede ser +necesario hacer un paso adicional de calibración, con una muestra +separada de calibración (ver por ejemplo @kuhn, sección 11.1). +2. En este ejemplo construimos intervalos para las proporciones observadas +usando intervalos bayesianos. Es posible usar intervalos normales o t (usando el +error estándar), pero estos intervalos tienen cobertura mala para proporciones +muy chicas o muy grandes [Binomial proportion wikipedia](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). Nuestro ejemplo es similar a los intervalos de Agresti-Coull. + + +## Ejemplo: artículos de ropa + +Regresamos a nuestro modelo para clasificar imágenes de artículos de ropa: + +```{r, message = FALSE} +library(keras) +ropa_datos <- dataset_fashion_mnist() +ropa_prueba <- ropa_datos$test +# estas son las categorias: +articulos <- c("playera/top", "pantalón", "suéter", "vestido", "abrigo", "sandalia", + "camisa", "tenis", "bolsa", "bota") +etiquetas_tbl <- tibble(codigo = 0:9, articulo = articulos) +x_prueba <- ropa_prueba$x / 255 +y_prueba <- to_categorical(ropa_prueba$y, 10) +modelo <- load_model_tf("cache/red_ropa_1") +``` + + +```{r, message = FALSE} +preds_mat <- predict(modelo, x_prueba) +colnames(preds_mat) <- articulos +preds_tbl <- as_tibble(preds_mat) |> + mutate(id = 1:nrow(preds_mat), .before = 1) |> + mutate(codigo = ropa_prueba$y) |> + left_join(etiquetas_tbl) +preds_tbl |> head() |> gt() |> fmt_number(where(is_double), decimals = 3) +``` + +Nuestra gráfica de calibración para la categoría *camisa* es: + +```{r} +preds_tbl |> select(camisa, articulo) |> + mutate(obs_camisa = articulo == "camisa") |> + mutate(grupo_prob = cut_number(camisa, 40)) |> + group_by(grupo_prob) |> + summarise(n = n(), + pred_camisa = mean(camisa), + prop_camisa = mean(obs_camisa)) |> + mutate(ee = sqrt(prop_camisa * (1 - prop_camisa) / n)) |> +ggplot(aes(x = pred_camisa, y = prop_camisa, ymin = prop_camisa - 2*ee, + ymax = prop_camisa + 2* ee)) + + geom_abline() + + geom_linerange() + + geom_point(colour = "red") + xlim(0,1) + ylim(0,1) +``` + +Vemos que la calibración no es muy mala, al menos para la categoría de *camisa*. Podemos +checar otras categorías de esta manera. Calibrar estas probabilidades puede ser más difícil, +pero podemos construir regiones conformes con garantías de cobertura como mostramos abajo: + +## Regiones conformes para clasificación + +Podemos construir conjuntos de predicción con la técnica de predicción conforme +con muestra de prueba, siguiendo ideas de [@gentle21], ver [aquí](http://people.eecs.berkeley.edu/~angelopoulos/blog/posts/gentle-intro/) también. + +En este caso, no es necesario tener probabilidades calibradas, pero los conjuntos +de predicción que obtendremos tendrán garantía de cobertura correcta promedio. + + +```{r} +preds_larga_tbl <- preds_tbl |> + pivot_longer(`playera/top`:bota, names_to = "articulo_prob", values_to = "prob") |> + group_by(id) |> + arrange(id, desc(prob)) |> + mutate(clase_verdadera = articulo == articulo_prob) |> + mutate(acumulado = cumsum(clase_verdadera)) +e_valores <- preds_larga_tbl |> + filter(clase_verdadera) |> + group_by(id) |> + summarise(e = sum(prob)) +q <- quantile(e_valores$e, prob = 0.05) +e_valores |> ggplot(aes(x = e)) + geom_histogram() + + geom_vline(xintercept = q, colour = "red") +``` + +```{r} +conjuntos_conf <- preds_larga_tbl |> group_by(id) |> + filter(prob > q) |> + select(id, articulo = articulo_prob, prob) +conjuntos_conf +``` + +Y así se ve el tamaño de los conjuntos conformes. La mayoría consiste de una sola +clase (con probabilidad de 95%), pero muchos tienen 2 o 3 categorías posibles, de modo +que el desempeño de nuestro modelo no es excelente. + +```{r} +conjuntos_conf |> count(id) |> + group_by(n) |> count() +``` + + +Por ejemplo: + +```{r} +filter(conjuntos_conf, id == 21) +``` + +```{r, message = FALSE} +library(imager) +plot(as.cimg(t(ropa_prueba$x[21,,])), axes = FALSE, main = articulos[ropa_prueba$y[21] + 1]) +``` + + +```{r} +id_1 <- 27 +ejemplo_conf <- filter(conjuntos_conf |> ungroup(), id == id_1) +ejemplo_conf +plot(as.cimg(t(ropa_prueba$x[id_1,,])), axes = FALSE, main = articulos[ropa_prueba$y[id_1] + 1]) +``` + +## Calibración de probabilidades + +Cuando la calibración no es muy buena, es posible seguir el camino de +predicción conforme de clase, o podemos también intentar un proceso de +calibración. La idea es construir una funcion ${f}_cal$ tal que si +$p'(x) = f_{cal}(p(x))$ , las probabilidades dadas por $p'(x)$ tienen +buena calibración. + +Los métodos más comunes son: + +- Aplicar regresión logística (quizá con splines, por ejemplo), usando +la probabilidad/score del modelo original como variable de entrada, y la respuesta como la variable de salida. +- Aplicar regresión isotónica, que es similar pero restringe a que la calibración +preserve el orden de las probabilidades/scores del modelo original. + +Existen otros métodos ( ver por ejemplo https://www.tidymodels.org/learn/models/calibration/) + +Veamos un ejemplo donde queremos contruir un modelo para predecir que células +están correctamente segmentadas y cuáles no, bajo un método que se llama +High content screening. Las clases son PS (poorly segmented) y WS (well segmented): + +```{r} +data(cells) +dim(cells) +cells$case <- NULL +cells |> count(class) +``` + + +```{r} +library(probably) +library(discrim) +set.seed(128) +split <- initial_split(cells, strata = class) +cells_tr <- training(split) +cells_te <- testing(split) + +cells_rs <- vfold_cv(cells_tr, v = 10, strata = class) +discrim_wflow <- + workflow() |> + add_formula(class ~ .) |> + add_model(discrim_linear() |> set_mode("classification")) +metricas <- metric_set(roc_auc, brier_class) + +discrim_res <- + discrim_wflow |> + fit_resamples(resamples = cells_rs, + metrics = metricas, + control = control_resamples(save_pred = TRUE)) +collect_metrics(discrim_res) +``` + + +```{r} +cal_plot_breaks(discrim_res, num_breaks = 15) +``` +Corremos ahora validación con regresión logística com ejemplo: + +```{r} +logit_val <- probably::cal_validate_logistic(discrim_res, + metrics = metricas, save_pred = TRUE) +collect_metrics(logit_val) +``` + +Y el resultado es ahora como sigue: + +```{r} +collect_predictions(logit_val) |> + filter(.type == "calibrated") |> + cal_plot_breaks(truth = class, estimate = .pred_PS) +``` + diff --git a/notas/_quarto.yml b/notas/_quarto.yml index 2286cb2f..2d8e89e3 100644 --- a/notas/_quarto.yml +++ b/notas/_quarto.yml @@ -15,6 +15,7 @@ book: - 07-intervalos-predictivos.qmd - 08-clasificacion-1.qmd - 09-clasificacion-2.qmd + - 10-clasificacion-calibracion.qmd appendices: - 81-apendice-descenso.qmd - 82-apendice-descenso-estocastico.qmd