diff --git a/plugins/arrow-base/src/main/java/org/opensearch/arrow/allocator/ArrowNativeAllocator.java b/plugins/arrow-base/src/main/java/org/opensearch/arrow/allocator/ArrowNativeAllocator.java index afab06a249d59..e886e24ef9c30 100644 --- a/plugins/arrow-base/src/main/java/org/opensearch/arrow/allocator/ArrowNativeAllocator.java +++ b/plugins/arrow-base/src/main/java/org/opensearch/arrow/allocator/ArrowNativeAllocator.java @@ -46,6 +46,7 @@ public class ArrowNativeAllocator implements NativeAllocator { private final RootAllocator root; private final ConcurrentMap pools = new ConcurrentHashMap<>(); + private final ConcurrentMap virtualPools = new ConcurrentHashMap<>(); private final ConcurrentMap poolMins = new ConcurrentHashMap<>(); private final ConcurrentMap poolMaxes = new ConcurrentHashMap<>(); private final ScheduledExecutorService rebalancer; @@ -178,6 +179,11 @@ public void setRootLimit(long limit) { @Override public NativeAllocatorPoolStats stats() { + // Refresh Rust-side stats before collecting + Runnable refresher = this.virtualPoolStatsRefresher; + if (refresher != null) { + refresher.run(); + } List poolStats = new ArrayList<>(); for (var entry : pools.entrySet()) { BufferAllocator alloc = entry.getValue().allocator; @@ -191,6 +197,11 @@ public NativeAllocatorPoolStats stats() { ) ); } + // Include Rust-side virtual pools in the unified stats view + for (var entry : virtualPools.entrySet()) { + VirtualPoolHandle vp = entry.getValue(); + poolStats.add(new NativeAllocatorPoolStats.PoolStats(entry.getKey(), vp.allocatedBytes(), vp.peakBytes(), vp.limit(), 0)); + } return new NativeAllocatorPoolStats(root.getAllocatedMemory(), root.getPeakMemoryAllocation(), root.getLimit(), poolStats); } @@ -352,4 +363,86 @@ public BufferAllocator getAllocator() { return allocator; } } + + /** + * A virtual pool handle for Rust-side memory pools that report stats back to Java + * without using Arrow's BufferAllocator. 
The Rust side enforces its own limit locally; + * this handle provides a unified view in {@link #stats()}. + */ + public static class VirtualPoolHandle implements PoolHandle { + + private final String name; + private final long limit; + private volatile long allocatedBytes; + private volatile long peakBytes; + + VirtualPoolHandle(String name, long limit) { + this.name = name; + this.limit = limit; + } + + /** Called by Rust (via FFM) to report current usage. */ + public void updateStats(long allocated, long peak) { + this.allocatedBytes = allocated; + this.peakBytes = peak; + } + + @Override + public PoolHandle newChild(String childName, long childLimit) { + throw new UnsupportedOperationException("Virtual pool [" + name + "] does not support children"); + } + + @Override + public long allocatedBytes() { + return allocatedBytes; + } + + @Override + public long peakBytes() { + return peakBytes; + } + + @Override + public long limit() { + return limit; + } + + @Override + public void close() { + // No-op — Rust owns the memory lifecycle + } + } + + /** + * Registers a virtual pool for a Rust-side memory pool. The Rust side tracks + * allocations locally and periodically reports usage via {@link VirtualPoolHandle#updateStats}. + * Stats appear in {@link #stats()} alongside real Arrow pools. + * + * @param poolName logical pool name (e.g., "write") + * @param limit the limit enforced by Rust + * @return the virtual handle (caller passes to Rust for stats reporting) + */ + public VirtualPoolHandle registerVirtualPool(String poolName, long limit) { + VirtualPoolHandle handle = new VirtualPoolHandle(poolName, limit); + virtualPools.put(poolName, handle); + return handle; + } + + /** + * Returns the virtual pool handle for a given name, or null if not registered. 
+ */ + public VirtualPoolHandle getVirtualPool(String poolName) { + return virtualPools.get(poolName); + } + + private volatile Runnable virtualPoolStatsRefresher; + + /** + * Registers a callback that refreshes virtual pool stats from the native layer. + * Called by the parquet plugin during initialization. The callback is invoked + * before stats() to ensure returned values reflect current Rust-side usage. + */ + public void setVirtualPoolStatsRefresher(Runnable refresher) { + this.virtualPoolStatsRefresher = refresher; + } } diff --git a/plugins/arrow-base/src/test/java/org/opensearch/arrow/allocator/ArrowNativeAllocatorTests.java b/plugins/arrow-base/src/test/java/org/opensearch/arrow/allocator/ArrowNativeAllocatorTests.java index 880ea445c3ea7..bf0b8c83796f2 100644 --- a/plugins/arrow-base/src/test/java/org/opensearch/arrow/allocator/ArrowNativeAllocatorTests.java +++ b/plugins/arrow-base/src/test/java/org/opensearch/arrow/allocator/ArrowNativeAllocatorTests.java @@ -232,4 +232,57 @@ public void testSetPoolMinDoesNotShrinkLiveLimit() { allocator.setPoolMin("p", 1L); assertEquals("dropping min must not shrink live limit", startLimit, pool.getLimit()); } + + public void testVirtualPoolRegistration() { + var handle = allocator.registerVirtualPool("write", 512 * 1024 * 1024); + assertNotNull(handle); + assertEquals(512 * 1024 * 1024, handle.limit()); + assertEquals(0, handle.allocatedBytes()); + assertEquals(0, handle.peakBytes()); + assertSame(handle, allocator.getVirtualPool("write")); + } + + public void testVirtualPoolUpdateStats() { + var handle = allocator.registerVirtualPool("merge", 0); + handle.updateStats(100_000, 200_000); + assertEquals(100_000, handle.allocatedBytes()); + assertEquals(200_000, handle.peakBytes()); + } + + public void testVirtualPoolAppearsInStats() { + allocator.registerVirtualPool("write", 0); + allocator.getVirtualPool("write").updateStats(42, 99); + + NativeAllocatorPoolStats stats = allocator.stats(); + boolean found = false; + 
for (NativeAllocatorPoolStats.PoolStats ps : stats.getPools()) { + if ("write".equals(ps.getName())) { + assertEquals(42, ps.getAllocatedBytes()); + assertEquals(99, ps.getPeakBytes()); + found = true; + } + } + assertTrue("Virtual pool 'write' should appear in stats", found); + } + + public void testStatsRefresherCalledBeforeStats() { + var handle = allocator.registerVirtualPool("write", 0); + allocator.setVirtualPoolStatsRefresher(() -> handle.updateStats(777, 888)); + + NativeAllocatorPoolStats stats = allocator.stats(); + boolean found = false; + for (NativeAllocatorPoolStats.PoolStats ps : stats.getPools()) { + if ("write".equals(ps.getName())) { + assertEquals(777, ps.getAllocatedBytes()); + assertEquals(888, ps.getPeakBytes()); + found = true; + } + } + assertTrue("Refresher should have been called before collecting stats", found); + } + + public void testVirtualPoolDoesNotSupportChildren() { + var handle = allocator.registerVirtualPool("vp", 0); + expectThrows(UnsupportedOperationException.class, () -> handle.newChild("child", 100)); + } } diff --git a/sandbox/libs/dataformat-native/rust/common/src/lib.rs b/sandbox/libs/dataformat-native/rust/common/src/lib.rs index 0f4b8c132407f..c44fa871c4fb3 100644 --- a/sandbox/libs/dataformat-native/rust/common/src/lib.rs +++ b/sandbox/libs/dataformat-native/rust/common/src/lib.rs @@ -11,6 +11,7 @@ pub mod error; pub mod logger; pub mod allocator; +pub mod memory_pool; // Re-export the proc macro so plugins use `#[native_bridge_common::ffm_safe]` pub use native_bridge_macros::ffm_safe; diff --git a/sandbox/libs/dataformat-native/rust/common/src/memory_pool.rs b/sandbox/libs/dataformat-native/rust/common/src/memory_pool.rs new file mode 100644 index 0000000000000..282b28ceaa8da --- /dev/null +++ b/sandbox/libs/dataformat-native/rust/common/src/memory_pool.rs @@ -0,0 +1,358 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the 
Apache-2.0 license or a + * compatible open source license. + */ + +//! Memory pool for tracking native memory usage across write and merge operations. +//! +//! Provides a simple atomic counter with an optional limit. Operations that allocate +//! significant memory (RecordBatch buffering, sort read-back, merge cursors) call +//! `try_grow` before allocating and `shrink` after freeing. The pool rejects +//! allocations that would exceed the configured limit. +//! +//! `MemoryReservation` is an RAII handle that automatically returns memory to the +//! pool on drop, preventing leaks even on error paths. + +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::fmt; +use crate::{log_info, log_error}; + +/// Error returned when a pool cannot satisfy an allocation request. +#[derive(Debug, Clone)] +pub struct PoolExhausted { + pub pool_name: &'static str, + pub requested: usize, + pub used: usize, + pub limit: usize, +} + +impl fmt::Display for PoolExhausted { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "[{}] memory limit exceeded: requested {} bytes, used {}, limit {}", + self.pool_name, self.requested, self.used, self.limit + ) + } +} + +impl std::error::Error for PoolExhausted {} + +/// A node-level memory pool backed by an atomic counter. +#[derive(Debug)] +pub struct MemoryPool { + name: &'static str, + used: AtomicUsize, + limit: AtomicUsize, + peak: AtomicUsize, +} + +impl MemoryPool { + /// Create a new pool. `limit = 0` means unlimited. + pub fn new(name: &'static str, limit: usize) -> Self { + Self { + name, + used: AtomicUsize::new(0), + limit: AtomicUsize::new(limit), + peak: AtomicUsize::new(0), + } + } + + /// Attempt to reserve `bytes`. Returns error if it would exceed the limit. 
+    pub fn try_grow(&self, bytes: usize, consumer: &str) -> Result<(), PoolExhausted> {
+        // A zero-byte request succeeds without touching the counters or the log.
+        if bytes == 0 {
+            return Ok(());
+        }
+        // The limit is read once per call; 0 means "no cap".
+        let limit = self.limit.load(Ordering::Relaxed);
+        // Atomically commit `used + bytes`; refuse on arithmetic overflow or when the
+        // new total would exceed a non-zero limit.
+        let outcome = self.used.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
+            current
+                .checked_add(bytes)
+                .filter(|&total| limit == 0 || total <= limit)
+        });
+
+        let previous = match outcome {
+            Ok(previous) => previous,
+            Err(_) => {
+                // Re-read usage for the report; it may differ from the value that was refused.
+                let used = self.used.load(Ordering::Relaxed);
+                log_info!(
+                    "[{}] REJECTED +{} bytes for '{}' (used: {}, limit: {})",
+                    self.name, bytes, consumer, used, limit
+                );
+                return Err(PoolExhausted {
+                    pool_name: self.name,
+                    requested: bytes,
+                    used,
+                    limit,
+                });
+            }
+        };
+
+        // Success: derive the new total from the value that was replaced and fold it into the peak.
+        let new_used = previous + bytes;
+        self.peak.fetch_max(new_used, Ordering::Relaxed);
+        log_info!(
+            "[{}] +{} bytes for '{}' (used: {}, limit: {})",
+            self.name, bytes, consumer, new_used, limit
+        );
+        Ok(())
+    }
+
+    /// Records `bytes` as used without consulting the limit. Intended for allocations that
+    /// have already happened and must still be accounted for.
+    pub fn grow(&self, bytes: usize, consumer: &str) {
+        if bytes > 0 {
+            let previous = self.used.fetch_add(bytes, Ordering::Relaxed);
+            let new_used = previous + bytes;
+            self.peak.fetch_max(new_used, Ordering::Relaxed);
+            log_info!(
+                "[{}] +{} bytes for '{}' (used: {}, limit: {})",
+                self.name, bytes, consumer, new_used, self.limit.load(Ordering::Relaxed)
+            );
+        }
+    }
+
+    /// Release `bytes` back to the pool.
+    pub fn shrink(&self, bytes: usize, consumer: &str) {
+        // Nothing to release.
+        if bytes == 0 {
+            return;
+        }
+        // Subtract with saturation. A plain `fetch_sub` wraps on underflow and would leave
+        // `used` as a huge wrapped value, corrupting all later accounting; clamp to zero instead.
+        let old = self
+            .used
+            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |used| {
+                Some(used.saturating_sub(bytes))
+            })
+            .unwrap_or_else(|unchanged| unchanged); // the closure never returns `None`
+        if old < bytes {
+            // More was released than was tracked; the counter has been clamped to 0.
+            log_error!(
+                "[{}] UNDERFLOW: shrink {} bytes for '{}' but only {} was tracked",
+                self.name, bytes, consumer, old
+            );
+        } else {
+            log_info!(
+                "[{}] -{} bytes for '{}' (used: {}, limit: {})",
+                self.name, bytes, consumer, old - bytes, self.limit.load(Ordering::Relaxed)
+            );
+        }
+    }
+
+    /// Bytes currently tracked as in use.
+    pub fn used(&self) -> usize {
+        self.used.load(Ordering::Relaxed)
+    }
+
+    /// Highest usage recorded by `try_grow` / `grow` so far.
+    pub fn peak(&self) -> usize {
+        self.peak.load(Ordering::Relaxed)
+    }
+
+    /// Current limit in bytes (0 = unlimited).
+    pub fn limit(&self) -> usize {
+        self.limit.load(Ordering::Relaxed)
+    }
+
+    /// Pool name used as the prefix of every log line.
+    pub fn name(&self) -> &'static str {
+        self.name
+    }
+
+    /// Replace the limit. Usage already tracked is left untouched; only later
+    /// `try_grow` calls observe the new value.
+    pub fn set_limit(&self, new_limit: usize) {
+        let old = self.limit.swap(new_limit, Ordering::Relaxed);
+        log_info!("[{}] limit changed: {} -> {}", self.name, old, new_limit);
+    }
+}
+
+/// RAII handle that tracks a portion of memory reserved from a [`MemoryPool`].
+/// Automatically releases all held memory on drop.
+pub struct MemoryReservation {
+    /// Pool this reservation draws from and returns bytes to.
+    pool: Arc<MemoryPool>,
+    /// Label passed to the pool for its log lines.
+    consumer: &'static str,
+    /// Bytes currently held by this reservation.
+    size: usize,
+}
+
+impl MemoryReservation {
+    /// Create an empty reservation against `pool`, labelled `consumer`.
+    pub fn new(pool: &Arc<MemoryPool>, consumer: &'static str) -> Self {
+        Self {
+            pool: Arc::clone(pool),
+            consumer,
+            size: 0,
+        }
+    }
+
+    /// Try to grow this reservation. On failure, the reservation is unchanged.
+    pub fn try_grow(&mut self, bytes: usize) -> Result<(), PoolExhausted> {
+        self.pool.try_grow(bytes, self.consumer)?;
+        self.size += bytes;
+        Ok(())
+    }
+
+    /// Infallible grow.
+    pub fn grow(&mut self, bytes: usize) {
+        self.pool.grow(bytes, self.consumer);
+        self.size += bytes;
+    }
+
+    /// Release `bytes` from this reservation, clamped to what it currently holds.
+    pub fn shrink(&mut self, bytes: usize) {
+        let actual = bytes.min(self.size);
+        self.pool.shrink(actual, self.consumer);
+        self.size -= actual;
+    }
+
+    /// Resize to a new total. Grows or shrinks as needed.
+ pub fn resize(&mut self, new_total: usize) { + if new_total > self.size { + self.grow(new_total - self.size); + } else if new_total < self.size { + self.shrink(self.size - new_total); + } + } + + /// Release all memory back to the pool. Returns bytes freed. + pub fn free(&mut self) -> usize { + let s = self.size; + if s > 0 { + self.pool.shrink(s, self.consumer); + self.size = 0; + } + s + } + + pub fn size(&self) -> usize { + self.size + } + + pub fn consumer(&self) -> &'static str { + self.consumer + } +} + +impl Drop for MemoryReservation { + fn drop(&mut self) { + if self.size > 0 { + log_info!( + "[{}] reservation '{}' dropped, releasing {} bytes", + self.pool.name, self.consumer, self.size + ); + self.pool.shrink(self.size, self.consumer); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_pool(limit: usize) -> Arc { + Arc::new(MemoryPool::new("TEST", limit)) + } + + #[test] + fn basic_grow_shrink() { + let pool = test_pool(0); // unlimited + let mut res = MemoryReservation::new(&pool, "test"); + res.grow(100); + assert_eq!(res.size(), 100); + assert_eq!(pool.used(), 100); + res.shrink(40); + assert_eq!(res.size(), 60); + assert_eq!(pool.used(), 60); + res.free(); + assert_eq!(pool.used(), 0); + } + + #[test] + fn try_grow_within_limit() { + let pool = test_pool(1000); + let mut res = MemoryReservation::new(&pool, "test"); + assert!(res.try_grow(500).is_ok()); + assert!(res.try_grow(400).is_ok()); + assert_eq!(pool.used(), 900); + } + + #[test] + fn try_grow_exceeds_limit() { + let pool = test_pool(1000); + let mut res = MemoryReservation::new(&pool, "test"); + assert!(res.try_grow(500).is_ok()); + let err = res.try_grow(600).unwrap_err(); + assert_eq!(err.requested, 600); + assert_eq!(err.used, 500); + assert_eq!(err.limit, 1000); + assert_eq!(res.size(), 500); // unchanged + assert_eq!(pool.used(), 500); + } + + #[test] + fn drop_releases_memory() { + let pool = test_pool(0); + { + let mut res = MemoryReservation::new(&pool, "test"); + 
res.grow(200); + assert_eq!(pool.used(), 200); + } // res dropped here + assert_eq!(pool.used(), 0); + } + + #[test] + fn resize() { + let pool = test_pool(0); + let mut res = MemoryReservation::new(&pool, "test"); + res.resize(100); + assert_eq!(res.size(), 100); + assert_eq!(pool.used(), 100); + res.resize(50); + assert_eq!(res.size(), 50); + assert_eq!(pool.used(), 50); + res.resize(200); + assert_eq!(res.size(), 200); + assert_eq!(pool.used(), 200); + } + + #[test] + fn multiple_reservations_share_pool() { + let pool = test_pool(1000); + let mut r1 = MemoryReservation::new(&pool, "writer1"); + let mut r2 = MemoryReservation::new(&pool, "writer2"); + r1.try_grow(400).unwrap(); + r2.try_grow(400).unwrap(); + assert_eq!(pool.used(), 800); + // Third allocation that would exceed + assert!(r2.try_grow(300).is_err()); + assert_eq!(pool.used(), 800); + drop(r1); + assert_eq!(pool.used(), 400); + // Now it fits + assert!(r2.try_grow(300).is_ok()); + assert_eq!(pool.used(), 700); + } + + #[test] + fn peak_tracking() { + let pool = test_pool(0); + let mut res = MemoryReservation::new(&pool, "test"); + res.grow(100); + res.grow(200); + assert_eq!(pool.peak(), 300); + res.shrink(250); + assert_eq!(pool.peak(), 300); // peak unchanged + assert_eq!(pool.used(), 50); + } + + #[test] + fn set_limit_at_runtime() { + let pool = test_pool(100); + let mut res = MemoryReservation::new(&pool, "test"); + assert!(res.try_grow(80).is_ok()); + assert!(res.try_grow(30).is_err()); // 80+30 > 100 + pool.set_limit(200); + assert!(res.try_grow(30).is_ok()); // 80+30 < 200 + assert_eq!(pool.used(), 110); + } + + #[test] + fn zero_bytes_is_noop() { + let pool = test_pool(100); + let mut res = MemoryReservation::new(&pool, "test"); + assert!(res.try_grow(0).is_ok()); + res.grow(0); + res.shrink(0); + assert_eq!(pool.used(), 0); + assert_eq!(res.size(), 0); + } +} diff --git a/sandbox/libs/dataformat-native/rust/lib/Cargo.toml b/sandbox/libs/dataformat-native/rust/lib/Cargo.toml index 
6eadb23e82a21..bc23046a898cb 100644 --- a/sandbox/libs/dataformat-native/rust/lib/Cargo.toml +++ b/sandbox/libs/dataformat-native/rust/lib/Cargo.toml @@ -9,6 +9,9 @@ license = "Apache-2.0" name = "opensearch_native" crate-type = ["cdylib"] +[features] +test-limits = ["opensearch-parquet-format/test-limits"] + [dependencies] opensearch-datafusion = { path = "../../../../plugins/analytics-backend-datafusion/rust" } opensearch-parquet-format = { path = "../../../../plugins/parquet-data-format/src/main/rust" } diff --git a/sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/DelegationThreadTracker.java b/sandbox/plugins/analytics-backend-datafusion/src/main/java/DelegationThreadTracker.java similarity index 100% rename from sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/DelegationThreadTracker.java rename to sandbox/plugins/analytics-backend-datafusion/src/main/java/DelegationThreadTracker.java diff --git a/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/ParquetDataFormatPlugin.java b/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/ParquetDataFormatPlugin.java index 5a368f9753cb8..fcd1b70c1915e 100644 --- a/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/ParquetDataFormatPlugin.java +++ b/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/ParquetDataFormatPlugin.java @@ -112,6 +112,22 @@ public Collection createComponents( .addSettingsUpdateConsumer(ParquetSettings.MAX_PER_VSR_ALLOCATION_DIVISOR, v -> this.maxPerVsrAllocationDivisor = v); this.nativeAllocator = pluginComponentRegistry.getComponent(ArrowNativeAllocator.class) .orElseThrow(() -> new IllegalStateException("ArrowNativeAllocator not available; arrow-base plugin must be installed")); + + // Register virtual pools for Rust-side write and merge memory tracking. + // Stats are refreshed on demand when _nodes/stats is called. 
+ var writePool = this.nativeAllocator.registerVirtualPool("write", 0); + var mergePool = this.nativeAllocator.registerVirtualPool("merge", 0); + this.nativeAllocator.setVirtualPoolStatsRefresher(() -> { + try { + long[] stats = org.opensearch.parquet.bridge.RustBridge.getPoolStats(); + // Layout: [writeUsed, writePeak, writeLimit, mergeUsed, mergePeak, mergeLimit] + writePool.updateStats(stats[0], stats[1]); + mergePool.updateStats(stats[3], stats[4]); + } catch (Exception e) { + // Best-effort — stats may be stale if native lib is unavailable + } + }); + return Collections.emptyList(); } diff --git a/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/bridge/RustBridge.java b/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/bridge/RustBridge.java index e59a3549a0dd1..423a71566dbbc 100644 --- a/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/bridge/RustBridge.java +++ b/sandbox/plugins/parquet-data-format/src/main/java/org/opensearch/parquet/bridge/RustBridge.java @@ -45,6 +45,7 @@ public class RustBridge { private static final MethodHandle FREE_MERGE_RESULT; private static final MethodHandle READ_AS_JSON; private static final MethodHandle FREE_ROW_ID_MAPPING; + private static final MethodHandle GET_POOL_STATS; static { SymbolLookup lib = NativeLibraryLoader.symbolLookup(); @@ -251,6 +252,11 @@ public class RustBridge { ValueLayout.JAVA_LONG // mapping_len ) ); + // i64 parquet_get_pool_stats(out_ptr, out_cap) → writes 6 × i64 + GET_POOL_STATS = linker.downcallHandle( + lib.find("parquet_get_pool_stats").orElseThrow(), + FunctionDescriptor.of(ValueLayout.JAVA_LONG, ValueLayout.ADDRESS, ValueLayout.JAVA_LONG) + ); } public static void initLogger() {} @@ -688,5 +694,29 @@ private static LongMapArrays toLongMapArrays(NativeCall call, Map return new LongMapArrays(call.strArray(keys), seg); } + /** + * Fetches write and merge pool stats from Rust in a single FFM call. 
+ * Returns [writeUsed, writePeak, writeLimit, mergeUsed, mergePeak, mergeLimit]. + */ + public static long[] getPoolStats() { + try (var arena = java.lang.foreign.Arena.ofConfined()) { + var buf = arena.allocate(ValueLayout.JAVA_LONG, 6); + long rc; + try { + rc = (long) GET_POOL_STATS.invokeExact(buf, 6L); + } catch (Throwable t) { + throw new RuntimeException("parquet_get_pool_stats failed", t); + } + if (rc != 0) { + throw new RuntimeException("parquet_get_pool_stats returned error: " + rc); + } + long[] stats = new long[6]; + for (int i = 0; i < 6; i++) { + stats[i] = buf.getAtIndex(ValueLayout.JAVA_LONG, i); + } + return stats; + } + } + private RustBridge() {} } diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/Cargo.toml b/sandbox/plugins/parquet-data-format/src/main/rust/Cargo.toml index 365f571c62c5d..3020d4f5cfc25 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/Cargo.toml +++ b/sandbox/plugins/parquet-data-format/src/main/rust/Cargo.toml @@ -7,6 +7,7 @@ workspace = "../../../../../libs/dataformat-native/rust" [features] test-utils = [] +test-limits = [] [lib] name = "opensearch_parquet_format" diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/ffm.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/ffm.rs index e96d6b70d5b00..00bf9243baa48 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/src/ffm.rs +++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/ffm.rs @@ -14,7 +14,7 @@ use std::slice; use std::str; -use native_bridge_common::{ffm_safe, log_debug}; +use native_bridge_common::ffm_safe; use crate::native_settings::NativeSettings; use crate::field_config::FieldConfig; @@ -694,3 +694,82 @@ pub unsafe extern "C" fn parquet_free_row_id_mapping( let _ = Box::from_raw(slice::from_raw_parts_mut(mapping_ptr as *mut i64, mapping_len as usize)); } } + +// --------------------------------------------------------------------------- +// Memory pool management +// 
--------------------------------------------------------------------------- + +/// Initialize write and merge memory pools with limits (bytes). 0 = unlimited. +#[no_mangle] +pub extern "C" fn parquet_init_memory_pools(write_limit: i64, merge_limit: i64) -> i64 { + let wl = if write_limit < 0 { 0 } else { write_limit as usize }; + let ml = if merge_limit < 0 { 0 } else { merge_limit as usize }; + crate::memory::init_pools(wl, ml); + 0 +} + +/// Set write pool limit at runtime. +#[no_mangle] +pub extern "C" fn parquet_set_write_pool_limit(limit: i64) -> i64 { + let l = if limit < 0 { 0 } else { limit as usize }; + crate::memory::write_pool().set_limit(l); + 0 +} + +/// Set merge pool limit at runtime. +#[no_mangle] +pub extern "C" fn parquet_set_merge_pool_limit(limit: i64) -> i64 { + let l = if limit < 0 { 0 } else { limit as usize }; + crate::memory::merge_pool().set_limit(l); + 0 +} + +/// Returns current write pool usage in bytes. +#[no_mangle] +pub extern "C" fn parquet_get_write_pool_used() -> i64 { + crate::memory::write_pool().used() as i64 +} + +/// Returns current merge pool usage in bytes. +#[no_mangle] +pub extern "C" fn parquet_get_merge_pool_used() -> i64 { + crate::memory::merge_pool().used() as i64 +} + +/// Writes all pool stats into a caller-provided buffer in one FFM call. +/// Layout (6 × i64 = 48 bytes): +/// [0]: write_used +/// [1]: write_peak +/// [2]: write_limit +/// [3]: merge_used +/// [4]: merge_peak +/// [5]: merge_limit +/// +/// Returns 0 on success, -1 if buffer too small. 
+#[no_mangle]
+pub unsafe extern "C" fn parquet_get_pool_stats(out_ptr: *mut i64, out_cap: i64) -> i64 {
+    // Reject a null buffer as well as one with fewer than 6 slots: every write below
+    // dereferences `out_ptr`, so a null pointer must never reach them.
+    if out_ptr.is_null() || out_cap < 6 {
+        return -1;
+    }
+    let wp = crate::memory::write_pool();
+    let mp = crate::memory::merge_pool();
+    // Slot order: write used/peak/limit, then merge used/peak/limit.
+    *out_ptr.add(0) = wp.used() as i64;
+    *out_ptr.add(1) = wp.peak() as i64;
+    *out_ptr.add(2) = wp.limit() as i64;
+    *out_ptr.add(3) = mp.used() as i64;
+    *out_ptr.add(4) = mp.peak() as i64;
+    *out_ptr.add(5) = mp.limit() as i64;
+    0
+}
+
+/// Returns peak write pool usage in bytes.
+#[no_mangle]
+pub extern "C" fn parquet_get_write_pool_peak() -> i64 {
+    crate::memory::write_pool().peak() as i64
+}
+
+/// Returns peak merge pool usage in bytes.
+#[no_mangle]
+pub extern "C" fn parquet_get_merge_pool_peak() -> i64 {
+    crate::memory::merge_pool().peak() as i64
+}
diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/lib.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/lib.rs
index 2ce15506f12c4..72d27d421a919 100644
--- a/sandbox/plugins/parquet-data-format/src/main/rust/src/lib.rs
+++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/lib.rs
@@ -20,6 +20,7 @@ pub mod writer_properties_builder;
 pub mod rate_limited_writer;
 pub mod crc_writer;
 pub mod merge;
+pub mod memory;
 
 pub use native_settings::NativeSettings;
 pub use field_config::FieldConfig;
diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/memory.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/memory.rs
new file mode 100644
index 0000000000000..e81d8de3e2bbd
--- /dev/null
+++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/memory.rs
@@ -0,0 +1,44 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+//! Global memory pool instances for write and merge operations.
+
+use std::sync::{Arc, OnceLock};
+use native_bridge_common::memory_pool::MemoryPool;
+
+// Process-wide pool singletons, created on first access by `write_pool()` / `merge_pool()`.
+static WRITE_POOL: OnceLock<Arc<MemoryPool>> = OnceLock::new();
+static MERGE_POOL: OnceLock<Arc<MemoryPool>> = OnceLock::new();
+
+/// Default: 0 (unlimited). Set via FFI at runtime.
+/// For testing limit enforcement, build with: cargo build --features test-limits
+#[cfg(not(feature = "test-limits"))]
+const DEFAULT_WRITE_LIMIT: usize = 0;
+// With `test-limits` enabled the write pool starts with a 200 KiB cap.
+#[cfg(feature = "test-limits")]
+const DEFAULT_WRITE_LIMIT: usize = 200 * 1024;
+
+/// Default: 0 (unlimited). Set via FFI at runtime.
+#[cfg(not(feature = "test-limits"))]
+const DEFAULT_MERGE_LIMIT: usize = 0;
+// With `test-limits` enabled the merge pool starts with a 100 KiB cap.
+#[cfg(feature = "test-limits")]
+const DEFAULT_MERGE_LIMIT: usize = 100 * 1024;
+
+/// Returns the node-level write memory pool, creating it with
+/// `DEFAULT_WRITE_LIMIT` on first use.
+pub fn write_pool() -> &'static Arc<MemoryPool> {
+    WRITE_POOL.get_or_init(|| Arc::new(MemoryPool::new("WRITE", DEFAULT_WRITE_LIMIT)))
+}
+
+/// Returns the node-level merge memory pool, creating it with
+/// `DEFAULT_MERGE_LIMIT` on first use.
+pub fn merge_pool() -> &'static Arc<MemoryPool> {
+    MERGE_POOL.get_or_init(|| Arc::new(MemoryPool::new("MERGE", DEFAULT_MERGE_LIMIT)))
+}
+
+/// Initialize both pools with limits. Called from FFI at node startup.
+pub fn init_pools(write_limit: usize, merge_limit: usize) { + write_pool().set_limit(write_limit); + merge_pool().set_limit(merge_limit); +} diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/context.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/context.rs index e71707b21b9a0..f990b0b62bca3 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/context.rs +++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/context.rs @@ -20,9 +20,11 @@ use rayon::prelude::*; use tokio::sync::{mpsc as tokio_mpsc, oneshot}; use crate::crc_writer::CrcWriter; +use crate::memory::merge_pool; use crate::rate_limited_writer::RateLimitedWriter; use crate::writer_properties_builder::WriterPropertiesBuilder; use crate::{log_debug, SETTINGS_STORE}; +use native_bridge_common::memory_pool::MemoryReservation; use super::error::{MergeError, MergeResult}; use super::io_task::{ @@ -45,6 +47,7 @@ pub struct MergeContext { next_row_id: i64, total_rows_written: usize, rayon_threads: Option, + reservation: MemoryReservation, } impl MergeContext { @@ -121,6 +124,7 @@ impl MergeContext { next_row_id: 0, total_rows_written: 0, rayon_threads, + reservation: MemoryReservation::new(merge_pool(), "merge_output_buffer"), }) } @@ -131,6 +135,10 @@ impl MergeContext { /// Buffers a batch (already padded to data_schema) and auto-flushes when /// the row count threshold is reached. 
pub fn push_batch(&mut self, batch: RecordBatch) -> MergeResult<()> { + let batch_bytes = batch.get_array_memory_size(); + if let Err(e) = self.reservation.try_grow(batch_bytes) { + return Err(MergeError::Logic(format!("Merge memory limit exceeded: {}", e))); + } self.output_row_count += batch.num_rows(); self.output_chunks.push(batch); if self.output_row_count >= self.output_flush_rows { @@ -155,8 +163,15 @@ impl MergeContext { }; let n = merged.num_rows(); + // Track temporary spike: merged + with_id coexist briefly + let merged_bytes = merged.get_array_memory_size(); + self.reservation.grow(merged_bytes); + let with_id = append_row_id(&merged, self.next_row_id, &self.output_schema)?; + let with_id_bytes = with_id.get_array_memory_size(); + self.reservation.grow(with_id_bytes); drop(merged); + self.reservation.shrink(merged_bytes); let col_writers = self .rg_writer_factory @@ -198,6 +213,9 @@ impl MergeContext { self.total_rows_written += n; self.output_row_count = 0; + // Release buffered batch memory — data has been encoded and sent to IO + self.reservation.free(); + log_debug!( "[RUST] Flushed row group {}: {} rows (total: {})", self.row_group_index - 1, diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/sorted.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/sorted.rs index 000fff97b2bde..7a5248bc0ceb6 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/sorted.rs +++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/sorted.rs @@ -14,6 +14,8 @@ use arrow::datatypes::Schema as ArrowSchema; use parquet::schema::types::SchemaDescriptor; use crate::log_debug; +use crate::memory::merge_pool; +use native_bridge_common::memory_pool::MemoryReservation; use super::context::MergeContext; use super::cursor::FileCursor; @@ -112,7 +114,18 @@ pub fn merge_sorted( // Row-ID mapping: pre-allocate the flat mapping array and compute offsets // from file metadata row counts (known before reading any data). 
let total_rows: usize = file_row_counts.iter().sum(); + let mut reservation = MemoryReservation::new(merge_pool(), "merge_sorted"); + // Track mapping vec allocation + reservation.grow(total_rows * std::mem::size_of::()); let mut mapping: Vec = vec![0i64; total_rows]; + + // Track cursor batch memory: each cursor holds one batch + one prefetched batch + for cursor in &cursors { + if let Some(batch) = &cursor.current_batch { + // 2x: current batch + prefetched next batch in channel + reservation.grow(batch.get_array_memory_size() * 2); + } + } let mut gen_keys: Vec = Vec::with_capacity(num_cursors); let mut gen_offsets: Vec = Vec::with_capacity(num_cursors); let mut gen_sizes: Vec = Vec::with_capacity(num_cursors); diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/unsorted.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/unsorted.rs index 61222edb504f3..7ba2b06c6450b 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/unsorted.rs +++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/merge/unsorted.rs @@ -14,6 +14,8 @@ use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchR use parquet::schema::types::SchemaDescriptor; use crate::log_debug; +use crate::memory::merge_pool; +use native_bridge_common::memory_pool::MemoryReservation; use super::context::MergeContext; use super::error::MergeResult; @@ -87,6 +89,8 @@ pub fn merge_unsorted( // Build row-ID mapping: for unsorted merge, files are concatenated sequentially. // old_row_id maps directly to new_row_id with a per-file offset. 
let total_rows: usize = file_row_counts.iter().sum(); + let mut reservation = MemoryReservation::new(merge_pool(), "merge_unsorted"); + reservation.grow(total_rows * std::mem::size_of::()); let mut mapping: Vec = vec![0i64; total_rows]; let mut gen_keys: Vec = Vec::with_capacity(input_files.len()); let mut gen_offsets: Vec = Vec::with_capacity(input_files.len()); diff --git a/sandbox/plugins/parquet-data-format/src/main/rust/src/writer.rs b/sandbox/plugins/parquet-data-format/src/main/rust/src/writer.rs index 104fdc0a054bf..a3407d01a068d 100644 --- a/sandbox/plugins/parquet-data-format/src/main/rust/src/writer.rs +++ b/sandbox/plugins/parquet-data-format/src/main/rust/src/writer.rs @@ -22,9 +22,11 @@ use std::sync::{Arc, Mutex}; use crate::{log_error, log_debug, log_info}; use crate::crc_writer::CrcWriter; +use crate::memory::write_pool; use crate::merge::{merge_sorted, schema::ROW_ID_COLUMN_NAME}; use crate::native_settings::NativeSettings; use crate::writer_properties_builder::WriterPropertiesBuilder; +use native_bridge_common::memory_pool::MemoryReservation; /// Result from finalizing a writer: Parquet metadata + whole-file CRC32 + optional sort permutation. #[derive(Debug)] @@ -95,6 +97,8 @@ struct SortingChunkedWriter { total_rows: usize, /// Writer generation propagated into Parquet file metadata for each chunk. writer_generation: i64, + /// Memory reservation tracking sort-phase allocations. 
+ reservation: MemoryReservation, } impl SortingChunkedWriter { @@ -108,6 +112,7 @@ impl SortingChunkedWriter { nulls_first: Vec, writer_generation: i64, ) -> Result> { + let reservation = MemoryReservation::new(write_pool(), "sorted_writer"); let mut writer = Self { base_path, schema, @@ -125,6 +130,7 @@ impl SortingChunkedWriter { chunk_crcs: Vec::new(), total_rows: 0, writer_generation, + reservation, }; writer.open_new_ipc()?; Ok(writer) @@ -233,29 +239,41 @@ impl SortingChunkedWriter { let file = File::open(&ipc_path)?; let reader = IpcFileReader::try_new(file, None)?; let mut batches: Vec = Vec::new(); + let mut batch_bytes: usize = 0; for batch_result in reader { let batch = batch_result?; if batch.num_rows() > 0 { + batch_bytes += batch.get_array_memory_size(); batches.push(batch); } } if batches.is_empty() { - // Nothing to sort, just reopen let _ = std::fs::remove_file(&ipc_path); self.open_new_ipc()?; return Ok(()); } - // Concat and sort + // Reserve for the full sort peak: read-back + concat + sort output coexist briefly. + // Peak is ~3x the data: batches(1x) + combined(1x) + sorted(1x). + // We reserve 3x upfront, then shrink as intermediates are freed. 
+ let sort_peak = batch_bytes * 3; + self.reservation.try_grow(sort_peak) + .map_err(|e| -> Box { + format!("Write memory limit exceeded during sort (need {} bytes): {}", sort_peak, e).into() + })?; + + // Concat all batches into one (peak: batches + combined = 2x) let combined = concat_batches(&self.schema, &batches)?; - drop(batches); // free memory before sort allocates + drop(batches); // free batches → now holding combined + sorted reservation (2x used, 1x freed) + + // Sort (peak: combined + sorted = 2x) let sorted_batch = NativeParquetWriter::sort_batch( &combined, &self.sort_columns, &self.reverse_sorts, &self.nulls_first, )?; - drop(combined); // free unsorted data + drop(combined); // free combined → now holding only sorted (1x used, 2x freed) - // Capture original row IDs for permutation building, then rewrite to sequential 0..N + // Capture original row IDs for permutation building let row_id_col_idx = self.schema.fields().iter().position(|f| f.name() == ROW_ID_COLUMN_NAME); let final_batch = if let Some(idx) = row_id_col_idx { let row_id_array = sorted_batch.column(idx) @@ -291,6 +309,11 @@ impl SortingChunkedWriter { // Delete the IPC staging file and open a fresh one let _ = std::fs::remove_file(&ipc_path); self.open_new_ipc()?; + + // Release all sort-phase memory — data is now on disk + self.reservation.shrink(sort_peak); + // Track chunk_row_ids growth (long-lived until finish()) + self.reservation.resize(self.memory_size()); Ok(()) } @@ -335,6 +358,7 @@ struct WriterState { settings: NativeSettings, crc_handle: Option, writer_generation: i64, + reservation: MemoryReservation, } /// Path suffix for the intermediate Arrow IPC file used during sort-on-close. 
@@ -434,11 +458,14 @@ impl NativeParquetWriter { (WriterVariant::Parquet(Arc::new(Mutex::new(writer))), Some(crc_handle)) }; + let reservation = MemoryReservation::new(write_pool(), "parquet_writer"); + WRITERS.insert(temp_filename, WriterState { variant, settings, crc_handle, writer_generation, + reservation, }); Ok(()) @@ -464,17 +491,41 @@ impl NativeParquetWriter { let record_batch = RecordBatch::try_new(schema, struct_array.columns().to_vec())?; log_debug!("Created RecordBatch with {} rows and {} columns", record_batch.num_rows(), record_batch.num_columns()); - if let Some(state) = WRITERS.get_mut(&temp_filename) { - match &state.variant { - WriterVariant::Ipc(writer_arc) => { - log_debug!("Writing RecordBatch to IPC staging file"); - let mut writer = writer_arc.lock().unwrap(); - writer.write(&record_batch)?; - } - WriterVariant::Parquet(writer_arc) => { - log_debug!("Writing RecordBatch to Parquet file"); - let mut writer = writer_arc.lock().unwrap(); - writer.write(&record_batch)?; + if let Some(mut state) = WRITERS.get_mut(&temp_filename) { + let is_ipc = matches!(&state.variant, WriterVariant::Ipc(_)); + if is_ipc { + let writer_arc = match &state.variant { + WriterVariant::Ipc(w) => Arc::clone(w), + _ => unreachable!(), + }; + log_debug!("Writing RecordBatch to IPC staging file"); + let mut writer = writer_arc.lock().unwrap(); + writer.write(&record_batch)?; + // SortingChunkedWriter tracks its own memory via its reservation + } else { + let writer_arc = match &state.variant { + WriterVariant::Parquet(w) => Arc::clone(w), + _ => unreachable!(), + }; + log_debug!("Writing RecordBatch to Parquet file"); + + // Write the batch — ArrowWriter encodes into internal buffers + let mut writer = writer_arc.lock().unwrap(); + writer.write(&record_batch)?; + let actual_mem = writer.memory_size(); + drop(writer); + + // Now check: try to reserve the actual memory used. 
+ // If over limit, the write already happened but we reject — + // the writer will be dropped (data replayed from translog). + let current = state.reservation.size(); + if actual_mem > current { + state.reservation.try_grow(actual_mem - current) + .map_err(|e| -> Box { + format!("Write memory limit exceeded: {}", e).into() + })?; + } else if actual_mem < current { + state.reservation.shrink(current - actual_mem); } } Ok(()) @@ -494,7 +545,7 @@ impl NativeParquetWriter { log_debug!("finalize_writer called for file: {} (temp: {})", filename, temp_filename); if let Some((_, state)) = WRITERS.remove(&temp_filename) { - let WriterState { variant, settings, crc_handle, writer_generation } = state; + let WriterState { variant, settings, crc_handle, writer_generation, reservation: _reservation } = state; let index_name = settings.index_name.as_deref().unwrap_or(""); match variant { @@ -622,6 +673,10 @@ impl NativeParquetWriter { let row_id_mapping = if !chunk_row_ids.is_empty() && !chunk_row_ids[0].is_empty() { let ids = &chunk_row_ids[0]; let total = ids.len(); + // Track mapping allocation in write pool + let mapping_bytes = total * std::mem::size_of::(); + let mut mapping_reservation = MemoryReservation::new(write_pool(), "finalize_mapping"); + mapping_reservation.grow(mapping_bytes); let mut mapping = vec![0i64; total]; for (new_pos, &old_row_id) in ids.iter().enumerate() { let orig_idx = old_row_id as usize; @@ -630,6 +685,7 @@ impl NativeParquetWriter { } } Some(mapping) + // mapping_reservation dropped here — memory transferred to Java via Box::into_raw } else { None }; @@ -665,6 +721,9 @@ impl NativeParquetWriter { // Build the flat permutation: result[original_row_id] = new_row_id let row_id_mapping = if !merge_output.mapping.is_empty() && !chunk_row_ids.is_empty() { let total = merge_output.mapping.len(); + let mapping_bytes = total * std::mem::size_of::(); + let mut mapping_reservation = MemoryReservation::new(write_pool(), "finalize_flat_mapping"); + 
mapping_reservation.grow(mapping_bytes); let mut flat_mapping = vec![0i64; total]; for i in 0..total { flat_mapping[i] = i as i64; @@ -782,16 +841,12 @@ impl NativeParquetWriter { let mut total_memory = 0; for entry in WRITERS.iter() { if entry.key().starts_with(&path_prefix) { - match &entry.value().variant { - WriterVariant::Parquet(writer_arc) => { - if let Ok(writer) = writer_arc.lock() { - total_memory += writer.memory_size(); - } - } - WriterVariant::Ipc(writer_arc) => { - if let Ok(writer) = writer_arc.lock() { - total_memory += writer.memory_size(); - } + // WriterState.reservation tracks Parquet variant memory + total_memory += entry.value().reservation.size(); + // SortingChunkedWriter has its own reservation tracked in the pool + if let WriterVariant::Ipc(writer_arc) = &entry.value().variant { + if let Ok(writer) = writer_arc.lock() { + total_memory += writer.reservation.size(); } } }