Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ pub async unsafe fn execute_query(

// Create per-query context — auto-registers in the global registry
let global_pool = runtime.runtime_env.memory_pool.clone();
let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone());
let mut query_context = QueryTrackingContext::new(context_id, global_pool.clone(), query_tracker::QueryType::Shard);

let query_memory_pool = query_context
.memory_pool()
.map(|p| p as Arc<dyn datafusion::execution::memory_pool::MemoryPool>);
Expand Down Expand Up @@ -1407,7 +1408,7 @@ pub async unsafe fn execute_local_plan(
// Per-query memory tracking — wraps the session's global pool. A
// `context_id` of 0 disables tracking (pool is not consulted) and no
// cancellation token is registered in the global QUERY_REGISTRY.
let query_context = QueryTrackingContext::new(context_id, session.memory_pool());
let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator);
let token = query_tracker::get_cancellation_token(context_id);

// Race substrait planning + execution against the cancellation token so
Expand Down Expand Up @@ -1461,7 +1462,7 @@ pub unsafe fn execute_local_prepared_plan(
// so a `cancel_query(context_id)` call fires the token here too.
// The token is held via the QueryStreamHandle's context and consulted by
// stream_next on each batch pull.
let query_context = QueryTrackingContext::new(context_id, session.memory_pool());
let query_context = QueryTrackingContext::new(context_id, session.memory_pool(), query_tracker::QueryType::Coordinator);

// DataFusion's execute_stream is sync, but kicks off RepartitionExec /
// stream channels that require a Tokio reactor. Enter the IO runtime's
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ pub extern "C" fn df_cancel_query(context_id: i64) {
api::cancel_query(context_id);
}

/// Sets the cancellation stats threshold in milliseconds.
/// Queries cancelled for less than this duration are not counted in stats.
#[no_mangle]
pub extern "C" fn df_set_cancel_stats_threshold_ms(millis: i64) {
crate::query_tracker::set_cancel_stats_threshold(millis as u64);
}

// ---------------------------------------------------------------------------
// Per-query registry top-N snapshot
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pub async fn execute_indexed_query(
table_path: shard_view.table_path.clone(),
object_metas: shard_view.object_metas.clone(),
writer_generations: shard_view.writer_generations.clone(),
query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone()),
query_context: crate::query_tracker::QueryTrackingContext::new(0, runtime.runtime_env.memory_pool.clone(), crate::query_tracker::QueryType::Shard),
table_name: table_name.clone(),
indexed_config: None, // derive classification from tree
query_config: Arc::unwrap_or_clone(query_config),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ mod tests {
let ctx_id = 98_765;
let pool: Arc<dyn datafusion::execution::memory_pool::MemoryPool> =
Arc::new(GreedyMemoryPool::new(10_000));
let _tracking = QueryTrackingContext::new(ctx_id, pool);
let _tracking = QueryTrackingContext::new(ctx_id, pool, query_tracker::QueryType::Coordinator);

// A future that would block indefinitely — `cancel_query` is the
// only way out. Mirrors a coord reduce stalled on an input partition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,25 @@
use std::sync::atomic::{AtomicI64, Ordering};

// ---------------------------------------------------------------------------
// 4 independent atomic counters — no struct, no lock
// Total counters — incremented by QueryTrackingContext::drop() when a
// cancelled query ran past the threshold before completing.
// ---------------------------------------------------------------------------

static NATIVE_SEARCH_TASK_CURRENT: AtomicI64 = AtomicI64::new(0);
static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0);
static NATIVE_SEARCH_SHARD_TASK_CURRENT: AtomicI64 = AtomicI64::new(0);
static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0);
pub(crate) static NATIVE_SEARCH_TASK_TOTAL: AtomicI64 = AtomicI64::new(0);
pub(crate) static NATIVE_SEARCH_SHARD_TASK_TOTAL: AtomicI64 = AtomicI64::new(0);

// ---------------------------------------------------------------------------
// Public increment/decrement functions for producers
// Public increment functions for producers (called from query_tracker Drop)
// ---------------------------------------------------------------------------

/// Increments the current count of native search tasks executing post-cancellation.
pub fn inc_native_search_task_current() {
NATIVE_SEARCH_TASK_CURRENT.fetch_add(1, Ordering::Relaxed);
}

/// Decrements the current count of native search tasks executing post-cancellation.
pub fn dec_native_search_task_current() {
NATIVE_SEARCH_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed);
}

/// Increments the total count of native search tasks that have executed post-cancellation.
/// Increments the total count of native search tasks that executed post-cancellation
/// beyond the threshold duration.
pub fn inc_native_search_task_total() {
NATIVE_SEARCH_TASK_TOTAL.fetch_add(1, Ordering::Relaxed);
}

/// Increments the current count of native search shard tasks executing post-cancellation.
pub fn inc_native_search_shard_task_current() {
NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_add(1, Ordering::Relaxed);
}

/// Decrements the current count of native search shard tasks executing post-cancellation.
pub fn dec_native_search_shard_task_current() {
NATIVE_SEARCH_SHARD_TASK_CURRENT.fetch_sub(1, Ordering::Relaxed);
}

/// Increments the total count of native search shard tasks that have executed post-cancellation.
/// Increments the total count of native search shard tasks that executed post-cancellation
/// beyond the threshold duration.
pub fn inc_native_search_shard_task_total() {
NATIVE_SEARCH_SHARD_TASK_TOTAL.fetch_add(1, Ordering::Relaxed);
}
Expand All @@ -64,7 +45,9 @@ pub fn inc_native_search_shard_task_total() {
// FFM entry point
// ---------------------------------------------------------------------------

/// Reads 4 AtomicI64 counters and writes them as `[i64; 4]` to the caller buffer.
/// Reads task cancellation stats and writes them as `[i64; 4]` to the caller buffer.
/// `current` counts are computed by scanning the live query registry.
/// `total` counts come from the atomic counters (incremented on query drop).
/// Returns 0 on success, -1 if buffer too small.
///
/// Buffer layout (32 bytes):
Expand All @@ -77,14 +60,26 @@ pub fn inc_native_search_shard_task_total() {
/// `out_ptr` must point to a valid buffer of at least `out_cap` bytes.
#[no_mangle]
pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) -> i64 {
use crate::query_tracker;

if (out_cap as usize) < 32 {
return -1;
}
// Read totals FIRST, then scan registry. This ordering prevents double-counting:
// if a query drops between our total read and scan, it increments total (which
// we already captured at the lower value) and leaves the registry (so the scan
// won't find it). Worst case: transient undercount by 1, self-corrects next read.
// The reverse order (scan first, then total) would allow a query to appear in
// both the scan result AND the incremented total.
let coordinator_total = NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed);
let shard_total = NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed);
let (shard_current, coordinator_current) =
query_tracker::count_cancelled_running(query_tracker::cancel_stats_threshold());
let vals: [i64; 4] = [
NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed),
NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed),
NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed),
NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed),
coordinator_current,
coordinator_total + coordinator_current,
shard_current,
shard_total + shard_current,
];
std::ptr::copy_nonoverlapping(vals.as_ptr() as *const u8, out_ptr, 32);
0
Expand All @@ -96,41 +91,25 @@ pub unsafe extern "C" fn df_native_node_stats(out_ptr: *mut u8, out_cap: i64) ->

#[cfg(test)]
pub(crate) fn reset_all_counters() {
NATIVE_SEARCH_TASK_CURRENT.store(0, Ordering::Relaxed);
NATIVE_SEARCH_TASK_TOTAL.store(0, Ordering::Relaxed);
NATIVE_SEARCH_SHARD_TASK_CURRENT.store(0, Ordering::Relaxed);
NATIVE_SEARCH_SHARD_TASK_TOTAL.store(0, Ordering::Relaxed);
}

#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;

/// Reset counters before each test to avoid cross-test interference.
fn setup() {
reset_all_counters();
}

#[test]
fn test_counters_initialized_to_zero() {
fn test_total_counters_initialized_to_zero() {
setup();
assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 0);
assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0);
assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0);
assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0);
}

#[test]
fn test_inc_dec_native_search_task_current() {
setup();
inc_native_search_task_current();
inc_native_search_task_current();
assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 2);
dec_native_search_task_current();
assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1);
}

#[test]
fn test_inc_native_search_task_total() {
setup();
Expand All @@ -140,48 +119,12 @@ mod tests {
assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 3);
}

#[test]
fn test_inc_dec_native_search_shard_task_current() {
setup();
inc_native_search_shard_task_current();
inc_native_search_shard_task_current();
inc_native_search_shard_task_current();
assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 3);
dec_native_search_shard_task_current();
dec_native_search_shard_task_current();
assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 1);
}

#[test]
fn test_inc_native_search_shard_task_total() {
setup();
inc_native_search_shard_task_total();
assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 1);
}

#[test]
fn test_df_native_node_stats_reads_correct_values() {
setup();
// Set up known state
inc_native_search_task_current();
inc_native_search_task_current();
inc_native_search_task_total();
inc_native_search_task_total();
inc_native_search_task_total();
inc_native_search_shard_task_current();
inc_native_search_shard_task_total();
inc_native_search_shard_task_total();

let mut buf = [0u8; 32];
let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) };
assert_eq!(rc, 0);

// Decode the buffer as [i64; 4]
let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) };
assert_eq!(vals[0], 2); // native_search_task_current
assert_eq!(vals[1], 3); // native_search_task_total
assert_eq!(vals[2], 1); // native_search_shard_task_current
assert_eq!(vals[3], 2); // native_search_shard_task_total
assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 2);
}

#[test]
Expand All @@ -207,103 +150,24 @@ mod tests {
}

#[test]
fn test_counters_are_independent() {
setup();
// Increment only one counter and verify others remain zero
inc_native_search_task_current();
assert_eq!(NATIVE_SEARCH_TASK_CURRENT.load(Ordering::Relaxed), 1);
assert_eq!(NATIVE_SEARCH_TASK_TOTAL.load(Ordering::Relaxed), 0);
assert_eq!(NATIVE_SEARCH_SHARD_TASK_CURRENT.load(Ordering::Relaxed), 0);
assert_eq!(NATIVE_SEARCH_SHARD_TASK_TOTAL.load(Ordering::Relaxed), 0);
}

// -----------------------------------------------------------------------
// Property-based test: AtomicI64 increment/decrement correctness
// -----------------------------------------------------------------------

/// **Validates: Requirements 1.1, 2.1–2.6**
///
/// Property 1: AtomicI64 increment/decrement correctness
///
/// For any random number of increments and decrements applied to each of
/// the 4 counters, reading the counters via `df_native_node_stats` SHALL
/// return values equal to the algebraic sum (increments - decrements) for
/// each counter.
proptest! {
#[test]
fn prop_atomic_inc_dec_correctness(
search_task_current_incs in 0u32..100,
search_task_current_decs in 0u32..100,
search_task_total_incs in 0u32..100,
shard_task_current_incs in 0u32..100,
shard_task_current_decs in 0u32..100,
shard_task_total_incs in 0u32..100,
) {
// Reset counters before each iteration
reset_all_counters();

// Apply increments and decrements for NATIVE_SEARCH_TASK_CURRENT
for _ in 0..search_task_current_incs {
inc_native_search_task_current();
}
for _ in 0..search_task_current_decs {
dec_native_search_task_current();
}

// Apply only increments for NATIVE_SEARCH_TASK_TOTAL (no dec function exists)
for _ in 0..search_task_total_incs {
inc_native_search_task_total();
}

// Apply increments and decrements for NATIVE_SEARCH_SHARD_TASK_CURRENT
for _ in 0..shard_task_current_incs {
inc_native_search_shard_task_current();
}
for _ in 0..shard_task_current_decs {
dec_native_search_shard_task_current();
}

// Apply only increments for NATIVE_SEARCH_SHARD_TASK_TOTAL (no dec function exists)
for _ in 0..shard_task_total_incs {
inc_native_search_shard_task_total();
}

// Expected algebraic sums
let expected_search_task_current =
search_task_current_incs as i64 - search_task_current_decs as i64;
let expected_search_task_total = search_task_total_incs as i64;
let expected_shard_task_current =
shard_task_current_incs as i64 - shard_task_current_decs as i64;
let expected_shard_task_total = shard_task_total_incs as i64;
fn test_df_native_node_stats_total_counters() {
// Read baseline (other tests may run in parallel and share globals)
let mut buf = [0u8; 32];
let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) };
assert_eq!(rc, 0);
let baseline: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) };

// Read counters via df_native_node_stats
let mut buf = [0u8; 32];
let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) };
prop_assert_eq!(rc, 0);
inc_native_search_task_total();
inc_native_search_task_total();
inc_native_search_shard_task_total();

// Decode the buffer as [i64; 4]
let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) };
let rc = unsafe { df_native_node_stats(buf.as_mut_ptr(), 32) };
assert_eq!(rc, 0);
let vals: [i64; 4] = unsafe { std::ptr::read(buf.as_ptr() as *const [i64; 4]) };

prop_assert_eq!(
vals[0], expected_search_task_current,
"native_search_task_current mismatch: incs={}, decs={}",
search_task_current_incs, search_task_current_decs
);
prop_assert_eq!(
vals[1], expected_search_task_total,
"native_search_task_total mismatch: incs={}",
search_task_total_incs
);
prop_assert_eq!(
vals[2], expected_shard_task_current,
"native_search_shard_task_current mismatch: incs={}, decs={}",
shard_task_current_incs, shard_task_current_decs
);
prop_assert_eq!(
vals[3], expected_shard_task_total,
"native_search_shard_task_total mismatch: incs={}",
shard_task_total_incs
);
}
// vals[1] = coordinator total — should be baseline + 2
assert_eq!(vals[1], baseline[1] + 2);
// vals[3] = shard total — should be baseline + 1
assert_eq!(vals[3], baseline[3] + 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ pub fn wrap_stream_as_handle(
let query_context = crate::query_tracker::QueryTrackingContext::new(
0,
runtime.runtime_env.memory_pool.clone(),
crate::query_tracker::QueryType::Shard,
);
let handle = crate::api::QueryStreamHandle::new(wrapped, query_context, None);
Box::into_raw(Box::new(handle)) as i64
Expand Down
Loading
Loading