From e36bc67a8b6289c0fd18afa9f9e6af15e95c3588 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 11:42:00 +0900 Subject: [PATCH 1/3] add functionality --- src/graphnet/training/weight_fitting.py | 37 +++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index 0766facf8..ae9d63f6b 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -55,6 +55,8 @@ def fit( selection: Optional[List[int]] = None, transform: Optional[Callable] = None, db_count_norm: Optional[int] = None, + automatic_log_bins: bool = False, + max_weight: Optional[float] = None, **kwargs: Any, ) -> pd.DataFrame: """Fit weights. @@ -86,6 +88,9 @@ def fit( self._selection = selection self._bins = bins self._transform = transform + if max_weight is not None: + assert max_weight > 0 and max_weight < 1 + self._max_weight = max_weight if weight_name is None: self._weight_name = self._generate_weight_name() @@ -95,12 +100,29 @@ def fit( truth = self._get_truth(self._variable, self._selection) if self._transform is not None: truth[self._variable] = self._transform(truth[self._variable]) + if automatic_log_bins: + assert isinstance(bins, int) + self._bins = np.logspace( + np.log10(truth[self._variable].min()), + np.log10(truth[self._variable].max() + 1), + bins, + ) + weights = self._fit_weights(truth, **kwargs) + if self._max_weight is not None: + weights[self._weight_name] = np.where( + weights[self._weight_name] + > weights[self._weight_name].sum() * self._max_weight, + weights[self._weight_name].sum() * self._max_weight, + weights[self._weight_name], + ) + if db_count_norm is not None: weights[self._weight_name] = ( - weights[self._weight_name] * db_count_norm / len(weights) + weights[self._weight_name] + * db_count_norm + / weights[self._weight_name].sum() ) - if add_to_database: create_table_and_save_to_sql( weights, 
self._weight_name, self._database_path @@ -154,7 +176,11 @@ class BjoernLow(WeightFitter): """ def _fit_weights( # type: ignore[override] - self, truth: pd.DataFrame, x_low: float, alpha: float = 0.05 + self, + truth: pd.DataFrame, + x_low: float, + alpha: float = 0.05, + percentile: bool = False, ) -> pd.DataFrame: """Fit per-event weights. @@ -186,6 +212,11 @@ def _fit_weights( # type: ignore[override] weights=truth[self._weight_name], ) c = bin_counts.max() + + if percentile: + assert 0 < x_low < 1 + x_low = np.quantile(truth[self._variable], x_low) + slice = truth[self._variable][truth[self._variable] > x_low] truth[self._weight_name][truth[self._variable] > x_low] = 1 / ( 1 + alpha * (slice - x_low) From b452e3ba93ad48fb826e14d4bc0c5e5511f888f4 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 14 May 2024 14:00:38 +0900 Subject: [PATCH 2/3] update docstrings --- src/graphnet/training/weight_fitting.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index ae9d63f6b..2e4a1879d 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -76,9 +76,16 @@ def fit( transform: A callable method that transform the variable into a desired space. E.g. np.log10 for energy. If given, fitting will happen in this space. - db_count_norm: If given, the total sum of the weights for the given db will be this number. + db_count_norm: If given, the total sum of the weights for the given + db will be this number. + automatic_log_bins: If True, the bins are generated as a log10 space + between the min and max of the variable. + max_weight: If given, the weights are capped such that the sum of a + single event's weight cannot exceed this number times the sum of + all weights. **kwargs: Additional arguments passed to `_fit_weights`. + Returns: DataFrame that contains weights, event_nos. 
""" @@ -191,6 +198,8 @@ def _fit_weights( # type: ignore[override] 1/(1+a*(x_low -x)) curve. alpha: A scalar factor that controls how fast the weights above x_low approaches zero. Larger means faster. + percentile: If True, x_low is treated as a quantile of the truth + variable and must lie strictly between 0 and 1. Returns: The fitted weights. From e338c058c37f9377f92ad2784f6e5c42df7d3008 Mon Sep 17 00:00:00 2001 From: "askerosted@gmail.com" Date: Tue, 28 May 2024 14:23:43 +0900 Subject: [PATCH 3/3] Code climate fix --- src/graphnet/training/weight_fitting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index 2e4a1879d..e66c2d4c5 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -78,10 +78,10 @@ def fit( happen in this space. db_count_norm: If given, the total sum of the weights for the given db will be this number. - automatic_log_bins: If True, the bins are generated as a log10 space - between the min and max of the variable. - max_weight: If given, the weights are capped such that the sum of a - single event's weight cannot exceed this number times the sum of + automatic_log_bins: If True, the bins are generated as a log10 + space between the min and max of the variable. + max_weight: If given, the weights are capped such that a single + event weight cannot exceed this number times the sum of all weights. **kwargs: Additional arguments passed to `_fit_weights`.