Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 129 additions & 11 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ const SPAWN_FUNCTIONS: &[&str] = &["spawn", "scope", "scope_fifo", "join"];
/// Scope-like functions whose closures coordinate workers (don't adopt, recurse body).
const SCOPE_FUNCTIONS: &[&str] = &["scope", "scope_fifo"];

/// Names of calls that open a *scoped* concurrency boundary. Closures passed
/// to these may capture non-'static references, which is what makes
/// fork/adopt injection safe.
///
/// "spawn" is deliberately left out. Every free-function spawn
/// (std::thread::spawn, rayon::spawn, tokio::spawn) requires
/// F: Send + 'static. Moving the owned SpanContext into such a closure does
/// compile, but CPU time attribution breaks: SpanContext::drop calls
/// apply_children(), which reads the STACK thread-local. If SpanContext is
/// dropped on the child thread, that thread's stack is already empty
/// (AdoptGuard has popped the synthetic parent), so the child's CPU
/// contribution is silently lost.
///
/// Scoped s.spawn() method calls inside a scope closure body are not covered
/// by this list; recurse_closure_body_for_spawns handles them. They work
/// because the scope guarantees the closure finishes before the parent
/// returns, so SpanContext stays on the parent thread.
const FORK_TRIGGER_FUNCTIONS: &[&str] = &[
    "scope",
    "scope_fifo",
    "join",
];

struct Instrumenter {
targets: HashSet<String>,
current_impl: Option<String>,
Expand Down Expand Up @@ -106,8 +123,8 @@ impl Instrumenter {
/// Find the first concurrency pattern in a block and return its name.
///
/// For method calls matching PARALLEL_ITER_METHODS: returns the method name (e.g. "par_iter").
/// For method calls matching SPAWN_FUNCTIONS: returns the method name (e.g. "spawn").
/// For function calls matching SPAWN_FUNCTIONS: returns the full path (e.g. "rayon::scope").
/// For method calls matching FORK_TRIGGER_FUNCTIONS: returns the method name (e.g. "scope").
/// For function calls matching FORK_TRIGGER_FUNCTIONS: returns the full path (e.g. "rayon::scope").
fn find_concurrency_pattern(block: &syn::Block) -> Option<String> {
block.stmts.iter().find_map(find_pattern_in_stmt)
}
Expand All @@ -130,7 +147,7 @@ fn find_pattern_in_expr(expr: &syn::Expr) -> Option<String> {
if PARALLEL_ITER_METHODS.contains(&method.as_str()) {
return Some(method);
}
if SPAWN_FUNCTIONS.contains(&method.as_str()) {
if FORK_TRIGGER_FUNCTIONS.contains(&method.as_str()) {
return Some(method);
}
if let Some(p) = find_pattern_in_expr(&mc.receiver) {
Expand All @@ -142,7 +159,7 @@ fn find_pattern_in_expr(expr: &syn::Expr) -> Option<String> {
if let syn::Expr::Path(path) = &*call.func {
let last_seg = path.path.segments.last().map(|s| s.ident.to_string());
if let Some(ref name) = last_seg
&& SPAWN_FUNCTIONS.contains(&name.as_str())
&& FORK_TRIGGER_FUNCTIONS.contains(&name.as_str())
{
// Build the full path, e.g. "rayon::scope"
let full_path: String = path
Expand All @@ -155,6 +172,14 @@ fn find_pattern_in_expr(expr: &syn::Expr) -> Option<String> {
return Some(full_path);
}
}
// Don't recurse into detached spawn args -- anything inside
// inherits the 'static boundary, so fork/adopt can't help.
if let syn::Expr::Path(path) = &*call.func {
let last = path.path.segments.last().map(|s| s.ident.to_string());
if last.as_deref() == Some("spawn") {
return None;
}
}
call.args.iter().find_map(find_pattern_in_expr)
}
syn::Expr::Block(b) => b.block.stmts.iter().find_map(find_pattern_in_stmt),
Expand Down Expand Up @@ -218,6 +243,7 @@ fn inject_adopt_in_concurrency_closures(expr: &mut syn::Expr, in_parallel_chain:
let is_scope = func_name
.as_ref()
.is_some_and(|n| SCOPE_FUNCTIONS.contains(&n.as_str()));
let is_detached = func_name.as_deref() == Some("spawn");

if is_scope {
for arg in &mut call.args {
Expand All @@ -227,14 +253,19 @@ fn inject_adopt_in_concurrency_closures(expr: &mut syn::Expr, in_parallel_chain:
inject_adopt_in_concurrency_closures(arg, false);
}
}
} else if is_spawn {
} else if is_spawn && !is_detached {
for arg in &mut call.args {
if let syn::Expr::Closure(closure) = arg {
inject_adopt_at_closure_start(closure);
} else {
inject_adopt_in_concurrency_closures(arg, false);
}
}
} else if is_detached {
// Detached spawn (std::thread::spawn, rayon::spawn, etc.)
// Don't inject adopt -- can't cross 'static boundary.
// Don't recurse into closure body -- nested scopes can't
// use the parent's fork either.
} else {
for arg in &mut call.args {
inject_adopt_in_concurrency_closures(arg, false);
Expand Down Expand Up @@ -955,7 +986,7 @@ fn process_all(items: &[Item]) -> Vec<Result> {
}

#[test]
fn injects_fork_and_adopt_for_thread_spawn() {
fn skips_fork_for_thread_spawn() {
let source = r#"
fn do_work() {
std::thread::spawn(|| {
Expand All @@ -964,15 +995,102 @@ fn do_work() {
}
"#;
let targets: HashSet<String> = ["do_work".to_string()].into();
let result = instrument_source(source, &targets).unwrap().source;
let result = instrument_source(source, &targets).unwrap();

assert!(
result.contains("piano_runtime::fork()"),
"should inject fork. Got:\n{result}"
!result.source.contains("piano_runtime::fork()"),
"should NOT inject fork for std::thread::spawn. Got:\n{}",
result.source
);
assert!(
result.contains("piano_runtime::adopt"),
"should inject adopt in spawn closure. Got:\n{result}"
!result.source.contains("piano_runtime::adopt"),
"should NOT inject adopt for std::thread::spawn. Got:\n{}",
result.source
);
assert!(
result.source.contains("piano_runtime::enter(\"do_work\")"),
"should still inject enter guard. Got:\n{}",
result.source
);
// Detached spawns should not report concurrency (no fork/adopt to act on)
assert!(
result.concurrency.is_empty(),
"should not report concurrency for detached spawn. Got: {:?}",
result.concurrency
);
}

/// A function that mixes a scoped `rayon::scope` + `s.spawn` with a detached
/// `std::thread::spawn` must get fork/adopt for the scoped part only.
#[test]
fn mixed_scope_and_thread_spawn() {
    let source = r#"
fn mixed() {
rayon::scope(|s| {
s.spawn(|_| { work_a(); });
});
std::thread::spawn(|| {
work_b();
});
}
"#;
    let targets: HashSet<String> = HashSet::from([String::from("mixed")]);
    let output = instrument_source(source, &targets).unwrap();
    let code = &output.source;

    // Tally adopt call sites up front; the assertions below check it in order.
    let adopt_count = code.matches("piano_runtime::adopt").count();

    // rayon::scope is a fork trigger, so a fork must be present.
    assert!(
        code.contains("piano_runtime::fork()"),
        "should inject fork for rayon::scope. Got:\n{}",
        code
    );
    // The scoped s.spawn closure is the one place adopt belongs.
    assert!(
        code.contains("piano_runtime::adopt"),
        "should inject adopt for scoped s.spawn. Got:\n{}",
        code
    );
    // Exactly one adopt: the detached thread::spawn closure must stay untouched.
    assert_eq!(
        adopt_count, 1,
        "should have exactly 1 adopt (in s.spawn), not in thread::spawn. Got {adopt_count} in:\n{}",
        code
    );
}

/// A bare `spawn(...)` call (as after `use std::thread::spawn;`) has no path
/// prefix, and must still not cause a fork to be injected.
#[test]
fn skips_fork_for_short_path_thread_spawn() {
    let source = r#"
fn do_work() {
spawn(|| {
heavy_computation();
});
}
"#;
    let targets: HashSet<String> = HashSet::from([String::from("do_work")]);
    let output = instrument_source(source, &targets).unwrap();

    let has_fork = output.source.contains("piano_runtime::fork()");
    assert!(
        !has_fork,
        "should NOT inject fork for bare spawn(). Got:\n{}",
        output.source
    );
}

/// A detached `std::thread::spawn` must leave the reported concurrency list
/// empty.
#[test]
fn no_concurrency_for_thread_spawn() {
    let targets: HashSet<String> = HashSet::from([String::from("do_work")]);
    let source = r#"
fn do_work() {
std::thread::spawn(|| { work(); });
}
"#;
    let output = instrument_source(source, &targets).unwrap();
    let reported = &output.concurrency;

    assert!(
        reported.is_empty(),
        "detached spawn should not report concurrency. Got: {:?}",
        reported
    );
}

Expand Down