Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions AGENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# MultiScaleTreeGraph Performance Agent Notes

## Goal
- Optimize traversal-heavy workloads for very large trees.
- Prioritize low allocations and type-stable code paths.

## Benchmark Commands
- Local suite:
- `julia --project=benchmark benchmark/benchmarks.jl`
- Package tests (workaround for current precompile deadlock on Julia 1.12.1):
- `julia --project --compiled-modules=no -e 'using Pkg; Pkg.test()'`

## CI Benchmarks
- Uses `AirspeedVelocity.jl` via `.github/workflows/Benchmarks.yml`.
- Benchmark definitions live in `benchmark/benchmarks.jl` and must expose `const SUITE`.

## Current Hot Paths
- `src/compute_MTG/traverse.jl`
- `src/compute_MTG/ancestors.jl`
- `src/compute_MTG/descendants.jl`
- `src/compute_MTG/indexing.jl`
- `src/compute_MTG/check_filters.jl`
- `src/types/Node.jl`
- `src/compute_MTG/node_funs.jl`

## Practical Optimization Rules
- Avoid allocating temporary arrays in per-node loops.
- Prefer in-place APIs for repeated queries:
- `ancestors!(buffer, node, key; ...)`
- `descendants!(buffer, node, key; ...)`
- Keep filter checks branch-light when no filters are provided.
- Keep key access on typed attribute containers (`NamedTuple`, `MutableNamedTuple`, typed dicts) in specialized methods when possible.
- Preserve API behavior and add tests for every optimization that changes internals.
26 changes: 7 additions & 19 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,12 @@ function ancestors_workload(nodes, reps::Int)
return s
end

function ancestors_workload_inplace_1(nodes, reps::Int)
s = 0.0
@inbounds for _ in 1:reps
for n in nodes
out = ancestors!(n, :mass, recursivity_level=4, type=Float64)
for v in out
s += v
end
end
end
return s
end

function ancestors_workload_inplace_2(nodes, reps::Int)
function ancestors_workload_inplace(nodes, reps::Int)
s = 0.0
buf = Float64[]
@inbounds for _ in 1:reps
for n in nodes
ancestors!(buf, n, :mass, recursivity_level=4, type=Float64)
ancestors!(buf, n, :mass, recursivity_level=4)
for v in buf
s += v
end
Expand All @@ -130,7 +117,7 @@ end

function descendants_extraction_workload_inplace_2(root)
vals = Float64[]
descendants!(vals, root, :mass, type=Float64)
descendants!(vals, root, :mass)
end

suite_name = "mstg"
Expand All @@ -151,17 +138,18 @@ SUITE[suite_name] = BenchmarkGroup([
root, leaves, sample_nodes = synthetic_tree()
SUITE[suite_name]["traverse"]["full_tree_nodes"] = @benchmarkable traverse!($root, _ -> nothing)
SUITE[suite_name]["traverse_extract"]["descendants_mass"] = @benchmarkable descendants_extraction_workload($root)
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_1($root)
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_1"] = @benchmarkable descendants_extraction_workload_inplace_1($root)

# Add this one only if we have a method for `descendants!(val, node, key, type)`
if hasmethod(descendants!, Tuple{AbstractVector,Node,Symbol})
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace"] = @benchmarkable descendants_extraction_workload_inplace_2($root)
SUITE[suite_name]["traverse_extract"]["descendants_mass_inplace_2"] = @benchmarkable descendants_extraction_workload_inplace_2($root)
end

SUITE[suite_name]["many_queries"]["children_repeated"] = @benchmarkable children_workload($sample_nodes, 300)
SUITE[suite_name]["many_queries"]["parent_repeated"] = @benchmarkable parent_workload($sample_nodes, 300)
SUITE[suite_name]["many_queries"]["ancestors_repeated"] = @benchmarkable ancestors_workload($leaves, 40)
if hasmethod(ancestors!, Tuple{AbstractVector,Node,Symbol})
# Test if ancestors! exists in the package first:
if isdefined(MultiScaleTreeGraph, :ancestors!)
SUITE[suite_name]["many_queries"]["ancestors_repeated_inplace"] = @benchmarkable ancestors_workload_inplace($leaves, 40)
end

Expand Down
2 changes: 1 addition & 1 deletion src/MultiScaleTreeGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export insert_parents!, insert_generations!, insert_children!, insert_siblings!
export insert_parent!, insert_generation!, insert_child!, insert_sibling!
export write_mtg
export is_segment!
export descendants, ancestors, descendants!
export descendants, ancestors, ancestors!, descendants!
export Node
export AbstractNodeMTG
export NodeMTG
Expand Down
147 changes: 123 additions & 24 deletions src/compute_MTG/ancestors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ function ancestors(

# Change the filtering function if we also want to remove nodes with nothing values.
filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key)
use_no_filter = no_node_filters(scale, symbol, link, filter_fun_)

val = Array{type,1}()
# Put the recursivity level into an array so it is mutable in-place:

if self
if is_filtered(node, scale, symbol, link, filter_fun_)
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_)
val_ = unsafe_getindex(node, key)
push!(val, val_)
elseif !all
Expand All @@ -91,31 +92,47 @@ function ancestors(
end
end

ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
if use_no_filter
ancestors_values_no_filter!(node, key, val, recursivity_level)
else
ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, val, recursivity_level)
end
return val
end


function ancestors_(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)

if !isroot(node) && recursivity_level != 0
parent_ = parent(node)
function ancestors_values!(node, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
current = node
remaining = recursivity_level

# Is there any filter happening for the current node? (FALSE if filtered out):
while !isroot(current) && remaining != 0
parent_ = parent(current)
keep = is_filtered(parent_, scale, symbol, link, filter_fun)

if keep
val_ = unsafe_getindex(parent_, key)
push!(val, val_)
push!(val, unsafe_getindex(parent_, key))
# Only decrement the recursivity level when the current node is not filtered-out
recursivity_level -= 1
remaining -= 1
end

# If we want to continue even if the current node is filtered-out
if all || keep
ancestors_(parent_, key, scale, symbol, link, all, filter_fun, val, recursivity_level)
end
(all || keep) || break
current = parent_
end
return val
end

function ancestors_values_no_filter!(node, key, val, recursivity_level)
current = node
remaining = recursivity_level

while !isroot(current) && remaining != 0
parent_ = parent(current)
push!(val, unsafe_getindex(parent_, key))
remaining -= 1
current = parent_
end
return val
end

# Version that returns the nodes instead of the values:
Expand All @@ -133,41 +150,123 @@ function ancestors(
# Check the filters once, and then compute the ancestors recursively using `ancestors_`
check_filters(node, scale=scale, symbol=symbol, link=link)

# Change the filtering function if we also want to remove nodes with nothing values.
use_no_filter = no_node_filters(scale, symbol, link, filter_fun)
val = Array{typeof(node),1}()
# Put the recursivity level into an array so it is mutable in-place:

if self
if is_filtered(node, scale, symbol, link, filter_fun)
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun)
push!(val, node)
elseif !all
# We don't keep the value and we have to stop at the first filtered-out value
return val
end
end

ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
if use_no_filter
ancestors_nodes_no_filter!(node, val, recursivity_level)
else
ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
end
return val
end


function ancestors_(node, scale, symbol, link, all, filter_fun, val, recursivity_level)

if !isroot(node) && recursivity_level != 0
parent_ = parent(node)
function ancestors_nodes!(node, scale, symbol, link, all, filter_fun, val, recursivity_level)
current = node
remaining = recursivity_level

# Is there any filter happening for the current node? (FALSE if filtered out):
while !isroot(current) && remaining != 0
parent_ = parent(current)
keep = is_filtered(parent_, scale, symbol, link, filter_fun)

if keep
push!(val, parent_)
# Only decrement the recursivity level when the current node is not filtered-out
recursivity_level -= 1
remaining -= 1
end

# If we want to continue even if the current node is filtered-out
if all || keep
ancestors_(parent_, scale, symbol, link, all, filter_fun, val, recursivity_level)
(all || keep) || break
current = parent_
end
return val
end

function ancestors_nodes_no_filter!(node, val, recursivity_level)
current = node
remaining = recursivity_level

while !isroot(current) && remaining != 0
parent_ = parent(current)
push!(val, parent_)
remaining -= 1
current = parent_
end
return val
end

function ancestors!(
out::AbstractVector,
node, key;
scale=nothing,
symbol=nothing,
link=nothing,
all::Bool=true,
self=false,
filter_fun=nothing,
recursivity_level=-1,
ignore_nothing=false,
type::Union{Union,DataType}=Any,
)
check_filters(node, scale=scale, symbol=symbol, link=link)
filter_fun_ = filter_fun_nothing(filter_fun, ignore_nothing, key)
use_no_filter = no_node_filters(scale, symbol, link, filter_fun_)

empty!(out)
if self
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun_)
push!(out, unsafe_getindex(node, key))
elseif !all
return out
end
end

if use_no_filter
ancestors_values_no_filter!(node, key, out, recursivity_level)
else
ancestors_values!(node, key, scale, symbol, link, all, filter_fun_, out, recursivity_level)
end
return out
end

function ancestors!(
out::AbstractVector,
node;
scale=nothing,
symbol=nothing,
link=nothing,
all::Bool=true,
self=false,
filter_fun=nothing,
recursivity_level=-1,
)
check_filters(node, scale=scale, symbol=symbol, link=link)
use_no_filter = no_node_filters(scale, symbol, link, filter_fun)

empty!(out)
if self
if use_no_filter || is_filtered(node, scale, symbol, link, filter_fun)
push!(out, node)
elseif !all
return out
end
end

if use_no_filter
ancestors_nodes_no_filter!(node, out, recursivity_level)
else
ancestors_nodes!(node, scale, symbol, link, all, filter_fun, out, recursivity_level)
end
return out
end
9 changes: 8 additions & 1 deletion src/compute_MTG/check_filters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ check_filters(mtg, scale = (1,2))
check_filters(mtg, scale = (1,2), symbol = "Leaf", link = "<")
```
"""
@inline no_node_filters(scale, symbol, link, filter_fun=nothing) =
isnothing(scale) && isnothing(symbol) && isnothing(link) && isnothing(filter_fun)

function check_filters(node::Node{N,A}; scale=nothing, symbol=nothing, link=nothing) where {N<:AbstractNodeMTG,A}
no_node_filters(scale, symbol, link) && return nothing

root_node = get_root(node)

Expand Down Expand Up @@ -76,7 +80,10 @@ end
end

@inline function is_filtered(filter, value::T) where {T<:Union{Tuple,Array}}
all(map(x -> is_filtered(filter, x), value))
for x in value
is_filtered(filter, x) || return false
end
return true
end


Expand Down
Loading
Loading