From 81b587fb521dc49cca43141d4730ff0561fddef6 Mon Sep 17 00:00:00 2001 From: john Date: Thu, 19 Oct 2023 16:27:32 +0200 Subject: [PATCH] fix and update tests --- tests/testthat/test_glmnet_surv_cv_glmnet.R | 2 +- tests/testthat/test_glmnet_surv_glmnet.R | 2 +- tests/testthat/test_paramtest_glmnet_surv_cv_glmnet.R | 6 +++++- tests/testthat/test_paramtest_glmnet_surv_glmnet.R | 5 ++++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test_glmnet_surv_cv_glmnet.R b/tests/testthat/test_glmnet_surv_cv_glmnet.R index a40385a61..dc7f3ae9c 100644 --- a/tests/testthat/test_glmnet_surv_cv_glmnet.R +++ b/tests/testthat/test_glmnet_surv_cv_glmnet.R @@ -12,7 +12,7 @@ test_that("autotest", { test_that("selected_features", { task = tsk("gbcs") - learner = lrn("surv.glmnet") + learner = lrn("surv.cv_glmnet") learner$train(task) expect_equal( diff --git a/tests/testthat/test_glmnet_surv_glmnet.R b/tests/testthat/test_glmnet_surv_glmnet.R index e06e42b89..9850bab3b 100644 --- a/tests/testthat/test_glmnet_surv_glmnet.R +++ b/tests/testthat/test_glmnet_surv_glmnet.R @@ -12,7 +12,7 @@ test_that("autotest", { test_that("selected_features", { task = tsk("gbcs") - learner = lrn("surv.cv_glmnet") + learner = lrn("surv.glmnet") learner$train(task) expect_equal( diff --git a/tests/testthat/test_paramtest_glmnet_surv_cv_glmnet.R b/tests/testthat/test_paramtest_glmnet_surv_cv_glmnet.R index 52b2780d2..489f06741 100644 --- a/tests/testthat/test_paramtest_glmnet_surv_cv_glmnet.R +++ b/tests/testthat/test_paramtest_glmnet_surv_cv_glmnet.R @@ -20,7 +20,11 @@ test_that("predict surv.cv_glmnet", { exclude = c( "object", # handled via mlr3 "newx", # handled via mlr3 - "predict.gamma" # renamed from gamma + "predict.gamma", # renamed from gamma + "offset", # for distr prediction + "newoffset", # for distr prediction + "stype", # for distr prediction + "ctype" # for distr prediction ) paramtest = run_paramtest(learner, fun, exclude, tag = "predict") diff --git a/tests/testthat/test_paramtest_glmnet_surv_glmnet.R b/tests/testthat/test_paramtest_glmnet_surv_glmnet.R index 225244e2a..a51a0b1f1 100644 --- a/tests/testthat/test_paramtest_glmnet_surv_glmnet.R +++ b/tests/testthat/test_paramtest_glmnet_surv_glmnet.R @@ -26,7 +26,10 @@ test_that("predict surv.glmnet", { "object", # handled via mlr3 "newx", # handled via mlr3 "type", # handled via mlr3 - "predict.gamma" # renamed from gamma + "predict.gamma", # renamed from gamma + "offset", # for distr prediction + "stype", # for distr prediction + "ctype" # for distr prediction ) paramtest = run_paramtest(learner, fun, exclude, tag = "predict")