
Commit

Merge pull request #26 from ai2es/bugfix-cat-training
Bugfix evidential training
djgagne authored Aug 9, 2024
2 parents a2fca35 + b8ab834 commit 07eda38
Showing 4 changed files with 31 additions and 25 deletions.
8 changes: 6 additions & 2 deletions applications/train_classifier_ptype.py
@@ -83,7 +83,7 @@ def trainer(conf, evaluate=True, data_split=0, mc_forward_passes=0):
output_features = conf["output_features"]
metric = conf["metric"]
# flag for using the evidential model
if conf["model"]["loss"] == "dirichlet":
if conf["model"]["loss"] == "evidential":
use_uncertainty = True
else:
use_uncertainty = False
@@ -177,6 +177,10 @@ def trainer(conf, evaluate=True, data_split=0, mc_forward_passes=0):
x = scaled_data[f"{name}_x"]
if use_uncertainty:
pred_probs, u, ale, epi = mlp.predict(x, return_uncertainties=True)
pred_probs = pred_probs.numpy()
u = u.numpy()
ale = ale.numpy()
epi = epi.numpy()
entropy = np.zeros(pred_probs.shape)
mutual_info = np.zeros(pred_probs.shape)
elif mc_forward_passes > 0: # Compute epistemic uncertainty with MC dropout
@@ -185,7 +189,7 @@ def trainer(conf, evaluate=True, data_split=0, mc_forward_passes=0):
x, mc_forward_passes=mc_forward_passes)
u = np.zeros(pred_probs.shape)
else:
pred_probs = mlp.predict(x)
pred_probs = mlp.predict(x, return_uncertainties=False)
ale = np.zeros(pred_probs.shape)
u = np.zeros(pred_probs.shape)
epi = np.zeros(pred_probs.shape)
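The trainer fix above converts the eager tensors returned by predict(..., return_uncertainties=True) to NumPy before the metric code runs. A minimal sketch of that pattern, using a stand-in model rather than the actual mlguess model class (DummyEvidentialModel and its random outputs are illustrative only, and a TensorFlow-backed install is assumed):

import numpy as np
import tensorflow as tf

class DummyEvidentialModel:
    """Stand-in for an evidential classifier whose predict can return uncertainty tensors."""
    def predict(self, x, return_uncertainties=True):
        n, k = x.shape[0], 4
        probs = tf.nn.softmax(tf.random.normal((n, k)), axis=-1)  # class probabilities
        u = tf.random.uniform((n, 1))    # total evidential uncertainty
        ale = tf.random.uniform((n, k))  # aleatoric component
        epi = tf.random.uniform((n, k))  # epistemic component
        return (probs, u, ale, epi) if return_uncertainties else probs

mlp = DummyEvidentialModel()
x = np.random.rand(8, 84).astype("float32")
pred_probs, u, ale, epi = mlp.predict(x, return_uncertainties=True)
# As in the fix above: convert eager tensors to NumPy arrays for downstream metrics.
pred_probs, u, ale, epi = (t.numpy() for t in (pred_probs, u, ale, epi))
print(pred_probs.shape, u.shape, ale.shape, epi.shape)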
31 changes: 15 additions & 16 deletions config/ptype/evidential.yml
@@ -135,7 +135,7 @@ case_studies:
- '2021-02-17'
- '2021-02-18'
- '2021-02-19'
data_path: /glade/p/cisl/aiml/ai2es/winter_ptypes/ptype_qc/mPING_interpolated_QC2.parquet
data_path: /glade/campaign/cisl/aiml/ai2es/winter_ptypes/ptype_qc/mPING_hourafter_interpolated_QC3.parquet
direction: max
ensemble:
mc_steps: 0
@@ -147,24 +147,23 @@ input_features:
- VGRD_m/s
metric: val_ave_acc
model:
activation: leaky
annealing_coeff: 34.593686950910275
balanced_classes: 1
batch_size: 100
dropout_alpha: 0.20146936081973893
epochs: 1000
hidden_layers: 2
hidden_neurons: 6461
activation: leaky_relu
annealing_coeff: 34
batch_size: 1130
dropout_alpha: 0.11676011477923032
epochs: 100
evidential: true
n_inputs: 84
hidden_layers: 4
hidden_neurons: 212
l2_weight: 0.000881889591229087
loss: evidential
loss_weights:
- 58.64242174310205
- 94.59680461256323
- 124.5896569779261
- 227.38800030539545
lr: 0.0027750619126744817
lr: 0.004800502096767794
n_classes: 4
optimizer: adam
output_activation: linear
use_dropout: 1
verbose: 0
verbose: 1
mping_path: /glade/p/cisl/aiml/ai2es/winter_ptypes/precip_rap/mPING_mixture/
output_features:
- ra_percent
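For context, a short sketch (assumed usage, not the actual applications/train_classifier_ptype.py code) of how a config such as config/ptype/evidential.yml would drive the evidential flag fixed above; the relative path and the key filtering are illustrative:

import yaml

with open("config/ptype/evidential.yml") as f:
    conf = yaml.safe_load(f)

# Mirrors the corrected check in trainer(): the loss name, not "dirichlet", toggles the flag.
use_uncertainty = conf["model"]["loss"] == "evidential"

# Hyperparameters shown above (hidden_layers, hidden_neurons, lr, ...) live under conf["model"].
model_kwargs = {k: v for k, v in conf["model"].items() if k not in ("loss", "loss_weights")}
print(use_uncertainty, model_kwargs["hidden_layers"], model_kwargs["hidden_neurons"])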
10 changes: 6 additions & 4 deletions mlguess/keras/callbacks.py
@@ -13,9 +13,12 @@
import os
import keras
import numpy as np
from hagelslag.evaluation.ProbabilityMetrics import DistributedROC
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

logger = logging.getLogger(__name__)


def get_callbacks(config: Dict[str, str], path_extend=False) -> List[Callback]:
callbacks = []

@@ -85,7 +88,6 @@ def get_config(self):
return {}



class MetricsCallback(keras.callbacks.Callback):
def __init__(self, x, y, name="val", n_bins=10, use_uncertainty=False, **kwargs):
super().__init__()
@@ -99,10 +101,11 @@ def __init__(self, x, y, name="val", n_bins=10, use_uncertainty=False, **kwargs)
self.bin_uppers = bin_boundaries[1:]

def on_epoch_end(self, epoch, logs={}):
pred_probs = np.asarray(self.model.predict(self.x))
if self.use_uncertainty:
pred_probs, _, _, _ = calc_prob_uncertainty(pred_probs)
pred_probs, _, _, _ = self.model.predict(self.x, return_uncertainties=True)
pred_probs = pred_probs.numpy()
else:
pred_probs = np.asarray(self.model.predict(self.x, return_uncertainties=False))
logs[f"{self.name}_csi"] = self.mean_csi(pred_probs)
true_labels = np.argmax(self.y, 1)
pred_labels = np.argmax(pred_probs, 1)
@@ -192,4 +195,3 @@ def ece(self, true_labels, pred_probs):
pass
mean = np.mean(ece) if np.isfinite(np.mean(ece)) else self.bin_lowers.shape[0]
return mean

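The bin_lowers/bin_uppers set up in MetricsCallback feed its ece method. A self-contained sketch of that expected calibration error computation (a standard formulation, assumed to simplify the class's version; the dummy probabilities and labels are illustrative):

import numpy as np

def expected_calibration_error(true_labels, pred_probs, n_bins=10):
    bin_boundaries = np.linspace(0.0, 1.0, n_bins + 1)
    confidences = np.max(pred_probs, axis=1)
    predictions = np.argmax(pred_probs, axis=1)
    accuracies = (predictions == true_labels).astype(float)
    ece = 0.0
    for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        prop_in_bin = in_bin.mean()  # fraction of samples whose confidence falls in this bin
        if prop_in_bin > 0:
            ece += np.abs(accuracies[in_bin].mean() - confidences[in_bin].mean()) * prop_in_bin
    return ece

rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(4), size=256)  # 4-class probabilities, matching n_classes above
labels = rng.integers(0, 4, size=256)
print(expected_calibration_error(labels, probs))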
7 changes: 4 additions & 3 deletions mlguess/keras/models.py
@@ -85,7 +85,7 @@ def __init__(self, hidden_layers=2, hidden_neurons=64, evidential=False, activat
annealing_coeff=1.0, use_noise=False, noise_sd=0.0, lr=0.001, use_dropout=False, dropout_alpha=0.2,
batch_size=128, epochs=2, kernel_reg=None, l1_weight=0.0, l2_weight=0.0, sgd_momentum=0.9,
adam_beta_1=0.9, adam_beta_2=0.999, epsilon=1e-7, decay=0, verbose=0, random_state=1000, n_classes=2,
n_inputs=42, callbacks=None, **kwargs):
n_inputs=42, callbacks=[], **kwargs):

super().__init__(**kwargs)
self.hidden_layers = hidden_layers
@@ -184,7 +184,7 @@ def fit(self, x=None, y=None, **kwargs):
report_epoch_callback = ReportEpoch(e)
self.loss = evidential_cat_loss(evi_coef=self.annealing_coeff,
epoch_callback=report_epoch_callback)
self.callbacks = [report_epoch_callback]
self.callbacks.append(report_epoch_callback)

super().compile(loss=self.loss,
optimizer=self.optimizer_obj,
@@ -257,6 +257,7 @@ def predict_dropout(self, x, mc_forward_passes=10, batch_size=None):
def get_config(self):
base_config = super().get_config()
parameter_config = {hp: getattr(self, hp) for hp in self.hyperparameters}
parameter_config['callbacks'] = []
return {**base_config, **parameter_config}


@@ -427,7 +428,7 @@ def fit(self, x=None, y=None, **kwargs):
self.loss = gaussian_nll
super().compile(optimizer=self.optimizer_obj, loss=self.loss)
hist = super().fit(x, y, epochs=self.epochs, batch_size=self.batch_size, **kwargs)
self.training_var = np.var(x, axis=-1)
self.training_var = np.var(y, axis=-1)

return hist

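The models.py hunks keep user-supplied callbacks (append instead of overwrite) and clear them in get_config so the model's hyperparameters can still be serialized. A toy sketch of that get_config pattern with an assumed TinyConfigurableModel class (not the mlguess model itself), requiring only keras:

import keras

class TinyConfigurableModel(keras.Model):
    def __init__(self, hidden_neurons=8, callbacks=[], **kwargs):
        super().__init__(**kwargs)
        self.hidden_neurons = hidden_neurons
        self.callbacks = list(callbacks)  # copy so appending never mutates the shared default
        self.dense = keras.layers.Dense(hidden_neurons)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        base_config = super().get_config()
        # Callback objects are runtime-only and not serializable, so reset them on save.
        return {**base_config, "hidden_neurons": self.hidden_neurons, "callbacks": []}

model = TinyConfigurableModel(hidden_neurons=4)
model.callbacks.append(keras.callbacks.ReduceLROnPlateau())
print(model.get_config()["callbacks"])  # [] regardless of what was attached at runtime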
