Skip to content

Commit

Permalink
Speed up calculate_tree_depth() (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 authored Mar 11, 2024
1 parent 2f544c5 commit dcacb6b
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions R/min_depth_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ calculate_tree_depth <- function(frame){
stop("The data frame has to contain columns called 'right daughter' and 'left daughter'!
It should be a product of the function getTree(..., labelVar = T).")
}
frame$depth <- NA
frame$depth[1] <- 0
for(i in 2:nrow(frame)){
frame[i, "depth"] <-
frame[frame[, "left daughter"] == as.numeric(rownames(frame[i,])) |
frame[, "right daughter"] == as.numeric(rownames(frame[i,])), "depth"] + 1
}
# Both child values of leaf nodes are 0, i.e., lower than min(node_id)
frame[["depth"]] <- calculate_tree_depth_(
node_id = seq_len(nrow(frame)),
left_child = frame[["left daughter"]],
right_child = frame[["right daughter"]]
)
return(frame)
}

Expand All @@ -20,16 +19,25 @@ calculate_tree_depth_ranger <- function(frame){
stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
It should be a product of the function ranger::treeInfo().")
}
frame$depth <- NA
frame$depth[1] <- 0
for(i in 2:nrow(frame)){
frame[i, "depth"] <-
frame[(!is.na(frame[, "leftChild"]) & frame[, "leftChild"] == frame[i, "nodeID"]) |
(!is.na(frame[, "rightChild"]) & frame[, "rightChild"] == frame[i, "nodeID"]), "depth"] + 1
}
frame[["depth"]] <- calculate_tree_depth_(
node_id = frame[["nodeID"]],
left_child = frame[["leftChild"]],
right_child = frame[["rightChild"]]
)
return(frame)
}

# Internal function used to determine the depth of each node of a single tree.
#
# Arguments:
#   node_id      Vector of node identifiers, one per node; the first element
#                is treated as the root.
#   left_child   Vector (same length as node_id) holding, for each node, the
#                identifier of its left child.
#   right_child  Vector (same length as node_id) holding, for each node, the
#                identifier of its right child.
#
# Returns a numeric vector of the same length as node_id: 0 for the first
# node and, for every other node, the depth of its parent plus 1.
#
# The loop reads the parent's depth while filling in the child's depth, so it
# relies on every parent appearing before its children in node_id.
calculate_tree_depth_ <- function(node_id, left_child, right_child) {
  n <- length(node_id)
  depth <- numeric(n)
  # seq_len(n)[-1] is empty for n <= 1. The former 2:n counted down to
  # c(2, 1) for a single-node tree and failed on it.
  for (i in seq_len(n)[-1]) {
    # %in% never yields NA, so missing child entries are simply "no match".
    parent_node <- left_child %in% node_id[i] | right_child %in% node_id[i]
    depth[i] <- depth[parent_node] + 1
  }
  return(depth)
}

#' Calculate minimal depth distribution of a random forest
#'
#' Get minimal depth values for all trees in a random forest
Expand Down

0 comments on commit dcacb6b

Please sign in to comment.