pyLocalGLMnet
Eine Python Implementierung des Richman/Wüthrich-Ansatzes
Abbildungsverzeichnis
Dependencies
1. Einleitung
2. Datensatz 1: Künstlicher Datensatz
2.1 Künstlichen Datensatz erzeugen
2.2 LocalGLMnet
2.2.1 GLM
2.2.2 LocalGLMnet
2.2.3 Performance Benchmark
2.3 Auswertung
2.3.1 Variable Selection
2.3.2 Feature Contribution
2.3.3 Interaction Strengths
3. Datensatz 2: freMTPL2freq
3.1 Vorverarbeitung
3.2 LocalGLMnet
3.3 Auswertung
3.3.1 Variable Selection
3.3.2 Neues LocalGLMnet trainieren
3.3.3 Feature Contribution
3.3.4 Interaction Strengths
4. Zusammenfassung
Literaturverzeichnis
Synthetischer Datensatz:
Abb. 1: LocalGLMnet vs. GLM
Abb. 2: Regression Attentions
Abb. 3: Feature Contributions
Abb. 4: Interaction Strengths
FreMTPL-Datensatz:
Abb. 5: Regression Attentions
Abb. 6: Area Code vs. Density
Abb. 7: Feature Contributions
Abb. 8: Feature Contribution kategorialer Variablen
Abb. 9: Interaction Strengths
import tensorflow as tf
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as patches
from scipy import interpolate
import scipy.stats as stats
from sklearn import linear_model
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from labellines import labelLines
import os
import random
Um reproduzierbare Ergebnisse zu gewährleisten, werden zusätzlich die Zufallsgeneratoren mit dem Seed 0 initialisiert.
# Seed der Zufallsgeneratoren festlegen
seed = 0
tf.random.set_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
In dem Paper LocalGLMnet: interpretable deep learning for tabular data beschreiben Richman & Wüthrich eine neue Struktur für Neuronale Netze, welche auf Generalisierten Linearen Modellen (GLMs) beruht [1]. Dies soll einen Kompromiss zwischen der hohen Performanz der klassischen vorwärts gerichteten Neuronalen Netzen und der Erklärbarkeit von GLMs schaffen. Der Grundgedanke besteht darin, dass die Koeffizienten des GLMs durch das Neuronale Netz bestimmt werden. Hierdurch können sie anders als bei einem klassischen GLM in Abhängigkeit von den Merkmalsausprägungen variieren. Für einen beschränkten Wertebereich können die Koeffizienten jedoch konstant erscheinen, weshalb von einem lokalen GLM gesprochen wird. Um den Zusammenhang zwischen Attention-Gewichten zu den ursprünglichen Merkmalen beizubehalten und Rückschlüsse auf den Einfluss unterschiedlicher Merkmale zu erlauben, wird in der Netzstruktur eine Skip-Connection verwendet. Bevor die Attention-Gewichte als Parameter des GLMs verwendet werden, werden sie hierfür mit der entsprechenden ursprünglichen Merkmalsausprägung multipliziert.
Im Folgenden soll die Modellierung des LocalGLMnet-Ansatzes in Python mithilfe der Tensorflow Implementierung der Keras API dargestellt werden. Hierfür werden dieselben Datensätze wie im ursprünglichen Paper verwendet, um einen einfachen Transfer der Inhalte des Papers zur Implementierung zu ermöglichen. Der erste Datensatz ist hierbei ein synthetischer Datensatz, der dadurch, dass der tatsächliche Regressionszusammenhang bekannt ist, Möglichkeiten sowie Grenzen des Ansatzes aufzeigt. Anschließend wird der Einsatz des LocalGLMnet an einem realen Sachverhalt, der Vorhersage der Schadensmeldungen einer Kfz-Haftpflicht, dargestellt.
Der synthetische Datensatz besteht aus insgesamt 8 Merkmalen. x7 und x8 haben keinen Einfluss auf die Zielvariable. x8 ist jedoch zu 50% mit x2 korreliert. Der funktionale Zusammenhang der Zielvariable ergibt sich wie folgt:
\begin{equation} \mu\left( x \right)=\frac{1}{2}x_{1}-\frac{1}{4}x^2_{2}+\frac{1}{2}\left\lvert x_{3} \right\rvert sin\left( 2x_{3} \right)+\frac{1}{2}x_{4}x_{5}+\frac{1}{8}x^{2}{5}x{6} \end{equation}
Die Merkmalsausprägungen werden mithilfe des Zufallsgenerators von Numpy auf Basis einer Standardnormalverteilung erzeugt. Hierdurch sind die Merkmale bereits standardisiert, d. h. alle haben den Mittelwert µ=0 und std=1. Bei einem anderen Datensatz müssten die Merkmale zuerst standardisiert werden, damit die Werte die gleiche Größenordnung haben. Da die Daten künstlich erzeugt werden, wird sowohl ein Trainings- als auch ein Testdatensatz mit 100000 Beobachtungen erzeugt. Bei einem realen Datensatz müsste der vorhandene Datensatz entsprechend aufgeteilt werden (bspw. 80:20).
# Zielfunktion
def target_variable(x):
return (
(1 / 2) * x[0]
- (1 / 4) * (x[1] ** 2)
+ (1 / 2) * abs(x[2]) * math.sin(2 * x[2])
+ (1 / 2) * x[3] * x[4]
+ (1 / 8) * (x[4] ** 2) * x[5]
)
# Random Number Generator
rng = np.random.default_rng()
# Trainingsdatensatz (n = 100.000) erzeugen (Variablen x1, x3, x4, x5, x6, x7)
x1_train = rng.standard_normal(size=(100000, 1))
x3_7_train = rng.standard_normal(size=(100000, 5))
# Variablen x2, x8 mit 50 % Korrelation erzeugen
cov_matrix = [[1, 0.5], [0.5, 1]]
x2_x8_train = rng.multivariate_normal(mean=[0, 0], cov=cov_matrix, size=100000)
# Trainingsdatensatz zusammenfügen und Zielvariable y bestimmen
x_train = np.column_stack((x1_train, x2_x8_train[:, 0], x3_7_train, x2_x8_train[:, 1]))
y_train = np.array(list(map(target_variable, x_train[:, 0:7])))
# Testdatensatz (n = 100.000) erzeugen (Variablen x1, x3, x4, x5, x6, x7)
x1_test = rng.standard_normal(size=(100000, 1))
x3_7_test = rng.standard_normal(size=(100000, 5))
# Variablen x2, x8 mit 50 % Korrelation erzeugen
x2_x8_test = rng.multivariate_normal(mean=[0, 0], cov=cov_matrix, size=100000)
# Testdatensatz zusammenfügen und Zielvariable y bestimmen
x_test = np.column_stack((x1_test, x2_x8_test[:, 0], x3_7_test, x2_x8_test[:, 1]))
y_test = np.array(list(map(target_variable, x_test[:, 0:7])))
Für den synthetischen Datensatz haben Richman & Wüthrich die Identity-Link Funktion verwendet. Das resultierende GLM entspricht also einer klassischen linearen Regression. Um ein GLM mit Python zu erzeugen, bieten sich Bibliotheken wie scikit-learn oder statsmodels an.
reg = linear_model.LinearRegression()
reg.fit(x_train, y_train)
LinearRegression()
# LocalGLMnet strukturieren
input = tf.keras.Input(shape=(8), dtype="float32")
attention = input
attention = tf.keras.layers.Dense(units=20, activation="tanh")(attention)
attention = tf.keras.layers.Dense(units=15, activation="tanh")(attention)
attention = tf.keras.layers.Dense(units=10, activation="tanh")(attention)
attention = tf.keras.layers.Dense(units=8, activation="linear", name="Attention")(
attention
)
# Skip-Connection
response = tf.keras.layers.Dot(axes=1)([input, attention])
# Response Schicht = lokales GLM
response = tf.keras.layers.Dense(units=1, activation="linear", name="Response")(
response
)
2023-02-19 13:13:06.119452: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
# Modell kompilieren
local_glm_net = tf.keras.Model(inputs=input, outputs=response)
local_glm_net.compile(loss="mse", optimizer="nadam")
local_glm_net.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 8)] 0
__________________________________________________________________________________________________
dense (Dense) (None, 20) 180 input_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 15) 315 dense[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 10) 160 dense_1[0][0]
__________________________________________________________________________________________________
Attention (Dense) (None, 8) 88 dense_2[0][0]
__________________________________________________________________________________________________
dot (Dot) (None, 1) 0 input_1[0][0]
Attention[0][0]
__________________________________________________________________________________________________
Response (Dense) (None, 1) 2 dot[0][0]
==================================================================================================
Total params: 745
Trainable params: 745
Non-trainable params: 0
__________________________________________________________________________________________________
# Modell trainieren
history = local_glm_net.fit(
x_train, y_train, batch_size=32, epochs=10, validation_split=0.2
)
2023-02-19 13:13:06.404522: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/10
2500/2500 [==============================] - 4s 1ms/step - loss: 0.2422 - val_loss: 0.0840
Epoch 2/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0612 - val_loss: 0.0480
Epoch 3/10
2500/2500 [==============================] - 4s 2ms/step - loss: 0.0391 - val_loss: 0.0470
Epoch 4/10
2500/2500 [==============================] - 4s 2ms/step - loss: 0.0295 - val_loss: 0.0257
Epoch 5/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0155 - val_loss: 0.0092
Epoch 6/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0074 - val_loss: 0.0058
Epoch 7/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0051 - val_loss: 0.0050
Epoch 8/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0043 - val_loss: 0.0046
Epoch 9/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0037 - val_loss: 0.0035
Epoch 10/10
2500/2500 [==============================] - 3s 1ms/step - loss: 0.0037 - val_loss: 0.0038
# Vorhersage mit localGLMnet und GLM
pred_local = local_glm_net.predict(x_test)
pred_reg = reg.predict(x_test)
fig_performance = plt.figure(tight_layout=True, figsize=(10, 5))
spec = GridSpec(ncols=2, nrows=1, figure=fig_performance)
axs_perf = [
fig_performance.add_subplot(spec[0, 0:1]),
fig_performance.add_subplot(spec[0, 1:2]),
]
axs_perf[0].scatter(y_test, pred_local, s=1)
axs_perf[1].scatter(y_test, pred_reg, s=1)
# Layout
for ax in axs_perf:
ax.set_xlabel("True value")
ax.set_ylabel("Estimated value")
ax.set_xlim((-4, 4))
ax.set_ylim((-4, 4))
fig_performance.suptitle("Abbildung 1: LocalGLMnet vs. GLM")
plt.show()
print("MSE LocalGLMnet: " + str(metrics.mean_squared_error(y_test, pred_local)))
print("MSE GLM: " + str(metrics.mean_squared_error(y_test, pred_reg)))
MSE LocalGLMnet: 0.0037230475765324548
MSE GLM: 0.5342545394609545
Streut ein Attention-Wert für einen Großteil des Wertebereichs der Inputvariable um 0, scheint der Einfluss vernachlässigbar, das Merkmal kann demnach entfernt werden. Um ein Maß für die Streuung zu bieten, haben Richman & Wüthrich einen empirischen Wald-Test entwickelt [1]. Hierbei wird dem Modell eine zusätzliche Variable ohne Zusammenhang zur Zielvariable hinzugefügt. Anschließend wird auf Basis der Streuung des zugehörigen Attention-Gewichts ein Konfidenzintervall berechnet. Hiermit lässt sich daraufhin die Coverage Ratio für jedes Attention-Gewicht, also der Anteil der Gewichte, die innerhalb der Grenzen liegen, berechnen. Ist diese kleiner als das Signifikanzniveau, kann die Variable entfernt und das Modell erneut ohne diese trainiert werden.
Bei dem verwendeten synthetischen Datensatz kann direkt β7 verwendet werden, da sie keinen Einfluss auf die tatsächliche Regressionsfunktion hat. Bei einem realen Datensatz bieten sich künstlich erzeugte normal- und gleichverteilte Merkmale mit μ=0 und std=1 an.
# Über die Methode get_weights() erhält man die Kantengewichte, sowie den Bias für jeder Schicht
# --> man erhält also eine Liste mit numpy Arrays die in der Länge der Anzahl der Ebenen * 2 entspricht
weights = local_glm_net.get_weights()
for i in weights:
print(i.shape, end=" | ")
(8, 20) | (20,) | (20, 15) | (15,) | (15, 10) | (10,) | (10, 8) | (8,) | (1, 1) | (1,) |
# Neues Model ohne Response-Schicht --> ermöglicht auslesen der Attention Gewichte
weights_local_glm = tf.keras.Model(
inputs=local_glm_net.input, outputs=local_glm_net.get_layer(name="Attention").output
)
# Gewichte bestimmen
beta_x = weights_local_glm.predict(x_test)
# Skalierung der Attention-Gewichte mithilfe des Gewichts der Response Schicht ( = Intercept beta_0)
beta_x_scaled = beta_x * weights[8]
# Merkmal 7 ist von der wahren Regressionsfunktion unabhängig
# --> Einsatz zur Berechnung des Konfidenzintervals
print("Mittelwert β7: " + str(beta_x_scaled[:, 6].mean()))
print("Standardabweichung β7: " + str(beta_x_scaled[:, 6].std()))
# Intervalgrenzen bestimmen
alpha = 0.001
bound = stats.norm.ppf(alpha / 2) * beta_x_scaled[:, 6].std()
print("Quantil " + str(1 - alpha / 2) + ": " + str(stats.norm.ppf(alpha / 2)))
print("Grenzen: ± " + str(abs(bound)))
Mittelwert β7: -0.003177547
Standardabweichung β7: 0.008993691
Quantil 0.9995: -3.2905267314918945
Grenzen: ± 0.029593980102238748
# Attention Plot
fig_attention = plt.figure(tight_layout=True, figsize=(30, 15))
# Gliederung der Subplots
spec = GridSpec(ncols=8, nrows=3, figure=fig_attention)
ax1_att = fig_attention.add_subplot(spec[0, 1:3])
ax2_att = fig_attention.add_subplot(spec[0, 3:5])
ax3_att = fig_attention.add_subplot(spec[0, 5:7])
ax4_att = fig_attention.add_subplot(spec[1, 1:3])
ax5_att = fig_attention.add_subplot(spec[1, 3:5])
ax6_att = fig_attention.add_subplot(spec[1, 5:7])
ax7_att = fig_attention.add_subplot(spec[2, 2:4])
ax8_att = fig_attention.add_subplot(spec[2, 4:6])
axs_att = [ax1_att, ax2_att, ax3_att, ax4_att, ax5_att, ax6_att, ax7_att, ax8_att]
# Ein Subplot pro Input Feature erstellen
for i in range(len(axs_att)):
# Linien zur Verdeutlichung der Höhe der Attention Gewichte
axs_att[i].hlines(y=0.5, xmin=-4, xmax=4, colors="orange")
axs_att[i].hlines(y=-0.5, xmin=-4, xmax=4, colors="orange")
axs_att[i].hlines(y=0.25, xmin=-4, xmax=4, colors="orange", linestyles="dashed")
axs_att[i].hlines(y=-0.25, xmin=-4, xmax=4, colors="orange", linestyles="dashed")
axs_att[i].hlines(y=0, xmin=-4, xmax=4, colors="red")
# Intervalgrenzen
interval = patches.Rectangle(
xy=(-4, bound),
height=2 * abs(bound),
width=8,
edgecolor="royalblue",
facecolor="lightcyan",
alpha=0.8,
zorder=1,
)
axs_att[i].add_patch(interval)
# Scatter Plot --> x: Werte der Inputfeatures, y: Attention Gewichte
axs_att[i].scatter(x_test[:, i], beta_x_scaled[:, i], s=0.5, c="black")
# Layout
axs_att[i].set_xlim((-4, 4))
axs_att[i].set_ylim((-1, 1))
axs_att[i].set_xlabel("x" + str(i + 1))
axs_att[i].set_ylabel("Attention β" + str(i + 1))
fig_attention.suptitle("Abbildung 2: Regression Attentions")
plt.show()
- β1
- liegt relativ konstant bei 0.5
- wenig laterale Verzerrungen --> kaum Interaktionen mit anderen Inputvariablen
- β2, β3
- wenig laterale Verzerrungen --> kaum Interaktionen mit anderen Inputvariablen
- β4, β5, β6
- Attention-Gewicht ≠ 0 --> Einfluss auf die Vorhersage
- laterale Verzerrungen --> Interaktionen mit anderen Inputvariablen
- β7, β8
- streuen um 0 --> β8 streut durch die Korrelation von x2 und x8 stärker (durch Wald-Test muss entschieden werden, ob x8 entfernt werden kann)
for i in range(8):
if i != 6:
size = beta_x_scaled.shape[0]
coverage = np.count_nonzero(
beta_x_scaled[:, i] < abs(bound)
) - np.count_nonzero(beta_x_scaled[:, i] < -abs(bound))
coverage_ratio = coverage / size
print("Coverage Ratio β" + str(i + 1) + ": " + str(coverage_ratio))
Coverage Ratio β1: 0.0
Coverage Ratio β2: 0.09149
Coverage Ratio β3: 0.01268
Coverage Ratio β4: 0.09206
Coverage Ratio β5: 0.086
Coverage Ratio β6: 0.22275
Coverage Ratio β8: 0.57559
Im Paper wird β8 entfernt, da die Coverage Ratio über 0.999 liegt. Schwankungen können bspw. durch die zufällig erzeugten Daten oder die für einen Trainingsbatch ausgewählten Beobachtungen verursacht werden. Bevor die Attention Gewichte als Parameter des GLMs verwendet werden, wird das Skalarprodukt mit der ursprünglichen Inputvariable gebildet. Die resultierende Größe ist die Feature Contribution. Eine Visualisierung dieser in Abhängigkeit von der Inputvariable zeigt deutlicher den resultierenden funktionalen Zusammenhang. Zur Verdeutlichung können zusätzlich Splines hinzugefügt werden, welche diesen approximieren.
# Feature Contribution Plot
fig_contribution = plt.figure(tight_layout=True, figsize=(30, 15))
spec = GridSpec(ncols=8, nrows=3, figure=fig_contribution)
ax1_con = fig_contribution.add_subplot(spec[0, 1:3])
ax2_con = fig_contribution.add_subplot(spec[0, 3:5])
ax3_con = fig_contribution.add_subplot(spec[0, 5:7])
ax4_con = fig_contribution.add_subplot(spec[1, 1:3])
ax5_con = fig_contribution.add_subplot(spec[1, 3:5])
ax6_con = fig_contribution.add_subplot(spec[1, 5:7])
ax7_con = fig_contribution.add_subplot(spec[2, 2:4])
ax8_con = fig_contribution.add_subplot(spec[2, 4:6])
axs_con = [ax1_con, ax2_con, ax3_con, ax4_con, ax5_con, ax6_con, ax7_con, ax8_con]
xs = np.linspace(-4, 4, 1000)
for i in range(len(axs_con)):
# Feature Contribution Splines berechnen
# Feature Contribution = beta(xi)*xi
contribution = np.column_stack([x_test[:, i], beta_x_scaled[:, i] * x_test[:, i]])
con_ind = np.lexsort((contribution[:, 1], contribution[:, 0]))
contribution_sorted = contribution[con_ind]
con_spline = interpolate.UnivariateSpline(
contribution_sorted[:, 0], contribution_sorted[:, 1]
)
# Hinzufügen von horizontalen Linien um die Stärke der Feature Contribution zu visualisieren
axs_con[i].hlines(y=0, xmin=-4, xmax=4, colors="orange", alpha=0.7, zorder=1)
axs_con[i].hlines(y=0.5, xmin=-4, xmax=4, colors="red", alpha=0.5, zorder=1)
axs_con[i].hlines(y=-0.5, xmin=-4, xmax=4, colors="red", alpha=0.5, zorder=1)
axs_con[i].hlines(y=1, xmin=-4, xmax=4, colors="lightcyan", alpha=0.7, zorder=1)
axs_con[i].hlines(y=-1, xmin=-4, xmax=4, colors="lightcyan", alpha=0.7, zorder=1)
axs_con[i].hlines(y=1.5, xmin=-4, xmax=4, colors="royalblue", alpha=0.7, zorder=1)
axs_con[i].hlines(y=-1.5, xmin=-4, xmax=4, colors="royalblue", alpha=0.7, zorder=1)
# Scatter Plot --> x: Werte der Inputfeatures, y:Feature Contribution (β(x)*x)
axs_con[i].scatter(contribution[:, 0], contribution[:, 1], s=0.5, zorder=10)
# Feature Contribution Spline plotten
axs_con[i].plot(xs, con_spline(xs), color="purple", zorder=20)
# Layout
axs_con[i].set_xlim((-4, 4))
axs_con[i].set_ylim((-2, 2))
axs_con[i].set_xlabel("x" + str(i + 1))
axs_con[i].set_ylabel("Feature Contribution (beta(x)*x)" + str(i + 1))
fig_contribution.suptitle("Abbildung 3: Feature Contributions")
plt.show()
- β1
- lineare Funktion entsprechend der zugrundeliegenden Regressionsfunktion (½x1)
- wenig laterale Verzerrungen --> kaum Interaktionen mit anderen Inputvariablen
- β2
- quadratischer Zusammenhang erkennbar (¼x22)
- wenig laterale Verzerrungen --> kaum Interaktionen mit anderen Inputvariablen
- β3
- Sinusfunktion (½ |x3| sin(2x3))
- wenig laterale Verzerrungen --> kaum Interaktionen mit anderen Inputvariablen
- β4, β5, β6
- laterale Verzerrungen --> starke Interaktionen mit anderen Inputvariablen
- β7, β8
- streuen um 0 --> β8 weist etwas mehr Interaktionen auf
Um die zuvor bereits erkannten Interaktionen genauer zu analysieren, bietet es sich an, die Gradienten der Attention Gewichte zu untersuchen. Liegt keine Interaktion zwischen einem Attention-Gewicht
Zur Darstellung der Gradienten bieten sich Regressionssplines an. Bei diesen handelt es sich um eine aus mehreren Polynomen zusammengesetzte Funktion, welche daher besonders "glatt" verläuft [3].
Im Paper verwenden Richman & Wüthrich die R Bibliothek locfit [1]. Da diese nicht für Python verfügbar ist muss auf eine andere Bibliothek ausgewichen werden. Eine Möglichkeit zur Modellierung eines Univariaten Splines bietet bspw. scipy. Dies entspricht nicht genau der Implementierung mittels locfit, ermöglicht jedoch die gleichen Schlüsse.
# Gradienten bestimmen
gradients = []
x = tf.constant(x_train)
# Für jede Inputvariable wird ein Modell gefittet, um anschließend die partiellen Ableitungen auslesen zu können
for i in range(input.shape[-1]):
# Lambda Layer als Output Schicht, um beta_i als Output zu erhalten (partielle Ableitungen ∂β_j(x)/∂x_j')
beta = attention
beta = tf.keras.layers.Lambda(lambda x: x[:, i])(beta)
grad_model = tf.keras.Model(inputs=input, outputs=beta)
# GradientTape ermöglicht das auslesen der Gradienten
with tf.GradientTape() as g:
g.watch(x)
pred_attention = grad_model.call(x)
grad = g.gradient(pred_attention, x)
# Array das sowohl den Wert von x, als auch den entsprechenden Wert von βk(x) enthält
grad_wrt_x = np.column_stack((x[:, i].numpy(), grad.numpy()))
# Um später die Splines zu modellieren muss die x-Komponente monoton steigend sein --> sortieren des Arrays
ind = np.lexsort((grad_wrt_x[:, 2], grad_wrt_x[:, 0]))
grad_wrt_x_sorted = grad_wrt_x[ind]
# Gradienten in Liste speichern
gradients.append(grad_wrt_x_sorted)
# Univariate Splines modellieren, um die Interaktion zwischen Features darzustellen
splines = []
# Für alle Attention Gewichte β
for i in range(input.shape[-1]):
splines.append([])
# Für alle Inputvariablen x
for j in range(input.shape[-1]):
splines[i].append(
interpolate.UnivariateSpline(gradients[i][:, 0], gradients[i][:, j + 1])
)
# Spline Interaction Plot
fig_spline = plt.figure(tight_layout=True, figsize=(30, 15))
spec = GridSpec(ncols=8, nrows=3, figure=fig_spline)
ax1_sp = fig_spline.add_subplot(spec[0, 1:3])
ax2_sp = fig_spline.add_subplot(spec[0, 3:5])
ax3_sp = fig_spline.add_subplot(spec[0, 5:7])
ax4_sp = fig_spline.add_subplot(spec[1, 1:3])
ax5_sp = fig_spline.add_subplot(spec[1, 3:5])
ax6_sp = fig_spline.add_subplot(spec[1, 5:7])
ax7_sp = fig_spline.add_subplot(spec[2, 2:4])
ax8_sp = fig_spline.add_subplot(spec[2, 4:6])
axs_sp = [ax1_sp, ax2_sp, ax3_sp, ax4_sp, ax5_sp, ax6_sp, ax7_sp, ax8_sp]
xs = np.linspace(-4, 4, 100)
for i in range(input.shape[-1]):
# Splines für jedes Merkmal plotten
for j in range(input.shape[-1]):
axs_sp[i].plot(xs, splines[i][j](xs), label="x" + str(j + 1))
# Inline Lables und Legende hinzufügen
labelLines(axs_sp[i].get_lines(), zorder=2.5)
axs_sp[i].legend(loc="lower right", ncol=2)
# Layout
axs_sp[i].set_xlim((-4, 4))
axs_sp[i].set_ylim((-0.5, 0.5))
axs_sp[i].set_xlabel("Feature Values x" + str(i + 1))
axs_sp[i].set_ylabel("Interaction Strengths")
axs_sp[i].set_title("Interactions of Feature Component x" + str(i + 1))
fig_spline.suptitle("Abbildung 4: Interaction Strengths")
plt.show()
- x1, x6, x7, x8
- Der Wert aller partiellen Ableitungen liegt konstant bei ≈0
- → Keine Interaktionen (β ist konstant)
- x2
- Großteil der Werte um 0 konzentriert
- x2 ≠ 0
- → Sehr geringe Interaktionen mit anderen Variablen, allerdings nicht-linearer Zusammenhang mit x2 (quadratisch)
- x3
- generell um 0 zentriert, aber größere Streuung als bei x2
- x3 ≠ 0 → Sinus ähnelndes Verhalten
- → geringe Interaktionen mit anderen Variablen
- x4
- lineare Interaktion mit x5 (const. ≈ 0.3)
- → Sehr geringe Interaktionen mit anderen Variablen
- x5
- Geringe Interaktionen mit anderen Variablen
- Stärkste Interaktionen mit x4 (linear) und x5
Nachdem der Ansatz des LocalGLMnet grundlegend anhand des synthetischen Datensatzes vorgestellt wurde, soll dieser nun auf einen realen Sachverhalt angewendet werden. Richman und Wüthrich verwenden hierfür den freMTPL (= French Motor Third-Part Liability) Datensatz, da er generell als Benchmark im Aktuarsbereich gilt [1]. Der Datensatz enthält Informationen über Kfz-Haftpflichtversicherungen und aufgetretene Schadensmeldungen.
Um ähnliche Ergebnisse wie Richman und Wüthrich [1] zu erhalten, wurden die gleichen Schritte zur Vorverarbeitung des Datensatzes durchgeführt. Diese werden in Wüthrich/Merz [4] genauer dargestellt.Im ersten Schritt werden die Datensätze FreMTPL2freq und FreMTPL2sev zusammengeführt. FreMTPL2freq enthält Informationen über die Versicherungspolicen und FreMTPL2sev über die aufgetretenen Schäden. FreMTPL2freq enthält zwar ebenfalls die Schadensanzahl, jedoch scheint es hierbei einige inkorrekte Aufzeichnungen zu geben. Eine Erläuterung der sonstigen Merkmale lässt sich Wüthrich/Merz [4, S. 555] entnehmen:
- IDpol: policy number (unique identifier)
- Exposure: total exposure in yearly units (years-at-risk) and within (0, 1 ]
- Area: area code (categorical, ordinal with 6 levels)
- VehPower: power of the car (continuous)
- VehAge age of the car in years
- DrivAge: age of the (most common) driver in years
- BonusMalus: bonus-malus level between 50 and 230 (with entrance level 100)
- VehBrand: car brand (categorical, nominal with 11 levels)
- VehGas: diesel or regular fuel car (binary)
- Density: density of population per km 2 at the location of the living place of the driver
- Region: regions in France (prior to 2016)(categorical)
Entsprechend der Anweisungen von Merz/Wüthrich [4] wurde der FreMTPL2freq Datensatz in der Version 1.0-8 über die OpenML ID 41214 heruntergeladen. Dennoch entspricht die Anzahl der Kategorien von VehBrand mit 14 nicht der Anzahl im Paper. Aus diesem Grund werden im Folgenden leichte Anpassungen vorgenommen, bspw. hat das LocalGLMnet hierdurch eine Inputdimension q=45.
Nachdem der Datensatz in ein DataFrame geladen wurde, wird den Merkmalen der zugehörige Datentyps zugeordnet und teils weitere Vorverarbeitungen vorgenommen. Der maximale Wert der Exposures wurde bspw. auf 1 begrenzt, da lediglich betrachtet wird, ob die Policen im ganzen Jahr aktiv sind. Beobachtungen mit mehr als 5 Schadensfällen werden zudem entfernt, da es sich hierbei höchstwahrscheinlich um fehlerhafte Daten handelt. Um den Einfluss der kategorialen Variablen VehBrand und Region im Modell abbilden zu können, werden diese mittels One-Hot Encoding transformiert. Um später ein Maß für die zufälligen Schwankungen der Attention-Gewichte zu haben, wird eine gleichverteilte (= RandU) und eine normalverteilte (= RandN) Störvariable hinzugefügt
Abschließend werden die Daten in Trainings- und Testdatensätze mit einem Split von 90:10 aufgeteilt und so skaliert, dass der Mittelwert null und die Standardabweichung 1 ist. Diese Standardisierung wird erst nach Aufteilung in Trainings- und Testdatensätze durchgeführt, um Information Leakage zu verhindern.
# Enthält Kundendaten von einer Kfz-Haftpflichtversicherung
freq = pd.read_csv("../data/freMTPL2freq.csv")
# Claim Anzahl entfernen (Erklärung siehe [4] Listing B.1)
freq = freq.drop(columns=["ClaimNb"])
freq["IDpol"] = freq["IDpol"].astype("int64")
# Enthält die Schadenshöhe für jeden Schaden
sev = pd.read_csv("../data/freMTPL2sev.csv")
# Schadenshöhe und Vorkommen nach Kunden-ID aggregieren
sev_agg = sev
sev_agg["ClaimNb"] = 1
sev_agg = sev_agg.groupby("IDpol").sum()[["ClaimNb", "ClaimAmount"]].reset_index()
sev_agg = sev_agg.rename(columns={"ClaimAmount": "ClaimTotal"})
# freq und sev zusammenführen --> Datensatz mit der korrekten Anzahl an Schadensmeldungen
freq = freq.merge(sev_agg, on="IDpol", how="left")
freq["ClaimNb"] = freq["ClaimNb"].fillna(0)
freq["ClaimTotal"] = freq["ClaimTotal"].fillna(0)
# Vehicle Brand als kategoriales Merkmal definieren um Reihenfolge der Brands festzulegen
freq["VehBrand"] = pd.Categorical(
freq["VehBrand"],
categories=[
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
"B8",
"B9",
"B10",
"B11",
"B12",
"B13",
"B14",
],
)
# Area Codes einer Ordinalskala zuweisen (A=1, B=2,...)
freq["Area"] = pd.Categorical(freq["Area"], categories=["A", "B", "C", "D", "E", "F"])
freq["Area"] = freq["Area"].cat.codes + 1
freq = freq.rename(columns={"Area": "AreaCode"})
# Binäre Variable "VehGas" den Codes 0 und 1 zuordnen (Diesel = 0, Regular = 1)
freq["VehGas"] = pd.Categorical(freq["VehGas"], categories=["Diesel", "Regular"])
freq["VehGas"] = freq["VehGas"].cat.codes
# Datentyp von ClaimNb und Region anpassen
freq = freq.astype({"ClaimNb": "int64", "Region": "category"})
# Alle Einträge mit mehr als 5 Schadensmeldungen entfernen:
freq = freq[freq["ClaimNb"] <= 5]
# Exposure kann maximal 1 sein --> alle Beobachtungen mit höheren Werten auf 1 setzen:
freq["Exposure"] = freq["Exposure"].clip(lower=0, upper=1)
# Log(Density)
freq["log_Density"] = np.log(freq["Density"])
freq = freq.drop(columns=["Density"])
# Alle Einträge aus sev entfernen die jetzt nicht mehr in freq enthalten sind:
sev = sev[sev["IDpol"].isin(freq["IDpol"])][["IDpol", "ClaimAmount"]]
freq.head()
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
IDpol | Exposure | AreaCode | VehPower | VehAge | DrivAge | BonusMalus | VehBrand | VehGas | Region | ClaimNb | ClaimTotal | log_Density | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0.10 | 4 | 5 | 0 | 55 | 50 | B12 | 1 | R82 | 0 | 0.0 | 7.104144 |
1 | 3 | 0.77 | 4 | 5 | 0 | 55 | 50 | B12 | 1 | R82 | 0 | 0.0 | 7.104144 |
2 | 5 | 0.75 | 2 | 6 | 2 | 52 | 50 | B12 | 0 | R22 | 0 | 0.0 | 3.988984 |
3 | 10 | 0.09 | 2 | 7 | 0 | 46 | 50 | B12 | 0 | R72 | 0 | 0.0 | 4.330733 |
4 | 11 | 0.84 | 2 | 7 | 0 | 46 | 50 | B12 | 0 | R72 | 0 | 0.0 | 4.330733 |
sev.head()
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
IDpol | ClaimAmount | |
---|---|---|
0 | 1552 | 995.20 |
1 | 1010996 | 1128.12 |
2 | 4024277 | 1851.11 |
3 | 4007252 | 1204.00 |
4 | 4046424 | 1204.00 |
# Zufällige Störvariablen hinzufügen, um später für die Regression Attentions ein Maß für die Streung um 0 definieren zu können
# Normalverteilte Zufallsvariable RandN
freq["RandN"] = rng.standard_normal(size=(freq.shape[0], 1))
# Gleichverteilte Zufallsvariable RandU (standardisiert)
freq["RandU"] = rng.uniform(size=(freq.shape[0], 1))
# Kategoriale Merkmale One-Hot Encoden (k-Kategorien führen zu k-1 Spalten)
categorical_columns = ["VehBrand", "Region"]
freq = pd.get_dummies(freq, columns=categorical_columns, drop_first=False)
# Datensatz in Merkmale x und Zielvariable y aufteilen
y_freq = freq["ClaimNb"]
x_freq = freq.drop(columns=["IDpol", "ClaimNb", "ClaimTotal"])
# Aufteilen in Trainings- und Testdaten
x_freq_train, x_freq_test, y_freq_train, y_freq_test = train_test_split(
x_freq, y_freq, test_size=0.1
)
# Exposures getrennt speichern
exposures_train = x_freq_train["Exposure"]
exposures_test = x_freq_test["Exposure"]
x_freq_train = x_freq_train.drop(columns=["Exposure"])
x_freq_test = x_freq_test.drop(columns=["Exposure"])
# Stetige und binäre Merkmale standardisieren:
continuous_columns = [
"AreaCode",
"BonusMalus",
"log_Density",
"DrivAge",
"VehAge",
"VehPower",
]
binary_columns = ["VehGas"]
x_freq_train_sc = x_freq_train.copy()
x_freq_test_sc = x_freq_test.copy()
# Trainings- und Testdatensatz werden getrennt standardisiert, um Information Leakage der Testdaten zu verhindern
scaler_freq = StandardScaler()
x_freq_train_sc[continuous_columns + binary_columns] = scaler_freq.fit_transform(
x_freq_train_sc[continuous_columns + binary_columns]
)
x_freq_test_sc[continuous_columns + binary_columns] = scaler_freq.transform(
x_freq_test_sc[continuous_columns + binary_columns]
)
# Zufallsvariable RandU standardisieren
scaler_freq_rand = StandardScaler()
x_freq_train_sc["RandU"] = scaler_freq_rand.fit_transform(x_freq_train_sc[["RandU"]])
x_freq_test_sc["RandU"] = scaler_freq_rand.transform(x_freq_test_sc[["RandU"]])
# LocalGLMnet Modell strukturieren
# LocalGLMnet nimmt als Input sowohl die Exposure als auch die Merkmale x
input_freq = tf.keras.Input(shape=(45), dtype="float32", name="Input")
vol_freq = tf.keras.Input(shape=(1), dtype="float32", name="Vol")
# Hidden Layer welche bis hin zur Attention Schicht mit 42 Neuronen (= Anzahl Inputmerkmale) führt
attention_freq = input_freq
attention_freq = tf.keras.layers.Dense(units=20, activation="tanh", name="Layer1")(
attention_freq
)
attention_freq = tf.keras.layers.Dense(units=15, activation="tanh", name="Layer2")(
attention_freq
)
attention_freq = tf.keras.layers.Dense(units=10, activation="tanh", name="Layer3")(
attention_freq
)
attention_freq = tf.keras.layers.Dense(units=45, activation="linear", name="Attention")(
attention_freq
)
# Skip-Connection
local_glm_freq = tf.keras.layers.Dot(name="LocalGLM", axes=1)(
[input_freq, attention_freq]
)
# Fügt Intercept hinzu
local_glm_freq = tf.keras.layers.Dense(
units=1, activation="exponential", name="Balance"
)(local_glm_freq)
# Response Schicht multipliziert Output des Netzes mit der Exposure
response_freq = tf.keras.layers.Multiply(name="Multiply")([local_glm_freq, vol_freq])
# Modell kompilieren
local_glm_net_freq = tf.keras.Model(
inputs=[input_freq, vol_freq], outputs=response_freq
)
local_glm_net_freq.compile(loss="poisson", optimizer="nadam")
local_glm_net_freq.summary()
Model: "model_10"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input (InputLayer) [(None, 45)] 0
__________________________________________________________________________________________________
Layer1 (Dense) (None, 20) 920 Input[0][0]
__________________________________________________________________________________________________
Layer2 (Dense) (None, 15) 315 Layer1[0][0]
__________________________________________________________________________________________________
Layer3 (Dense) (None, 10) 160 Layer2[0][0]
__________________________________________________________________________________________________
Attention (Dense) (None, 45) 495 Layer3[0][0]
__________________________________________________________________________________________________
LocalGLM (Dot) (None, 1) 0 Input[0][0]
Attention[0][0]
__________________________________________________________________________________________________
Balance (Dense) (None, 1) 2 LocalGLM[0][0]
__________________________________________________________________________________________________
Vol (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
Multiply (Multiply) (None, 1) 0 Balance[0][0]
Vol[0][0]
==================================================================================================
Total params: 1,892
Trainable params: 1,892
Non-trainable params: 0
__________________________________________________________________________________________________
# Modell trainieren
history_freq = local_glm_net_freq.fit(
[x_freq_train_sc, exposures_train],
y_freq_train,
batch_size=5000,
epochs=100,
validation_split=0.2,
)
Epoch 1/100
98/98 [==============================] - 2s 9ms/step - loss: 0.3123 - val_loss: 0.1945
Epoch 2/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1795 - val_loss: 0.1676
Epoch 3/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1678 - val_loss: 0.1626
Epoch 4/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1642 - val_loss: 0.1601
Epoch 5/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1620 - val_loss: 0.1584
Epoch 6/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1605 - val_loss: 0.1573
Epoch 7/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1593 - val_loss: 0.1563
Epoch 8/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1585 - val_loss: 0.1557
Epoch 9/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1578 - val_loss: 0.1553
Epoch 10/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1574 - val_loss: 0.1550
Epoch 11/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1571 - val_loss: 0.1548
Epoch 12/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1568 - val_loss: 0.1546
Epoch 13/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1566 - val_loss: 0.1545
Epoch 14/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1565 - val_loss: 0.1544
Epoch 15/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1563 - val_loss: 0.1543
Epoch 16/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1562 - val_loss: 0.1542
Epoch 17/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1561 - val_loss: 0.1542
Epoch 18/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1560 - val_loss: 0.1541
Epoch 19/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1559 - val_loss: 0.1540
Epoch 20/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1558 - val_loss: 0.1539
Epoch 21/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1558 - val_loss: 0.1539
Epoch 22/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1557 - val_loss: 0.1538
Epoch 23/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1556 - val_loss: 0.1538
Epoch 24/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1556 - val_loss: 0.1538
Epoch 25/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1555 - val_loss: 0.1538
Epoch 26/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1555 - val_loss: 0.1537
Epoch 27/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1554 - val_loss: 0.1537
Epoch 28/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1553 - val_loss: 0.1537
Epoch 29/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1553 - val_loss: 0.1536
Epoch 30/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1553 - val_loss: 0.1536
Epoch 31/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1552 - val_loss: 0.1537
Epoch 32/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1552 - val_loss: 0.1536
Epoch 33/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1551 - val_loss: 0.1536
Epoch 34/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1551 - val_loss: 0.1536
Epoch 35/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1551 - val_loss: 0.1536
Epoch 36/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1550 - val_loss: 0.1536
Epoch 37/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1550 - val_loss: 0.1537
Epoch 38/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1550 - val_loss: 0.1536
Epoch 39/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 40/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1549 - val_loss: 0.1536
Epoch 41/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 42/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 43/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 44/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 45/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 46/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 47/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 48/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 49/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1547 - val_loss: 0.1537
Epoch 50/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 51/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 52/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 53/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1546 - val_loss: 0.1537
Epoch 54/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 55/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 56/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 57/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1545 - val_loss: 0.1535
Epoch 58/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1545 - val_loss: 0.1535
Epoch 59/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1545 - val_loss: 0.1536
Epoch 60/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1545 - val_loss: 0.1537
Epoch 61/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1545 - val_loss: 0.1537
Epoch 62/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 63/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 64/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 65/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 66/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 67/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1544 - val_loss: 0.1537
Epoch 68/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1544 - val_loss: 0.1536
Epoch 69/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1543 - val_loss: 0.1537
Epoch 70/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1543 - val_loss: 0.1536
Epoch 71/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1543 - val_loss: 0.1536
Epoch 72/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1543 - val_loss: 0.1536
Epoch 73/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1543 - val_loss: 0.1536
Epoch 74/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1543 - val_loss: 0.1537
Epoch 75/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1542 - val_loss: 0.1536
Epoch 76/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 77/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 78/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 79/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 80/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 81/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 82/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1542 - val_loss: 0.1537
Epoch 83/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1541 - val_loss: 0.1537
Epoch 84/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1541 - val_loss: 0.1539
Epoch 85/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1541 - val_loss: 0.1537
Epoch 86/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1541 - val_loss: 0.1537
Epoch 87/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1541 - val_loss: 0.1538
Epoch 88/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1541 - val_loss: 0.1537
Epoch 89/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1541 - val_loss: 0.1538
Epoch 90/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 91/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 92/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 93/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 94/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1540 - val_loss: 0.1539
Epoch 95/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 96/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 97/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1540 - val_loss: 0.1538
Epoch 98/100
98/98 [==============================] - 1s 6ms/step - loss: 0.1540 - val_loss: 0.1537
Epoch 99/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1539 - val_loss: 0.1538
Epoch 100/100
98/98 [==============================] - 1s 5ms/step - loss: 0.1539 - val_loss: 0.1538
# Über die Methode get_weights() erhält man die Kantengewichte, sowie den Bias für jeder Schicht
# --> man erhält also eine Liste mit numpy Arrays die in der Länge der Anzahl der Ebenen * 2 entspricht
weights_freq = local_glm_net_freq.get_weights()
for i in weights:
print(i.shape, end=" | ")
(8, 20) | (20,) | (20, 15) | (15,) | (15, 10) | (10,) | (10, 8) | (8,) | (1, 1) | (1,) |
# Neues Model ohne Response-Schicht --> ermöglicht auslesen der Attention Gewichte
# Benötigt als Input nur die Features, nicht die Exposures, da diese erst im späteren Layer erforderlich werden
weights_model_freq = tf.keras.Model(
inputs=local_glm_net_freq.inputs[0],
outputs=local_glm_net_freq.get_layer(name="Attention").output,
)
# Gewichte bestimmen
beta_x_freq = weights_model_freq.predict(x_freq_test_sc)
# Skalierung der Attention-Gewichte mithilfe des Gewichts der Response Schicht ( = Intercept beta_0)
beta_x_freq_scaled = beta_x_freq * local_glm_net_freq.get_weights()[8]
# Als DataFrame speichern um mittels der Merkmalsnamen auf die Attention Gewichte zugreifen zu können
beta_x_freq_scaled = pd.DataFrame(beta_x_freq_scaled, columns=x_freq_test_sc.columns)
# Intervallgrenzen auf Basis der Zufallsvariablen RandN und RandU bestimmen
randn_mean = beta_x_freq_scaled["RandN"].mean()
randn_std = beta_x_freq_scaled["RandN"].std()
randu_mean = beta_x_freq_scaled["RandU"].mean()
randu_std = beta_x_freq_scaled["RandU"].std()
rand_mean = (randn_mean + randu_mean) / 2
rand_std = (randn_std + randu_std) / 2
print("Mittelwert RandN: " + str(randn_mean))
print("Standardabweichung RandN: " + str(randn_std))
print("\nMittelwert RandU: " + str(randu_mean))
print("Standardabweichung RandU: " + str(randu_std))
print("\nMittelwert Gesamt: " + str(rand_mean))
print("Standardabweichung Gesamt: " + str(rand_std))
# Intervalgrenzen bestimmen
alpha_freq = 0.001
bound_freq = stats.norm.ppf(alpha_freq / 2) * rand_std
print(
"\nQuantil " + str(1 - alpha_freq / 2) + ": " + str(stats.norm.ppf(alpha_freq / 2))
)
print("Grenzen: ± " + str(abs(bound_freq)))
Mittelwert RandN: 0.06710727
Standardabweichung RandN: 0.08473704
Mittelwert RandU: -0.040262144
Standardabweichung RandU: 0.108501375
Mittelwert Gesamt: 0.013422561809420586
Standardabweichung Gesamt: 0.09661920368671417
Quantil 0.9995: -3.2905267314918945
Grenzen: ± 0.31792807250659316
# Indizes der Testdaten zurücksetzen, damit sie mit beta_x_freq_scaled übereinstimmen
x_att = x_freq_test.copy()
x_att.reset_index(inplace=True)
# Attention Plot freq-Datensatz
fig_freq_attention, axs_freq_att = plt.subplots(nrows=3, ncols=3, figsize=(30, 15))
# Merkmale festlegen für die ein Attention Subplot erstellt werden soll
columns = continuous_columns + binary_columns + ["RandN", "RandU"]
for i, ax in enumerate(axs_freq_att.flatten()):
# Für VehGas wird ein Boxplot geplottet, da es eine binäre Variable ist und ein normaler Scatterplot nicht viel Sinn ergibt
if columns[i] == "VehGas":
diesel_index = x_att[x_att["VehGas"] == 0].index
regular_index = x_att[x_att["VehGas"] == 1].index
ax.boxplot(
[
beta_x_freq_scaled.loc[diesel_index]["VehGas"],
beta_x_freq_scaled.loc[regular_index]["VehGas"],
],
labels=["Diesel", "Regular"],
zorder=10,
)
# Scatterplot für alle anderen Merkmale hinzufügen
else:
ax.scatter(
x_att[columns[i]],
beta_x_freq_scaled[columns[i]],
s=0.5,
c="black",
zorder=10,
)
# x-Grenzen des Plots abfragen
x_min, x_max = ax.get_xlim()
# Intervallgrenzen
interval = patches.Rectangle(
xy=(x_min, -abs(bound_freq)),
height=2 * abs(bound_freq),
width=x_max - x_min,
edgecolor="royalblue",
facecolor="lightcyan",
alpha=0.8,
zorder=1,
)
ax.add_patch(interval)
# Linien zur Verdeutlichung der Höhe der Attention Gewichte
ax.hlines(y=0.25, xmin=x_min, xmax=x_max, colors="orange", linestyles="dashed")
ax.hlines(y=-0.25, xmin=x_min, xmax=x_max, colors="orange", linestyles="dashed")
ax.hlines(y=0, xmin=x_min, xmax=x_max, colors="red")
# Layout
ax.set_xlabel(columns[i])
ax.set_ylabel("Regression Attention")
ax.set_ylim((-1, 1))
fig_freq_attention.suptitle("Abbildung 5: Regression Attentions")
plt.show()
Der Attention Plot lässt bereits darauf schließen, dass der Zusammenhang zwischen VehGas, VehPower und AreaCode mit der Zielvariable nur sehr gering ausfällt. Ein Großteil der Attention-Gewichte fällt in die durch die Störvariablen berechneten Konfidenzintervalle. Für DriveAge, VehAge und BonusMalus ist ein deutlicher Zusammenhang zu erkennen. Bevor auf diese mittels des Feature Contribution Plots genauer eingegangen wird, liefert der Hypothesentest auf Basis der Coverage Ratio eine Aussage darüber, welche Merkmale entfernt werden sollten:
for col in columns:
if col not in ["RandN", "RandU"]:
size = beta_x_freq_scaled.shape[0]
coverage = np.count_nonzero(
beta_x_freq_scaled[col] < abs(bound_freq)
) - np.count_nonzero(beta_x_freq_scaled[col] < -abs(bound_freq))
coverage_ratio = coverage / size
print("Coverage Ratio " + col + ": " + str(coverage_ratio))
Coverage Ratio AreaCode: 0.9636436040766361
Coverage Ratio BonusMalus: 0.2912198935118951
Coverage Ratio log_Density: 0.9910620787304022
Coverage Ratio DrivAge: 0.5153463813218094
Coverage Ratio VehAge: 0.8742643913806581
Coverage Ratio VehPower: 0.9967404610551467
Coverage Ratio VehGas: 0.08660639223610271
area_density = []
labels = []
for i in np.sort(x_att["AreaCode"].unique()):
index = x_att[x_att["AreaCode"] == i].index
area_density.append(x_att.loc[index]["log_Density"])
labels.append(int(i))
plt.boxplot(x=area_density, labels=labels)
plt.xlabel("Area Code")
plt.ylabel("log(Density)")
plt.title("Abbildung 6: Area Code vs. Density")
plt.show()
# Nur Merkmale behalten die signifikant sind (RandN, RandU, AreaCode, VehPower entfernen)
x_freq_sig_train = x_freq_train.drop(columns=["RandN", "RandU", "AreaCode", "VehPower"])
x_freq_sig_test = x_freq_test.drop(columns=["RandN", "RandU", "AreaCode", "VehPower"])
sig_columns = continuous_columns + binary_columns
sig_columns.remove("VehPower")
sig_columns.remove("AreaCode")
print(sig_columns)
['BonusMalus', 'log_Density', 'DrivAge', 'VehAge', 'VehGas']
# Neuen Datensatz mit wichtigen Spalten skalieren (mean=0, std=1)
x_freq_sig_train_sc = x_freq_sig_train.copy()
x_freq_sig_test_sc = x_freq_sig_test.copy()
scaler_freq_sig = StandardScaler()
x_freq_sig_train_sc[sig_columns] = scaler_freq.fit_transform(
x_freq_sig_train_sc[sig_columns]
)
x_freq_sig_test_sc[sig_columns] = scaler_freq.transform(x_freq_sig_test_sc[sig_columns])
# Neues LocalGLMnet erstellen (41 Input Merkmale)
input_freq_sig = tf.keras.Input(shape=(41), dtype="float32", name="Input")
vol_freq_sig = tf.keras.Input(shape=(1), dtype="float32", name="Vol")
attention_freq_sig = input_freq_sig
attention_freq_sig = tf.keras.layers.Dense(units=20, activation="tanh", name="Layer1")(
attention_freq_sig
)
attention_freq_sig = tf.keras.layers.Dense(units=15, activation="tanh", name="Layer2")(
attention_freq_sig
)
attention_freq_sig = tf.keras.layers.Dense(units=10, activation="tanh", name="Layer3")(
attention_freq_sig
)
attention_freq_sig = tf.keras.layers.Dense(
units=41, activation="linear", name="Attention"
)(attention_freq_sig)
# Skip-Connection
local_glm_freq_sig = tf.keras.layers.Dot(name="LocalGLM", axes=1)(
[input_freq_sig, attention_freq_sig]
)
# Fügt Intercept hinzu
local_glm_freq_sig = tf.keras.layers.Dense(
units=1, activation="exponential", name="Balance"
)(local_glm_freq_sig)
# Response Schicht
response_freq_sig = tf.keras.layers.Multiply(name="Multiply")(
[local_glm_freq_sig, vol_freq_sig]
)
# Modell kompilieren
local_glm_net_freq_sig = tf.keras.Model(
inputs=[input_freq_sig, vol_freq_sig], outputs=response_freq_sig
)
local_glm_net_freq_sig.compile(loss="poisson", optimizer="nadam")
local_glm_net_freq_sig.summary()
Model: "model_12"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input (InputLayer) [(None, 41)] 0
__________________________________________________________________________________________________
Layer1 (Dense) (None, 20) 840 Input[0][0]
__________________________________________________________________________________________________
Layer2 (Dense) (None, 15) 315 Layer1[0][0]
__________________________________________________________________________________________________
Layer3 (Dense) (None, 10) 160 Layer2[0][0]
__________________________________________________________________________________________________
Attention (Dense) (None, 41) 451 Layer3[0][0]
__________________________________________________________________________________________________
LocalGLM (Dot) (None, 1) 0 Input[0][0]
Attention[0][0]
__________________________________________________________________________________________________
Balance (Dense) (None, 1) 2 LocalGLM[0][0]
__________________________________________________________________________________________________
Vol (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
Multiply (Multiply) (None, 1) 0 Balance[0][0]
Vol[0][0]
==================================================================================================
Total params: 1,768
Trainable params: 1,768
Non-trainable params: 0
__________________________________________________________________________________________________
# Modell trainieren
history_freq_sig = local_glm_net_freq_sig.fit(
[x_freq_sig_train_sc, exposures_train],
y_freq_train,
batch_size=5000,
epochs=100,
validation_split=0.2,
)
Epoch 1/100
98/98 [==============================] - 2s 9ms/step - loss: 0.3150 - val_loss: 0.1937
Epoch 2/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1776 - val_loss: 0.1646
Epoch 3/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1646 - val_loss: 0.1591
Epoch 4/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1610 - val_loss: 0.1569
Epoch 5/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1593 - val_loss: 0.1558
Epoch 6/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1583 - val_loss: 0.1552
Epoch 7/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1577 - val_loss: 0.1548
Epoch 8/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1573 - val_loss: 0.1546
Epoch 9/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1571 - val_loss: 0.1544
Epoch 10/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1569 - val_loss: 0.1543
Epoch 11/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1567 - val_loss: 0.1543
Epoch 12/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1566 - val_loss: 0.1541
Epoch 13/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1565 - val_loss: 0.1541
Epoch 14/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1564 - val_loss: 0.1540
Epoch 15/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1563 - val_loss: 0.1540
Epoch 16/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1562 - val_loss: 0.1539
Epoch 17/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1561 - val_loss: 0.1539
Epoch 18/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1561 - val_loss: 0.1539
Epoch 19/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1560 - val_loss: 0.1538
Epoch 20/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1560 - val_loss: 0.1538
Epoch 21/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1559 - val_loss: 0.1538
Epoch 22/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1559 - val_loss: 0.1537
Epoch 23/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1558 - val_loss: 0.1537
Epoch 24/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1558 - val_loss: 0.1537
Epoch 25/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1557 - val_loss: 0.1537
Epoch 26/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1557 - val_loss: 0.1536
Epoch 27/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1556 - val_loss: 0.1536
Epoch 28/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1556 - val_loss: 0.1536
Epoch 29/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1556 - val_loss: 0.1536
Epoch 30/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1555 - val_loss: 0.1536
Epoch 31/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1555 - val_loss: 0.1536
Epoch 32/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1555 - val_loss: 0.1536
Epoch 33/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1555 - val_loss: 0.1535
Epoch 34/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1554 - val_loss: 0.1536
Epoch 35/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1554 - val_loss: 0.1535
Epoch 36/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1554 - val_loss: 0.1535
Epoch 37/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1554 - val_loss: 0.1535
Epoch 38/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1553 - val_loss: 0.1535
Epoch 39/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1553 - val_loss: 0.1535
Epoch 40/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1553 - val_loss: 0.1534
Epoch 41/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1553 - val_loss: 0.1535
Epoch 42/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1552 - val_loss: 0.1535
Epoch 43/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1552 - val_loss: 0.1535
Epoch 44/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1552 - val_loss: 0.1535
Epoch 45/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1552 - val_loss: 0.1535
Epoch 46/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1552 - val_loss: 0.1535
Epoch 47/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1552 - val_loss: 0.1534
Epoch 48/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1551 - val_loss: 0.1535
Epoch 49/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1551 - val_loss: 0.1535
Epoch 50/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1551 - val_loss: 0.1535
Epoch 51/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1551 - val_loss: 0.1534
Epoch 52/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1551 - val_loss: 0.1535
Epoch 53/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1551 - val_loss: 0.1535
Epoch 54/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1550 - val_loss: 0.1535
Epoch 55/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1550 - val_loss: 0.1534
Epoch 56/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1550 - val_loss: 0.1535
Epoch 57/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1550 - val_loss: 0.1534
Epoch 58/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1550 - val_loss: 0.1534
Epoch 59/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1550 - val_loss: 0.1534
Epoch 60/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1536
Epoch 61/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 62/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 63/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 64/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 65/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 66/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 67/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 68/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1549 - val_loss: 0.1535
Epoch 69/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 70/100
98/98 [==============================] - 1s 10ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 71/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 72/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 73/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 74/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1548 - val_loss: 0.1536
Epoch 75/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 76/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 77/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1548 - val_loss: 0.1535
Epoch 78/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 79/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 80/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 81/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 82/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 83/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 84/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1547 - val_loss: 0.1537
Epoch 85/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 86/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1547 - val_loss: 0.1535
Epoch 87/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1547 - val_loss: 0.1536
Epoch 88/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 89/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 90/100
98/98 [==============================] - 1s 9ms/step - loss: 0.1546 - val_loss: 0.1537
Epoch 91/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 92/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1546 - val_loss: 0.1535
Epoch 93/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1535
Epoch 94/100
98/98 [==============================] - 1s 12ms/step - loss: 0.1546 - val_loss: 0.1537
Epoch 95/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 96/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 97/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1546 - val_loss: 0.1536
Epoch 98/100
98/98 [==============================] - 1s 8ms/step - loss: 0.1546 - val_loss: 0.1535
Epoch 99/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1545 - val_loss: 0.1536
Epoch 100/100
98/98 [==============================] - 1s 7ms/step - loss: 0.1545 - val_loss: 0.1536
# Neues Model ohne Response-Schicht --> ermöglicht auslesen der Attention Gewichte
# benötigt als Input nur die Features, nicht die Exposures, da diese erst im späteren Layer erforderlich werden
weights_model_freq_sig = tf.keras.Model(
inputs=local_glm_net_freq_sig.inputs[0],
outputs=local_glm_net_freq_sig.get_layer(name="Attention").output,
)
# Gewichte bestimmen
beta_x_freq_sig = weights_model_freq_sig.predict(x_freq_sig_test_sc)
# Skalierung der Attention-Gewichte mithilfe des Gewichts der Response Schicht ( = Intercept beta_0)
beta_x_freq_sig_sc = beta_x_freq_sig * local_glm_net_freq_sig.get_weights()[8]
# Als DataFrame speichern um mittels der Merkmalsnamen auf die Attention Gewichte zugreifen zu können
beta_x_freq_sig_sc = pd.DataFrame(
beta_x_freq_sig_sc, columns=x_freq_sig_test_sc.columns
)
# Feature Contribution Plot freq-Datensatz
# Index der Testdaten zurücsetzen
x_con = x_freq_sig_test.reset_index(drop=True)
# Feature Contribution berechnen (mit skalierten Featurewert)
feature_con = x_freq_sig_test_sc.reset_index(drop=True) * beta_x_freq_sig_sc
fig_freq_con = plt.figure(tight_layout=True, figsize=(30, 15))
spec = GridSpec(ncols=6, nrows=2, figure=fig_freq_con)
ax1_freq_con = fig_freq_con.add_subplot(spec[0, 0:2])
ax2_freq_con = fig_freq_con.add_subplot(spec[0, 2:4])
ax3_freq_con = fig_freq_con.add_subplot(spec[0, 4:6])
ax4_freq_con = fig_freq_con.add_subplot(spec[1, 1:3])
ax5_freq_con = fig_freq_con.add_subplot(spec[1, 3:5])
axs_freq_con = [ax1_freq_con, ax2_freq_con, ax3_freq_con, ax4_freq_con, ax5_freq_con]
for i in range(len(axs_freq_con)):
# Feature Contribution Splines berechnen
# Feature Contribution = beta(xi)*xi
# Wenn es sich um das Merkmal "VehGas" handelt, wird ein Boxplot hinzugefügt
if sig_columns[i] == "VehGas":
diesel_index = x_con[x_con["VehGas"] == 0].index
regular_index = x_con[x_con["VehGas"] == 1].index
axs_freq_con[i].boxplot(
[
feature_con.loc[diesel_index]["VehGas"],
feature_con.loc[regular_index]["VehGas"],
],
labels=["Diesel", "Regular"],
zorder=10,
)
x_min, x_max = axs_freq_con[i].get_xlim()
# Ansonsten wird ein Scatterplot + Spline hinzugefügt
else:
contribution = np.column_stack(
[x_con[sig_columns[i]], feature_con[sig_columns[i]]]
)
con_ind = np.lexsort((contribution[:, 1], contribution[:, 0]))
contribution_sorted = contribution[con_ind]
con_spline = interpolate.UnivariateSpline(
contribution_sorted[:, 0], contribution_sorted[:, 1]
)
# Scatter Plot --> x: Werte der Inputfeatures, y:Feature Contribution (β(x)*x)
axs_freq_con[i].scatter(
contribution[:, 0], contribution[:, 1], s=0.5, zorder=10
)
# X min und x max festlegen (min = kleinster Wert, max= mean+3*std)
x_min = x_con[sig_columns[i]].min()
x_max = x_con[sig_columns[i]].mean() + 3 * x_con[sig_columns[i]].std()
xs = np.linspace(x_min, x_max, 1000)
# Feature Contribution Spline plotten
axs_freq_con[i].plot(xs, con_spline(xs), color="purple", zorder=20)
axs_freq_con[i].set_xlim((x_min, x_max))
# Hinzufügen von horizontalen Linien um die Stärke der Feature Contribution zu visualisieren
axs_freq_con[i].hlines(
y=0, xmin=x_min, xmax=x_max, colors="red", alpha=0.7, zorder=1
)
axs_freq_con[i].hlines(
y=0.25, xmin=x_min, xmax=x_max, colors="orange", linestyles="dashed"
)
axs_freq_con[i].hlines(
y=-0.25, xmin=x_min, xmax=x_max, colors="orange", linestyles="dashed"
)
# Layout
axs_freq_con[i].set_ylim((-2, 2))
axs_freq_con[i].set_xlabel(sig_columns[i])
axs_freq_con[i].set_title("Feature Contribution: " + sig_columns[i])
axs_freq_con[i].set_ylabel("Feature Contribution")
fig_freq_con.suptitle("Abbildung 7: Feature Contribution")
plt.show()
Vor allem bei BonusMalus, DrivAge und VehAge zeigen sich nachvollziehbare Zusammenhänge. Bei dem Fahreralter ist die Feature Contribution bspw. besonders für sehr junge Fahrer und sehr alte Fahrer hoch. Zu Beginn kann sich dies auf fehlende Erfahrung, später auf nachlassende Reaktionsfähigkeit oder Sehstärke zurückführen. Die Feature Contribution des Bonus-Malus verläuft in etwa gespiegelt, weshalb Richmann und Wüthrich eine Interaktion zwischen den beiden Variablen unterstellen [1]. Dies liegt daran, dass ein Fahranfänger im Bonus-Malus-System bei 100 startet und dieser Wert erst mit zunehmender Erfahrung abnimmt. Dieser Zusammenhang lässt sich ebenfalls in Abbildung 6 des Papers erkennen [1]. Um die Feature Contribution von kategorialen Merkmalen darzustellen, müssen diese mittels One-Hot-Encoding und nicht Dummy-Encoding (k-Merkmale führen zu k-1 Spalten) encodiert worden sein. Bei Dummy Encoding wäre es nicht möglich, die Feature Contribution für die wegfallende Kategorie zu berechnen, da sie keiner Spalte zugeordnet werden kann. Eine genaue Erläuterung der Vor- und Nachteile von One-Hot Encoding gegenüber Dummy Encoding findet sich in Richman/Wüthrich Abschnitt 3.6 [1].
Da es sich bei den Werten der One-Hot encodierten Kategorien nur um 0 oder 1 handelt, entsprechen die Attention-Gewichte ebenfalls der Feature Contribution (β*x). Sie werden in Abbildung 8 dargestellt. Während die Darstellung für relativ wenige Ausprägungen möglich ist, wird sie schnell unübersichtlich.
regions_con = beta_x_freq_sig_sc.filter(regex="Region*")
regions_con.columns = regions_con.columns.str.replace("Region_", "")
brands_con = beta_x_freq_sig_sc.filter(regex="VehBrand*")
brands_con.columns = brands_con.columns.str.replace("VehBrand_", "")
fig_cat, axs_cat = plt.subplots(nrows=1, ncols=2, figsize=(30, 10))
axs_cat[0].boxplot(x=regions_con)
axs_cat[0].set_xticklabels(labels=regions_con.columns, rotation=90)
axs_cat[0].set_ylabel("Feature Contribution")
axs_cat[0].set_xlabel("Regions")
axs_cat[0].set_title("Feature Contribution: Regions")
axs_cat[1].boxplot(x=brands_con)
axs_cat[1].set_xticklabels(labels=brands_con.columns, rotation=90)
axs_cat[1].set_ylabel("Feature Contribution")
axs_cat[1].set_xlabel("Vehicle Brand")
axs_cat[1].set_title("Feature Contribution: VehBrand")
plt.suptitle("Abbildung 8: Feature Contribution kategorialer Variablen")
plt.show()
Die stärkste Auswirkung auf die Vorhersage bei den Regionen hat die Ausprägung R25. Bei den Fahrzeugmarken fällt vor allem B9 auf. Um die Interaktionen der (binären und stetigen) Merkmale zu analysieren, werden erneut die Gradienten mithilfe von Splines dargestellt. Hierfür werden zuerst wie auch bei dem synthetischen Datensatz die Gradienten ermittelt.
# Stetige Spalten festlegen (Splines nur für stetige Spalten)
sig_continuous_columns = ["BonusMalus", "log_Density", "DrivAge", "VehAge"]
sig_continuous_columns_id = []
sig_columns_id = []
# Spaltenindex der stetigen Merkmale ermitteln
for col in sig_continuous_columns:
sig_continuous_columns_id.append(x_freq_sig_train.columns.get_loc(col))
# Spaltenindex der signifikanten Merkmale ermitteln (stetige Merkmale + VehGas)
for col in sig_columns:
sig_columns_id.append(x_freq_sig_train.columns.get_loc(col))
# Gradienten bestimmen
gradients_freq = []
x_freq_grad = tf.constant(x_freq_sig_train_sc)
# Für jede Inputvariable wird ein Modell gefittet, um anschließend die partiellen Ableitungen auslesen zu können
for i in range(len(sig_continuous_columns)):
# Lambda Layer als Output Schicht, um beta_i als Output zu erhalten (partielle Ableitungen ∂β_j(x)/∂x_j')
beta = attention_freq_sig
beta = tf.keras.layers.Lambda(lambda x: x[:, sig_continuous_columns_id[i]])(beta)
grad_model = tf.keras.Model(inputs=input_freq_sig, outputs=beta)
# GradientTape ermöglicht das auslesen der Gradienten
with tf.GradientTape() as g:
g.watch(x_freq_grad)
pred_attention = grad_model.call(x_freq_grad)
grad = g.gradient(pred_attention, x_freq_grad)
# Array das sowohl den Wert von x, als auch den entsprechenden Wert von βk(x) enthält
grad_wrt_x = np.column_stack(
(x_freq_sig_train[sig_continuous_columns[i]], grad.numpy())
)
# Um später die Splines zu modellieren muss die x-Komponente monoton steigend sein --> sortieren des Arrays
ind = np.lexsort((grad_wrt_x[:, 2], grad_wrt_x[:, 0]))
grad_wrt_x_sorted = grad_wrt_x[ind]
# Gradienten in Liste speichern
gradients_freq.append(grad_wrt_x_sorted)
# Univariate Splines modellieren, um die Interaktion zwischen Features darzustellen
freq_splines = []
# Für alle Attention Gewichte β (nur stetige Merkmale)
for i in range(len(sig_continuous_columns)):
freq_splines.append([])
# Für alle Inputvariablen x (auch VehGas)
for j in range(len(sig_columns_id)):
freq_splines[i].append(
interpolate.UnivariateSpline(
gradients_freq[i][:, 0],
gradients_freq[i][:, sig_columns_id[j] + 1],
)
)
# Spline Interaction Plot freq-Datensatz
# x_freq_spline enthält nur stetige Spalten
x_freq_spline = x_freq_sig_train[sig_continuous_columns]
fig_freq_spline, axs_freq_spline = plt.subplots(nrows=2, ncols=2, figsize=(30, 15))
for i, ax in enumerate(axs_freq_spline.flatten()):
# x_min und x_max festlegen und darauf basierend Linspace für die Splines erzeugen
x_min = x_freq_spline.iloc[:, i].min()
x_max = x_freq_spline.iloc[:, i].mean() + 3 * x_freq_spline.iloc[:, i].std()
xs = np.linspace(x_min, x_max, 1000)
# Spline für alle noch vorhandenen Merkmale
for j in range(len(sig_columns)):
ax.plot(xs, freq_splines[i][j](xs), label=sig_columns[j])
# Hinzufügen von horizontalen Linien um die Stärke der Feature Contribution zu visualisieren
ax.hlines(y=0.25, xmin=x_min, xmax=x_max, colors="black", linestyles="dashed")
ax.hlines(y=-0.25, xmin=x_min, xmax=x_max, colors="black", linestyles="dashed")
ax.hlines(y=0, xmin=x_min, xmax=x_max, colors="black")
# Inline Labels und Legende hinzufügen
labelLines(ax.get_lines(), zorder=2.5)
ax.legend(loc="lower right", ncol=2)
# Layout
ax.set_xlabel(sig_continuous_columns[i])
ax.set_ylabel("Interaction Strengths")
ax.set_ylim((-2, 2))
ax.set_title("Interactions of feature component: " + sig_continuous_columns[i])
fig_freq_spline.suptitle("Abbildung 9: Interaction Strengths")
plt.show()
Zusammenfassend lässt sich sagen, dass der LocalGLMnet Ansatz vor allem in Bezug auf die Erklärbarkeit und somit auch Merkmalsselektion deutliche Vorteile gegenüber klassischen FFNs bietet. Anstatt Erklärbarkeit nur nachträglich durch Ansätze wie Surrogatmodelle oder Partial Dependency Plots zu erzeugen, erlaubt das LocalGLMnet bereits durch seine Struktur bereits ein gewisses Maß an Erklärbarkeit. Ein Vergleich von diesen klassischen Ansätzen zum LocalGLMnet findet sich ebenfalls im ursprünglichen Paper und dessen Anhang.
Richmann & Wüthrich können sich aufgrund der guten Erklärbarkeit bei gleichzeitig hoher Vorhersagegenauigkeit unterschiedliche Anwendungszwecke vorstellen. Es kann sowohl als eigenständiges Netz, aber auch als Surrogatmodell oder als Vorläufer eines klassischen FFNs zur initialen Merkmalselektion verwendet werden [1].
Vor allem für Aktuare, welche den Umgang mit GLMs gewohnt sind, bietet das LocalGLMnet einen interessanten Ansatz bei dem die Erklärbarkeit nicht vollständig für die Vorhersagegenauigkeit aufgegeben werden muss. Ein einziger Nachteil ist, dass das Modell zur Zeit hauptsächlich für strukturierte, möglichst stetige oder binäre Daten optimiert ist. Um weitere Anwendungsfelder zu erschließen, benötigt es weiterer Forschung. [1] Ronald Richman und Mario V. Wüthrich. 2022. LocalGLMnet: interpretable deep learning for tabular data. Scandinavian Actuarial Journal 2022, 1, 71–95. DOI: https://doi.org/10.1080/03461238.2022.2081816
[2] John A. Nelder und Robert W. M. Wedderburn. 1972. Generalized Linear Models. Journal of the Royal Statistical Society, Vol. 135, No. 3, 370–384 DOI: https://doi.org/10.2307/2344614.
[3] Martin Seehafer, Stefan Nörtemann, Jonas Offtermatt, Fabian Transchel, Axel Kiermaier, René Külheim, und Wiltrud Weidner. 2021. Actuarial Data Science. De Gruyter.
[4] Mario V. Wüthrich und Michael Merz. 2023. Statistical Foundations of Actuarial Learning and its Applications. Springer International Publishing, Cham.
[5] Alexander Noll, Robert Salzmann, und Mario V. Wüthrich. 2020. Case Study: French Motor Third-Party Liability Claims http://dx.doi.org/10.2139/ssrn.3164764