diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs index fa79a3edab101..6615a1ddf61a4 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/api.rs @@ -502,7 +502,8 @@ pub async unsafe fn execute_query( // Create per-query context — auto-registers in the global registry let global_pool = runtime.runtime_env.memory_pool.clone(); - let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone()); + let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone(), query_tracker::QueryType::Shard); + let query_memory_pool = query_context .memory_pool() .map(|p| p as Arc); @@ -1407,7 +1408,7 @@ pub async unsafe fn execute_local_plan( // Per-query memory tracking — wraps the session's global pool. A // `context_id` of 0 disables tracking (pool is not consulted) and no // cancellation token is registered in the global QUERY_REGISTRY. - let query_context = QueryTrackingContext::new(context_id, session.memory_pool()); + let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator); let token = query_tracker::get_cancellation_token(context_id); // Race substrait planning + execution against the cancellation token so @@ -1461,7 +1462,7 @@ pub unsafe fn execute_local_prepared_plan( // so a `cancel_query(context_id)` call fires the token here too. // The token is held via the QueryStreamHandle's context and consulted by // stream_next on each batch pull. - let query_context = QueryTrackingContext::new(context_id, session.memory_pool()); + let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator); // DataFusion's execute_stream is sync, but kicks off RepartitionExec / // stream channels that require a Tokio reactor. Enter the IO runtime's diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs index 7f83d64021fbc..864227c795dbd 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/ffm.rs @@ -347,6 +347,13 @@ pub extern "C" fn df_cancel_query(context_id: i64) { api::cancel_query(context_id); } +/// Sets the cancellation stats threshold in milliseconds. +/// Queries cancelled for less than this duration are not counted in stats. +#[no_mangle] +pub extern "C" fn df_set_cancel_stats_threshold_ms(millis: i64) { + crate::query_tracker::set_cancel_stats_threshold(millis as u64); +} + // --------------------------------------------------------------------------- // Per-query registry top-N snapshot // diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs index efd504c2f96b8..19ee93a47b29f 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/indexed_executor.rs @@ -175,7 +175,7 @@ pub async fn execute_indexed_query( table_path: shard_view.table_path.clone(), object_metas: shard_view.object_metas.clone(), writer_generations: shard_view.writer_generations.clone(), - query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone()), + query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone(), crate::query_tracker::QueryType::Shard), table_name: table_name.clone(), indexed_config: None, // derive classification from tree query_config: Arc::unwrap_or_clone(query_config), diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs index bf910273854eb..f61040de19509 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/local_executor.rs @@ -472,7 +472,7 @@ mod tests { let ctx_id = 98_765; let pool: Arc = Arc::new(GreedyMemoryPool::new(10_000)); - let _tracking = QueryTrackingContext::new(ctx_id, pool); + let _tracking = QueryTrackingContext::new(ctx_id, pool, query_tracker::QueryType::Coordinator); // A future that would block indefinitely — `cancel_query` is the // only way out. Mirrors a coord reduce stalled on an input partition diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs index 2325cf942292e..914e4f912da39 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/native_node_stats.rs @@ -18,44 +18,25 @@ use std::sync::atomic::{AtomicI64, Ordering}; // --------------------------------------------------------------------------- -// 4 independent atomic counters — no struct, no lock +// Total counters — incremented by QueryTrackingContext::drop() when a +// cancelled query ran past the threshold before completing. // --------------------------------------------------------------------------- -static NATIVE_SEARCH_TASK_CURRENT: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_SHARD_TASK_CURRENT: AtomicI64 = AtomicI64::new(0); -static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); +pub(crate) static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); +pub(crate) static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0); // --------------------------------------------------------------------------- -// Public increment/decrement functions for producers +// Public increment functions for producers (called from query_tracker Drop) // --------------------------------------------------------------------------- -/// Increments the current count of native search tasks executing post-cancellation. -pub fn inc_native_search_task_current() { - NATIVE_SEARCH_TASK_CURRENT.fetch_add(1, Ordering::Relaxed); -} - -/// Decrements the current count of native search tasks executing post-cancellation. -pub fn dec_native_search_task_current() { - NATIVE_SEARCH_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed); -} - -/// Increments the total count of native search tasks that have executed post-cancellation. +/// Increments the total count of native search tasks that executed post-cancellation +/// beyond the threshold duration. pub fn inc_native_search_task_total() { NATIVE_SEARCH_TASK_TOTAL.fetch_add(1, Ordering::Relaxed); } -/// Increments the current count of native search shard tasks executing post-cancellation. -pub fn inc_native_search_shard_task_current() { - NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_add(1, Ordering::Relaxed); -} - -/// Decrements the current count of native search shard tasks executing post-cancellation. -pub fn dec_native_search_shard_task_current() { - NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed); -} - -/// Increments the total count of native search shard tasks that have executed post-cancellation. +/// Increments the total count of native search shard tasks that executed post-cancellation +/// beyond the threshold duration. pub fn inc_native_search_shard_task_total() { NATIVE_SEARCH_SHARD_TASK_TOTAL.fetch_add(1, Ordering::Relaxed); } @@ -64,7 +45,9 @@ pub fn inc_native_search_shard_task_total() { // FFM entry point // --------------------------------------------------------------------------- -/// Reads 4 AtomicI64 counters and writes them as `[i64; 4]` to the caller buffer. +/// Reads task cancellation stats and writes them as `[i64; 4]` to the caller buffer. +/// `current` counts are computed by scanning the live query registry. +/// `total` counts come from the atomic counters (incremented on query drop). /// Returns 0 on success, -1 if buffer too small. /// /// Buffer layout (32 bytes): @@ -77,14 +60,26 @@ pub fn inc_native_search_shard_task_total() { /// `out_ptr` must point to a valid buffer of at least `out_cap` bytes. #[no_mangle] pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> i64 { + use crate::query_tracker; + if (out_cap as usize) < 32 { return -1; } + // Read totals FIRST, then scan registry. This ordering prevents double-counting: + // if a query drops between our total read and scan, it increments total (which + // we already captured at the lower value) and leaves the registry (so the scan + // won't find it). Worst case: transient undercount by 1, self-corrects next read. + // The reverse order (scan first, then total) would allow a query to appear in + // both the scan result AND the incremented total. + let coordinator_total = NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed); + let shard_total = NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed); + let (shard_current, coordinator_current) = + query_tracker::count_cancelled_running(query_tracker::cancel_stats_threshold()); let vals: [i64; 4] = [ - NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), - NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), - NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), - NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), + coordinator_current, + coordinator_total + coordinator_current, + shard_current, + shard_total + shard_current, ]; std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, out_ptr, 32); 0 @@ -96,41 +91,25 @@ pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> #[cfg(test)] pub(crate) fn reset_all_counters() { - NATIVE_SEARCH_TASK_CURRENT.store(0, Ordering::Relaxed); NATIVE_SEARCH_TASK_TOTAL.store(0, Ordering::Relaxed); - NATIVE_SEARCH_SHARD_TASK_CURRENT.store(0, Ordering::Relaxed); NATIVE_SEARCH_SHARD_TASK_TOTAL.store(0, Ordering::Relaxed); } #[cfg(test)] mod tests { use super::*; - use proptest::prelude::*; - /// Reset counters before each test to avoid cross-test interference. fn setup() { reset_all_counters(); } #[test] - fn test_counters_initialized_to_zero() { + fn test_total_counters_initialized_to_zero() { setup(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 0); assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0); assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0); } - #[test] - fn test_inc_dec_native_search_task_current() { - setup(); - inc_native_search_task_current(); - inc_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 2); - dec_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1); - } - #[test] fn test_inc_native_search_task_total() { setup(); @@ -140,48 +119,12 @@ mod tests { assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 3); } - #[test] - fn test_inc_dec_native_search_shard_task_current() { - setup(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_current(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 3); - dec_native_search_shard_task_current(); - dec_native_search_shard_task_current(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 1); - } - #[test] fn test_inc_native_search_shard_task_total() { setup(); inc_native_search_shard_task_total(); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 1); - } - - #[test] - fn test_df_native_node_stats_reads_correct_values() { - setup(); - // Set up known state - inc_native_search_task_current(); - inc_native_search_task_current(); - inc_native_search_task_total(); - inc_native_search_task_total(); - inc_native_search_task_total(); - inc_native_search_shard_task_current(); - inc_native_search_shard_task_total(); inc_native_search_shard_task_total(); - - let mut buf = [0u8; 32]; - let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; - assert_eq!(rc, 0); - - // Decode the buffer as [i64; 4] - let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - assert_eq!(vals[0], 2); // native_search_task_current - assert_eq!(vals[1], 3); // native_search_task_total - assert_eq!(vals[2], 1); // native_search_shard_task_current - assert_eq!(vals[3], 2); // native_search_shard_task_total + assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 2); } #[test] @@ -207,103 +150,24 @@ mod tests { } #[test] - fn test_counters_are_independent() { - setup(); - // Increment only one counter and verify others remain zero - inc_native_search_task_current(); - assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1); - assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0); - assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0); - } - - // ----------------------------------------------------------------------- - // Property-based test: AtomicI64 increment/decrement correctness - // ----------------------------------------------------------------------- - - /// **Validates: Requirements 1.1, 2.1–2.6** - /// - /// Property 1: AtomicI64 increment/decrement correctness - /// - /// For any random number of increments and decrements applied to each of - /// the 4 counters, reading the counters via `df_native_node_stats` SHALL - /// return values equal to the algebraic sum (increments - decrements) for - /// each counter. - proptest! { - #[test] - fn prop_atomic_inc_dec_correctness( - search_task_current_incs in 0u32..100, - search_task_current_decs in 0u32..100, - search_task_total_incs in 0u32..100, - shard_task_current_incs in 0u32..100, - shard_task_current_decs in 0u32..100, - shard_task_total_incs in 0u32..100, - ) { - // Reset counters before each iteration - reset_all_counters(); - - // Apply increments and decrements for NATIVE_SEARCH_TASK_CURRENT - for _ in 0..search_task_current_incs { - inc_native_search_task_current(); - } - for _ in 0..search_task_current_decs { - dec_native_search_task_current(); - } - - // Apply only increments for NATIVE_SEARCH_TASK_TOTAL (no dec function exists) - for _ in 0..search_task_total_incs { - inc_native_search_task_total(); - } - - // Apply increments and decrements for NATIVE_SEARCH_SHARD_TASK_CURRENT - for _ in 0..shard_task_current_incs { - inc_native_search_shard_task_current(); - } - for _ in 0..shard_task_current_decs { - dec_native_search_shard_task_current(); - } - - // Apply only increments for NATIVE_SEARCH_SHARD_TASK_TOTAL (no dec function exists) - for _ in 0..shard_task_total_incs { - inc_native_search_shard_task_total(); - } - - // Expected algebraic sums - let expected_search_task_current = - search_task_current_incs as i64 - search_task_current_decs as i64; - let expected_search_task_total = search_task_total_incs as i64; - let expected_shard_task_current = - shard_task_current_incs as i64 - shard_task_current_decs as i64; - let expected_shard_task_total = shard_task_total_incs as i64; + fn test_df_native_node_stats_total_counters() { + // Read baseline (other tests may run in parallel and share globals) + let mut buf = [0u8; 32]; + let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; + assert_eq!(rc, 0); + let baseline: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - // Read counters via df_native_node_stats - let mut buf = [0u8; 32]; - let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; - prop_assert_eq!(rc, 0); + inc_native_search_task_total(); + inc_native_search_task_total(); + inc_native_search_shard_task_total(); - // Decode the buffer as [i64; 4] - let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; + let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) }; + assert_eq!(rc, 0); + let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) }; - prop_assert_eq!( - vals[0], expected_search_task_current, - "native_search_task_current mismatch: incs={}, decs={}", - search_task_current_incs, search_task_current_decs - ); - prop_assert_eq!( - vals[1], expected_search_task_total, - "native_search_task_total mismatch: incs={}", - search_task_total_incs - ); - prop_assert_eq!( - vals[2], expected_shard_task_current, - "native_search_shard_task_current mismatch: incs={}, decs={}", - shard_task_current_incs, shard_task_current_decs - ); - prop_assert_eq!( - vals[3], expected_shard_task_total, - "native_search_shard_task_total mismatch: incs={}", - shard_task_total_incs - ); - } + // vals[1] = coordinator total — should be baseline + 2 + assert_eq!(vals[1], baseline[1] + 2); + // vals[3] = shard total — should be baseline + 1 + assert_eq!(vals[3], baseline[3] + 1); } } diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs index 8ed7843276955..c4a9dca1670be 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_executor.rs @@ -405,6 +405,7 @@ pub fn wrap_stream_as_handle( let query_context = crate::query_tracker::QueryTrackingContext::new( 0, runtime.runtime_env.memory_pool.clone(), + crate::query_tracker::QueryType::Shard, ); let handle = crate::api::QueryStreamHandle::new(wrapped, query_context, None); Box::into_raw(Box::new(handle)) as i64 diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs index 0f368d316f4e6..85e2bc51080a7 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/query_tracker.rs @@ -17,9 +17,9 @@ //! in the global [`QueryRegistry`] on creation, and removes the entry //! on [`Drop`]. -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; -use std::time::Instant; +use std::time::{Duration, Instant}; use dashmap::DashMap; use log::debug; @@ -27,9 +27,45 @@ use once_cell::sync::Lazy; use tokio::task::AbortHandle; use tokio_util::sync::CancellationToken; +/// Process-wide epoch for cancelled_at timestamps. Using nanos since this +/// instant avoids storing full Instant values (which aren't atomically sized). +static PROCESS_START: Lazy = Lazy::new(Instant::now); + use datafusion::common::DataFusionError; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +// --------------------------------------------------------------------------- +// Query type discriminator +// --------------------------------------------------------------------------- + +/// Distinguishes shard-level queries from coordinator-level queries for stats. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QueryType { + /// Data-node shard fragment execution (AnalyticsShardTask on the Java side). + Shard, + /// Coordinator-side local reduce execution (AnalyticsQueryTask on the Java side). + Coordinator, +} + +/// Default threshold for "long-running post-cancel" — matches the Java-side +/// `task_cancellation.duration_millis` default of 10 000 ms. +pub const DEFAULT_CANCEL_THRESHOLD: Duration = Duration::from_secs(10); + +/// Runtime-configurable threshold. Initialized to DEFAULT_CANCEL_THRESHOLD (10_000 ms). +/// Set via `set_cancel_stats_threshold`. +static CANCEL_STATS_THRESHOLD_MS: std::sync::atomic::AtomicU64 = + std::sync::atomic::AtomicU64::new(10_000); + +/// Returns the currently configured cancellation stats threshold. +pub fn cancel_stats_threshold() -> Duration { + Duration::from_millis(CANCEL_STATS_THRESHOLD_MS.load(Ordering::Relaxed)) +} + +/// Sets the cancellation stats threshold (in milliseconds). +pub fn set_cancel_stats_threshold(millis: u64) { + CANCEL_STATS_THRESHOLD_MS.store(millis, Ordering::Relaxed); +} + // --------------------------------------------------------------------------- // Per-query memory pool // --------------------------------------------------------------------------- @@ -115,10 +151,14 @@ impl MemoryPool for QueryMemoryPool { pub struct QueryTracker { pub start_time: Instant, pub context_id: i64, + pub query_type: QueryType, pub memory_pool: Arc, pub cancellation_token: CancellationToken, /// CPU task abort handle, set after the stream is created. pub abort_handle: OnceLock, + /// Nanos since PROCESS_START when cancellation was signalled, or 0 if not cancelled. + /// Set atomically via CAS in cancel_query — no lock needed. + pub cancelled_at_nanos: AtomicU64, completed: AtomicBool, wall_nanos: std::sync::atomic::AtomicU64, } @@ -303,6 +343,8 @@ pub fn cancel_query(context_id: i64) { if let Some(handle) = tracker.abort_handle.get() { handle.abort(); } + let nanos = PROCESS_START.elapsed().as_nanos() as u64; + tracker.cancelled_at_nanos.compare_exchange(0, nanos, Ordering::Release, Ordering::Relaxed).ok(); } } @@ -318,6 +360,28 @@ pub fn set_abort_handle(context_id: i64, handle: AbortHandle) { } } +/// Counts queries currently running past the cancellation threshold, by type. +/// Returns (shard_current, coordinator_current). +/// +/// Lock-free: reads each entry's `cancelled_at_nanos` atomically. +pub fn count_cancelled_running(threshold: Duration) -> (i64, i64) { + let mut shard_count: i64 = 0; + let mut coordinator_count: i64 = 0; + let threshold_nanos = threshold.as_nanos() as u64; + let now_nanos = PROCESS_START.elapsed().as_nanos() as u64; + for entry in QUERY_REGISTRY.iter() { + let tracker = entry.value(); + let cancelled_nanos = tracker.cancelled_at_nanos.load(Ordering::Acquire); + if cancelled_nanos > 0 && (now_nanos - cancelled_nanos) >= threshold_nanos { + match tracker.query_type { + QueryType::Shard => shard_count += 1, + QueryType::Coordinator => coordinator_count += 1, + } + } + } + (shard_count, coordinator_count) +} + // --------------------------------------------------------------------------- // QueryTrackingContext // --------------------------------------------------------------------------- @@ -338,7 +402,7 @@ pub struct QueryTrackingContext { impl QueryTrackingContext { /// Create a new query context. If `context_id` is 0, tracking is /// disabled and `memory_pool()` returns `None`. - pub fn new(context_id: i64, global_pool: Arc) -> Self { + pub fn new(context_id: i64, global_pool: Arc, query_type: QueryType) -> Self { if context_id == 0 { return Self { tracker: None, phantom_reservation: None, phantom_corrector: None }; } @@ -346,9 +410,11 @@ impl QueryTrackingContext { let tracker = Arc::new(QueryTracker { start_time: Instant::now(), context_id, + query_type, memory_pool: query_pool, cancellation_token: CancellationToken::new(), abort_handle: OnceLock::new(), + cancelled_at_nanos: AtomicU64::new(0), completed: AtomicBool::new(false), wall_nanos: std::sync::atomic::AtomicU64::new(0), }); @@ -421,7 +487,28 @@ impl Drop for QueryTrackingContext { tracker.memory_pool.current_bytes(), tracker.memory_pool.peak_bytes(), ); + + // Remove from registry BEFORE incrementing total. This ensures a + // query is never simultaneously visible in both the registry scan + // (current) and the total counter — preventing double-counting in + // snapshot_cancellation_stats. QUERY_REGISTRY.remove(&tracker.context_id); + + // If this query was cancelled and ran past the threshold, bump the total counter. + let cancelled_nanos = tracker.cancelled_at_nanos.load(Ordering::Acquire); + if cancelled_nanos > 0 { + let elapsed_since_cancel = PROCESS_START.elapsed().as_nanos() as u64 - cancelled_nanos; + if elapsed_since_cancel >= cancel_stats_threshold().as_nanos() as u64 { + match tracker.query_type { + QueryType::Shard => { + crate::native_node_stats::inc_native_search_shard_task_total(); + } + QueryType::Coordinator => { + crate::native_node_stats::inc_native_search_task_total(); + } + } + } + } } } } @@ -504,7 +591,7 @@ mod tests { #[test] fn test_context_returns_none_pool_for_zero_id() { let global = make_global_pool(10_000); - let ctx = QueryTrackingContext::new(0, global); + let ctx = QueryTrackingContext::new(0, global, QueryType::Shard); assert!(ctx.memory_pool().is_none()); } @@ -512,7 +599,7 @@ mod tests { fn test_context_registers_and_removes_on_drop() { let global = make_global_pool(10_000); let ctx_id = 50_000; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); assert!(ctx.memory_pool().is_some()); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); @@ -525,7 +612,7 @@ mod tests { fn test_drop_removes_from_registry() { let global = make_global_pool(10_000); let ctx_id = 50_001; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); thread::sleep(Duration::from_millis(50)); @@ -539,7 +626,7 @@ mod tests { fn test_wall_secs_ticks_while_running() { let global = make_global_pool(10_000); let ctx_id = 50_002; - let _ctx = QueryTrackingContext::new(ctx_id, global); + let _ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let t1 = QUERY_REGISTRY.get(&ctx_id).unwrap().wall_secs(); thread::sleep(Duration::from_millis(50)); @@ -554,7 +641,7 @@ mod tests { fn test_memory_tracking_through_full_lifecycle() { let global = make_global_pool(1_000_000); let ctx_id = 50_004; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let qp = ctx.memory_pool().unwrap(); let pool: Arc = qp.clone(); let mut reservation = make_reservation(&pool, "lifecycle_test"); @@ -588,8 +675,8 @@ mod tests { let ctx_a_id = 50_005; let ctx_b_id = 50_006; - let ctx_a = QueryTrackingContext::new(ctx_a_id, Arc::clone(&global)); - let ctx_b = QueryTrackingContext::new(ctx_b_id, Arc::clone(&global)); + let ctx_a = QueryTrackingContext::new(ctx_a_id, Arc::clone(&global), QueryType::Shard); + let ctx_b = QueryTrackingContext::new(ctx_b_id, Arc::clone(&global), QueryType::Shard); let pool_a = ctx_a.memory_pool().unwrap(); let pool_b = ctx_b.memory_pool().unwrap(); @@ -625,7 +712,7 @@ mod tests { let global = make_global_pool(1_000_000); let ctx_id = 50_010; - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let qp = ctx.memory_pool().unwrap(); let pool: Arc = qp.clone(); let mut reservation = make_reservation(&pool, "stream_data"); @@ -654,7 +741,7 @@ mod tests { let ctx_id = 50_011; { - let ctx = QueryTrackingContext::new(ctx_id, global); + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); let _pool = ctx.memory_pool(); assert!(QUERY_REGISTRY.contains_key(&ctx_id)); } // ctx dropped here — removes from registry @@ -662,6 +749,100 @@ mod tests { assert!(!QUERY_REGISTRY.contains_key(&ctx_id)); } + // ----------------------------------------------------------------------- + // Cancellation stats tests + // ----------------------------------------------------------------------- + + #[test] + fn test_cancel_query_sets_cancelled_at() { + let global = make_global_pool(10_000); + let ctx_id = 60_001; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + // Not cancelled yet + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) == 0); + + // Cancel + cancel_query(ctx_id); + + // cancelled_at should be set + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + drop(tracker); + drop(ctx); + } + + #[test] + fn test_cancel_query_idempotent() { + let global = make_global_pool(10_000); + let ctx_id = 60_002; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + cancel_query(ctx_id); + let first = QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at_nanos.load(Ordering::Relaxed); + + thread::sleep(Duration::from_millis(10)); + cancel_query(ctx_id); + let second = QUERY_REGISTRY.get(&ctx_id).unwrap().cancelled_at_nanos.load(Ordering::Relaxed); + + // Second cancel should not overwrite the first timestamp + assert_eq!(first, second); + drop(ctx); + } + + #[test] + fn test_count_cancelled_running_with_zero_threshold() { + let global = make_global_pool(10_000); + let ctx_id = 60_003; + let ctx = QueryTrackingContext::new(ctx_id, global, QueryType::Shard); + + // Not cancelled — cancelled_at_nanos should be 0 + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert_eq!(tracker.cancelled_at_nanos.load(Ordering::Relaxed), 0); + drop(tracker); + + // Cancel it + cancel_query(ctx_id); + + // Now cancelled_at_nanos should be > 0 + let tracker = QUERY_REGISTRY.get(&ctx_id).unwrap(); + assert!(tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + drop(tracker); + + drop(ctx); + + // After drop, not in registry + assert!(QUERY_REGISTRY.get(&ctx_id).is_none()); + } + + #[test] + fn test_count_cancelled_running_distinguishes_query_types() { + let global = make_global_pool(10_000); + let shard_id = 60_005; + let coord_id = 60_006; + + let shard_ctx = QueryTrackingContext::new(shard_id, Arc::clone(&global), QueryType::Shard); + let coord_ctx = QueryTrackingContext::new(coord_id, Arc::clone(&global), QueryType::Coordinator); + + cancel_query(shard_id); + cancel_query(coord_id); + + // Verify each query is registered, cancelled, and has correct type + let shard_tracker = QUERY_REGISTRY.get(&shard_id).unwrap(); + assert!(shard_tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + assert_eq!(shard_tracker.query_type, QueryType::Shard); + drop(shard_tracker); + + let coord_tracker = QUERY_REGISTRY.get(&coord_id).unwrap(); + assert!(coord_tracker.cancelled_at_nanos.load(Ordering::Relaxed) > 0); + assert_eq!(coord_tracker.query_type, QueryType::Coordinator); + drop(coord_tracker); + + drop(shard_ctx); + drop(coord_ctx); + } + + // ----------------------------------------------------------------------- // Top-N snapshot tests // ----------------------------------------------------------------------- @@ -693,9 +874,9 @@ mod tests { let id_md = 70_001; let id_hi = 70_002; - let ctx_lo = QueryTrackingContext::new(id_lo, Arc::clone(&global)); - let ctx_md = QueryTrackingContext::new(id_md, Arc::clone(&global)); - let ctx_hi = QueryTrackingContext::new(id_hi, Arc::clone(&global)); + let ctx_lo = QueryTrackingContext::new(id_lo, Arc::clone(&global), QueryType::Shard); + let ctx_md = QueryTrackingContext::new(id_md, Arc::clone(&global), QueryType::Shard); + let ctx_hi = QueryTrackingContext::new(id_hi, Arc::clone(&global), QueryType::Shard); let pool_lo: Arc = ctx_lo.memory_pool().unwrap(); let pool_md: Arc = ctx_md.memory_pool().unwrap(); @@ -743,17 +924,17 @@ mod tests { let zero_id = 70_101; let done_id = 70_102; - let live_ctx = QueryTrackingContext::new(live_id, Arc::clone(&global)); + let live_ctx = QueryTrackingContext::new(live_id, Arc::clone(&global), QueryType::Shard); let live_pool: Arc = live_ctx.memory_pool().unwrap(); let mut r_live = make_reservation(&live_pool, "live"); r_live.try_grow(4_096).unwrap(); // Registered but never reserved — current_bytes stays 0. - let _zero_ctx = QueryTrackingContext::new(zero_id, Arc::clone(&global)); + let _zero_ctx = QueryTrackingContext::new(zero_id, Arc::clone(&global), QueryType::Shard); // Completed before snapshot. Drop reservation first so QueryMemoryPool // is settled, then drop the context to flip the completed flag. - let done_ctx = QueryTrackingContext::new(done_id, Arc::clone(&global)); + let done_ctx = QueryTrackingContext::new(done_id, Arc::clone(&global), QueryType::Shard); let done_pool: Arc = done_ctx.memory_pool().unwrap(); let mut r_done = make_reservation(&done_pool, "done"); r_done.try_grow(8_192).unwrap(); @@ -782,7 +963,7 @@ mod tests { fn test_top_n_with_buffer_larger_than_live_set() { let global = make_global_pool(1_000_000); let id = 70_200; - let ctx = QueryTrackingContext::new(id, global); + let ctx = QueryTrackingContext::new(id, global, QueryType::Shard); let pool: Arc = ctx.memory_pool().unwrap(); let mut r = make_reservation(&pool, "only"); r.try_grow(2_048).unwrap(); @@ -813,7 +994,7 @@ mod tests { let mut contexts = Vec::with_capacity(ids.len()); let mut reservations = Vec::with_capacity(ids.len()); for (i, id) in ids.iter().enumerate() { - let ctx = QueryTrackingContext::new(*id, Arc::clone(&global)); + let ctx = QueryTrackingContext::new(*id, Arc::clone(&global), QueryType::Shard); let pool: Arc = ctx.memory_pool().unwrap(); let mut r = make_reservation(&pool, "cap"); r.try_grow((i as usize + 1) * 1_000).unwrap(); diff --git a/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs b/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs index 3e11ef7ea75c2..1033f19eca724 100644 --- a/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs +++ b/sandbox/plugins/analytics-backend-datafusion/rust/src/session_context.rs @@ -138,7 +138,7 @@ pub async unsafe fn create_session_context( let shard_view = &*(shard_view_ptr as *const ShardView); let global_pool = runtime.runtime_env.memory_pool.clone(); - let query_context = QueryTrackingContext::new(context_id, global_pool.clone()); + let query_context = QueryTrackingContext::new(context_id, global_pool.clone(), crate::query_tracker::QueryType::Shard); let query_memory_pool = query_context .memory_pool() .map(|p| p as Arc); @@ -506,7 +506,7 @@ mod tests { let table_path = datafusion::datasource::listing::ListingTableUrl::parse("file:///tmp") .expect("table_path"); let global_pool = ctx.runtime_env().memory_pool.clone(); - let query_context = QueryTrackingContext::new(0, global_pool); + let query_context = QueryTrackingContext::new(0, global_pool, crate::query_tracker::QueryType::Shard); let handle = SessionContextHandle { ctx, diff --git a/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java b/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java index 2a457956c5e51..1c53679ff891b 100644 --- a/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java +++ b/sandbox/plugins/analytics-backend-datafusion/src/main/java/org/opensearch/be/datafusion/nativelib/NativeBridge.java @@ -108,6 +108,7 @@ public final class NativeBridge { private static final MethodHandle CLOSE_SESSION_CONTEXT; private static final MethodHandle EXECUTE_WITH_CONTEXT; private static final MethodHandle CANCEL_QUERY; + private static final MethodHandle SET_CANCEL_STATS_THRESHOLD_MS; private static final MethodHandle STATS; private static final MethodHandle QUERY_REGISTRY_TOP_N_BY_CURRENT; private static final MethodHandle DF_NATIVE_NODE_STATS; @@ -453,6 +454,11 @@ public final class NativeBridge { CANCEL_QUERY = linker.downcallHandle(lib.find("df_cancel_query").orElseThrow(), FunctionDescriptor.ofVoid(ValueLayout.JAVA_LONG)); + SET_CANCEL_STATS_THRESHOLD_MS = linker.downcallHandle( + lib.find("df_set_cancel_stats_threshold_ms").orElseThrow(), + FunctionDescriptor.ofVoid(ValueLayout.JAVA_LONG) + ); + // Hand the five filter-tree upcall stubs to Rust now. No explicit // caller step required — as soon as this class is loaded, callbacks // are installed and `df_execute_indexed_query` can dispatch into Java. @@ -846,6 +852,15 @@ public static void cancelQuery(long contextId) { NativeCall.invokeVoid(CANCEL_QUERY, contextId); } + /** + * Sets the cancellation stats threshold in milliseconds. + * Queries cancelled for less than this duration are not counted in stats. + * Primarily for testing — production uses the default (10 000 ms). + */ + public static void setCancelStatsThresholdMs(long millis) { + NativeCall.invokeVoid(SET_CANCEL_STATS_THRESHOLD_MS, millis); + } + // ---- Stats collection ---- /**