open up PipeOpLearnerCV to all resampling methods #513
base: master
Conversation
I guess if we want to be really flexible with respect to which …

Tests currently fail because:
I just did another iteration, which I summarize here:
#'
#' @usage NULL
#' @name mlr_pipeops_learner_cv
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Wraps an [`mlr3::Learner`] into a [`PipeOp`].
#' Wraps a [`mlr3::Learner`] and [`mlr3::Resampling`] into a [`PipeOp`].
just say learner
#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from.
#' In the case of the resampling method returning multiple predictions per row id, the predictions
#' are returned unaltered. The output [`Task`][mlr3::Task] always gains a `row_reference` column
#' named `pre.<ID>` indicating the original row id prior to the resampling process. [`PipeOpAggregate`] should then
rowid.
prds = as.data.table(private$.learner$predict(task))
}
# compute resampled predictions
rr = resample(task, private$.learner, private$.resampling)
Is there a way to check if the resampling fits a model on all the data, which could then be used for prediction without needing to fit twice?
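A minimal sketch of one such check, assuming private$.resampling and task are in scope at this point; it only catches the simple case where a single iteration trains on every row (e.g. insample resampling):

# hedged sketch: detect a resampling whose single training set covers the whole task,
# so the model fitted during resample() could be reused for prediction
resampling = private$.resampling$clone()$instantiate(task)
trains_on_all_rows = resampling$iters == 1L &&
  setequal(resampling$train_set(1L), task$row_ids)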
# dependency). We will opt for the least annoying behaviour here and just not use dependencies
# in PipeOp ParamSets.
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.
private$.additional_param_set$values = list(keep_response = FALSE)
Name it resampling.keep_response (and hope no resampling method has keep_response as a parameter).
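A small sketch of the guard this implies, assuming private$.resampling is already set when the extra parameter is added:

# hedged sketch: fail early if the chosen resampling already defines keep_response,
# which would collide with the prefixed resampling.keep_response parameter
if ("keep_response" %in% private$.resampling$param_set$ids()) {
  stop("Resampling already has a parameter named 'keep_response'")
}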
# get task_type from mlr_reflections and call constructor
constructor = get(mlr_reflections$task_types[["task"]][chmatch(task$task_type, table = mlr_reflections$task_types[["type"]], nomatch = 0L)][[1L]])
newtask = invoke(constructor$new, id = task$id, backend = backend, target = task$target_names, .args = task$extra_args)
Note to @mb706, this needs to be brought in accord with PipeOpTaskPreproc's affect_columns.
But it should be PipeOpTaskPreproc's responsibility to keep all col-roles (that are not disabled by affect_columns), so in particular it should respect weights etc.
Maybe use this for inspiration.
Also the previous-id-column should really be a different col-role to avoid accidentally training on the ID.
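A possible sketch of that, assuming newtask from the hunk above, row_reference holding the reference column name, and the row_reference col-role proposed elsewhere in this PR:

# hedged sketch: move the reference column out of the features into a dedicated col-role
newtask$set_col_roles(row_reference, roles = "row_reference")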
renaming = setdiff(colnames(prds), c("row_ids", "truth"))
setnames(prds, old = renaming, new = sprintf("%s.%s", self$id, renaming))
setnames(prds, old = "truth", new = task$target_names)
row_reference = paste0("pre.", self$id)
change
setnames(prds, old = "truth", new = task$target_names) | ||
row_reference = paste0("pre.", self$id) | ||
while (row_reference %in% task$col_info$id) { | ||
row_reference = paste0(row_reference, ".") |
Instead throw an error when IDs collide; the user then has to change the PipeOp's ID. In any case, IDs are unique within a graph, so this usually shouldn't be a problem.
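A sketch of the suggested error in place of the while-loop, assuming stopf() from mlr3misc is available (mlr3pipelines imports it):

# hedged sketch: error out instead of silently renaming the reference column
row_reference = paste0("pre.", self$id)
if (row_reference %in% task$col_info$id) {
  stopf("Column '%s' already exists in Task '%s'; change this PipeOp's id to avoid the collision.",
    row_reference, task$id)
}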
# the following is needed to retain correct row ids in the case of e.g. cv
# here we do not necessarily apply PipeOpAggregate later
backend = if (identical(sort(prds[[row_reference]]), sort(task$row_ids))) {
  set(prds, j = task$backend$primary_key, value = prds[[row_reference]])
check primary_key is not in names(prds)
(assert)
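For example (a minimal sketch using checkmate, which mlr3pipelines already imports):

# hedged sketch of the suggested assert before writing the primary key column
assert_false(task$backend$primary_key %in% names(prds))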
@@ -143,7 +143,12 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
},
.task_to_prediction = function(input) {
  prob = as.matrix(input$data(cols = input$feature_names))
  colnames(prob) = unlist(input$levels())
  # setting the column names the following way is safer
  nms = map_chr(strsplit(colnames(prob), "\\."), function(x) x[length(x)])
maybe breaks when factor level has a period?
Better way: use input$levels(input$target_names), generate putative colnames from that, then compare to the given colnames. The assert is good though.
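A sketch of that approach, assuming the probability columns produced upstream end in .<level>:

# hedged sketch: derive column names from the target levels instead of splitting on "."
lvls = unlist(input$levels(cols = input$target_names), use.names = FALSE)
matched = lapply(colnames(prob), function(cn) lvls[endsWith(cn, paste0(".", lvls))])
if (!all(lengths(matched) == 1L) || anyDuplicated(unlist(matched))) {
  stop("Could not unambiguously map probability columns to target levels")
}
colnames(prob) = unlist(matched)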
@@ -143,7 +143,12 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
},
.task_to_prediction = function(input) {
  prob = as.matrix(input$data(cols = input$feature_names))
  colnames(prob) = unlist(input$levels())
This worked before because there was only one column in the task with factors for classification, but it is very brittle and should be avoided.
@@ -15,6 +15,9 @@ register_mlr3 = function() {
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
  "data transform", "target transform", "ensemble", "robustify", "learner", "encode",
  "multiplicity")))
if (!all(grepl("row_reference", x$task_col_roles))) {
  x$task_col_roles = map(x$task_col_roles, function(col_roles) c(col_roles, "row_reference"))
This doesn't work if other tasks get added after mlr3pipelines, so we should ask @mllg: (1) is there a way to do things like this well? (2) can we just have the col role like that?
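At least the repeated-registration part can be made idempotent by using union() instead of the grepl() guard; whether tasks registered after mlr3pipelines pick the role up remains the question for @mllg. A sketch, with x being mlr_reflections as in the hunk above:

# hedged sketch: adding the role via union() is safe to run more than once
x$task_col_roles = map(x$task_col_roles, function(col_roles) union(col_roles, "row_reference"))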
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Aggregates features row-wise based on multiple observations indicated via a column of role `row_reference` according to expressions given as formulas.
maybe we don't need to restrict to that colrole?
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
#' * `aggregation` :: named `list` of `formula`\cr
#' Expressions for how features should be aggregated, in the form of `formula`.
#' Each element of the list is a `formula` with the name of the element naming the feature to aggregate and the formula expression determining the result.
Maybe this shouldn't be a named list of formulae, but just a single formula naming a data.table expression, such as lapply(.SD, mean) or .(Sepal.Length = first(Sepal.Length), Sepal.Width = last(Sepal.Width)). We wouldn't need to teach the user data.table.
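For illustration, the two data.table idioms mentioned above as a standalone sketch (not tied to the PipeOp interface):

library(data.table)
dt = as.data.table(iris)

# one function applied to all columns per group
dt[, lapply(.SD, mean), by = Species]

# per-column expressions, stored as a one-sided formula and evaluated in j
agg = ~ list(Sepal.Length = mean(Sepal.Length), Sepal.Width = max(Sepal.Width))
dt[, eval(agg[[2L]]), by = Species]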
Alternatively: aggregation.all (a single-argument function) and aggregation.specific (a named list of formulae), similar to PipeOpMutate.
aggregation.all does not apply to (1) things named in aggregation.specific or (2) by columns.
#' Initialized to `list()`, i.e., no aggregation is performed.
#' * `by` :: `character(1)` | `NULL`\cr
#' Column indicating the `row_reference` column of the [`Task`][mlr3::Task] that should be the row-wise basis for the aggregation.
#' Initialized to `NULL`, i.e., no aggregation is performed.
Maybe also .SDcols
# checks that `aggregation` is
# * a named list of `formula`
# * that each element has only a rhs
check_aggregation_formulae = function(x) {
by now we can use mlr3misc::crate()
private = list(
  .transform = function(task) {

    if (length(self$param_set$values$aggregation) == 0L || is.null(self$param_set$values$by)) {
empty aggregation should not be allowed
Empty by should still not early-exit.
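A minimal sketch of the first point, with names taken from the diff above:

# hedged sketch: reject an empty aggregation instead of silently passing the task through
if (length(self$param_set$values$aggregation) == 0L) {
  stop("Parameter 'aggregation' must not be empty")
}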
Allow all resamplings currently listed in mlr_resamplings (and more, e.g. all that inherit from Resampling). Closes #500.

If a resampling returns multiple predictions for a row id, they are aggregated using the mean; if task_type = "classif", aggregation for response is done using the mode (maybe we should simply base this on the argmax of the mean aggregation of the probs instead, if available). Maybe we should also open this up to custom aggregation functions that can be passed as a hyperparameter, but we have to make sure that boundaries of probs etc. are respected, e.g., all aggregated probs must be within [0, 1] and the sum for one row id must still be 1 (or we enforce this later).

If a resampling fails to return predictions for a row id present in the input task, this row id is added with missing values.

All in all, this results in the row ids of the input task matching the row ids of the output features based on the resampled prediction.

Probably should be renamed from PipeOpLearnerCV to PipeOpLearnerResampling or something.

For custom resampling, train_sets and test_sets are currently passed as ParamUty parameters "resampling.custom.train_sets" and "resampling.custom.test_sets"; not sure if this is the best way here.