I want to create a decision tree (using evtree, which has a VERY long run time on large datasets) on a subsample of the data.
I then want to take this model and update the terminal node estimates with estimates from held-out data. This is analogous to the concept of "honesty" in the GRF package, where the bias that comes from using the same sample both to choose the splits and to estimate the leaf values is countered by estimating the leaves on held-out data. The end result would be a final model that is generally less biased, runs faster (smaller training input), and has lower variance. Ideally I'd be able to take the new model and use it to predict on new data.
library(partykit)
library(dplyr)  # for %>%, group_by() and summarize()
set.seed(12)
train = sample(nrow(mtcars), nrow(mtcars)/1.5)
sample_tree = ctree(mpg ~ ., data = mtcars[train, ])
sample_tree %>% as.simpleparty
# Fitted party:
# [1] root
# | [2] cyl <= 6: 23.755 (n = 11, err = 224.8)
# | [3] cyl > 6: 15.380 (n = 10, err = 42.1)
data.frame(node = predict(sample_tree, newdata = mtcars[-train, ], type = 'node'),
           prediction = mtcars[-train, ]$mpg) %>%
  group_by(node) %>%
  summarize(mpg = mean(prediction)) %>%
  as.list
# $node
# [1] 2 3
# $mpg
# [1] 24.31429 14.40000
In this case I'd update terminal nodes 2 and 3 in the tree to 24.31429 and 14.40000, respectively.
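An alternative that leaves the tree object untouched is to keep the hold-out node means in a lookup table and route every new row through predict(type = "node"). This is a minimal sketch; honest_node_means() and predict_honest() are hypothetical helper names, not partykit functions, and the response is assumed to be numeric.

```r
library(partykit)

# Hypothetical helper: mean hold-out response in each terminal node
honest_node_means <- function(tree, holdout, response) {
  # Terminal node that each hold-out row falls into
  nodes <- predict(tree, newdata = holdout, type = "node")
  tapply(holdout[[response]], nodes, mean)
}

# Hypothetical helper: predict by looking up the hold-out mean of each row's node
predict_honest <- function(tree, node_means, newdata) {
  nodes <- predict(tree, newdata = newdata, type = "node")
  # NA for rows landing in a node that received no hold-out rows
  as.vector(node_means[as.character(nodes)])
}

set.seed(12)
train <- sample(nrow(mtcars), floor(nrow(mtcars) / 1.5))
tree <- ctree(mpg ~ ., data = mtcars[train, ])

node_means <- honest_node_means(tree, mtcars[-train, ], "mpg")
node_means
head(predict_honest(tree, node_means, mtcars))
```

The tree structure is still learned only from the training subsample; only the leaf values come from the hold-out rows.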
Things I've tried: asking ChatGPT countless times, a lot of googling, jumping through hoops to figure out how to extract terminal node values, etc.
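For reference, the terminal node values of a constparty (such as a ctree fit) can be recovered from the object itself: its fitted slot stores the terminal node ("(fitted)") and the observed response ("(response)") of every training row. A short sketch, assuming a numeric response and unit case weights:

```r
library(partykit)

tree <- ctree(mpg ~ ., data = mtcars)

# Ids of the terminal nodes
terminal_ids <- nodeids(tree, terminal = TRUE)

# Per-row terminal node and response as stored in the constparty
fit <- tree$fitted
train_node_means <- tapply(fit[["(response)"]], fit[["(fitted)"]], mean)
train_node_means

# For a numeric response these are the values predict() returns
pred <- predict(tree, newdata = mtcars)
```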
Edit 2: this seems to work, but I don't fully understand why. Proceed with caution.
Adapted from Achim Zeileis's answer:
library(partykit)
library(evtree)
library(ggplot2)  # for the diamonds data
library(dplyr)    # for %>%, select() and any_of()
set.seed(123)
train = sample(nrow(diamonds), nrow(diamonds)/20)
diamonds_evtree = evtree(price ~ ., data = (diamonds %>% select(any_of(c("carat", "depth", "table", "price"))))[train, ],
                         maxdepth = 3L, niterations = 101)
diamonds_ctree = ctree(price ~ ., data = (diamonds %>% select(any_of(c("depth", "table", "price", "x", "y"))))[train, ])
refit_constparty(as.constparty(diamonds_evtree), diamonds[-train, ])  # fails
refit_constparty(diamonds_ctree, diamonds[-train, ])                  # works
as.constparty(diamonds_evtree)
refit_simpleparty <- function(object, newdata) {
  # accept either a constparty or a simpleparty
  stopifnot(inherits(object, c("constparty", "simpleparty")))
  # only unit case weights are supported
  if (any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
    stop("weights not implemented yet")
  }
  # model.frame() orders/transforms the variables the way the tree's terms expect
  d <- model.frame(terms(object), data = newdata)
  # reuse the tree structure, but store the new data and their terminal nodes
  ret <- party(object$node,
               data = d,
               fitted = data.frame(
                 "(fitted)" = fitted_node(object$node, d),
                 "(response)" = d[[1L]],
                 "(weights)" = 1L,
                 check.names = FALSE),
               terms = terms(object))
  as.simpleparty(ret)
}
# works with "arbitrary data"
refit_simpleparty(diamonds_ctree %>% as.simpleparty, newdata = diamonds)
This can be accomplished by setting up a new party() with the new data and fitted values and subsequently coercing to constparty. See vignette("constparty", package = "partykit") for more details and worked examples. I have written a short function that encapsulates the necessary steps:
Note that calling model.frame() is important for potentially re-ordering and transforming the variables (e.g., setting up factors or logs on the fly). For your data split I obtain the following:
In node 2 the fitted value is NA because there are no observations. (Maybe I did something wrong, but I could not replicate the fitted values you show above.)
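One way to avoid such NA leaves is to fall back to the training-sample estimate in any terminal node that receives no hold-out observations. A minimal sketch, assuming a numeric response and that the training mean is an acceptable fallback; honest_means_with_fallback() is a hypothetical helper, not part of partykit:

```r
library(partykit)

# Hypothetical helper: hold-out node means, training means where a node is empty
honest_means_with_fallback <- function(tree, holdout, response) {
  ids <- nodeids(tree, terminal = TRUE)
  # Training means per terminal node (what the original tree predicts)
  fit <- tree$fitted
  train_means <- tapply(fit[["(response)"]],
                        factor(fit[["(fitted)"]], levels = ids), mean)
  # Hold-out means; empty nodes are kept as levels and get NA
  ho_nodes <- predict(tree, newdata = holdout, type = "node")
  ho_means <- tapply(holdout[[response]], factor(ho_nodes, levels = ids), mean)
  out <- ifelse(is.na(ho_means), train_means, ho_means)
  setNames(as.vector(out), as.character(ids))
}

tree <- ctree(mpg ~ ., data = mtcars)
# Deliberately tiny hold-out set, so some terminal nodes are likely empty
honest_means_with_fallback(tree, mtcars[1:2, ], "mpg")
```

Whether a training-mean fallback is appropriate depends on the use case; dropping or merging sparse leaves are other options.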