diff --git a/src/graphnet/training/weight_fitting.py b/src/graphnet/training/weight_fitting.py index ae9d63f6b..2e4a1879d 100644 --- a/src/graphnet/training/weight_fitting.py +++ b/src/graphnet/training/weight_fitting.py @@ -76,9 +76,16 @@ def fit( transform: A callable method that transform the variable into a desired space. E.g. np.log10 for energy. If given, fitting will happen in this space. - db_count_norm: If given, the total sum of the weights for the given db will be this number. + db_count_norm: If given, the total sum of the weights for the given + db will be this number. + automatic_log_bins: If True, the bins are generated as a log10 space + between the min and max of the variable. + max_weight: If given, the weights are capped such that the sum of a + single event's weight cannot exceed this number times the sum of + all weights. **kwargs: Additional arguments passed to `_fit_weights`. + Returns: DataFrame that contains weights, event_nos. """ @@ -191,6 +198,8 @@ def _fit_weights( # type: ignore[override] 1/(1+a*(x_low -x)) curve. alpha: A scalar factor that controls how fast the weights above x_low approaches zero. Larger means faster. + percentile: If True, x_low is interpreted as a percentile of the + truth variable. Returns: The fitted weights.