diff --git a/policyengine_uk_data/utils/reweight.py b/policyengine_uk_data/utils/reweight.py index 384168b..f07d669 100644 --- a/policyengine_uk_data/utils/reweight.py +++ b/policyengine_uk_data/utils/reweight.py @@ -33,6 +33,12 @@ def loss(weights): raise ValueError("Relative error contains NaNs") return rel_error.mean() + def pct_close(weights, t=0.1): + # Return the percentage of metrics that are within t% of the target + estimate = weights @ loss_matrix + abs_error = torch.abs((estimate - targets_array) / (1 + targets_array)) + return (abs_error < t).sum() / abs_error.numel() + def dropout_weights(weights, p): if p == 0: return weights @@ -53,12 +59,15 @@ def dropout_weights(weights, p): optimizer.zero_grad() weights_ = dropout_weights(weights, dropout_rate) l = loss(torch.exp(weights_)) + close = pct_close(torch.exp(weights_)) if start_loss is None: start_loss = l.item() loss_rel_change = (l.item() - start_loss) / start_loss l.backward() if i % 100 == 0: - print(f"Loss: {l.item()}, Rel change: {loss_rel_change}") + print( + f"Loss: {l.item()}, Rel change: {loss_rel_change}, Epoch: {i}, Within 10%: {close:.2%}" + ) optimizer.step() return torch.exp(weights).detach().numpy()