Skip to content

Commit

Permalink
Exponentially faster tree depth (#38)
Browse files Browse the repository at this point in the history
* Loop over depth, not over nodes

* integer index and remove leave nodes from index

* fixed security condition in the while loop

* Fixed comment

* move as.matrix() into function
  • Loading branch information
mayer79 authored Mar 13, 2024
1 parent 69c3e8a commit 96572e0
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions R/min_depth_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ calculate_tree_depth <- function(frame){
stop("The data frame has to contain columns called 'right daughter' and 'left daughter'!
It should be a product of the function getTree(..., labelVar = T).")
}
# Both child values of leaf nodes are 0, i.e., lower than min(node_id)
frame[["depth"]] <- calculate_tree_depth_(
node_id = seq_len(nrow(frame)),
left_child = frame[["left daughter"]],
right_child = frame[["right daughter"]]
frame[, c("left daughter", "right daughter")]
)
return(frame)
}
Expand All @@ -19,22 +16,30 @@ calculate_tree_depth_ranger <- function(frame){
stop("The data frame has to contain columns called 'rightChild' and 'leftChild'!
It should be a product of the function ranger::treeInfo().")
}
# Child nodes are zero based, so we increase them by 1
frame[["depth"]] <- calculate_tree_depth_(
node_id = frame[["nodeID"]],
left_child = frame[["leftChild"]],
right_child = frame[["rightChild"]]
frame[, c("leftChild", "rightChild")] + 1
)
return(frame)
}

# Internal function used to determine the depth of each node
calculate_tree_depth_ <- function(node_id, left_child, right_child) {
n <- length(node_id)
depth <- numeric(n)
for (i in 2:n) {
parent_node <- left_child %in% node_id[i] | right_child %in% node_id[i]
depth[i] <- depth[parent_node] + 1
# Internal function used to determine the depth of each node.
# The input is a data.frame with left and right child nodes in 1:nrow(childs).
calculate_tree_depth_ <- function(childs) {
childs <- as.matrix(childs)
n <- nrow(childs)
depth <- rep(NA, times = n)
j <- depth[1L] <- 0
ix <- 1L # current nodes, initialized with root node index

# j loops over tree depth
while(anyNA(depth) && j < n) { # The second condition is never used
ix <- as.integer(childs[ix, ])
ix <- ix[!is.na(ix) & ix >= 1L] # leaf nodes do not have childs
j <- j + 1
depth[ix] <- j
}

return(depth)
}

Expand Down

0 comments on commit 96572e0

Please sign in to comment.