Skip to content

Commit

Permalink
Add pct close to EFRS
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff committed Dec 23, 2024
1 parent 1bcab90 commit 6774eb9
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion policyengine_uk_data/utils/reweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def loss(weights):
raise ValueError("Relative error contains NaNs")
return rel_error.mean()

def pct_close(weights, t=0.1):
# Return the percentage of metrics that are within t% of the target
estimate = weights @ loss_matrix
abs_error = torch.abs((estimate - targets_array) / (1 + targets_array))
return (abs_error < t).sum() / abs_error.numel()

def dropout_weights(weights, p):
if p == 0:
return weights
Expand All @@ -53,12 +59,15 @@ def dropout_weights(weights, p):
optimizer.zero_grad()
weights_ = dropout_weights(weights, dropout_rate)
l = loss(torch.exp(weights_))
close = pct_close(torch.exp(weights_))
if start_loss is None:
start_loss = l.item()
loss_rel_change = (l.item() - start_loss) / start_loss
l.backward()
if i % 100 == 0:
print(f"Loss: {l.item()}, Rel change: {loss_rel_change}")
print(
f"Loss: {l.item()}, Rel change: {loss_rel_change}, Epoch: {i}, Within 10%: {close:.2%}"
)
optimizer.step()

return torch.exp(weights).detach().numpy()

0 comments on commit 6774eb9

Please sign in to comment.