diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 0000000..c99dd9c --- /dev/null +++ b/AGENT.md @@ -0,0 +1,33 @@ +# MultiScaleTreeGraph Performance Agent Notes + +## Goal +- Optimize traversal-heavy workloads for very large trees. +- Prioritize low allocations and type-stable code paths. + +## Benchmark Commands +- Local suite: + - `julia --project=benchmark benchmark/benchmarks.jl` +- Package tests (workaround for current precompile deadlock on Julia 1.12.1): + - `julia --project --compiled-modules=no -e 'using Pkg; Pkg.test()'` + +## CI Benchmarks +- Uses `AirspeedVelocity.jl` via `.github/workflows/Benchmarks.yml`. +- Benchmark definitions live in `benchmark/benchmarks.jl` and must expose `const SUITE`. + +## Current Hot Paths +- `src/compute_MTG/traverse.jl` +- `src/compute_MTG/ancestors.jl` +- `src/compute_MTG/descendants.jl` +- `src/compute_MTG/indexing.jl` +- `src/compute_MTG/check_filters.jl` +- `src/types/Node.jl` +- `src/compute_MTG/node_funs.jl` + +## Practical Optimization Rules +- Avoid allocating temporary arrays in per-node loops. +- Prefer in-place APIs for repeated queries: + - `ancestors!(buffer, node, key; ...)` + - `descendants!(buffer, node, key; ...)` +- Keep filter checks branch-light when no filters are provided. +- Keep key access on typed attribute containers (`NamedTuple`, `MutableNamedTuple`, typed dicts) in specialized methods when possible. +- Preserve API behavior and add tests for every optimization that changes internals. 
diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 3024600..f35bb4b 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -88,25 +88,12 @@ function ancestors_workload(nodes, reps::Int) return s end -function ancestors_workload_inplace_1(nodes, reps::Int) - s = 0.0 - @inbounds for _ in 1:reps - for n in nodes - out = ancestors!(n, :mass, recursivity_level=4, type=Float64) - for v in out - s += v - end - end - end - return s -end - -function ancestors_workload_inplace_2(nodes, reps::Int) +function ancestors_workload_inplace(nodes, reps::Int) s = 0.0 buf = Float64[] @inbounds for _ in 1:reps for n in nodes - ancestors!(buf, n, :mass, recursivity_level=4, type=Float64) + ancestors!(buf, n, :mass, recursivity_level=4) for v in buf s += v end @@ -130,7 +117,7 @@ end function descendants_extraction_workload_inplace_2(root) vals = Float64[] - descendants!(vals, root, :mass, type=Float64) + descendants!(vals, root, :mass) end suite_name = "mstg" @@ -151,17 +138,18 @@ SUITE[suite_name] = BenchmarkGroup([ root, leaves, sample_nodes = synthetic_tree() SUITE[suite_name]["traverse"]["full_tree_nodes"] = @benchmarkable traverse!($root, _ -> nothing) SUITE[suite_name]["traverse_extract"]["descendants_mass"] = @benchmarkable descendants_extraction_workload($root) -SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_1($root) +SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_1"] = @benchmarkable descendants_extraction_workload_inplace_1($root) # Add this one only if we have a method for `descendants!(val, node, key, type)` if hasmethod(descendants!, Tuple{AbstractVector,Node,Symbol}) - SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_2($root) + SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_2"] = @benchmarkable descendants_extraction_workload_inplace_2($root) end 
SUITE[suite_name]["many_queries"]["children_repeated"] = @benchmarkable children_workload($sample_nodes, 300) SUITE[suite_name]["many_queries"]["parent_repeated"] = @benchmarkable parent_workload($sample_nodes, 300) SUITE[suite_name]["many_queries"]["ancestors_repeated"] = @benchmarkable ancestors_workload($leaves, 40) -if hasmethod(ancestors!, Tuple{AbstractVector,Node,Symbol}) +# Test if ancestors! exists in the package first: +if isdefined(MultiScaleTreeGraph, :ancestors!) SUITE[suite_name]["many_queries"]["ancestors_repeated_inplace"] = @benchmarkable ancestors_workload_inplace($leaves, 40) end diff --git a/src/MultiScaleTreeGraph.jl b/src/MultiScaleTreeGraph.jl index bf43ac1..f25a5f7 100644 --- a/src/MultiScaleTreeGraph.jl +++ b/src/MultiScaleTreeGraph.jl @@ -79,7 +79,7 @@ export insert_parents!, insert_generations!, insert_children!, insert_siblings! export insert_parent!, insert_generation!, insert_child!, insert_sibling! export write_mtg export is_segment! -export descendants, ancestors, descendants! +export descendants, ancestors, ancestors!, descendants! export Node export AbstractNodeMTG export NodeMTG diff --git a/src/compute_MTG/ancestors.jl b/src/compute_MTG/ancestors.jl index 96e61c1..80ac022 100644 --- a/src/compute_MTG/ancestors.jl +++ b/src/compute_MTG/ancestors.jl @@ -77,12 +77,13 @@ function ancestors( # Change the filtering function if we also want to remove nodes with nothing values. 
filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key) + use_no_filter = no_node_filters(scale, symbol, link, filter_fun_) val = Array{type,1}() # Put the recursivity level into an array so it is mutable in-place: if self - if is_filtered(node, scale, symbol, link, filter_fun_) + if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_) val_ = unsafe_getindex(node, key) push!(val, val_) elseif !all @@ -91,31 +92,47 @@ function ancestors( end end - ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level) + if use_no_filter + ancestors_values_no_filter!(node, key, val, recursivity_level) + else + ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, val, recursivity_level) + end return val end -function ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level) - - if !isroot(node) && recursivity_level != 0 - parent_ = parent(node) +function ancestors_values!(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level) + current = node + remaining = recursivity_level - # Is there any filter happening for the current node? 
(FALSE if filtered out): + while !isroot(current) && remaining != 0 + parent_ = parent(current) keep = is_filtered(parent_, scale, symbol, link, filter_fun) if keep - val_ = unsafe_getindex(parent_, key) - push!(val, val_) + push!(val, unsafe_getindex(parent_, key)) # Only decrement the recursivity level when the current node is not filtered-out - recursivity_level -= 1 + remaining -= 1 end # If we want to continue even if the current node is filtered-out - if all || keep - ancestors_(parent_, key, scale, symbol, link, all, filter_fun, val, recursivity_level) - end + (all || keep) || break + current = parent_ + end + return val +end + +function ancestors_values_no_filter!(node, key, val, recursivity_level) + current = node + remaining = recursivity_level + + while !isroot(current) && remaining != 0 + parent_ = parent(current) + push!(val, unsafe_getindex(parent_, key)) + remaining -= 1 + current = parent_ end + return val end # Version that returns the nodes instead of the values: @@ -133,12 +150,12 @@ function ancestors( # Check the filters once, and then compute the ancestors recursively using `ancestors_` check_filters(node, scale=scale, symbol=symbol, link=link) - # Change the filtering function if we also want to remove nodes with nothing values. 
+ use_no_filter = no_node_filters(scale, symbol, link, filter_fun) val = Array{typeof(node),1}() # Put the recursivity level into an array so it is mutable in-place: if self - if is_filtered(node, scale, symbol, link, filter_fun) + if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun) push!(val, node) elseif !all # We don't keep the value and we have to stop at the first filtered-out value @@ -146,28 +163,110 @@ function ancestors( end end - ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level) + if use_no_filter + ancestors_nodes_no_filter!(node, val, recursivity_level) + else + ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level) + end return val end -function ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level) - - if !isroot(node) && recursivity_level != 0 - parent_ = parent(node) +function ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level) + current = node + remaining = recursivity_level - # Is there any filter happening for the current node? 
(FALSE if filtered out): + while !isroot(current) && remaining != 0 + parent_ = parent(current) keep = is_filtered(parent_, scale, symbol, link, filter_fun) if keep push!(val, parent_) # Only decrement the recursivity level when the current node is not filtered-out - recursivity_level -= 1 + remaining -= 1 end # If we want to continue even if the current node is filtered-out - if all || keep - ancestors_(parent_, scale, symbol, link, all, filter_fun, val, recursivity_level) + (all || keep) || break + current = parent_ + end + return val +end + +function ancestors_nodes_no_filter!(node, val, recursivity_level) + current = node + remaining = recursivity_level + + while !isroot(current) && remaining != 0 + parent_ = parent(current) + push!(val, parent_) + remaining -= 1 + current = parent_ + end + return val +end + +function ancestors!( + out::AbstractVector, + node, key; + scale=nothing, + symbol=nothing, + link=nothing, + all::Bool=true, + self=false, + filter_fun=nothing, + recursivity_level=-1, + ignore_nothing=false, + type::Union{Union,DataType}=Any, +) + check_filters(node, scale=scale, symbol=symbol, link=link) + filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key) + use_no_filter = no_node_filters(scale, symbol, link, filter_fun_) + + empty!(out) + if self + if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_) + push!(out, unsafe_getindex(node, key)) + elseif !all + return out + end + end + + if use_no_filter + ancestors_values_no_filter!(node, key, out, recursivity_level) + else + ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, out, recursivity_level) + end + return out +end + +function ancestors!( + out::AbstractVector, + node; + scale=nothing, + symbol=nothing, + link=nothing, + all::Bool=true, + self=false, + filter_fun=nothing, + recursivity_level=-1, +) + check_filters(node, scale=scale, symbol=symbol, link=link) + use_no_filter = no_node_filters(scale, symbol, link, filter_fun) + + empty!(out) + if 
self + if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun) + push!(out, node) + elseif !all + return out end end + + if use_no_filter + ancestors_nodes_no_filter!(node, out, recursivity_level) + else + ancestors_nodes!(node, scale, symbol, link, all, filter_fun, out, recursivity_level) + end + return out end diff --git a/src/compute_MTG/check_filters.jl b/src/compute_MTG/check_filters.jl index b0dd0fe..c8002b9 100644 --- a/src/compute_MTG/check_filters.jl +++ b/src/compute_MTG/check_filters.jl @@ -11,7 +11,11 @@ check_filters(mtg, scale = (1,2)) check_filters(mtg, scale = (1,2), symbol = "Leaf", link = "<") ``` """ +@inline no_node_filters(scale, symbol, link, filter_fun=nothing) = + isnothing(scale) && isnothing(symbol) && isnothing(link) && isnothing(filter_fun) + function check_filters(node::Node{N,A}; scale=nothing, symbol=nothing, link=nothing) where {N<:AbstractNodeMTG,A} + no_node_filters(scale, symbol, link) && return nothing root_node = get_root(node) @@ -76,7 +80,10 @@ end end @inline function is_filtered(filter, value::T) where {T<:Union{Tuple,Array}} - all(map(x -> is_filtered(filter, x), value)) + for x in value + is_filtered(filter, x) || return false + end + return true end diff --git a/src/compute_MTG/descendants.jl b/src/compute_MTG/descendants.jl index 4037a89..e8cbb14 100644 --- a/src/compute_MTG/descendants.jl +++ b/src/compute_MTG/descendants.jl @@ -1,3 +1,59 @@ +function collect_descendant_values!(node, key, scale, symbol, link, filter_fun, all, val, recursivity_level) + recursivity_level == 0 && return val + recursivity_level -= 1 + + keep = is_filtered(node, scale, symbol, link, filter_fun) + if keep + push!(val, unsafe_getindex(node, key)) + elseif !all + return val + end + + @inbounds for chnode in children(node) + collect_descendant_values!(chnode, key, scale, symbol, link, filter_fun, all, val, recursivity_level) + end + return val +end + +function collect_descendant_values_no_filter!(node, key, val, recursivity_level) 
+ recursivity_level == 0 && return val + recursivity_level -= 1 + + push!(val, unsafe_getindex(node, key)) + @inbounds for chnode in children(node) + collect_descendant_values_no_filter!(chnode, key, val, recursivity_level) + end + return val +end + +function collect_descendant_nodes!(node, scale, symbol, link, filter_fun, all, val, recursivity_level) + recursivity_level == 0 && return val + recursivity_level -= 1 + + keep = is_filtered(node, scale, symbol, link, filter_fun) + if keep + push!(val, node) + elseif !all + return val + end + + @inbounds for chnode in children(node) + collect_descendant_nodes!(chnode, scale, symbol, link, filter_fun, all, val, recursivity_level) + end + return val +end + +function collect_descendant_nodes_no_filter!(node, val, recursivity_level) + recursivity_level == 0 && return val + recursivity_level -= 1 + + push!(val, node) + @inbounds for chnode in children(node) + collect_descendant_nodes_no_filter!(chnode, val, recursivity_level) + end + return val +end + function descendants( node, key; scale=nothing, @@ -17,17 +73,21 @@ function descendants( filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key) val = Array{type,1}() + use_no_filter = no_node_filters(scale, symbol, link, filter_fun_) if self - traverse!(node, scale=scale, symbol=symbol, link=link, filter_fun=filter_fun_, all=all, recursivity_level=recursivity_level) do chnode - push!(val, unsafe_getindex(chnode, key)) - # Only decrement the recursivity level when the current node is not filtered-out + if use_no_filter + collect_descendant_values_no_filter!(node, key, val, recursivity_level) + else + collect_descendant_values!(node, key, scale, symbol, link, filter_fun_, all, val, recursivity_level) end else # If we don't want to include the value of the current node, we apply the traversal to its children directly: for chnode in children(node) - traverse!(chnode, scale=scale, symbol=symbol, link=link, filter_fun=filter_fun_, all=all, 
recursivity_level=recursivity_level) do chnode - push!(val, unsafe_getindex(chnode, key)) + if use_no_filter + collect_descendant_values_no_filter!(chnode, key, val, recursivity_level) + else + collect_descendant_values!(chnode, key, scale, symbol, link, filter_fun_, all, val, recursivity_level) end end end @@ -51,16 +111,21 @@ function descendants( check_filters(node, scale=scale, symbol=symbol, link=link) val = Array{typeof(node),1}() + use_no_filter = no_node_filters(scale, symbol, link, filter_fun) if self - traverse!(node, scale=scale, symbol=symbol, link=link, filter_fun=filter_fun, all=all, recursivity_level=recursivity_level) do chnode - push!(val, chnode) + if use_no_filter + collect_descendant_nodes_no_filter!(node, val, recursivity_level) + else + collect_descendant_nodes!(node, scale, symbol, link, filter_fun, all, val, recursivity_level) end else # If we don't want to include the value of the current node, we apply the traversal to its children directly: for chnode in children(node) - traverse!(chnode, scale=scale, symbol=symbol, link=link, filter_fun=filter_fun, all=all, recursivity_level=recursivity_level) do chnode - push!(val, chnode) + if use_no_filter + collect_descendant_nodes_no_filter!(chnode, val, recursivity_level) + else + collect_descendant_nodes!(chnode, scale, symbol, link, filter_fun, all, val, recursivity_level) end end end @@ -68,6 +133,74 @@ function descendants( return val end +function descendants!( + out::AbstractVector, + node, key; + scale=nothing, + symbol=nothing, + link=nothing, + all::Bool=true, + self=false, + filter_fun=nothing, + recursivity_level=Inf, + ignore_nothing::Bool=false, +) + check_filters(node, scale=scale, symbol=symbol, link=link) + filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key) + use_no_filter = no_node_filters(scale, symbol, link, filter_fun_) + + empty!(out) + if self + if use_no_filter + collect_descendant_values_no_filter!(node, key, out, recursivity_level) + else + 
collect_descendant_values!(node, key, scale, symbol, link, filter_fun_, all, out, recursivity_level) + end + else + for chnode in children(node) + if use_no_filter + collect_descendant_values_no_filter!(chnode, key, out, recursivity_level) + else + collect_descendant_values!(chnode, key, scale, symbol, link, filter_fun_, all, out, recursivity_level) + end + end + end + return out +end + +function descendants!( + out::AbstractVector, + node; + scale=nothing, + symbol=nothing, + link=nothing, + all::Bool=true, + self=false, + filter_fun=nothing, + recursivity_level=Inf, +) + check_filters(node, scale=scale, symbol=symbol, link=link) + use_no_filter = no_node_filters(scale, symbol, link, filter_fun) + + empty!(out) + if self + if use_no_filter + collect_descendant_nodes_no_filter!(node, out, recursivity_level) + else + collect_descendant_nodes!(node, scale, symbol, link, filter_fun, all, out, recursivity_level) + end + else + for chnode in children(node) + if use_no_filter + collect_descendant_nodes_no_filter!(chnode, out, recursivity_level) + else + collect_descendant_nodes!(chnode, scale, symbol, link, filter_fun, all, out, recursivity_level) + end + end + end + return out +end + #Note: The mutating version is more complicated, so we don't use `traverse!` but make another implementation. 
function descendants!( node::Node{N,A}, key; diff --git a/src/compute_MTG/filter/filter-funs.jl b/src/compute_MTG/filter/filter-funs.jl index 8d4a262..96ab173 100644 --- a/src/compute_MTG/filter/filter-funs.jl +++ b/src/compute_MTG/filter/filter-funs.jl @@ -39,6 +39,13 @@ function is_segment!(node::Node{N,A}) where {N<:AbstractNodeMTG,A} return false end +@inline function all_not_nothing(node, attr_keys) + for key in attr_keys + unsafe_getindex(node, key) === nothing && return false + end + return true +end + """ filter_fun_nothing(filter_fun, ignore_nothing, attr_keys) @@ -51,12 +58,12 @@ function filter_fun_nothing(filter_fun, ignore_nothing, attr_keys) if filter_fun !== nothing filter_fun_ = function (node) - all([unsafe_getindex(node, i) !== nothing for i in attr_keys]) && filter_fun(node) + all_not_nothing(node, attr_keys) && filter_fun(node) end else filter_fun_ = function (node) - all([unsafe_getindex(node, i) !== nothing for i in attr_keys]) + all_not_nothing(node, attr_keys) end end else diff --git a/src/compute_MTG/indexing.jl b/src/compute_MTG/indexing.jl index fa4ae30..8722cb8 100644 --- a/src/compute_MTG/indexing.jl +++ b/src/compute_MTG/indexing.jl @@ -24,13 +24,33 @@ end unsafe_getindex(node::Node, key) = unsafe_getindex(node, Symbol(key)) +@inline function unsafe_getindex(node::Node{M,NamedTuple}, key::Symbol) where {M<:AbstractNodeMTG} + attrs = node_attributes(node) + hasproperty(attrs, key) ? getproperty(attrs, key) : nothing +end + +@inline function unsafe_getindex(node::Node{M,MutableNamedTuple}, key::Symbol) where {M<:AbstractNodeMTG} + attrs = node_attributes(node) + hasproperty(attrs, key) ? 
getproperty(attrs, key) : nothing +end + # For a vector of keys: -unsafe_getindex(node::Node, key::Union{Vector{Symbol},Vector{String}}) = [unsafe_getindex(node, i) for i in key] +function unsafe_getindex(node::Node, key::Union{Vector{Symbol},Vector{String}}) + vals = Vector{Any}(undef, length(key)) + @inbounds for i in eachindex(key) + vals[i] = unsafe_getindex(node, key[i]) + end + vals +end function unsafe_getindex( node::Node{M,T} where {M<:AbstractNodeMTG,T<:AbstractDict{Symbol,S} where {S}}, key::Union{Vector{Symbol},Vector{String}} ) - [unsafe_getindex(node, i) for i in key] + vals = Vector{Any}(undef, length(key)) + @inbounds for i in eachindex(key) + vals[i] = unsafe_getindex(node, key[i]) + end + vals end function unsafe_getindex( diff --git a/src/compute_MTG/node_funs.jl b/src/compute_MTG/node_funs.jl index 1b02d0c..a9f54e2 100644 --- a/src/compute_MTG/node_funs.jl +++ b/src/compute_MTG/node_funs.jl @@ -2,7 +2,7 @@ isleaf(node::Node) Test whether a node is a leaf or not. """ -isleaf(node::Node) = length(children(node)) == 0 +isleaf(node::Node) = isempty(children(node)) """ isroot(node::Node) @@ -20,8 +20,7 @@ function lastchild(node::Node) if isleaf(node) return nothing else - allchildren = children(node) - return allchildren[maximum(keys(allchildren))] + return last(children(node)) end end @@ -85,11 +84,7 @@ function addchild!(p::Node{N,A}, child::Node; force=false) where {N<:AbstractNod error("The node already has a parent. Hint: use `force=true` if needed.") end - if children(p) === nothing - rechildren!(child, Node{N,A}[child]) - else - push!(children(p), child) - end + push!(children(p), child) return child end @@ -99,11 +94,11 @@ end Find the root node of a tree, given any node in the tree. 
""" function get_root(node::Node) - if isroot(node) - return (node) - else - get_root(parent(node)) + root = node + while !isroot(root) + root = parent(root) end + return root end """ @@ -113,11 +108,23 @@ Return the siblings of `node` as a vector of nodes (or `nothing` if non-existant """ function siblings(node::Node) # If there is no parent, no siblings, return nothing: - parent(node) === nothing && return nothing - - all_siblings = children(parent(node)) - - return all_siblings[findall(x -> x != node, all_siblings)] + parent_ = parent(node) + parent_ === nothing && return nothing + + all_siblings = children(parent_) + nsiblings = length(all_siblings) + nsiblings <= 1 && return similar(all_siblings, 0) + + out = Vector{eltype(all_siblings)}(undef, nsiblings - 1) + j = 1 + @inbounds for sibling in all_siblings + if sibling !== node + out[j] = sibling + j += 1 + end + end + resize!(out, j - 1) + return out end """ @@ -127,12 +134,9 @@ Return the last sibling of `node` (or `nothing` if non-existant). 
""" function lastsibling(node::Node) # If there is no parent, no siblings, return nothing: - parent(node) === nothing && return nothing - - all_siblings = children(parent(node)) - # Get the index of the current node in the siblings: - - return all_siblings[maximum(keys(all_siblings))] + parent_ = parent(node) + parent_ === nothing && return nothing + return last(children(parent_)) end """ diff --git a/src/compute_MTG/traverse.jl b/src/compute_MTG/traverse.jl index 30665ef..462cacf 100644 --- a/src/compute_MTG/traverse.jl +++ b/src/compute_MTG/traverse.jl @@ -44,6 +44,65 @@ end """ traverse!, traverse +function traverse_no_filter!(node::Node, f::Function, recursivity_level) + nodes = Vector{typeof(node)}(undef, 1) + levels = Vector{typeof(recursivity_level)}(undef, 1) + nodes[1] = node + levels[1] = recursivity_level + + while !isempty(nodes) + current = pop!(nodes) + current_level = pop!(levels) + current_level == 0 && continue + next_level = current_level - 1 + + try + f(current) + catch e + println("Issue in function $f for node #$(node_id(current)).") + rethrow(e) + end + + all_children = children(current) + @inbounds for i in lastindex(all_children):-1:firstindex(all_children) + push!(nodes, all_children[i]) + push!(levels, next_level) + end + end + + return nothing +end + +function traverse_no_filter(node::Node, f::Function, val, recursivity_level) + nodes = Vector{typeof(node)}(undef, 1) + levels = Vector{typeof(recursivity_level)}(undef, 1) + nodes[1] = node + levels[1] = recursivity_level + + while !isempty(nodes) + current = pop!(nodes) + current_level = pop!(levels) + current_level == 0 && continue + next_level = current_level - 1 + + val_ = try + f(current) + catch e + println("Issue in function $f for node $(node_id(current)).") + rethrow(e) + end + push!(val, val_) + + all_children = children(current) + @inbounds for i in lastindex(all_children):-1:firstindex(all_children) + push!(nodes, all_children[i]) + push!(levels, next_level) + end + end + + 
return val +end + function traverse!(node::Node, f::Function, args...; scale=nothing, symbol=nothing, link=nothing, filter_fun=nothing, all=true, recursivity_level=Inf) if !isempty(args) g = node -> f(node, args...) @@ -52,11 +111,21 @@ function traverse!(node::Node, f::Function, args...; scale=nothing, symbol=nothi end # If the node has already a cache of the traversal, we use it instead of traversing the mtg: - if haskey(node_traversal_cache(node), cache_name(scale, symbol, link, all, filter_fun)) - for i in node_traversal_cache(node)[cache_name(scale, symbol, link, all, filter_fun)] - # NB: node_traversal_cache(node)[cache_name(scale, symbol, link, filter_fun)] is a Vector of nodes corresponding to the traversal filters applied. - g(i) + cache = node_traversal_cache(node) + if !isempty(cache) + cache_key = cache_name(scale, symbol, link, all, filter_fun) + cached_nodes = get(cache, cache_key, nothing) + if cached_nodes !== nothing + for i in cached_nodes + # NB: node_traversal_cache(node)[cache_name(scale, symbol, link, filter_fun)] is a Vector of nodes corresponding to the traversal filters applied. + g(i) + end + return end + end + + if no_node_filters(scale, symbol, link, filter_fun) + traverse_no_filter!(node, g, recursivity_level) return end @@ -99,17 +168,28 @@ function traverse(node::Node, f::Function, args...; scale=nothing, symbol=nothin # NB: f has to return someting here, if its a mutating function, use traverse! # If the node has already a cache of the traversal, we use it instead of traversing the mtg: - if haskey(node_traversal_cache(node), cache_name(scale, symbol, link, all, filter_fun)) - for i in node_traversal_cache(node)[cache_name(scale, symbol, link, all, filter_fun)] - # NB: node_traversal_cache(node)[cache_name(scale, symbol, link, filter_fun)] is a Vector of nodes corresponding to the traversal filters applied. 
- val_ = try - g(i) - catch e - error("Issue in function $f for node $(node_id(node)).") - rethrow(e) + cache = node_traversal_cache(node) + if !isempty(cache) + cache_key = cache_name(scale, symbol, link, all, filter_fun) + cached_nodes = get(cache, cache_key, nothing) + + if cached_nodes !== nothing + for i in cached_nodes + # NB: node_traversal_cache(node)[cache_name(scale, symbol, link, all, filter_fun)] is a Vector of nodes corresponding to the traversal filters applied. + val_ = try + g(i) + catch e + println("Issue in function $f for node $(node_id(i)).") # name the cached node `i` whose callback failed; the original exception is rethrown on the next line + rethrow(e) + end + push!(val, val_) end - push!(val, val_) + return val end + end + + if no_node_filters(scale, symbol, link, filter_fun) + traverse_no_filter(node, g, val, recursivity_level) return val end @@ -175,4 +255,4 @@ function traverse( recursivity_level=Inf, ) traverse(node, f, args...; scale=scale, symbol=symbol, link=link, filter_fun=filter_fun, all=all, type=type, recursivity_level=recursivity_level) -end \ No newline at end of file +end diff --git a/src/types/Node.jl b/src/types/Node.jl index 86f246d..9a7b9e4 100644 --- a/src/types/Node.jl +++ b/src/types/Node.jl @@ -193,31 +193,40 @@ AbstractTrees.ChildIndexing(::Type{<:Node{T,A}}) where {T<:AbstractNodeMTG,A} = AbstractTrees.NodeType(::Type{<:Node{T,A}}) where {T<:AbstractNodeMTG,A} = HasNodeType() AbstractTrees.nodetype(::Type{<:Node{T,A}}) where {T<:AbstractNodeMTG,A} = Node{T,A} +@inline function sibling_index(all_siblings, node) + @inbounds for i in eachindex(all_siblings) + all_siblings[i] === node && return i + end + return nothing +end + function AbstractTrees.nextsibling(node::Node) # If there is no parent, no siblings, return nothing: - parent(node) === nothing && return nothing + parent_ = parent(node) + parent_ === nothing && return nothing - all_siblings = children(parent(node)) + all_siblings = children(parent_) # Get the index of the current node in the siblings: - node_index = findfirst(x -> x == node, all_siblings) - if 
node_index < length(all_siblings) - all_siblings[node_index+1] - else + node_index = sibling_index(all_siblings, node) + if node_index === nothing || node_index >= lastindex(all_siblings) nothing + else + all_siblings[node_index+1] end end function AbstractTrees.prevsibling(node::Node) # If there is no parent, no siblings, return nothing: - parent(node) === nothing && return nothing + parent_ = parent(node) + parent_ === nothing && return nothing - all_siblings = children(parent(node)) + all_siblings = children(parent_) # Get the index of the current node in the siblings: - node_index = findfirst(x -> x == node, all_siblings) - if node_index > 1 - all_siblings[node_index-1] - else + node_index = sibling_index(all_siblings, node) + if node_index === nothing || node_index <= firstindex(all_siblings) nothing + else + all_siblings[node_index-1] end end diff --git a/test/test-ancestors.jl b/test/test-ancestors.jl index ca00bc9..b48a194 100644 --- a/test/test-ancestors.jl +++ b/test/test-ancestors.jl @@ -19,7 +19,19 @@ @test ancestors(leaf_node, :Width, symbol=("Leaf", "Internode"), self=true) == width_all[end:-1:end-1] + buf_vals = Union{Nothing,Float64}[] + @test ancestors!(buf_vals, leaf_node, :Width; type=Union{Nothing,Float64}) == + reverse(width_all[1:4]) + @test ancestors!(buf_vals, leaf_node, :Width, symbol=("Leaf", "Internode"), self=true) == + width_all[end:-1:end-1] + # Using the method that returns the nodes directly: @test ancestors(leaf_node) == [leaf_node |> parent, leaf_node |> parent |> parent, leaf_node |> parent |> parent |> parent, leaf_node |> parent |> parent |> parent |> parent] @test ancestors(leaf_node, self=true) == [leaf_node, leaf_node |> parent, leaf_node |> parent |> parent, leaf_node |> parent |> parent |> parent, leaf_node |> parent |> parent |> parent |> parent] + + buf_nodes = typeof(leaf_node)[] + @test ancestors!(buf_nodes, leaf_node) == + [leaf_node |> parent, leaf_node |> parent |> parent, leaf_node |> parent |> parent |> parent, 
leaf_node |> parent |> parent |> parent |> parent] + @test ancestors!(buf_nodes, leaf_node, self=true) == + [leaf_node, leaf_node |> parent, leaf_node |> parent |> parent, leaf_node |> parent |> parent |> parent, leaf_node |> parent |> parent |> parent |> parent] end diff --git a/test/test-descendants.jl b/test/test-descendants.jl index 9a5e986..262763c 100644 --- a/test/test-descendants.jl +++ b/test/test-descendants.jl @@ -18,6 +18,11 @@ @test descendants(mtg2, :Width, symbol=("Leaf", "Internode"), self=true) == width_all[end-1:end] + out_vals = Union{Nothing,Float64}[] + @test descendants!(out_vals, mtg, :Width) == width_all + @test descendants!(out_vals, mtg2, :Width, symbol=("Leaf", "Internode"), self=true) == + width_all[end-1:end] + # Using the mutating version: @test descendants!(mtg, :Width) == descendants(mtg, :Width) @test descendants!(mtg2, :Width, symbol=("Leaf", "Internode"), self=true) == @@ -31,6 +36,11 @@ @test descendants(mtg) == traverse(mtg[1], x -> x) @test descendants(mtg, self=true) == traverse(mtg, x -> x) @test descendants(get_node(mtg, 6), self=true) == [get_node(mtg, 6), get_node(mtg, 7)] + + out_nodes = typeof(mtg)[] + @test descendants!(out_nodes, mtg) == traverse(mtg[1], x -> x) + @test descendants!(out_nodes, mtg, self=true) == traverse(mtg, x -> x) + @test descendants!(out_nodes, get_node(mtg, 6), self=true) == [get_node(mtg, 6), get_node(mtg, 7)] end # using BenchmarkTools diff --git a/test/test-traverse.jl b/test/test-traverse.jl index 86018c2..709c919 100644 --- a/test/test-traverse.jl +++ b/test/test-traverse.jl @@ -38,4 +38,21 @@ end @test traverse(mtg, x -> x, symbol="Internode", all=false) == Any[] # No internode in the first level, all=false -> iteration stops before the first node @test traverse(mtg, node -> node[:Length]) == Any[nothing, nothing, nothing, 0.1, 0.2, 0.1, 0.2] @test traverse(mtg, node -> node[:Length], type=Union{Nothing,Float64}) == Union{Nothing,Float64}[nothing, nothing, nothing, 0.1, 0.2, 0.1, 0.2] -end \ 
No newline at end of file +end + +@testset "traverse deep no-filter path" begin + root = Node(1, NodeMTG("/", "Plant", 1, 1)) + current = root + for i in 1:5000 + current = Node(i + 1, current, NodeMTG("<", "Segment", i, 2)) + end + + out = traverse(root, node -> node_id(node), type=Int) + @test length(out) == 5001 + @test out[1] == 1 + @test out[end] == 5001 + + n = Ref(0) + traverse!(root, _ -> n[] += 1) + @test n[] == 5001 +end