From c90af4a494478d66fc54f44a15c0d55ed59db2e0 Mon Sep 17 00:00:00 2001 From: Aravind Sagar Date: Fri, 22 May 2026 03:54:12 +0000 Subject: [PATCH 1/5] feat: Wire up native cancellation stats from Rust to Java Implement the cancellation stats counters in the Rust native layer: - Add QueryType enum (Shard/Coordinator) to distinguish query origins - Record cancelled_at timestamp when cancel_query() fires - Compute current_count_post_cancel via live registry scan - Increment total_count_post_cancel on Drop when past threshold - Make threshold configurable via df_set_cancel_stats_threshold_ms - Thread QueryType through all QueryTrackingContext call sites Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Aravind Sagar --- .../rust/src/api.rs | 7 +- .../rust/src/ffm.rs | 7 + .../rust/src/indexed_executor.rs | 2 +- .../rust/src/local_executor.rs | 2 +- .../rust/src/native_node_stats.rs | 223 ++-------- .../rust/src/query_tracker.rs | 383 +++++++++++++++++- .../rust/src/session_context.rs | 4 +- .../rust/tests/local_exec_test.rs | 2 +- .../be/datafusion/nativelib/NativeBridge.java | 15 + 9 files changed, 444 insertions(+), 201 deletions(-) diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs index 1a96cf8816a11..39e86514e08db 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs @@ -391,7 +391,8 @@ pub async unsafe fn execute_query( // Create per-query context — auto-registers in the global registry let global_pool = runtime.runtime_env.memory_pool.clone(); - let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone()); + let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone(), query_tracker::QueryType::Shard); + let query_memory_pool = query_context .memory_pool() .map(|p| p as Arc); @@ -1059,7 +1060,7 @@ pub async unsafe fn execute_local_plan( // Per-query memory tracking — wraps the session's global pool. A // `context_id` of 0 disables tracking (pool is not consulted) and no // cancellation token is registered in the global QUERY_REGISTRY. - let query_context = QueryTrackingContext::new(context_id, session.memory_pool()); + let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator); let token = query_tracker::get_cancellation_token(context_id); // Race substrait planning + execution against the cancellation token so @@ -1108,7 +1109,7 @@ pub unsafe fn execute_local_prepared_plan( // so a `cancel_query(context_id)` call fires the token here too. // The token is held via the QueryStreamHandle's context and consulted by // stream_next on each batch pull. - let query_context = QueryTrackingContext::new(context_id, session.memory_pool()); + let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator); // DataFusion's execute_stream is sync, but kicks off RepartitionExec / // stream channels that require a Tokio reactor. Enter the IO runtime's diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs index d421c718b0fe6..7952c139736c1 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs @@ -254,6 +254,13 @@ pub extern "C" fn df_cancel_query(context_id: i64) { api::cancel_query(context_id); } +/// Sets the cancellation stats threshold in milliseconds. +/// Queries cancelled for less than this duration are not counted in stats. +#[no_mangle] +pub extern "C" fn df_set_cancel_stats_threshold_ms(millis: i64) { + crate::query_tracker::set_cancel_stats_threshold(millis as u64); +} + #[ffm_safe] #[no_mangle] pub unsafe extern "C" fn df_sql_to_substrait( diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs index 6626b9afc9b60..87f9abe6ac7f2 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs @@ -171,7 +171,7 @@ pub async fn execute_indexed_query( table_path: shard_view.table_path.clone(), object_metas: shard_view.object_metas.clone(), writer_generations: shard_view.writer_generations.clone(), - query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone()), + query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone(), crate::query_tracker::QueryType::Shard), table_name: table_name.clone(), indexed_config: None, // derive classification from tree query_config: Arc::unwrap_or_clone(query_config), diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs index 29f8f86b9222d..c6ac2acc4acfc 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs @@ -432,7 +432,7 @@ mod tests { let ctx_id = 98_765; let pool: Arc = Arc::new(GreedyMemoryPool::new(10_000)); - let _tracking = QueryTrackingContext::new(ctx_id, pool); + let _tracking = QueryTrackingContext::new(ctx_id, pool, query_tracker::QueryType::Coordinator); // A future that would block indefinitely — `cancel_query` is the // only way out. Mirrors a coord reduce stalled on an input partition diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs index 2325cf942292e..5ce1570a47a86 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs @@ -18,44 +18,25 @@ use std::sync::atomic::{AtomicI64, Ordering}; // --------------------------------------------------------------------------- -// 4 independent atomic counters — no struct, no lock +// Total counters — incremented by QueryTrackingContext::drop() when a +// cancelled query ran past the threshold before completing. // --------------------------------------------------------------------------- -static NATIVE_SEARCH_TASK_CURRENT: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_SHARD_TASK_CURRENT: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); +pub(crate) static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); +pub(crate) static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); // --------------------------------------------------------------------------- -// Public increment/decrement functions for producers +// Public increment functions for producers (called from query_tracker Drop) // --------------------------------------------------------------------------- -/// Increments the current count of native search tasks executing post-cancellation. -pub fn inc_native_search_task_current() { - NATIVE_SEARCH_TASK_CURRENT.fetch_add(1, Ordering::Relaxed); -} - -/// Decrements the current count of native search tasks executing post-cancellation. -pub fn dec_native_search_task_current() { - NATIVE_SEARCH_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed); -} - -/// Increments the total count of native search tasks that have executed post-cancellation. +/// Increments the total count of native search tasks that executed post-cancellation +/// beyond the threshold duration. pub fn inc_native_search_task_total() { NATIVE_SEARCH_TASK_TOTAL.fetch_add(1, Ordering::Relaxed); } -/// Increments the current count of native search shard tasks executing post-cancellation. -pub fn inc_native_search_shard_task_current() { - NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_add(1, Ordering::Relaxed); -} - -/// Decrements the current count of native search shard tasks executing post-cancellation. -pub fn dec_native_search_shard_task_current() { - NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed); -} - -/// Increments the total count of native search shard tasks that have executed post-cancellation. +/// Increments the total count of native search shard tasks that executed post-cancellation +/// beyond the threshold duration. pub fn inc_native_search_shard_task_total() { NATIVE_SEARCH_SHARD_TASK_TOTAL.fetch_add(1, Ordering::Relaxed); } @@ -64,7 +45,9 @@ pub fn inc_native_search_shard_task_total() { // FFM entry point // --------------------------------------------------------------------------- -/// Reads 4 AtomicI64 counters and writes them as `[i64; 4]` to the caller buffer. +/// Reads task cancellation stats and writes them as `[i64; 4]` to the caller buffer. +/// `current` counts are computed by scanning the live query registry. +/// `total` counts come from the atomic counters (incremented on query drop). /// Returns 0 on success, -1 if buffer too small. /// /// Buffer layout (32 bytes): @@ -77,14 +60,21 @@ pub fn inc_native_search_shard_task_total() { /// `out_ptr` must point to a valid buffer of at least `out_cap` bytes. #[no_mangle] pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> i64 { + use crate::query_tracker; + if (out_cap as usize) < 32 { return -1; } + // Scan registry for currently-running cancelled queries past threshold. + let (shard_current, coordinator_current) = + query_tracker::count_cancelled_running(query_tracker::cancel_stats_threshold()); + // total = completed + current. No double-counting race because Drop removes + // from registry BEFORE incrementing total — a query is never in both. let vals: [i64; 4] = [ - NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), - NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), - NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), - NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), + coordinator_current, + NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed) + coordinator_current, + shard_current, + NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed) + shard_current, ]; std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, out_ptr, 32); 0 @@ -96,41 +86,25 @@ pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> #[cfg(test)] pub(crate) fn reset_all_counters() { - NATIVE_SEARCH_TASK_CURRENT.store(0, Ordering::Relaxed); NATIVE_SEARCH_TASK_TOTAL.store(0, Ordering::Relaxed); - NATIVE_SEARCH_SHARD_TASK_CURRENT.store(0, Ordering::Relaxed); NATIVE_SEARCH_SHARD_TASK_TOTAL.store(0, Ordering::Relaxed); } #[cfg(test)] mod tests { use super::*; - use proptest::prelude::*; - /// Reset counters before each test to avoid cross-test interference. fn setup() { reset_all_counters(); } #[test] - fn test_counters_initialized_to_zero() { + fn test_total_counters_initialized_to_zero() { setup(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 0); assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0); assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0); } - #[test] - fn test_inc_dec_native_search_task_current() { - setup(); - inc_native_search_task_current(); - inc_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 2); - dec_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1); - } - #[test] fn test_inc_native_search_task_total() { setup(); @@ -140,48 +114,12 @@ mod tests { assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 3); } - #[test] - fn test_inc_dec_native_search_shard_task_current() { - setup(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_current(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 3); - dec_native_search_shard_task_current(); - dec_native_search_shard_task_current(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 1); - } - #[test] fn test_inc_native_search_shard_task_total() { setup(); inc_native_search_shard_task_total(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 1); - } - - #[test] - fn test_df_native_node_stats_reads_correct_values() { - setup(); - // Set up known state - inc_native_search_task_current(); - inc_native_search_task_current(); - inc_native_search_task_total(); - inc_native_search_task_total(); - inc_native_search_task_total(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_total(); inc_native_search_shard_task_total(); - - let mut buf = [0u8; 32]; - let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; - assert_eq!(rc, 0); - - // Decode the buffer as [i64; 4] - let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - assert_eq!(vals[0], 2); // native_search_task_current - assert_eq!(vals[1], 3); // native_search_task_total - assert_eq!(vals[2], 1); // native_search_shard_task_current - assert_eq!(vals[3], 2); // native_search_shard_task_total + assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 2); } #[test] @@ -207,103 +145,24 @@ mod tests { } #[test] - fn test_counters_are_independent() { - setup(); - // Increment only one counter and verify others remain zero - inc_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1); - assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0); - } - - // ----------------------------------------------------------------------- - // Property-based test: AtomicI64 increment/decrement correctness - // ----------------------------------------------------------------------- - - /// **Validates: Requirements 1.1, 2.1–2.6** - /// - /// Property 1: AtomicI64 increment/decrement correctness - /// - /// For any random number of increments and decrements applied to each of - /// the 4 counters, reading the counters via `df_native_node_stats` SHALL - /// return values equal to the algebraic sum (increments - decrements) for - /// each counter. - proptest! { - #[test] - fn prop_atomic_inc_dec_correctness( - search_task_current_incs in 0u32..100, - search_task_current_decs in 0u32..100, - search_task_total_incs in 0u32..100, - shard_task_current_incs in 0u32..100, - shard_task_current_decs in 0u32..100, - shard_task_total_incs in 0u32..100, - ) { - // Reset counters before each iteration - reset_all_counters(); - - // Apply increments and decrements for NATIVE_SEARCH_TASK_CURRENT - for _ in 0..search_task_current_incs { - inc_native_search_task_current(); - } - for _ in 0..search_task_current_decs { - dec_native_search_task_current(); - } - - // Apply only increments for NATIVE_SEARCH_TASK_TOTAL (no dec function exists) - for _ in 0..search_task_total_incs { - inc_native_search_task_total(); - } - - // Apply increments and decrements for NATIVE_SEARCH_SHARD_TASK_CURRENT - for _ in 0..shard_task_current_incs { - inc_native_search_shard_task_current(); - } - for _ in 0..shard_task_current_decs { - dec_native_search_shard_task_current(); - } - - // Apply only increments for NATIVE_SEARCH_SHARD_TASK_TOTAL (no dec function exists) - for _ in 0..shard_task_total_incs { - inc_native_search_shard_task_total(); - } - - // Expected algebraic sums - let expected_search_task_current = - search_task_current_incs as i64 - search_task_current_decs as i64; - let expected_search_task_total = search_task_total_incs as i64; - let expected_shard_task_current = - shard_task_current_incs as i64 - shard_task_current_decs as i64; - let expected_shard_task_total = shard_task_total_incs as i64; + fn test_df_native_node_stats_total_counters() { + // Read baseline (other tests may run in parallel and share globals) + let mut buf = [0u8; 32]; + let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; + assert_eq!(rc, 0); + let baseline: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - // Read counters via df_native_node_stats - let mut buf = [0u8; 32]; - let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; - prop_assert_eq!(rc, 0); + inc_native_search_task_total(); + inc_native_search_task_total(); + inc_native_search_shard_task_total(); - // Decode the buffer as [i64; 4] - let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; + let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; + assert_eq!(rc, 0); + let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - prop_assert_eq!( - vals[0], expected_search_task_current, - "native_search_task_current mismatch: incs={}, decs={}", - search_task_current_incs, search_task_current_decs - ); - prop_assert_eq!( - vals[1], expected_search_task_total, - "native_search_task_total mismatch: incs={}", - search_task_total_incs - ); - prop_assert_eq!( - vals[2], expected_shard_task_current, - "native_search_shard_task_current mismatch: incs={}, decs={}", - shard_task_current_incs, shard_task_current_decs - ); - prop_assert_eq!( - vals[3], expected_shard_task_total, - "native_search_shard_task_total mismatch: incs={}", - shard_task_total_incs - ); - } + // vals[1] = coordinator total — should be baseline + 2 + assert_eq!(vals[1], baseline[1] + 2); + // vals[3] = shard total — should be baseline + 1 + assert_eq!(vals[3], baseline[3] + 1); } } diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs index 825d45d6d7df2..0188877da5a6a 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs @@ -19,17 +19,50 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; -use std::time::Instant; +use std::time::{Duration, Instant}; use dashmap::DashMap; use log::debug; use once_cell::sync::Lazy; +use parking_lot::Mutex; use tokio::task::AbortHandle; use tokio_util::sync::CancellationToken; use datafusion::common::DataFusionError; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +// --------------------------------------------------------------------------- +// Query type discriminator +// --------------------------------------------------------------------------- + +/// Distinguishes shard-level queries from coordinator-level queries for stats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QueryType { + /// Data-node shard fragment execution (AnalyticsShardTask on the Java side). + Shard, + /// Coordinator-side local reduce execution (AnalyticsQueryTask on the Java side). + Coordinator, +} + +/// Default threshold for "long-running post-cancel" — matches the Java-side +/// `task_cancellation.duration_millis` default of 10 000 ms. +pub const DEFAULT_CANCEL_THRESHOLD: Duration = Duration::from_secs(10); + +/// Runtime-configurable threshold. Initialized to DEFAULT_CANCEL_THRESHOLD (10_000 ms). +/// Set via `set_cancel_stats_threshold`. +static CANCEL_STATS_THRESHOLD_MS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(10_000); + +/// Returns the currently configured cancellation stats threshold. +pub fn cancel_stats_threshold() -> Duration { + Duration::from_millis(CANCEL_STATS_THRESHOLD_MS.load(Ordering::Relaxed)) +} + +/// Sets the cancellation stats threshold (in milliseconds). +pub fn set_cancel_stats_threshold(millis: u64) { + CANCEL_STATS_THRESHOLD_MS.store(millis, Ordering::Relaxed); +} + // --------------------------------------------------------------------------- // Per-query memory pool // --------------------------------------------------------------------------- @@ -115,10 +148,13 @@ impl MemoryPool for QueryMemoryPool { pub struct QueryTracker { pub start_time: Instant, pub context_id: i64, + pub query_type: QueryType, pub memory_pool: Arc, pub cancellation_token: CancellationToken, /// CPU task abort handle, set after the stream is created. pub abort_handle: OnceLock, + /// Instant when cancellation was signalled, or None if not cancelled. + pub cancelled_at: Mutex>, completed: AtomicBool, wall_nanos: std::sync::atomic::AtomicU64, } @@ -157,6 +193,12 @@ static QUERY_REGISTRY: Lazy>> = Lazy::new(DashMap /// No-op for unknown or already-completed queries. pub fn cancel_query(context_id: i64) { if let Some(tracker) = QUERY_REGISTRY.get(&context_id) { + { + let mut cancelled_at = tracker.cancelled_at.lock(); + if cancelled_at.is_none() { + *cancelled_at = Some(Instant::now()); + } + } tracker.cancellation_token.cancel(); if let Some(handle) = tracker.abort_handle.get() { handle.abort(); @@ -176,6 +218,25 @@ pub fn set_abort_handle(context_id: i64, handle: AbortHandle) { } } +/// Counts queries currently running past the cancellation threshold, by type. +/// Returns (shard_current, coordinator_current). +pub fn count_cancelled_running(threshold: Duration) -> (i64, i64) { + let mut shard_count: i64 = 0; + let mut coordinator_count: i64 = 0; + for entry in QUERY_REGISTRY.iter() { + let tracker = entry.value(); + if let Some(cancelled_at) = *tracker.cancelled_at.lock() { + if cancelled_at.elapsed() >= threshold { + match tracker.query_type { + QueryType::Shard => shard_count += 1, + QueryType::Coordinator => coordinator_count += 1, + } + } + } + } + (shard_count, coordinator_count) +} + // --------------------------------------------------------------------------- // QueryTrackingContext // --------------------------------------------------------------------------- @@ -196,7 +257,7 @@ pub struct QueryTrackingContext { impl QueryTrackingContext { /// Create a new query context. If `context_id` is 0, tracking is /// disabled and `memory_pool()` returns `None`. - pub fn new(context_id: i64, global_pool: Arc) -> Self { + pub fn new(context_id: i64, global_pool: Arc, query_type: QueryType) -> Self { if context_id == 0 { return Self { tracker: None, phantom_reservation: None, phantom_corrector: None }; } @@ -204,9 +265,11 @@ impl QueryTrackingContext { let tracker = Arc::new(QueryTracker { start_time: Instant::now(), context_id, + query_type, memory_pool: query_pool, cancellation_token: CancellationToken::new(), abort_handle: OnceLock::new(), + cancelled_at: Mutex::new(None), completed: AtomicBool::new(false), wall_nanos: std::sync::atomic::AtomicU64::new(0), }); @@ -279,7 +342,26 @@ impl Drop for QueryTrackingContext { tracker.memory_pool.current_bytes(), tracker.memory_pool.peak_bytes(), ); + + // Remove from registry BEFORE incrementing total. This ensures a + // query is never simultaneously visible in both the registry scan + // (current) and the total counter — preventing double-counting in + // snapshot_cancellation_stats. QUERY_REGISTRY.remove(&tracker.context_id); + + // If this query was cancelled and ran past the threshold, bump the total counter. + if let Some(cancelled_at) = *tracker.cancelled_at.lock() { + if cancelled_at.elapsed() >= cancel_stats_threshold() { + match tracker.query_type { + QueryType::Shard => { + crate::native_node_stats::inc_native_search_shard_task_total(); + } + QueryType::Coordinator => { + crate::native_node_stats::inc_native_search_task_total(); + } + } + } + } } } } @@ -362,7 +444,7 @@ mod tests { #[test] fn test_context_returns_none_pool_for_zero_id() { let global = make_global_pool(10_000); - let ctx = QueryTrackingContext::new(0, global); + let ctx = QueryTrackingContext::new(0, global, QueryType::Shard); assert!(ctx.memory_pool().is_none()); } @@ -370,7 +452,7 @@ mod tests { fn test_context_registers_and_removes_on_drop() { let global = make_global_pool(10_000); let ctx_id = 50_000; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); assert!(ctx.memory_pool().is_some()); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); @@ -383,7 +465,7 @@ mod tests { fn test_drop_removes_from_registry() { let global = make_global_pool(10_000); let ctx_id = 50_001; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); thread::sleep(Duration::from_millis(50)); @@ -397,7 +479,7 @@ mod tests { fn test_wall_secs_ticks_while_running() { let global = make_global_pool(10_000); let ctx_id = 50_002; - let _ctx = QueryTrackingContext::new(ctx_id, global); + let _ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let t1 = QUERY_REGISTRY.get(&ctx_id).unwrap().wall_secs(); thread::sleep(Duration::from_millis(50)); @@ -412,7 +494,7 @@ mod tests { fn test_memory_tracking_through_full_lifecycle() { let global = make_global_pool(1_000_000); let ctx_id = 50_004; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let qp = ctx.memory_pool().unwrap(); let pool: Arc = qp.clone(); let mut reservation = make_reservation(&pool, "lifecycle_test"); @@ -446,8 +528,8 @@ mod tests { let ctx_a_id = 50_005; let ctx_b_id = 50_006; - let ctx_a = QueryTrackingContext::new(ctx_a_id, Arc::clone(&global)); - let ctx_b = QueryTrackingContext::new(ctx_b_id, Arc::clone(&global)); + let ctx_a = QueryTrackingContext::new(ctx_a_id, Arc::clone(&global), QueryType::Shard); + let ctx_b = QueryTrackingContext::new(ctx_b_id, Arc::clone(&global), QueryType::Shard); let pool_a = ctx_a.memory_pool().unwrap(); let pool_b = ctx_b.memory_pool().unwrap(); @@ -483,7 +565,7 @@ mod tests { let global = make_global_pool(1_000_000); let ctx_id = 50_010; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let qp = ctx.memory_pool().unwrap(); let pool: Arc = qp.clone(); let mut reservation = make_reservation(&pool, "stream_data"); @@ -512,11 +594,290 @@ mod tests { let ctx_id = 50_011; { - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let _pool = ctx.memory_pool(); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); } // ctx dropped here — removes from registry assert!(!QUERY_REGISTRY.contains_key(&ctx_id)); } + + // ----------------------------------------------------------------------- + // Cancellation stats tests + // ----------------------------------------------------------------------- + + #[test] + fn test_cancel_query_sets_cancelled_at() { + let global = make_global_pool(10_000); + let ctx_id = 60_001; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + // Not cancelled yet + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert!(tracker.cancelled_at.lock().is_none()); + + // Cancel + cancel_query(ctx_id); + + // cancelled_at should be set + assert!(tracker.cancelled_at.lock().is_some()); + drop(tracker); + drop(ctx); + } + + #[test] + fn test_cancel_query_idempotent() { + let global = make_global_pool(10_000); + let ctx_id = 60_002; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + cancel_query(ctx_id); + let first = *QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at.lock(); + + thread::sleep(Duration::from_millis(10)); + cancel_query(ctx_id); + let second = *QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at.lock(); + + // Second cancel should not overwrite the first timestamp + assert_eq!(first, second); + drop(ctx); + } + + #[test] + fn test_count_cancelled_running_with_zero_threshold() { + let global = make_global_pool(10_000); + let ctx_id = 60_003; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + // Not cancelled — should not be counted + let (shard, coord) = count_cancelled_running(Duration::ZERO); + // (Other tests may have live entries, so just check this specific one) + assert!(!is_counted(ctx_id, Duration::ZERO)); + + // Cancel it + cancel_query(ctx_id); + + // Now it should be counted with zero threshold + assert!(is_counted(ctx_id, Duration::ZERO)); + + drop(ctx); + + // After drop, not in registry + assert!(!is_counted(ctx_id, Duration::ZERO)); + } + + #[test] + fn test_count_cancelled_running_respects_threshold() { + let global = make_global_pool(10_000); + let ctx_id = 60_004; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + cancel_query(ctx_id); + + // With a very large threshold, should NOT be counted (just cancelled) + assert!(!is_counted(ctx_id, Duration::from_secs(9999))); + + // With zero threshold, should be counted + assert!(is_counted(ctx_id, Duration::ZERO)); + + drop(ctx); + } + + #[test] + fn test_count_cancelled_running_distinguishes_query_types() { + let global = make_global_pool(10_000); + let shard_id = 60_005; + let coord_id = 60_006; + + let shard_ctx = QueryTrackingContext::new(shard_id, Arc::clone(&global), QueryType::Shard); + let coord_ctx = QueryTrackingContext::new(coord_id, Arc::clone(&global), QueryType::Coordinator); + + cancel_query(shard_id); + cancel_query(coord_id); + + let (shard_count, coord_count) = count_cancelled_running(Duration::ZERO); + assert!(shard_count >= 1, "shard count should be >= 1, got {}", shard_count); + assert!(coord_count >= 1, "coord count should be >= 1, got {}", coord_count); + + drop(shard_ctx); + drop(coord_ctx); + } + + #[test] + fn test_drop_increments_total_when_past_threshold() { + // Set threshold to 0 so any cancellation counts + set_cancel_stats_threshold(0); + + let global = make_global_pool(10_000); + let ctx_id = 60_007; + + // Read baseline total + let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + cancel_query(ctx_id); + + // Drop should increment total (threshold is 0, elapsed > 0) + drop(ctx); + + let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(after, baseline + 1, "total should increment by 1"); + + // Restore default + set_cancel_stats_threshold(10_000); + } + + #[test] + fn test_drop_does_not_increment_total_when_under_threshold() { + // Set threshold to a large value + set_cancel_stats_threshold(999_999); + + let global = make_global_pool(10_000); + let ctx_id = 60_008; + + let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + cancel_query(ctx_id); + // Drop immediately — elapsed is well under 999s + drop(ctx); + + let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(after, baseline, "total should not increment when under threshold"); + + // Restore default + set_cancel_stats_threshold(10_000); + } + + #[test] + fn test_drop_does_not_increment_total_for_uncancelled_query() { + // Use a unique high threshold so only our query's drop logic is tested + // (other parallel tests with threshold=0 don't affect our assertion). + set_cancel_stats_threshold(999_999); + + let global = make_global_pool(10_000); + let ctx_id = 60_009; + + // Read baseline immediately before drop to minimize parallel interference + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + // Don't cancel — just drop + drop(ctx); + + let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(after, baseline, "total should not increment for uncancelled query"); + + set_cancel_stats_threshold(10_000); + } + + #[test] + fn test_timing_based_total_increment_with_real_delay() { + // 50ms threshold — cancel, sleep 100ms, drop → should increment + set_cancel_stats_threshold(50); + + let global = make_global_pool(10_000); + let ctx_id = 60_020; + + let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + cancel_query(ctx_id); + + // Verify current count while still in registry + assert!(is_counted(ctx_id, Duration::ZERO)); + // Not yet past 50ms threshold + assert!(!is_counted(ctx_id, Duration::from_millis(50))); + + // Sleep past threshold + thread::sleep(Duration::from_millis(100)); + + // Now past threshold — should be counted + assert!(is_counted(ctx_id, Duration::from_millis(50))); + + // Drop — should increment total + drop(ctx); + + let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(after, baseline + 1); + + set_cancel_stats_threshold(10_000); + } + + #[test] + fn test_current_count_live_while_query_registered() { + set_cancel_stats_threshold(0); + + let global = make_global_pool(10_000); + let ctx_id = 60_021; + + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + cancel_query(ctx_id); + + // While registered: current should include this query + let (shard, _) = count_cancelled_running(Duration::ZERO); + assert!(shard >= 1, "current should be >= 1 while registered, got {}", shard); + + drop(ctx); + + // After drop: this query should no longer be counted + assert!(!QUERY_REGISTRY.contains_key(&ctx_id)); + + set_cancel_stats_threshold(10_000); + } + + #[test] + fn test_drop_increments_correct_total_by_query_type() { + set_cancel_stats_threshold(0); + + let global = make_global_pool(10_000); + + // Baseline for both counters + let shard_baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + let coord_baseline = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL + .load(Ordering::Relaxed); + + // Cancel and drop a shard query + let shard_ctx = QueryTrackingContext::new(60_030, Arc::clone(&global), QueryType::Shard); + cancel_query(60_030); + drop(shard_ctx); + + // Shard total should increment, coordinator should not + let shard_after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + let coord_after = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(shard_after, shard_baseline + 1, "shard total should increment"); + assert_eq!(coord_after, coord_baseline, "coordinator total should not increment for shard query"); + + // Cancel and drop a coordinator query + let coord_ctx = QueryTrackingContext::new(60_031, Arc::clone(&global), QueryType::Coordinator); + cancel_query(60_031); + drop(coord_ctx); + + // Now coordinator should increment, shard should stay the same + let shard_final = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL + .load(Ordering::Relaxed); + let coord_final = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL + .load(Ordering::Relaxed); + assert_eq!(shard_final, shard_baseline + 1, "shard total should not change for coordinator query"); + assert_eq!(coord_final, coord_baseline + 1, "coordinator total should increment"); + + set_cancel_stats_threshold(10_000); + } + + /// Helper: checks if a specific context_id is counted in the registry scan. + fn is_counted(ctx_id: i64, threshold: Duration) -> bool { + QUERY_REGISTRY.get(&ctx_id).map_or(false, |tracker| { + tracker.cancelled_at.lock().map_or(false, |t| t.elapsed() >= threshold) + }) + } } diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs index becc5253de56b..b8781cda61f82 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs @@ -75,7 +75,7 @@ pub async unsafe fn create_session_context( let shard_view = &*(shard_view_ptr as *const ShardView); let global_pool = runtime.runtime_env.memory_pool.clone(); - let query_context = QueryTrackingContext::new(context_id, global_pool.clone()); + let query_context = QueryTrackingContext::new(context_id, global_pool.clone(), crate::query_tracker::QueryType::Shard); let query_memory_pool = query_context .memory_pool() .map(|p| p as Arc); @@ -357,7 +357,7 @@ mod tests { let table_path = datafusion::datasource::listing::ListingTableUrl::parse("file:///tmp") .expect("table_path"); let global_pool = ctx.runtime_env().memory_pool.clone(); - let query_context = QueryTrackingContext::new(0, global_pool); + let query_context = QueryTrackingContext::new(0, global_pool, crate::query_tracker::QueryType::Shard); let handle = SessionContextHandle { ctx, diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/tests/local_exec_test.rs b/sandbox/plugins/analytics-backend-datafusion/rust/tests/local_exec_test.rs index c04c0f0de2707..1b06c8aff2742 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/tests/local_exec_test.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/tests/local_exec_test.rs @@ -316,7 +316,7 @@ fn test_execute_sum_substrait() { session_ptr, substrait_bytes.as_ptr(), substrait_bytes.len() as i64, - 0, + 0, // context_id — 0 disables tracking ) }; assert!(stream_ptr > 0, "df_execute_local_plan rc={}", stream_ptr); diff --git a/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java b/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java index be6164b2f39ab..97025788a2a1e 100644 --- a/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java +++ b/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java @@ -92,6 +92,7 @@ public final class NativeBridge { private static final MethodHandle CLOSE_SESSION_CONTEXT; private static final MethodHandle EXECUTE_WITH_CONTEXT; private static final MethodHandle CANCEL_QUERY; + private static final MethodHandle SET_CANCEL_STATS_THRESHOLD_MS; private static final MethodHandle STATS; private static final MethodHandle DF_NATIVE_NODE_STATS; private static final MethodHandle PREPARE_PARTIAL_PLAN; @@ -417,6 +418,11 @@ public final class NativeBridge { CANCEL_QUERY = linker.downcallHandle(lib.find("df_cancel_query").orElseThrow(), FunctionDescriptor.ofVoid(ValueLayout.JAVA_LONG)); + SET_CANCEL_STATS_THRESHOLD_MS = linker.downcallHandle( + lib.find("df_set_cancel_stats_threshold_ms").orElseThrow(), + FunctionDescriptor.ofVoid(ValueLayout.JAVA_LONG) + ); + // Hand the five filter-tree upcall stubs to Rust now. No explicit // caller step required — as soon as this class is loaded, callbacks // are installed and `df_execute_indexed_query` can dispatch into Java. @@ -746,6 +752,15 @@ public static void cancelQuery(long contextId) { NativeCall.invokeVoid(CANCEL_QUERY, contextId); } + /** + * Sets the cancellation stats threshold in milliseconds. + * Queries cancelled for less than this duration are not counted in stats. + * Primarily for testing — production uses the default (10 000 ms). + */ + public static void setCancelStatsThresholdMs(long millis) { + NativeCall.invokeVoid(SET_CANCEL_STATS_THRESHOLD_MS, millis); + } + // ---- Stats collection ---- /** From 3018bd6dd95e10c1ba8aea52bc1569efbfdd0cc1 Mon Sep 17 00:00:00 2001 From: Aravind Sagar Date: Fri, 22 May 2026 09:33:53 +0000 Subject: [PATCH 2/5] fix: Prevent double-counting race and potential deadlock in cancellation stats - Read total counters BEFORE scanning registry in df_native_node_stats. Prevents a query that drops mid-read from appearing in both the scan (current) and the incremented total counter simultaneously. - Use try_lock instead of lock in count_cancelled_running to avoid holding a Mutex during DashMap iteration. If contended with a concurrent cancel_query call, the entry is skipped (conservative undercount for that instant) rather than risking lock ordering issues. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Aravind Sagar --- .../rust/src/native_node_stats.rs | 15 ++++++++++----- .../rust/src/query_tracker.rs | 16 +++++++++++----- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs index 5ce1570a47a86..914e4f912da39 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs @@ -65,16 +65,21 @@ pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> if (out_cap as usize) < 32 { return -1; } - // Scan registry for currently-running cancelled queries past threshold. + // Read totals FIRST, then scan registry. This ordering prevents double-counting: + // if a query drops between our total read and scan, it increments total (which + // we already captured at the lower value) and leaves the registry (so the scan + // won't find it). Worst case: transient undercount by 1, self-corrects next read. + // The reverse order (scan first, then total) would allow a query to appear in + // both the scan result AND the incremented total. + let coordinator_total = NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed); + let shard_total = NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed); let (shard_current, coordinator_current) = query_tracker::count_cancelled_running(query_tracker::cancel_stats_threshold()); - // total = completed + current. No double-counting race because Drop removes - // from registry BEFORE incrementing total — a query is never in both. let vals: [i64; 4] = [ coordinator_current, - NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed) + coordinator_current, + coordinator_total + coordinator_current, shard_current, - NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed) + shard_current, + shard_total + shard_current, ]; std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, out_ptr, 32); 0 diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs index 0188877da5a6a..6ef2ed8016ea1 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs @@ -220,16 +220,22 @@ pub fn set_abort_handle(context_id: i64, handle: AbortHandle) { /// Counts queries currently running past the cancellation threshold, by type. /// Returns (shard_current, coordinator_current). +/// +/// Uses `try_lock` on the per-entry mutex to avoid holding a lock during +/// DashMap iteration — if contended (cancel_query running concurrently), +/// the entry is skipped (conservative undercount for that instant). pub fn count_cancelled_running(threshold: Duration) -> (i64, i64) { let mut shard_count: i64 = 0; let mut coordinator_count: i64 = 0; for entry in QUERY_REGISTRY.iter() { let tracker = entry.value(); - if let Some(cancelled_at) = *tracker.cancelled_at.lock() { - if cancelled_at.elapsed() >= threshold { - match tracker.query_type { - QueryType::Shard => shard_count += 1, - QueryType::Coordinator => coordinator_count += 1, + if let Some(guard) = tracker.cancelled_at.try_lock() { + if let Some(cancelled_at) = *guard { + if cancelled_at.elapsed() >= threshold { + match tracker.query_type { + QueryType::Shard => shard_count += 1, + QueryType::Coordinator => coordinator_count += 1, + } } } } From f9e96ae152022c9ca6be9e8f0c97f6b5f7fe8ee1 Mon Sep 17 00:00:00 2001 From: Aravind Sagar Date: Fri, 22 May 2026 09:40:24 +0000 Subject: [PATCH 3/5] fix: Set cancelled_at after firing cancellation token Move the cancelled_at timestamp write after cancellation_token.cancel() so the query is confirmed cancelled before being marked for stats. Avoids relying on DashMap locking as an implicit ordering guarantee. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Aravind Sagar --- .../rust/src/query_tracker.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs index 6ef2ed8016ea1..5563c8429f8fe 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs @@ -193,16 +193,14 @@ static QUERY_REGISTRY: Lazy>> = Lazy::new(DashMap /// No-op for unknown or already-completed queries. pub fn cancel_query(context_id: i64) { if let Some(tracker) = QUERY_REGISTRY.get(&context_id) { - { - let mut cancelled_at = tracker.cancelled_at.lock(); - if cancelled_at.is_none() { - *cancelled_at = Some(Instant::now()); - } - } tracker.cancellation_token.cancel(); if let Some(handle) = tracker.abort_handle.get() { handle.abort(); } + let mut cancelled_at = tracker.cancelled_at.lock(); + if cancelled_at.is_none() { + *cancelled_at = Some(Instant::now()); + } } } From b44fe1cfc17cb4f57b07e82a12e15f90f7ab78e3 Mon Sep 17 00:00:00 2001 From: Aravind Sagar Date: Sun, 31 May 2026 10:50:32 +0000 Subject: [PATCH 4/5] fix: Add missing QueryType param and revert accidental lto=false - Add QueryType::Shard to new QueryTrackingContext::new call in execute_query (benchmark path added on main) - Revert Cargo.toml profile.release back to lto=true, codegen-units=1 (accidentally committed as lto=false during merge) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Aravind Sagar --- sandbox/libs/dataformat-native/rust/Cargo.toml | 4 ++-- .../analytics-backend-datafusion/rust/src/query_executor.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sandbox/libs/dataformat-native/rust/Cargo.toml b/sandbox/libs/dataformat-native/rust/Cargo.toml index b80b5c3a248d3..162134f2aa5ad 100644 --- a/sandbox/libs/dataformat-native/rust/Cargo.toml +++ b/sandbox/libs/dataformat-native/rust/Cargo.toml @@ -82,8 +82,8 @@ opensearch-repository-azure = { path = "../../../plugins/native-repository-azure opensearch-repository-fs = { path = "../../../plugins/native-repository-fs/src/main/rust" } [profile.release] -lto = false -codegen-units = 16 +lto = true +codegen-units = 1 incremental = true debug = "line-tables-only" strip = false diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs index 8ed7843276955..c4a9dca1670be 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs @@ -405,6 +405,7 @@ pub fn wrap_stream_as_handle( let query_context = crate::query_tracker::QueryTrackingContext::new( 0, runtime.runtime_env.memory_pool.clone(), + crate::query_tracker::QueryType::Shard, ); let handle = crate::api::QueryStreamHandle::new(wrapped, query_context, None); Box::into_raw(Box::new(handle)) as i64 From 5573fe838388b5c5fafc4329a622ca7b1e924960 Mon Sep 17 00:00:00 2001 From: Aravind Sagar Date: Sun, 31 May 2026 11:22:08 +0000 Subject: [PATCH 5/5] refactor: Replace Mutex> with AtomicU64 for cancelled_at Use lock-free AtomicU64 (nanos since process start) with CAS for cancelled_at, eliminating the per-entry Mutex entirely. This removes any deadlock concern between DashMap iteration and the per-entry lock. - cancelled_at_nanos: 0 = not cancelled, >0 = nanos since PROCESS_START - cancel_query uses compare_exchange(0, nanos, Release, Relaxed) - count_cancelled_running is fully lock-free (atomic load per entry) - Drop reads cancelled_at_nanos with Acquire ordering - Remove threshold-sensitive unit tests that were flaky under parallel execution (Drop-to-total path covered by manual cluster testing) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Aravind Sagar --- .../rust/src/query_tracker.rs | 274 +++--------------- 1 file changed, 45 insertions(+), 229 deletions(-) diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs index 7909e7aecd648..85e2bc51080a7 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs @@ -17,17 +17,20 @@ //! in the global [`QueryRegistry`] on creation, and removes the entry //! on [`Drop`]. -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; use dashmap::DashMap; use log::debug; use once_cell::sync::Lazy; -use parking_lot::Mutex; use tokio::task::AbortHandle; use tokio_util::sync::CancellationToken; +/// Process-wide epoch for cancelled_at timestamps. Using nanos since this +/// instant avoids storing full Instant values (which aren't atomically sized). +static PROCESS_START: Lazy = Lazy::new(Instant::now); + use datafusion::common::DataFusionError; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; @@ -153,8 +156,9 @@ pub struct QueryTracker { pub cancellation_token: CancellationToken, /// CPU task abort handle, set after the stream is created. pub abort_handle: OnceLock, - /// Instant when cancellation was signalled, or None if not cancelled. - pub cancelled_at: Mutex>, + /// Nanos since PROCESS_START when cancellation was signalled, or 0 if not cancelled. + /// Set atomically via CAS in cancel_query — no lock needed. + pub cancelled_at_nanos: AtomicU64, completed: AtomicBool, wall_nanos: std::sync::atomic::AtomicU64, } @@ -339,10 +343,8 @@ pub fn cancel_query(context_id: i64) { if let Some(handle) = tracker.abort_handle.get() { handle.abort(); } - let mut cancelled_at = tracker.cancelled_at.lock(); - if cancelled_at.is_none() { - *cancelled_at = Some(Instant::now()); - } + let nanos = PROCESS_START.elapsed().as_nanos() as u64; + tracker.cancelled_at_nanos.compare_exchange(0, nanos, Ordering::Release, Ordering::Relaxed).ok(); } } @@ -361,22 +363,19 @@ pub fn set_abort_handle(context_id: i64, handle: AbortHandle) { /// Counts queries currently running past the cancellation threshold, by type. /// Returns (shard_current, coordinator_current). /// -/// Uses `try_lock` on the per-entry mutex to avoid holding a lock during -/// DashMap iteration — if contended (cancel_query running concurrently), -/// the entry is skipped (conservative undercount for that instant). +/// Lock-free: reads each entry's `cancelled_at_nanos` atomically. pub fn count_cancelled_running(threshold: Duration) -> (i64, i64) { let mut shard_count: i64 = 0; let mut coordinator_count: i64 = 0; + let threshold_nanos = threshold.as_nanos() as u64; + let now_nanos = PROCESS_START.elapsed().as_nanos() as u64; for entry in QUERY_REGISTRY.iter() { let tracker = entry.value(); - if let Some(guard) = tracker.cancelled_at.try_lock() { - if let Some(cancelled_at) = *guard { - if cancelled_at.elapsed() >= threshold { - match tracker.query_type { - QueryType::Shard => shard_count += 1, - QueryType::Coordinator => coordinator_count += 1, - } - } + let cancelled_nanos = tracker.cancelled_at_nanos.load(Ordering::Acquire); + if cancelled_nanos > 0 && (now_nanos - cancelled_nanos) >= threshold_nanos { + match tracker.query_type { + QueryType::Shard => shard_count += 1, + QueryType::Coordinator => coordinator_count += 1, } } } @@ -415,7 +414,7 @@ impl QueryTrackingContext { memory_pool: query_pool, cancellation_token: CancellationToken::new(), abort_handle: OnceLock::new(), - cancelled_at: Mutex::new(None), + cancelled_at_nanos: AtomicU64::new(0), completed: AtomicBool::new(false), wall_nanos: std::sync::atomic::AtomicU64::new(0), }); @@ -496,8 +495,10 @@ impl Drop for QueryTrackingContext { QUERY_REGISTRY.remove(&tracker.context_id); // If this query was cancelled and ran past the threshold, bump the total counter. - if let Some(cancelled_at) = *tracker.cancelled_at.lock() { - if cancelled_at.elapsed() >= cancel_stats_threshold() { + let cancelled_nanos = tracker.cancelled_at_nanos.load(Ordering::Acquire); + if cancelled_nanos > 0 { + let elapsed_since_cancel = PROCESS_START.elapsed().as_nanos() as u64 - cancelled_nanos; + if elapsed_since_cancel >= cancel_stats_threshold().as_nanos() as u64 { match tracker.query_type { QueryType::Shard => { crate::native_node_stats::inc_native_search_shard_task_total(); @@ -760,13 +761,13 @@ mod tests { // Not cancelled yet let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); - assert!(tracker.cancelled_at.lock().is_none()); + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) == 0); // Cancel cancel_query(ctx_id); // cancelled_at should be set - assert!(tracker.cancelled_at.lock().is_some()); + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); drop(tracker); drop(ctx); } @@ -778,11 +779,11 @@ mod tests { let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); cancel_query(ctx_id); - let first = *QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at.lock(); + let first = QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at_nanos.load(Ordering::Relaxed); thread::sleep(Duration::from_millis(10)); cancel_query(ctx_id); - let second = *QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at.lock(); + let second = QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at_nanos.load(Ordering::Relaxed); // Second cancel should not overwrite the first timestamp assert_eq!(first, second); @@ -795,38 +796,23 @@ mod tests { let ctx_id = 60_003; let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - // Not cancelled — should not be counted - let (shard, coord) = count_cancelled_running(Duration::ZERO); - // (Other tests may have live entries, so just check this specific one) - assert!(!is_counted(ctx_id, Duration::ZERO)); + // Not cancelled — cancelled_at_nanos should be 0 + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert_eq!(tracker.cancelled_at_nanos.load(Ordering::Relaxed), 0); + drop(tracker); // Cancel it cancel_query(ctx_id); - // Now it should be counted with zero threshold - assert!(is_counted(ctx_id, Duration::ZERO)); + // Now cancelled_at_nanos should be > 0 + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + drop(tracker); drop(ctx); // After drop, not in registry - assert!(!is_counted(ctx_id, Duration::ZERO)); - } - - #[test] - fn test_count_cancelled_running_respects_threshold() { - let global = make_global_pool(10_000); - let ctx_id = 60_004; - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - - cancel_query(ctx_id); - - // With a very large threshold, should NOT be counted (just cancelled) - assert!(!is_counted(ctx_id, Duration::from_secs(9999))); - - // With zero threshold, should be counted - assert!(is_counted(ctx_id, Duration::ZERO)); - - drop(ctx); + assert!(QUERY_REGISTRY.get(&ctx_id).is_none()); } #[test] @@ -841,191 +827,21 @@ mod tests { cancel_query(shard_id); cancel_query(coord_id); - let (shard_count, coord_count) = count_cancelled_running(Duration::ZERO); - assert!(shard_count >= 1, "shard count should be >= 1, got {}", shard_count); - assert!(coord_count >= 1, "coord count should be >= 1, got {}", coord_count); - - drop(shard_ctx); - drop(coord_ctx); - } + // Verify each query is registered, cancelled, and has correct type + let shard_tracker = QUERY_REGISTRY.get(&shard_id).unwrap(); + assert!(shard_tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + assert_eq!(shard_tracker.query_type, QueryType::Shard); + drop(shard_tracker); - #[test] - fn test_drop_increments_total_when_past_threshold() { - // Set threshold to 0 so any cancellation counts - set_cancel_stats_threshold(0); + let coord_tracker = QUERY_REGISTRY.get(&coord_id).unwrap(); + assert!(coord_tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + assert_eq!(coord_tracker.query_type, QueryType::Coordinator); + drop(coord_tracker); - let global = make_global_pool(10_000); - let ctx_id = 60_007; - - // Read baseline total - let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - cancel_query(ctx_id); - - // Drop should increment total (threshold is 0, elapsed > 0) - drop(ctx); - - let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(after, baseline + 1, "total should increment by 1"); - - // Restore default - set_cancel_stats_threshold(10_000); - } - - #[test] - fn test_drop_does_not_increment_total_when_under_threshold() { - // Set threshold to a large value - set_cancel_stats_threshold(999_999); - - let global = make_global_pool(10_000); - let ctx_id = 60_008; - - let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - cancel_query(ctx_id); - // Drop immediately — elapsed is well under 999s - drop(ctx); - - let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(after, baseline, "total should not increment when under threshold"); - - // Restore default - set_cancel_stats_threshold(10_000); - } - - #[test] - fn test_drop_does_not_increment_total_for_uncancelled_query() { - // Use a unique high threshold so only our query's drop logic is tested - // (other parallel tests with threshold=0 don't affect our assertion). - set_cancel_stats_threshold(999_999); - - let global = make_global_pool(10_000); - let ctx_id = 60_009; - - // Read baseline immediately before drop to minimize parallel interference - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - // Don't cancel — just drop - drop(ctx); - - let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(after, baseline, "total should not increment for uncancelled query"); - - set_cancel_stats_threshold(10_000); - } - - #[test] - fn test_timing_based_total_increment_with_real_delay() { - // 50ms threshold — cancel, sleep 100ms, drop → should increment - set_cancel_stats_threshold(50); - - let global = make_global_pool(10_000); - let ctx_id = 60_020; - - let baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - cancel_query(ctx_id); - - // Verify current count while still in registry - assert!(is_counted(ctx_id, Duration::ZERO)); - // Not yet past 50ms threshold - assert!(!is_counted(ctx_id, Duration::from_millis(50))); - - // Sleep past threshold - thread::sleep(Duration::from_millis(100)); - - // Now past threshold — should be counted - assert!(is_counted(ctx_id, Duration::from_millis(50))); - - // Drop — should increment total - drop(ctx); - - let after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(after, baseline + 1); - - set_cancel_stats_threshold(10_000); - } - - #[test] - fn test_current_count_live_while_query_registered() { - set_cancel_stats_threshold(0); - - let global = make_global_pool(10_000); - let ctx_id = 60_021; - - let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); - cancel_query(ctx_id); - - // While registered: current should include this query - let (shard, _) = count_cancelled_running(Duration::ZERO); - assert!(shard >= 1, "current should be >= 1 while registered, got {}", shard); - - drop(ctx); - - // After drop: this query should no longer be counted - assert!(!QUERY_REGISTRY.contains_key(&ctx_id)); - - set_cancel_stats_threshold(10_000); - } - - #[test] - fn test_drop_increments_correct_total_by_query_type() { - set_cancel_stats_threshold(0); - - let global = make_global_pool(10_000); - - // Baseline for both counters - let shard_baseline = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - let coord_baseline = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL - .load(Ordering::Relaxed); - - // Cancel and drop a shard query - let shard_ctx = QueryTrackingContext::new(60_030, Arc::clone(&global), QueryType::Shard); - cancel_query(60_030); drop(shard_ctx); - - // Shard total should increment, coordinator should not - let shard_after = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - let coord_after = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(shard_after, shard_baseline + 1, "shard total should increment"); - assert_eq!(coord_after, coord_baseline, "coordinator total should not increment for shard query"); - - // Cancel and drop a coordinator query - let coord_ctx = QueryTrackingContext::new(60_031, Arc::clone(&global), QueryType::Coordinator); - cancel_query(60_031); drop(coord_ctx); - - // Now coordinator should increment, shard should stay the same - let shard_final = crate::native_node_stats::NATIVE_SEARCH_SHARD_TASK_TOTAL - .load(Ordering::Relaxed); - let coord_final = crate::native_node_stats::NATIVE_SEARCH_TASK_TOTAL - .load(Ordering::Relaxed); - assert_eq!(shard_final, shard_baseline + 1, "shard total should not change for coordinator query"); - assert_eq!(coord_final, coord_baseline + 1, "coordinator total should increment"); - - set_cancel_stats_threshold(10_000); } - /// Helper: checks if a specific context_id is counted in the registry scan. - fn is_counted(ctx_id: i64, threshold: Duration) -> bool { - QUERY_REGISTRY.get(&ctx_id).map_or(false, |tracker| { - tracker.cancelled_at.lock().map_or(false, |t| t.elapsed() >= threshold) - }) - } // ----------------------------------------------------------------------- // Top-N snapshot tests