Break label to include [min, max] of continuous splitvar rather than breakpoint only #45

Open
leombastos opened this issue Apr 21, 2021 · 1 comment

@leombastos

Hello.

I was trying to change the break label layout from showing only a numerical splitvar breakpoint to showing the splitvar range of values.

For example, suppose a tree is fit and only one variable ends up in the final model. This variable ranges from 0 to 60, and the model finds a breakpoint at 30.

Instead of the break labels "<30" and ">=30", I would like it to show "[0,30)" and "[30,60]".

Any chance this can be implemented with the current version of the package?

I gave it a try with add_vars, but gave up after spending some time.

Thanks!

@martin-borkovec
Owner

Do you need it for just a single instance?
If so, the easiest way would probably be to just adjust the ggplot data like this:

library(ggparty)
#> Loading required package: ggplot2
#> Loading required package: partykit
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
data("WeatherPlay", package = "partykit")
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L, info = "yes"),
    partynode(4L, info = "no"))),
  partynode(5L, info = "yes"),
  partynode(6L, split = sp_w, kids = list(
    partynode(7L, info = "yes"),
    partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

ggpy <- ggparty(py)

ggpy$data$breaks_label[3] <- "(-Inf, 75]"
ggpy$data$breaks_label[4] <- "(75, Inf)"

ggpy +
  geom_edge() +
  geom_edge_label() +
  geom_node_label(aes(label = splitvar),
                  ids = "inner") +
  geom_node_label(aes(label = info),
                  ids = "terminal")

Created on 2021-04-26 by the reprex package (v0.3.0)

If you want to automate it for every plot, it will be a bit trickier, but it should be possible. Let me know.
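As a starting point for automating this, here is a minimal sketch of a label builder in base R. The function `interval_labels()` is a hypothetical helper, not part of ggparty or partykit. It only formats the two labels for a single binary numeric split; finding the right nodes and the data that reach them is left to the caller.

```r
# Hypothetical helper (not part of ggparty or partykit): build interval-style
# labels for one binary numeric split from the observed values of the
# split variable.
#   x          numeric values of the split variable reaching the split node
#   breakpoint the split point
#   right      TRUE  for splits of the form x <= breakpoint / x > breakpoint
#              FALSE for splits of the form x <  breakpoint / x >= breakpoint
interval_labels <- function(x, breakpoint, right = TRUE) {
  lo <- min(x, na.rm = TRUE)
  hi <- max(x, na.rm = TRUE)
  if (right) {
    c(sprintf("[%s, %s]", lo, breakpoint),
      sprintf("(%s, %s]", breakpoint, hi))
  } else {
    c(sprintf("[%s, %s)", lo, breakpoint),
      sprintf("[%s, %s]", breakpoint, hi))
  }
}

# Example from the question: values from 0 to 60, breakpoint at 30,
# labels in the "<30" / ">=30" style.
interval_labels(c(0, 60), 30, right = FALSE)

# Possible use with the reprex above (untested assumption: nodes 3 and 4 are
# the children of the humidity split in node 2, and the data reaching node 2
# can be extracted with partykit::data_party()):
# labs <- interval_labels(data_party(py, id = 2)$humidity, 75)
# ggpy$data$breaks_label[3:4] <- labs
```

Using only the observations that reach the split node, rather than the whole data set, keeps the interval ends consistent with what that branch of the tree actually contains.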
