Skip to content

Commit

Permalink
added extra tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aglenis committed Jan 12, 2025
1 parent 075b351 commit 5ba597d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
18 changes: 17 additions & 1 deletion exareme2/algorithms/flower/xgboost/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,27 @@ class CustomFedXgbBagging(FedXgbBagging):
def __init__(self, num_rounds, **kwargs):
super().__init__(**kwargs)
self.num_rounds = num_rounds
self.initial_auc = 0.0

def aggregate_evaluate(self, rnd, results, failures):
aggregated_metrics = super().aggregate_evaluate(rnd, results, failures)
if rnd == 1:
self.initial_auc = aggregated_metrics["AUC"]
if rnd == self.num_rounds:
post_result({"metrics_aggregated": aggregated_metrics})
print(aggregated_metrics)
curr_auc = aggregated_metrics["AUC"]
auc_diff = curr_auc - self.initial_auc
auc_ascending = ""
if auc_diff > 0.0:
auc_ascending = "correct"
else:
auc_ascending = "not_correct"
post_result(
{
"metrics_aggregated": aggregated_metrics,
"auc_ascending": auc_ascending,
}
)
return aggregated_metrics


Expand Down
2 changes: 2 additions & 0 deletions tests/algorithm_validation_tests/flower/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ def test_xgboost(get_algorithm_result):
# {'metrics_aggregated': {'AUC': 0.7575790087463558}}
print(algorithm_result)
auc_aggregated = algorithm_result["metrics_aggregated"][1]["AUC"]
auc_ascending = algorithm_result["auc_ascending"]
assert auc_aggregated > 0.0
assert auc_ascending == "correct"

0 comments on commit 5ba597d

Please sign in to comment.