hierarchical iterations as parameter
FelixWick committed Oct 16, 2023
1 parent ecb603c commit 5a1d7c8
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions cyclic_boosting/base.py
@@ -179,7 +179,8 @@ class CyclicBoostingBase(
idea of such hierarchical iterations is to support the modeling of
hierarchical or causal effects (e.g., mitigate confounding).
- If this argument is omitted, such no hierarchical iterations are run.
+ If this argument is not explicitly set, no such hierarchical iterations
+ are run.
feature_properties: :obj:`dict` of :obj:`int`
Dictionary listing the names of all features for the training as keys
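Illustration (not part of the commit): a minimal sketch of how the new parameter could be passed together with hierarchical_feature_groups. The concrete estimator (CBPoissonRegressor), the feature names, and the flag choices are assumptions, and whether a given subclass forwards training_iterations_hierarchical_features to CyclicBoostingBase needs to be checked against that subclass's signature.

from cyclic_boosting import CBPoissonRegressor, flags

# Hypothetical features: the two hierarchical groups are fitted exclusively
# during the first training_iterations_hierarchical_features iterations,
# after which all feature groups take part.
est = CBPoissonRegressor(
    feature_properties={
        "store": flags.IS_UNORDERED,
        "item": flags.IS_UNORDERED,
        "price_ratio": flags.IS_CONTINUOUS,
    },
    feature_groups=["store", "item", "price_ratio"],
    hierarchical_feature_groups=["store", "item"],
    training_iterations_hierarchical_features=5,  # new parameter, default 3
)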
@@ -268,6 +269,7 @@ def __init__(
self,
feature_groups=None,
hierarchical_feature_groups=None,
+ training_iterations_hierarchical_features=3,
feature_properties: Optional[Dict[int, int]] = None,
weight_column: Optional[Union[str, int, None]] = None,
prior_prediction_column: Optional[Union[str, int, None]] = None,
@@ -297,6 +299,7 @@ def __init__(
for fg in self.hierarchical_feature_groups:
hierarchical_feature = create_feature_id(fg)
self.hierarchical_features.append(hierarchical_feature.feature_group)
+ self.training_iterations_hierarchical_features = training_iterations_hierarchical_features
self.feature_importances = {}
self.aggregate = aggregate

@@ -543,7 +546,10 @@ def get_state(self) -> Dict[str, Any]:
"globale_scale": self.global_scale_,
"insample_loss": self.insample_loss_,
}
- if self.hierarchical_feature_groups is not None and self.iteration_ < 3:
+ if (
+     self.hierarchical_feature_groups is not None
+     and self.iteration_ < self.training_iterations_hierarchical_features
+ ):
est_state["features"] = [
feature for feature in self.features if feature.feature_group in self.hierarchical_features
]
@@ -728,7 +734,7 @@ def _fit_main(self, X: np.ndarray, y: np.ndarray, pred: CBLinkPredictionsFactors
for i, feature, pf_data in self.cb_features(X, y, pred, prefit_data):
if (
self.hierarchical_feature_groups is not None
- and self.iteration_ < 3
+ and self.iteration_ < self.training_iterations_hierarchical_features
and feature.feature_group not in self.hierarchical_features
):
feature.factors_link_old = feature.factors_link.copy()
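Reading aid (not in the diff): the condition above skips the regular update for every non-hierarchical feature while the estimator is still in its hierarchical warm-up iterations. The same logic, paraphrased as a standalone predicate with a hypothetical helper name:

def skip_non_hierarchical_update(est, feature) -> bool:
    # Sketch only: True while the warm-up is active (iteration_ below the new
    # parameter) and the feature does not belong to a hierarchical group.
    return (
        est.hierarchical_feature_groups is not None
        and est.iteration_ < est.training_iterations_hierarchical_features
        and feature.feature_group not in est.hierarchical_features
    )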
@@ -874,7 +880,7 @@ def transform(self, X: pd.DataFrame, y: Optional[np.ndarray] = None) -> pd.DataFrame:

def _check_stop_criteria(self, iterations: int, convergence_parameters: ConvergenceParameters) -> bool:
"""
- Checks the stop criteria and returns True if none are satisfied else False.
+ Checks the stop criteria and returns True if at least one is satisfied.
You can check the stop criteria in the estimated parameter
`stop_criteria_`.
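As the docstring says, the individual criteria are exposed on the fitted estimator as stop_criteria_ (set further below as a three-element tuple). A usage sketch; est, X_train, and y_train are placeholders:

# Sketch only: inspect which stop criteria were satisfied after fitting.
est.fit(X_train, y_train)
stop_iterations, stop_factor_change, stop_loss_change = est.stop_criteria_
print(stop_iterations, stop_factor_change, stop_loss_change)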
@@ -920,7 +926,10 @@ def _check_stop_criteria(self, iterations: int, convergence_parameters: ConvergenceParameters) -> bool:
"analysis plots."
)

- if iterations <= 3 and self.hierarchical_feature_groups is not None:
+ if (
+     iterations <= self.training_iterations_hierarchical_features
+     and self.hierarchical_feature_groups is not None
+ ):
veto_hierarchical = True

self.stop_criteria_ = (stop_iterations, stop_factor_change, stop_loss_change)
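The final return statement lies outside this hunk, so how veto_hierarchical is combined with the three criteria is an assumption here. A sketch of the intended effect, namely that no criterion can end training while the hierarchical warm-up iterations are still running:

def _combine_stop_criteria_sketch(stop_iterations: bool, stop_factor_change: bool,
                                  stop_loss_change: bool, veto_hierarchical: bool) -> bool:
    # Hypothetical helper, not part of base.py: any satisfied criterion stops
    # training unless the hierarchical veto is active.
    should_stop = stop_iterations or stop_factor_change or stop_loss_change
    return should_stop and not veto_hierarchical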