Skip to content

Commit

Permalink
Merge pull request #717 from Aske-Rosted/weight_utils
Browse files Browse the repository at this point in the history
weights utilities + fix
  • Loading branch information
Aske-Rosted authored May 29, 2024
2 parents 1d942a4 + e338c05 commit 6837b28
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions src/graphnet/training/weight_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def fit(
selection: Optional[List[int]] = None,
transform: Optional[Callable] = None,
db_count_norm: Optional[int] = None,
automatic_log_bins: bool = False,
max_weight: Optional[float] = None,
**kwargs: Any,
) -> pd.DataFrame:
"""Fit weights.
Expand All @@ -74,9 +76,16 @@ def fit(
transform: A callable that transforms the variable into a
desired space. E.g. np.log10 for energy. If given, fitting will
happen in this space.
db_count_norm: If given, the total sum of the weights for the given db will be this number.
db_count_norm: If given, the total sum of the weights for the given
db will be this number.
automatic_log_bins: If True, `bins` must be an integer and the bin
edges are generated as a log10 space with `bins` points between
the min and (max + 1) of the variable.
max_weight: If given, must lie strictly between 0 and 1. The
weights are capped such that a single event weight cannot exceed
this fraction of the sum of all weights (computed before capping).
**kwargs: Additional arguments passed to `_fit_weights`.
Returns:
DataFrame that contains weights, event_nos.
"""
Expand All @@ -86,6 +95,9 @@ def fit(
self._selection = selection
self._bins = bins
self._transform = transform
if max_weight is not None:
assert max_weight > 0 and max_weight < 1
self._max_weight = max_weight

if weight_name is None:
self._weight_name = self._generate_weight_name()
Expand All @@ -95,12 +107,29 @@ def fit(
truth = self._get_truth(self._variable, self._selection)
if self._transform is not None:
truth[self._variable] = self._transform(truth[self._variable])
if automatic_log_bins:
assert isinstance(bins, int)
self._bins = np.logspace(
np.log10(truth[self._variable].min()),
np.log10(truth[self._variable].max() + 1),
bins,
)

weights = self._fit_weights(truth, **kwargs)
if self._max_weight is not None:
weights[self._weight_name] = np.where(
weights[self._weight_name]
> weights[self._weight_name].sum() * self._max_weight,
weights[self._weight_name].sum() * self._max_weight,
weights[self._weight_name],
)

if db_count_norm is not None:
weights[self._weight_name] = (
weights[self._weight_name] * db_count_norm / len(weights)
weights[self._weight_name]
* db_count_norm
/ weights[self._weight_name].sum()
)

if add_to_database:
create_table_and_save_to_sql(
weights, self._weight_name, self._database_path
Expand Down Expand Up @@ -154,7 +183,11 @@ class BjoernLow(WeightFitter):
"""

def _fit_weights( # type: ignore[override]
self, truth: pd.DataFrame, x_low: float, alpha: float = 0.05
self,
truth: pd.DataFrame,
x_low: float,
alpha: float = 0.05,
percentile: bool = False,
) -> pd.DataFrame:
"""Fit per-event weights.
Expand All @@ -165,6 +198,8 @@ def _fit_weights( # type: ignore[override]
1/(1+a*(x_low -x)) curve.
alpha: A scalar factor that controls how fast the weights above
x_low approach zero. Larger means faster.
percentile: If True, x_low is interpreted as a quantile (a fraction
strictly between 0 and 1) of the truth variable.
Returns:
The fitted weights.
Expand All @@ -186,6 +221,11 @@ def _fit_weights( # type: ignore[override]
weights=truth[self._weight_name],
)
c = bin_counts.max()

if percentile:
assert 0 < x_low < 1
x_low = np.quantile(truth[self._variable], x_low)

slice = truth[self._variable][truth[self._variable] > x_low]
truth[self._weight_name][truth[self._variable] > x_low] = 1 / (
1 + alpha * (slice - x_low)
Expand Down

0 comments on commit 6837b28

Please sign in to comment.