diff --git a/R/check_dag.R b/R/check_dag.R index 7a64ae4d0..272cc5814 100644 --- a/R/check_dag.R +++ b/R/check_dag.R @@ -285,15 +285,13 @@ check_dag <- function(..., adjustment_set <- unlist(dagitty::adjustmentSets(dag, effect = x), use.names = FALSE) adjustment_nodes <- unlist(dagitty::adjustedNodes(dag), use.names = FALSE) minimal_adjustments <- as.list(dagitty::adjustmentSets(dag, effect = x)) - collider <- adjustment_nodes[vapply(adjustment_nodes, ggdag::is_collider, logical(1), .dag = dag, downstream = FALSE)] - if (!length(collider)) { + collider <- adjustment_nodes[vapply(adjustment_nodes, ggdag::is_collider, logical(1), .dag = dag, downstream = FALSE)] # nolint + if (length(collider)) { + # if we *have* colliders, remove them from minimal adjustments + minimal_adjustments <- lapply(minimal_adjustments, setdiff, y = collider) + } else { # if we don't have colliders, set to NULL collider <- NULL - } else { - # if we *have* colliders, remove them from minimal adjustments - minimal_adjustments <- lapply(minimal_adjustments, function(ma) { - setdiff(ma, collider) - }) } list( # no adjustment needed when @@ -303,7 +301,7 @@ check_dag <- function(..., # incorrect adjustment when # - required is NULL and current adjustment not NULL # - OR we have a collider in current adjustments - incorrectly_adjusted = (is.null(adjustment_set) && !is.null(adjustment_nodes)) || (!is.null(collider) && collider %in% adjustment_nodes), + incorrectly_adjusted = (is.null(adjustment_set) && !is.null(adjustment_nodes)) || (!is.null(collider) && collider %in% adjustment_nodes), # nolint current_adjustments = adjustment_nodes, minimal_adjustments = minimal_adjustments, collider = collider