Skip to content

Commit

Permalink
Updated handling of metrics calculations during predict/validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jsschreck committed Jul 11, 2024
1 parent 8eb8214 commit 9071135
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 57 deletions.
2 changes: 1 addition & 1 deletion applications/predict_regressor_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def main(rank, world_size, conf, trial=False):
if split == "train":
df = train_loader.dataset.train_data
elif split == "valid":
df = train_loader.dataset.valid_data
df = valid_loader.dataset.valid_data
elif split == "test":
df = train_loader.dataset.test_data

Expand Down
2 changes: 1 addition & 1 deletion applications/torch_dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, conf, split='train'):

# Compute var on the total training data set
self.training_var = [
np.var(self.y_scaler.transform(train_data[output_cols]))
np.var(self.y_scaler.transform(train_data[output_cols])[:, i])
for i in range(self.train_data[output_cols].shape[-1])
]

Expand Down
16 changes: 11 additions & 5 deletions applications/train_regressor_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
TorchFSDPCheckpointIO
)
from mlguess.torch.trainer import Trainer
from mlguess.torch.regression_losses import LipschitzMSELoss
from mlguess.torch.regression_losses import LipschitzMSELoss, EvidentialRegressionLoss
from mlguess.torch.models import seed_everything, DNN
from mlguess.regression_metrics import regression_metrics

Expand Down Expand Up @@ -69,7 +69,7 @@ def load_dataset_and_sampler(conf, world_size, rank, is_train, seed=42):
rank=rank,
seed=seed,
shuffle=is_train,
drop_last=True
drop_last=(not is_train)
)
flag = 'training' if is_train else 'validation'
logging.info(f"Loaded a {flag} torch dataset, and a distributed sampler")
Expand Down Expand Up @@ -187,9 +187,9 @@ def main(rank, world_size, conf, trial=False):
batch_size=valid_batch_size,
shuffle=False,
sampler=valid_sampler,
pin_memory=False,
pin_memory=True,
num_workers=valid_thread_workers,
drop_last=True
drop_last=False
)

# model
Expand All @@ -213,6 +213,8 @@ def main(rank, world_size, conf, trial=False):
model, optimizer, scheduler, scaler = load_model_states_and_optimizer(conf, model, device)

# Train and validation losses
# train_criterion = EvidentialRegressionLoss(coef=10.84134458514458)
# valid_criterion = EvidentialRegressionLoss(coef=10.84134458514458)

train_criterion = LipschitzMSELoss(**conf["train_loss"])
valid_criterion = LipschitzMSELoss(**conf["valid_loss"])
Expand Down Expand Up @@ -248,6 +250,10 @@ def __init__(self, config, metric="val_loss", device="cpu"):

def train(self, trial, conf):

conf['trainer']['train_batch_size'] = conf['data']['batch_size']
conf['trainer']['valid_batch_size'] = conf['data']['batch_size']
conf['valid_loss']['factor'] = conf['train_loss']['factor']

try:
return main(0, 1, conf, trial=trial)

Expand All @@ -257,7 +263,7 @@ def train(self, trial, conf):
f"Pruning trial {trial.number} due to CUDA memory overflow: {str(E)}."
)
raise optuna.TrialPruned()
elif "non-singleton" in str(E):
elif "non-singleton" in str(E) or "nan" in str(E):
logging.warning(
f"Pruning trial {trial.number} due to shape mismatch: {str(E)}."
)
Expand Down
31 changes: 15 additions & 16 deletions mlguess/regression_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,41 @@ def regression_metrics(y_true, y_pred, total=None, split="val"):
# metrics[f"{split}_crps_ss"] = r2_score(result['bin'], result['crps'], sample_weight=result["count"])
# metrics[f"{split}_rmse_ss"] = r2_score(result['bin'], result['rmse'], sample_weight=result["count"])

rmse_ss = rmse_crps_skill_scores(y_true, y_pred, total, filter_top_percentile=5)
rmse_ss = rmse_crps_skill_scores(y_true, y_pred, total, filter_top_percentile=0)
metrics[f"{split}_r2_rmse_sigma"] = rmse_ss["r2_rmse"]
metrics[f"{split}_r2_crps_sigma"] = rmse_ss["r2_crps"]

return metrics


def rmse_crps_skill_scores(y_true, y_pred, total, filter_top_percentile=0):
# Initialize dictionaries to store r2_rmse and r2_crps for each column
def rmse_crps_skill_scores(y, mu, total, filter_top_percentile=0):
# Create a grid of subplots with the number of rows determined by the length of output_cols
r2_rmse_dict = {}
r2_crps_dict = {}

# Get the number of columns from y_pred
num_cols = y_pred.shape[1]

# Loop over the columns
for col in range(num_cols):
# Loop over the length of output_cols
num_cols = y.shape[-1]
for col in range(y.shape[-1]):
result = calculate_skill_score(
y_true[:, col], # Use y_true for the true values
y_pred[:, col], # Use y_pred for the predicted values
total[:, col],
y[:, col],
mu[:, col],
np.sqrt(total)[:, col],
num_bins=100,
log=True,
filter_top_percentile=filter_top_percentile
)

r2_rmse = r2_score(result['bin'], result['rmse'])
r2_crps = r2_score(result['bin'], result['crps'])
r2_rmse_dict[col] = r2_rmse
r2_crps_dict[col] = r2_crps

if np.isnan(r2_rmse):
r2_rmse = -10
# if np.isnan(r2_rmse):
# r2_rmse = -10

# Check if r2_crps is NaN and replace it with -10
if np.isnan(r2_crps):
r2_crps = -10
# # Check if r2_crps is NaN and replace it with -10
# if np.isnan(r2_crps):
# r2_crps = -10

# Calculate the average of r2_rmse and r2_crps
avg_r2_rmse = sum(r2_rmse_dict.values()) / num_cols
Expand Down
6 changes: 3 additions & 3 deletions mlguess/torch/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ def predict_uncertainty(self, input, y_scaler=None):
if y_scaler:
mu = y_scaler.inverse_transform(mu)

for i in range(mu.shape[-1]):
aleatoric[:, i] *= self.training_var[i]
epistemic[:, i] *= self.training_var[i]
for i in range(mu.shape[-1]):
aleatoric[:, i] *= self.training_var[i]
epistemic[:, i] *= self.training_var[i]

return mu, aleatoric, epistemic

Expand Down
15 changes: 7 additions & 8 deletions mlguess/torch/regression_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ def normal_inverse_gamma_reg(self, y, gamma, v, alpha, beta):
evi = 2 * v + alpha
return error * evi

def __call__(self, y, pred):
def __call__(self, gamma, v, alpha, beta, y):
"""Calculate the Evidential Regression Loss"""
gamma, v, alpha, beta = pred
loss_nll = self.normal_inverse_gamma_nll(y, gamma, v, alpha, beta)
loss_reg = self.normal_inverse_gamma_reg(y, gamma, v, alpha, beta)
return loss_nll.mean() + self.coef * loss_reg.mean()
Expand Down Expand Up @@ -206,9 +205,9 @@ class EvidenceRegularizer(torch.nn.modules.loss._Loss):
Reference: https://www.mit.edu/~amini/pubs/pdf/deep-evidential-regression.pdf
Source: https://github.com/deargen/MT-ENet/tree/468822188f52e517b1ee8e386eea607b2b7d8829
"""
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', factor=0.1):
def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', coef=0.1):
super(EvidenceRegularizer, self).__init__(size_average, reduce, reduction)
self.factor = factor
self.coef = coef

def forward(self, gamma: torch.Tensor, nu: torch.Tensor, alpha: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
Expand All @@ -224,7 +223,7 @@ def forward(self, gamma: torch.Tensor, nu: torch.Tensor, alpha: torch.Tensor,
Loss = |y - gamma|*(2*nu + alpha) * factor
"""
loss_value = torch.abs(target - gamma)*(2*nu + alpha) * self.factor
loss_value = torch.abs(target - gamma)*(2*nu + alpha) * self.coef
if self.reduction == 'mean':
return loss_value.mean()
elif self.reduction == 'sum':
Expand All @@ -234,13 +233,13 @@ def forward(self, gamma: torch.Tensor, nu: torch.Tensor, alpha: torch.Tensor,


class LipschitzMSELoss(torch.nn.Module):
def __init__(self, tol=1e-8, factor=0.1, reduction='mean'):
def __init__(self, tol=1e-8, coef=0.1, reduction='mean'):
super(LipschitzMSELoss, self).__init__()
self.tol = tol
self.factor = factor
self.coef = coef
self.reduction = reduction
self.evidential_marginal_likelihood = EvidentialMarginalLikelihood(reduction=reduction)
self.evidence_regularizer = EvidenceRegularizer(factor=factor, reduction=reduction)
self.evidence_regularizer = EvidenceRegularizer(coef=coef, reduction=reduction)

def forward(self, gamma, nu, alpha, beta, target):
loss = self.evidential_marginal_likelihood(gamma, nu, alpha, beta, target)
Expand Down
47 changes: 24 additions & 23 deletions mlguess/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def train_one_epoch(
commit_loss = 0.0

with autocast(enabled=amp):

x = x.to(self.device)
y_pred = self.model(x)
gamma, nu, alpha, beta = y_pred
Expand All @@ -91,7 +90,7 @@ def train_one_epoch(
# Metrics
y_pred = (_.cpu().detach() for _ in y_pred)
mu, ale, epi = self.model.predict_uncertainty(y_pred, y_scaler=transform)
total = np.sqrt(ale + epi)
total = ale + epi
if transform:
y = transform.inverse_transform(y.cpu())
metrics_dict = metrics(y, mu, total, split="train")
Expand Down Expand Up @@ -192,7 +191,7 @@ def validate(
# Metrics
y_pred = (_.cpu() for _ in y_pred)
mu, ale, epi = self.model.predict_uncertainty(y_pred, y_scaler=transform)
total = np.sqrt(ale + epi)
total = ale + epi
if transform:
y = transform.inverse_transform(y.cpu())
metrics_dict = metrics(y, mu, total, split="valid")
Expand Down Expand Up @@ -235,20 +234,14 @@ def validate(

def predict(self, conf, test_loader, criterion, metrics, transform=None, split=None):
self.model.eval()
valid_batches_per_epoch = conf['trainer']['valid_batches_per_epoch']
distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False

results_dict = defaultdict(list)
mu_list, ale_list, epi_list, y_list = [], [], [], []

# Set up a custom tqdm
valid_batches_per_epoch = (
valid_batches_per_epoch if 0 < valid_batches_per_epoch < len(test_loader) else len(test_loader)
)

batch_group_generator = tqdm.tqdm(
enumerate(test_loader),
total=valid_batches_per_epoch,
total=len(test_loader),
leave=True,
disable=True if self.rank > 0 else False
)
Expand All @@ -273,21 +266,18 @@ def predict(self, conf, test_loader, criterion, metrics, transform=None, split=N
batch_loss = torch.Tensor([loss.item()]).cuda(self.device)
if distributed:
torch.distributed.barrier()
results_dict["loss"].append(batch_loss[0].item())
results_dict[f"{split}_loss"].append(batch_loss[0].item())

# Print to tqdm
to_print = f"{split} loss: {np.mean(results_dict['loss']):.6f}"
to_print = f'{split} loss: {np.mean(results_dict[f"{split}_loss"]):.6f}'
if self.rank == 0:
batch_group_generator.set_description(to_print)

if i >= valid_batches_per_epoch and i > 0:
break

# Concatenate arrays
mu = np.concatenate(mu_list, axis=0)
ale = np.concatenate(ale_list, axis=0)
epi = np.concatenate(epi_list, axis=0)
total = np.sqrt(ale + epi)
total = ale + epi
y = np.concatenate(y_list, axis=0)

if transform:
Expand All @@ -300,7 +290,7 @@ def predict(self, conf, test_loader, criterion, metrics, transform=None, split=N
if distributed:
dist.all_reduce(value, dist.ReduceOp.AVG, async_op=False)
results_dict[name].append(value[0].item())
results_dict["loss"] = np.mean(results_dict["loss"])
results_dict[f"{split}_loss"].append(np.mean(results_dict[f"{split}_loss"]))

# Shutdown the progbar
batch_group_generator.close()
Expand Down Expand Up @@ -393,14 +383,25 @@ def fit(

else:

valid_results = self.validate(
epoch,
valid_results = self.predict(
conf,
valid_loader,
valid_criterion,
metrics,
transform
)
transform,
split="valid"
)["metrics"]

# this version of validation computes metrics batch-by-batch, which may affect metrics computed through binning

# valid_results = self.validate(
# epoch,
# conf,
# valid_loader,
# valid_criterion,
# metrics,
# transform
# )

#################
#
Expand Down Expand Up @@ -510,8 +511,8 @@ def fit(
gc.collect()

# Report result to the trial
if trial:
trial.report(results_dict[training_metric][-1], step=epoch)
# if trial:
# trial.report(results_dict[training_metric][-1], step=epoch)

# Stop training if we have not improved after X epochs (stopping patience)
best_epoch = [
Expand Down

0 comments on commit 9071135

Please sign in to comment.