diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java new file mode 100644 index 0000000000..0e01c12d81 --- /dev/null +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method + * pattern used by CometScalarSubquery so the native side can dispatch via + * call_static_method_unchecked. + */ +public class CometUdfBridge { + + // Per-thread, bounded LRU of UDF instances keyed by class name. Comet + // native execution threads (Tokio/DataFusion worker pool) are reused + // across tasks within an executor, so the effective lifetime of cached + // entries is the worker thread (i.e. the executor JVM). 
This is fine for + // stateless UDFs like ArrayExistsUDF; future stateful UDFs would need + // explicit per-task isolation. + private static final int CACHE_CAPACITY = 64; + + private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES = + ThreadLocal.withInitial( + () -> + new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) { + return size() > CACHE_CAPACITY; + } + }); + + /** + * Called from native via JNI. + * + * @param udfClassName fully-qualified class name implementing CometUDF + * @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input) + * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input) + * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result + * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result + */ + public static void evaluate( + String udfClassName, + long[] inputArrayPtrs, + long[] inputSchemaPtrs, + long outArrayPtr, + long outSchemaPtr) { + LinkedHashMap<String, CometUDF> cache = INSTANCES.get(); + CometUDF udf = cache.get(udfClassName); + if (udf == null) { + try { + // Resolve via the executor's context classloader so user-supplied UDF jars + // (added via spark.jars / --jars) are visible. 
+ ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + cl = CometUdfBridge.class.getClassLoader(); + } + udf = + (CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e); + } + cache.put(udfClassName, udf); + } + + BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); + + ValueVector[] inputs = new ValueVector[inputArrayPtrs.length]; + ValueVector result = null; + try { + for (int i = 0; i < inputArrayPtrs.length; i++) { + ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]); + ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]); + inputs[i] = Data.importVector(allocator, inArr, inSch, null); + } + + result = udf.evaluate(inputs); + if (!(result instanceof FieldVector)) { + throw new RuntimeException( + "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); + } + // Result length must match the longest input. Scalar (length-1) inputs + // are allowed to be shorter, but a vector input bounds the output. 
+ int expectedLen = 0; + for (ValueVector v : inputs) { + expectedLen = Math.max(expectedLen, v.getValueCount()); + } + if (result.getValueCount() != expectedLen) { + throw new RuntimeException( + "CometUDF.evaluate() returned " + + result.getValueCount() + + " rows, expected " + + expectedLen); + } + ArrowArray outArr = ArrowArray.wrap(outArrayPtr); + ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); + Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch); + } finally { + for (ValueVector v : inputs) { + if (v != null) { + try { + v.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + if (result != null) { + try { + result.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala new file mode 100644 index 0000000000..5e020ae74a --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.udf + +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan + * time the serde layer registers a lambda expression under a unique key; at execution time the + * UDF retrieves it by that key (passed as a scalar argument). + */ +object CometLambdaRegistry { + + private val registry = new ConcurrentHashMap[String, Expression]() + + def register(expression: Expression): String = { + val key = UUID.randomUUID().toString + registry.put(key, expression) + key + } + + def get(key: String): Expression = { + val expr = registry.get(key) + if (expr == null) { + throw new IllegalStateException( + s"Lambda expression not found in registry for key: $key. " + + "This indicates a lifecycle issue between plan creation and execution.") + } + expr + } + + def remove(key: String): Unit = { + registry.remove(key) + } + + // Visible for testing + def size(): Int = registry.size() +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala new file mode 100644 index 0000000000..ac7b72a883 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.ValueVector + +/** + * Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns + * an Arrow vector. + * + * - Vector arguments arrive at the row count of the current batch. + * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. + * - The returned vector's length must match the longest input. + * + * Implementations must have a public no-arg constructor and should be stateless: instances are + * cached per executor thread for the lifetime of the JVM. 
*/ +trait CometUDF { + def evaluate(inputs: Array[ValueVector]): ValueVector +} diff --git a/native/Cargo.lock b/native/Cargo.lock index ae2d6b074c..75e84d851d 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2116,8 +2116,10 @@ dependencies = [ "criterion", "datafusion", "datafusion-comet-common", + "datafusion-comet-jni-bridge", "futures", "hex", + "jni 0.22.4", "num", "rand 0.10.1", "regex", diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 844cc07c69..6019f168cc 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -122,10 +122,10 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, - NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, - WideDecimalBinaryExpr, WideDecimalOp, + jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, + Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, + GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, + ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::{Global, JObject}; @@ -701,6 +701,23 @@ impl PhysicalPlanner { expr.names.clone(), ))) } + ExprStruct::JvmScalarUdf(udf) => { + let args = udf + .args + .iter() + .map(|e| self.create_expr(e, Arc::clone(&input_schema))) + .collect::<Result<Vec<_>, _>>()?; + let return_type = + to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| { + GeneralError("JvmScalarUdf missing return_type".to_string()) + })?); + Ok(Arc::new(JvmScalarUdfExpr::new( + udf.class_name.clone(), + args, + return_type, + 
udf.return_nullable, + ))) + } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } } diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs new file mode 100644 index 0000000000..89cd8ee514 --- /dev/null +++ b/native/jni-bridge/src/comet_udf_bridge.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::{ + errors::Result as JniResult, + objects::{JClass, JStaticMethodID}, + signature::{Primitive, ReturnType}, + strings::JNIString, + Env, +}; + +/// JNI handle for the JVM `org.apache.comet.udf.CometUdfBridge` class. +/// Mirrors the static-method pattern in `comet_exec.rs` (`CometScalarSubquery`). 
+#[allow(dead_code)] // class field is held to keep JStaticMethodID alive +pub struct CometUdfBridge<'a> { + pub class: JClass<'a>, + pub method_evaluate: JStaticMethodID, + pub method_evaluate_ret: ReturnType, +} + +impl<'a> CometUdfBridge<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/udf/CometUdfBridge"; + + pub fn new(env: &mut Env<'a>) -> JniResult<CometUdfBridge<'a>> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; + Ok(CometUdfBridge { + method_evaluate: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("evaluate"), + jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"), + )?, + method_evaluate_ret: ReturnType::Primitive(Primitive::Void), + class, + }) + } +} diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 21c647135b..d72323c961 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -192,11 +192,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod comet_udf_bridge; mod shuffle_block_iterator; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use comet_udf_bridge::CometUdfBridge; use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. @@ -228,6 +230,9 @@ pub struct JVMClasses<'a> { /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, + /// The CometUdfBridge class used to dispatch JVM scalar UDFs. + /// `None` if the class is not on the classpath. 
+ pub comet_udf_bridge: Option<CometUdfBridge<'a>>, } unsafe impl Send for JVMClasses<'_> {} @@ -298,6 +303,13 @@ impl JVMClasses<'_> { comet_batch_iterator: CometBatchIterator::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), + comet_udf_bridge: { + let bridge = CometUdfBridge::new(env).ok(); + if env.exception_check() { + env.exception_clear(); + } + bridge + }, } }); } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index c7a305285d..90e3d87032 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -90,6 +90,7 @@ message Expr { ToCsv to_csv = 67; HoursTransform hours_transform = 68; ArraysZip arrays_zip = 69; + JvmScalarUdf jvm_scalar_udf = 70; } // Optional QueryContext for error reporting (contains SQL text and position) @@ -514,3 +515,18 @@ message ArraysZip { repeated Expr values = 1; repeated string names = 2; } + +// Scalar UDF dispatched to the JVM via JNI. Native side exports input arrays +// through Arrow C Data Interface, calls CometUdfBridge.evaluate, and imports +// the result. +message JvmScalarUdf { + // Fully-qualified Java/Scala class name implementing + // org.apache.comet.udf.CometUDF (must have a public no-arg constructor). + string class_name = 1; + // Argument expressions, evaluated by the native side before invocation. + repeated Expr args = 2; + // Expected return type. Used to import the result FFI_ArrowArray. + DataType return_type = 3; + // Whether the result column may contain nulls. 
+ bool return_nullable = 4; +} diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index e9a4a546c1..33ffc1c886 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -36,6 +36,8 @@ regex = { workspace = true } # preserve_order: needed for get_json_object to match Spark's JSON key ordering serde_json = { version = "1.0", features = ["preserve_order"] } datafusion-comet-common = { workspace = true } +datafusion-comet-jni-bridge = { workspace = true } +jni = "0.22.4" futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs new file mode 100644 index 0000000000..668a2b6727 --- /dev/null +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::{make_array, ArrayRef}; +use arrow::datatypes::{DataType, Schema}; +use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::record_batch::RecordBatch; + +use datafusion::common::Result as DFResult; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; + +use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError}; +use datafusion_comet_jni_bridge::JVMClasses; +use jni::objects::{JObject, JValue}; + +/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI. +/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`. +#[derive(Debug)] +pub struct JvmScalarUdfExpr { + class_name: String, + args: Vec<Arc<dyn PhysicalExpr>>, + return_type: DataType, + return_nullable: bool, +} + +impl JvmScalarUdfExpr { + pub fn new( + class_name: String, + args: Vec<Arc<dyn PhysicalExpr>>, + return_type: DataType, + return_nullable: bool, + ) -> Self { + Self { + class_name, + args, + return_type, + return_nullable, + } + } +} + +impl Display for JvmScalarUdfExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "JvmScalarUdf({}", self.class_name)?; + for a in &self.args { + write!(f, ", {a}")?; + } + write!(f, ")") + } +} + +impl Hash for JvmScalarUdfExpr { + fn hash<H: Hasher>(&self, state: &mut H) { + self.class_name.hash(state); + for a in &self.args { + a.hash(state); + } + self.return_type.hash(state); + self.return_nullable.hash(state); + } +} + +impl PartialEq for JvmScalarUdfExpr { + fn eq(&self, other: &Self) -> bool { + self.class_name == other.class_name + && self.return_type == other.return_type + && self.return_nullable == other.return_nullable + && self.args.len() == other.args.len() + && self.args.iter().zip(&other.args).all(|(a, b)| a.eq(b)) + } +} + +impl Eq for JvmScalarUdfExpr {} + +impl PhysicalExpr for JvmScalarUdfExpr { + fn as_any(&self) -> 
&dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DFResult<DataType> { + Ok(self.return_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> DFResult<bool> { + Ok(self.return_nullable) + } + + fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> { + // Step 1: evaluate child expressions to get Arrow arrays. Scalar children + // (e.g. literal patterns) are sent as length-1 vectors rather than expanded + // to batch-row count, so the JVM bridge does not pay an O(rows) copy for + // values that never vary across the batch. + let arrays: Vec<ArrayRef> = self + .args + .iter() + .map(|e| match e.evaluate(batch)? { + ColumnarValue::Array(a) => Ok(a), + ColumnarValue::Scalar(s) => s.to_array_of_size(1), + }) + .collect::<Result<Vec<_>, _>>()?; + + // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers. + // The JVM writes into the out_array/out_schema slots and reads from the in_ slots. + let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays + .iter() + .map(|arr| Box::new(FFI_ArrowArray::new(&arr.to_data()))) + .collect(); + let in_ffi_schemas: Vec<Box<FFI_ArrowSchema>> = arrays + .iter() + .map(|arr| { + FFI_ArrowSchema::try_from(arr.data_type()) + .map(Box::new) + .map_err(|e| CometError::Arrow { source: e }) + }) + .collect::<Result<Vec<_>, _>>()?; + + let in_arr_ptrs: Vec<i64> = in_ffi_arrays + .iter() + .map(|b| b.as_ref() as *const FFI_ArrowArray as i64) + .collect(); + let in_sch_ptrs: Vec<i64> = in_ffi_schemas + .iter() + .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64) + .collect(); + + // Allocate output FFI slots. 
+ let mut out_array = Box::new(FFI_ArrowArray::empty()); + let mut out_schema = Box::new(FFI_ArrowSchema::empty()); + let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64; + let out_sch_ptr = out_schema.as_mut() as *mut FFI_ArrowSchema as i64; + + let class_name = self.class_name.clone(); + let n_args = arrays.len(); + + // Step 3: attach a JNI env for this thread and call the static bridge method. + JVMClasses::with_env(|env| { + let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { + CometError::from(ExecutionError::GeneralError( + "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \ + class was not found on the JVM classpath." + .to_string(), + )) + })?; + + // Build the JVM String for the class name. + let jclass_name = env + .new_string(&class_name) + .map_err(|e| CometError::JNI { source: e })?; + + // Build the long[] arrays for input pointers. + let in_arr_java = env + .new_long_array(n_args) + .map_err(|e| CometError::JNI { source: e })?; + in_arr_java + .set_region(env, 0, &in_arr_ptrs) + .map_err(|e| CometError::JNI { source: e })?; + + let in_sch_java = env + .new_long_array(n_args) + .map_err(|e| CometError::JNI { source: e })?; + in_sch_java + .set_region(env, 0, &in_sch_ptrs) + .map_err(|e| CometError::JNI { source: e })?; + + // Call CometUdfBridge.evaluate(String, long[], long[], long, long) + let ret = unsafe { + env.call_static_method_unchecked( + &bridge.class, + bridge.method_evaluate, + bridge.method_evaluate_ret, + &[ + JValue::from(&jclass_name).as_jni(), + JValue::Object(JObject::from(in_arr_java).as_ref()).as_jni(), + JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(), + JValue::Long(out_arr_ptr).as_jni(), + JValue::Long(out_sch_ptr).as_jni(), + ], + ) + }; + + if let Some(exception) = datafusion_comet_jni_bridge::check_exception(env)? 
{ + return Err(exception); + } + + ret.map_err(|e| CometError::JNI { source: e })?; + Ok(()) + })?; + + // Step 4: import the result from the FFI slots filled by the JVM. + // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap + // allocation is freed by the move), and `from_ffi` wraps it in an Arc that + // keeps the JVM-installed release callback alive until the resulting + // ArrayData drops. `out_schema` is borrowed; its release callback runs + // exactly once when the Box drops at end of scope. + let result_data = unsafe { from_ffi(*out_array, &out_schema) } + .map_err(|e| CometError::Arrow { source: e })?; + Ok(ColumnarValue::Array(make_array(result_data))) + } + + fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn PhysicalExpr>>, + ) -> DFResult<Arc<dyn PhysicalExpr>> { + Ok(Arc::new(JvmScalarUdfExpr::new( + self.class_name.clone(), + children, + self.return_type.clone(), + self.return_nullable, + ))) + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index eddf2ff460..d5297f27fd 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -55,6 +55,8 @@ pub use cast::{spark_cast, Cast, SparkCastOptions}; mod bloom_filter; pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain, SparkBloomFilterVersion}; +pub mod jvm_udf; + mod conditional_funcs; mod conversion_funcs; mod map_funcs; diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 48b8905035..63936a94b7 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -243,7 +243,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp val df = spark.read .parquet(path.toString) .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("idx", udf((_: 
Int) => 1).apply(col("_4"))) + .withColumn("idx", org.apache.spark.sql.functions.udf((_: Int) => 1).apply(col("_4"))) .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"),