Skip to content

Commit

Permalink
[xgboost] Fix eval dataset issues
Browse files Browse the repository at this point in the history
Signed-off-by: Bobby Wang <[email protected]>
  • Loading branch information
wbo4958 committed Oct 22, 2024
1 parent 1c36fb9 commit eba85f2
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -316,23 +316,13 @@
"## Benchmark and train\n",
"The object `benchmark` is used to compute the elapsed time of some operations.\n",
"\n",
"Training with evaluation sets is also supported in 2 ways, the same as CPU version's behavior:\n",
"Training with evaluation dataset is also supported, the same as CPU version's behavior:\n",
"\n",
"* Call API `setEvalSets` after initializing an XGBoostClassifier\n",
"* Call API `setEvalDataset` after initializing an XGBoostClassifier\n",
"\n",
"```scala\n",
"xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))\n",
"\n",
"```\n",
"\n",
"* Use parameter `eval_sets` when initializing an XGBoostClassifier\n",
"\n",
"```scala\n",
"val paramMapWithEval = paramMap + (\"eval_sets\" -> Map(\"eval\" -> evalSet))\n",
"val xgbClassifierWithEval = new XGBoostClassifier(paramMapWithEval)\n",
"```\n",
"\n",
"Here chooses the API way to set evaluation sets."
"xgbClassifier.setEvalDataset(evalSet)\n",
"```"
]
},
{
Expand All @@ -352,7 +342,7 @@
}
],
"source": [
"xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))"
"xgbClassifier.setEvalDataset(evalSet)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ object Main {
// build XGBoost classifier
val paramMap = xgboostArgs.xgboostParams(Map(
"objective" -> "binary:logistic",
"eval_sets" -> datasets(1).map(ds => Map("eval" -> ds)).getOrElse(Map.empty)
))
val xgbClassifier = new XGBoostClassifier(paramMap)
.setLabelCol(labelName)
// === diff ===
.setFeaturesCol(featureCols)

datasets(1).foreach(_ => xgbClassifier.setEvalDataset(_))

println("\n------ Training ------")
val (model, _) = benchmark.time("train") {
xgbClassifier.fit(datasets(0).get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,23 +304,13 @@
"## Benchmark and train\n",
"The object `benchmark` is used to compute the elapsed time of some operations.\n",
"\n",
"Training with evaluation sets is also supported in 2 ways, the same as CPU version's behavior:\n",
"Training with evaluation dataset is also supported, the same as CPU version's behavior:\n",
"\n",
"* Call API `setEvalSets` after initializing an XGBoostClassifier\n",
"* Call API `setEvalDataset` after initializing an XGBoostClassifier\n",
"\n",
"```scala\n",
"xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))\n",
"\n",
"```\n",
"\n",
"* Use parameter `eval_sets` when initializing an XGBoostClassifier\n",
"\n",
"```scala\n",
"val paramMapWithEval = paramMap + (\"eval_sets\" -> Map(\"eval\" -> evalSet))\n",
"val xgbClassifierWithEval = new XGBoostClassifier(paramMapWithEval)\n",
"```\n",
"\n",
"Here chooses the API way to set evaluation sets."
"xgbClassifier.setEvalDataset(evalSet)\n",
"```"
]
},
{
Expand All @@ -340,7 +330,7 @@
}
],
"source": [
"xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))"
"xgbClassifier.setEvalDataset(evalSet)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ object Main extends Mortgage {

val xgbClassificationModel = if (appArgs.isToTrain) {
// build XGBoost classifier
val xgbParamFinal = appArgs.xgboostParams(commParamMap +
// Add train-eval dataset if specified
("eval_sets" -> datasets(1).map(ds => Map("eval" -> ds)).getOrElse(Map.empty))
)
val xgbParamFinal = appArgs.xgboostParams(commParamMap)
val xgbClassifier = new XGBoostClassifier(xgbParamFinal)
.setLabelCol(labelColName)
.setFeaturesCol(featureNames)

datasets(1).foreach(_ => xgbClassifier.setEvalDataset(_))

// Start training
println("\n------ Training ------")
// Shall we not log the time if it is abnormal, which is usually caused by training failure
Expand Down
22 changes: 6 additions & 16 deletions examples/XGBoost-Examples/taxi/notebooks/scala/taxi-gpu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,13 @@
"## Benchmark and train\n",
"The object `benchmark` is used to compute the elapsed time of some operations.\n",
"\n",
"Training with evaluation sets is also supported in 2 ways, the same as CPU version's behavior:\n",
"Training with evaluation dataset is also supported, the same as CPU version's behavior:\n",
"\n",
"* Call API `setEvalSets` after initializing an XGBoostRegressor\n",
"* Call API `setEvalDataset` after initializing an XGBoostClassifier\n",
"\n",
"```scala\n",
"xgbRegressor.setEvalSets(Map(\"eval\" -> evalSet))\n",
"\n",
"```\n",
"\n",
"* Use parameter `eval_sets` when initializing an XGBoostRegressor\n",
"\n",
"```scala\n",
"val paramMapWithEval = paramMap + (\"eval_sets\" -> Map(\"eval\" -> evalSet))\n",
"val xgbRegressorWithEval = new XGBoostRegressor(paramMapWithEval)\n",
"```\n",
"\n",
"Here chooses the API way to set evaluation sets."
"xgbClassifier.setEvalDataset(evalSet)\n",
"```"
]
},
{
Expand All @@ -356,7 +346,7 @@
}
],
"source": [
"xgbRegressor.setEvalSets(Map(\"eval\" -> evalSet))"
"xgbRegressor.setEvalDataset(evalSet)"
]
},
{
Expand Down Expand Up @@ -609,4 +599,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,13 @@ object Main extends Taxi {

val xgbRegressionModel = if (xgboostArgs.isToTrain) {
// build XGBoost XGBoostRegressor
val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap +
// Add train-eval dataset if specified
("eval_sets" -> datasets(1).map(ds => Map("test" -> ds)).getOrElse(Map.empty))
)
val xgbParamFinal = xgboostArgs.xgboostParams(commParamMap)
val xgbRegressor = new XGBoostRegressor(xgbParamFinal)
.setLabelCol(labelColName)
.setFeaturesCol(featureNames)

datasets(1).foreach(_ => xgbRegressor.setEvalDataset(_))

println("\n------ Training ------")
// Shall we not log the time if it is abnormal, which is usually caused by training failure
val (model, _) = benchmark.time("train") {
Expand Down

0 comments on commit eba85f2

Please sign in to comment.