diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
new file mode 100644
index 0000000000..0e01c12d81
--- /dev/null
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.udf;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.arrow.c.ArrowArray;
+import org.apache.arrow.c.ArrowSchema;
+import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.ValueVector;
+
+/**
+ * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
+ * pattern used by CometScalarSubquery so the native side can dispatch via
+ * call_static_method_unchecked.
+ */
+public class CometUdfBridge {
+
+  // Per-thread, bounded LRU of UDF instances keyed by class name. Comet
+  // native execution threads (Tokio/DataFusion worker pool) are reused
+  // across tasks within an executor, so the effective lifetime of cached
+  // entries is the worker thread (i.e. the executor JVM). This is fine for
+  // stateless UDFs like ArrayExistsUDF; future stateful UDFs would need
+  // explicit per-task isolation.
+  private static final int CACHE_CAPACITY = 64;
+
+  private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
+      ThreadLocal.withInitial(
+          () ->
+              new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
+                @Override
+                protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
+                  return size() > CACHE_CAPACITY;
+                }
+              });
+
+  /**
+   * Called from native via JNI.
+   *
+   * @param udfClassName fully-qualified class name implementing CometUDF
+   * @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input)
+   * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
+   * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
+   * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
+   */
+  public static void evaluate(
+      String udfClassName,
+      long[] inputArrayPtrs,
+      long[] inputSchemaPtrs,
+      long outArrayPtr,
+      long outSchemaPtr) {
+    LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
+    CometUDF udf = cache.get(udfClassName);
+    if (udf == null) {
+      try {
+        // Resolve via the executor's context classloader so user-supplied UDF jars
+        // (added via spark.jars / --jars) are visible.
+ ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + cl = CometUdfBridge.class.getClassLoader(); + } + udf = + (CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e); + } + cache.put(udfClassName, udf); + } + + BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); + + ValueVector[] inputs = new ValueVector[inputArrayPtrs.length]; + ValueVector result = null; + try { + for (int i = 0; i < inputArrayPtrs.length; i++) { + ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]); + ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]); + inputs[i] = Data.importVector(allocator, inArr, inSch, null); + } + + result = udf.evaluate(inputs); + if (!(result instanceof FieldVector)) { + throw new RuntimeException( + "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); + } + // Result length must match the longest input. Scalar (length-1) inputs + // are allowed to be shorter, but a vector input bounds the output. + int expectedLen = 0; + for (ValueVector v : inputs) { + expectedLen = Math.max(expectedLen, v.getValueCount()); + } + if (result.getValueCount() != expectedLen) { + throw new RuntimeException( + "CometUDF.evaluate() returned " + + result.getValueCount() + + " rows, expected " + + expectedLen); + } + ArrowArray outArr = ArrowArray.wrap(outArrayPtr); + ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); + Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch); + } finally { + for (ValueVector v : inputs) { + if (v != null) { + try { + v.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + if (result != null) { + try { + result.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/ArrayExistsUDF.scala b/common/src/main/scala/org/apache/comet/udf/ArrayExistsUDF.scala new file mode 100644 index 0000000000..ffe990a48b --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/ArrayExistsUDF.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.ListVector +import org.apache.spark.sql.catalyst.expressions.{ArrayExists, LambdaFunction, NamedLambdaVariable} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometArrowAllocator + +/** + * JVM UDF implementing Spark's `exists(array, x -> predicate(x))` higher-order function. + * + * Inputs: + * - inputs(0): ListVector (the array column) + * - inputs(1): VarCharVector length-1 scalar (registry key for the lambda expression) + * + * Output: BitVector (nullable boolean), same length as the input array vector. + * + * Implements Spark's three-valued logic: + * - true if any element satisfies the predicate + * - null if no element satisfies but the predicate returned null for at least one element + * - false if all elements produce false + */ +class ArrayExistsUDF extends CometUDF { + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + require(inputs.length == 2, s"ArrayExistsUDF expects 2 inputs, got ${inputs.length}") + val listVec = inputs(0).asInstanceOf[ListVector] + val keyVec = inputs(1).asInstanceOf[VarCharVector] + require( + keyVec.getValueCount >= 1 && !keyVec.isNull(0), + "ArrayExistsUDF requires a non-null scalar registry key") + + val registryKey = new String(keyVec.get(0), StandardCharsets.UTF_8) + val arrayExistsExpr = CometLambdaRegistry.get(registryKey).asInstanceOf[ArrayExists] + + val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = arrayExistsExpr.function + val body = arrayExistsExpr.functionForEval + val followThreeValuedLogic = arrayExistsExpr.followThreeValuedLogic + val elementType = elementVar.dataType + + val dataVec = listVec.getDataVector + val n = listVec.getValueCount + val out = new BitVector("exists_result", CometArrowAllocator) + out.allocateNew(n) + + var i = 0 + while (i < n) { + if (listVec.isNull(i)) { + out.setNull(i) + } else { + val startIdx = listVec.getElementStartIndex(i) + val endIdx = listVec.getElementEndIndex(i) + var exists = false + var foundNull = false + var j = startIdx + while (j < endIdx && !exists) { + if (dataVec.isNull(j)) { + elementVar.value.set(null) + val ret = body.eval(null) + if (ret == null) foundNull = true + else if (ret.asInstanceOf[Boolean]) exists = true + } else { + val elem = getSparkValue(dataVec, j, elementType) + elementVar.value.set(elem) + val ret = body.eval(null) + if (ret == null) foundNull = true + else if (ret.asInstanceOf[Boolean]) exists = true + } + j += 1 + } + if (exists) { + out.set(i, 1) + } else if (followThreeValuedLogic && foundNull) { + out.setNull(i) + } else { + out.set(i, 0) + } + } + i += 1 + } + out.setValueCount(n) + out + } + + private def getSparkValue(vec: ValueVector, index: Int, sparkType: DataType): Any = { + sparkType match { + case BooleanType => + vec.asInstanceOf[BitVector].get(index) == 1 + case ByteType => + vec.asInstanceOf[TinyIntVector].get(index).toByte + case ShortType => + vec.asInstanceOf[SmallIntVector].get(index).toShort + case IntegerType => + vec.asInstanceOf[IntVector].get(index) + case LongType => + vec.asInstanceOf[BigIntVector].get(index) + case FloatType => + vec.asInstanceOf[Float4Vector].get(index) + case DoubleType => + vec.asInstanceOf[Float8Vector].get(index) + case StringType => + val bytes = vec.asInstanceOf[VarCharVector].get(index) + UTF8String.fromBytes(bytes) + case BinaryType => + 
vec.asInstanceOf[VarBinaryVector].get(index) + case _: DecimalType => + val dt = sparkType.asInstanceOf[DecimalType] + val decimal = vec.asInstanceOf[DecimalVector].getObject(index) + Decimal(decimal, dt.precision, dt.scale) + case DateType => + vec.asInstanceOf[DateDayVector].get(index) + case TimestampType => + vec.asInstanceOf[TimeStampMicroTZVector].get(index) + case _ => + throw new UnsupportedOperationException( + s"ArrayExistsUDF does not yet support element type: $sparkType") + } + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala new file mode 100644 index 0000000000..5e020ae74a --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan + * time the serde layer registers a lambda expression under a unique key; at execution time the + * UDF retrieves it by that key (passed as a scalar argument). + */ +object CometLambdaRegistry { + + private val registry = new ConcurrentHashMap[String, Expression]() + + def register(expression: Expression): String = { + val key = UUID.randomUUID().toString + registry.put(key, expression) + key + } + + def get(key: String): Expression = { + val expr = registry.get(key) + if (expr == null) { + throw new IllegalStateException( + s"Lambda expression not found in registry for key: $key. " + + "This indicates a lifecycle issue between plan creation and execution.") + } + expr + } + + def remove(key: String): Unit = { + registry.remove(key) + } + + // Visible for testing + def size(): Int = registry.size() +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala new file mode 100644 index 0000000000..ac7b72a883 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.ValueVector + +/** + * Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns + * an Arrow vector. + * + * - Vector arguments arrive at the row count of the current batch. + * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. + * - The returned vector's length must match the longest input. + * + * Implementations must have a public no-arg constructor and should be stateless: instances are + * cached per executor thread for the lifetime of the JVM. + */ +trait CometUDF { + def evaluate(inputs: Array[ValueVector]): ValueVector +} diff --git a/native/Cargo.lock b/native/Cargo.lock index ae2d6b074c..e66a8d8e7e 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -230,9 +230,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d441fdda254b65f3e9025910eb2c2066b6295d9c8ed409522b8d2ace1ff8574c" +checksum = "607e64bb911ee4f90483e044fe78f175989148c2892e659a2cd25429e782ec54" dependencies = [ "arrow-arith", "arrow-array", @@ -251,9 +251,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced5406f8b720cc0bc3aa9cf5758f93e8593cda5490677aa194e4b4b383f9a59" +checksum = "e754319ed8a85d817fe7adf183227e0b5308b82790a737b426c1124626b48118" dependencies = [ "arrow-array", "arrow-buffer", @@ -265,9 +265,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "772bd34cacdda8baec9418d80d23d0fb4d50ef0735685bd45158b83dfeb6e62d" +checksum = "841321891f247aa86c6112c80d83d89cb36e0addd020fa2425085b8eb6c3f579" dependencies = [ "ahash", "arrow-buffer", @@ -276,7 +276,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "num-complex", "num-integer", "num-traits", @@ -284,9 +284,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898f4cf1e9598fdb77f356fdf2134feedfd0ee8d5a4e0a5f573e7d0aec16baa4" +checksum = "f955dfb73fae000425f49c8226d2044dab60fb7ad4af1e24f961756354d996c9" dependencies = [ "bytes", "half", @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0127816c96533d20fc938729f48c52d3e48f99717e7a0b5ade77d742510736d" +checksum = "ca5e686972523798f76bef355145bc1ae25a84c731e650268d31ab763c701663" dependencies = [ "arrow-array", "arrow-buffer", @@ -318,9 +318,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ca025bd0f38eeecb57c2153c0123b960494138e6a957bbda10da2b25415209fe" +checksum = "86c276756867fc8186ec380c72c290e6e3b23a1d4fb05df6b1d62d2e62666d48" dependencies = [ "arrow-array", "arrow-cast", @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d10beeab2b1c3bb0b53a00f7c944a178b622173a5c7bcabc3cb45d90238df4" +checksum = "db3b5846209775b6dc8056d77ff9a032b27043383dd5488abd0b663e265b9373" dependencies = [ "arrow-buffer", "arrow-schema", @@ -346,9 +346,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "609a441080e338147a84e8e6904b6da482cefb957c5cdc0f3398872f69a315d0" +checksum = "fd8907ddd8f9fbabf91ec2c85c1d81fe2874e336d2443eb36373595e28b98dd5" dependencies = [ "arrow-array", "arrow-buffer", @@ -361,15 +361,16 @@ dependencies = [ [[package]] name = "arrow-json" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ead0914e4861a531be48fe05858265cf854a4880b9ed12618b1d08cba9bebc8" +checksum = "f4518c59acc501f10d7dcae397fe12b8db3d81bc7de94456f8a58f9165d6f502" dependencies = [ "arrow-array", "arrow-buffer", "arrow-cast", - "arrow-data", + "arrow-ord", "arrow-schema", + "arrow-select", "chrono", "half", "indexmap 2.14.0", @@ -385,9 +386,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a7ba279b20b52dad300e68cfc37c17efa65e68623169076855b3a9e941ca5" +checksum = "efa70d9d6b1356f1fb9f1f651b84a725b7e0abb93f188cf7d31f14abfa2f2e6f" dependencies = [ "arrow-array", "arrow-buffer", @@ -398,9 +399,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14fe367802f16d7668163ff647830258e6e0aeea9a4d79aaedf273af3bdcd3e" +checksum = "faec88a945338192beffbbd4be0def70135422930caa244ac3cec0cd213b26b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -411,9 +412,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c30a1365d7a7dc50cc847e54154e6af49e4c4b0fddc9f607b687f29212082743" +checksum = "18aa020f6bc8e5201dcd2d4b7f98c68f8a410ef37128263243e6ff2a47a67d4f" dependencies = [ "bitflags 2.11.1", "serde_core", @@ -422,9 +423,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78694888660a9e8ac949853db393af2a8b8fc82c19ce333132dfa2e72cc1a7fe" +checksum = "a657ab5132e9c8ca3b24eb15a823d0ced38017fe3930ff50167466b02e2d592c" dependencies = [ "ahash", "arrow-array", @@ -436,9 +437,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e04a01f8bb73ce54437514c5fd3ee2aa3e8abe4c777ee5cc55853b1652f79e" +checksum = "f6de2efbbd1a9f9780ceb8d1ff5d20421b35863b361e3386b4f571f1fc69fcb8" dependencies = [ "arrow-array", "arrow-buffer", @@ -492,9 +493,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" +checksum = "e79b3f8a79cccc2898f31920fc69f304859b3bd567490f75ebf51ae1c792a9ac" dependencies = [ "compression-codecs", "compression-core", @@ -1125,9 +1126,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.4" +version = "1.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e" +checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce" dependencies = [ "arrayref", "arrayvec", @@ -1496,9 +1497,9 @@ dependencies = [ [[package]] name = "compression-codecs" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" +checksum = "ce2548391e9c1929c21bf6aa2680af86fe4c1b33e6cea9ac1cfeec0bd11218cf" dependencies = [ "bzip2", "compression-core", @@ -1511,9 +1512,9 @@ dependencies = [ [[package]] name = "compression-core" -version = "0.4.31" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +checksum = "cc14f565cf027a105f7a44ccf9e5b424348421a1d8952a8fc9d499d313107789" [[package]] name = "concurrent-queue" @@ -1580,9 +1581,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpp_demangle" -version = "0.5.1" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0667304c32ea56cb4cd6d2d7c0cfe9a2f8041229db8c033af7f8d69492429def" +checksum = "f2bb79cb74d735044c972aae58ed0aaa9a837e85b01106a54c39e42e97f62253" dependencies = [ "cfg-if", ] @@ -2116,8 +2117,10 @@ dependencies = [ "criterion", "datafusion", "datafusion-comet-common", + "datafusion-comet-jni-bridge", "futures", "hex", + "jni 0.22.4", "num", "rand 0.10.1", "regex", @@ -2813,9 +2816,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ "block-buffer 0.12.0", "const-oid 0.10.2", @@ -3264,9 +3267,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +checksum = "171fefbc92fe4a4de27e0698d6a5b392d6a0e333506bc49133760b3bcf948733" dependencies = [ "atomic-waker", "bytes", @@ -3388,7 +3391,7 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" dependencies = [ - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -3469,9 +3472,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hybrid-array" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3944cf8cf766b40e2a1a333ee5e9b563f854d5fa49d6a8ca2764e97c6eddb214" +checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" dependencies = [ "typenum", ] @@ -3742,9 +3745,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.1" +version = "1.2.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" dependencies = [ "icu_normalizer", "icu_properties", @@ -3884,9 +3887,9 @@ dependencies = [ [[package]] name = "jiff" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" dependencies = [ "jiff-static", "jiff-tzdb-platform", @@ -3901,9 +3904,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.23" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" dependencies = [ "proc-macro2", "quote", @@ -4013,9 +4016,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" dependencies = [ "cfg-if", "futures-util", @@ -4537,7 +4540,7 @@ dependencies = [ "md-5", "parking_lot", "percent-encoding", - "quick-xml 0.39.2", + "quick-xml 0.39.3", "rand 0.10.1", "reqwest 0.12.28", "ring", @@ -4658,7 +4661,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.38.4", "reqsign-core", - "reqwest 0.13.2", + "reqwest 0.13.3", "serde", "serde_json", "tokio", @@ -4796,9 +4799,9 @@ dependencies = [ [[package]] name = "parquet" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d3f9f2205199603564127932b89695f52b62322f541d0fc7179d57c2e1c9877" +checksum = "43d7efd3052f7d6ef601085559a246bc991e9a8cc77e02753737df6322ce35f1" dependencies = [ "ahash", "arrow-array", @@ -4814,7 +4817,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "lz4_flex", "num-bigint", "num-integer", @@ -4836,23 +4839,25 @@ dependencies = [ [[package]] name = "parquet-variant" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf493f3c9ddd984d0efb019f67343e4aa4bab893931f6a14b82083065dc3d28" +checksum = "262fd51760f388670dbab2283efaadd0f4ed87ad584e60bd0db7fb79d527f045" dependencies = [ + "arrow", "arrow-schema", "chrono", "half", "indexmap 2.14.0", + "num-traits", "simdutf8", "uuid", ] [[package]] name = "parquet-variant-compute" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ac038d46a503a7d563b4f5df5802c4315d5343d009feab195d15ac512b4cb27" +checksum = "4c94fc2c2c077a00b3d5232f965037cee3455c432567a78d66db101daa035689" dependencies = [ "arrow", "arrow-schema", @@ -4867,9 +4872,9 @@ dependencies = [ [[package]] name = "parquet-variant-json" -version = "58.1.0" +version = "58.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "015a09c2ffe5108766c7c1235c307b8a3c2ea64eca38455ba1a7f3a7f32f16e2" +checksum = "7ed1077da4aeb4e4141aa2f9858ac354975595eb30f907762894587941e8f2f7" dependencies = [ "arrow-schema", "base64", @@ -4958,18 +4963,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.11" 
+version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +checksum = "cbf0d9e68100b3a7989b4901972f265cd542e560a3a8a724e1e20322f4d06ce9" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.11" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +checksum = "a990e22f43e84855daf260dded30524ef4a9021cc7541c26540500a50b624389" dependencies = [ "proc-macro2", "quote", @@ -5289,9 +5294,9 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.39.2" +version = "0.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958f21e8e7ceb5a1aa7fa87fab28e7c75976e0bfe7e23ff069e0a260f894067d" +checksum = "721da970c312655cde9b4ffe0547f20a8494866a4af5ff51f18b7c633d0c870b" dependencies = [ "memchr", "serde", @@ -5633,9 +5638,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" +checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" dependencies = [ "base64", "bytes", @@ -5694,9 +5699,9 @@ dependencies = [ [[package]] name = "roaring" -version = "0.11.3" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ba9ce64a8f45d7fc86358410bb1a82e8c987504c0d4900e9141d69a9f26c885" +checksum = "1dedc5658c6ecb3bdb5ef5f3295bb9253f42dcf3fd1402c03f6b1f7659c3c4a9" dependencies = [ "bytemuck", "byteorder", @@ -5788,9 +5793,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.38" +version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", @@ -5815,9 +5820,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.14.0" +version = "1.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" dependencies = [ "web-time", "zeroize", @@ -5825,13 +5830,13 @@ dependencies = [ [[package]] name = "rustls-platform-verifier" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" +checksum = "26d1e2536ce4f35f4846aa13bff16bd0ff40157cdb14cc056c7b14ba41233ba0" dependencies = [ "core-foundation", "core-foundation-sys", - "jni 0.21.1", + "jni 0.22.4", "log", "once_cell", "rustls", @@ -5852,9 +5857,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -6075,9 +6080,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.18.0" +version = "3.19.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" +checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" dependencies = [ "base64", "chrono", @@ -6094,9 +6099,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" +checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -6147,7 +6152,7 @@ checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -6212,9 +6217,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -6337,9 +6342,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "symbolic-common" -version = "12.18.1" +version = "12.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f3cdeaae6779ecba2567f20bf7716718b8c4ce6717c9def4ced18786bb11ea" +checksum = "332615d90111d8eeaf86a84dc9bbe9f65d0d8c5cf11b4caccedc37754eb0dcfd" dependencies = [ "debugid", "memmap2", @@ -6349,9 +6354,9 @@ dependencies = [ [[package]] name = "symbolic-demangle" -version = "12.18.1" +version = "12.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672c6ad9cb8fce6a1283cc9df9070073cccad00ae241b80e3686328a64e3523b" +checksum = "912017718eb4d21930546245af9a3475c9dccf15675a5c215664e76621afc471" dependencies = [ "cpp_demangle", "rustc-demangle", @@ -6588,9 +6593,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" dependencies = [ "bytes", "libc", @@ -6777,9 +6782,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "typetag" @@ -6947,11 +6952,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] @@ -6960,14 +6965,14 @@ version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + 
"wit-bindgen 0.51.0", ] [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" dependencies = [ "cfg-if", "once_cell", @@ -6978,9 +6983,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.68" +version = "0.4.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" dependencies = [ "js-sys", "wasm-bindgen", @@ -6988,9 +6993,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6998,9 +7003,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" dependencies = [ "bumpalo", "proc-macro2", @@ -7011,9 +7016,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" dependencies = [ "unicode-ident", ] @@ -7080,9 +7085,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" dependencies = [ "js-sys", "wasm-bindgen", @@ -7458,6 +7463,12 @@ dependencies = [ "wit-bindgen-rust-macro", ] +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + [[package]] name = "wit-bindgen-core" version = "0.51.0" diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 844cc07c69..6019f168cc 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -122,10 +122,10 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, - NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, - WideDecimalBinaryExpr, WideDecimalOp, + jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, + Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields, + GetStructField, IfExpr, 
ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal,
+    ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
 };
 use itertools::Itertools;
 use jni::objects::{Global, JObject};
@@ -701,6 +701,23 @@ impl PhysicalPlanner {
                     expr.names.clone(),
                 )))
             }
+            ExprStruct::JvmScalarUdf(udf) => {
+                let args = udf
+                    .args
+                    .iter()
+                    .map(|e| self.create_expr(e, Arc::clone(&input_schema)))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let return_type =
+                    to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| {
+                        GeneralError("JvmScalarUdf missing return_type".to_string())
+                    })?);
+                Ok(Arc::new(JvmScalarUdfExpr::new(
+                    udf.class_name.clone(),
+                    args,
+                    return_type,
+                    udf.return_nullable,
+                )))
+            }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
         }
     }
diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs
new file mode 100644
index 0000000000..89cd8ee514
--- /dev/null
+++ b/native/jni-bridge/src/comet_udf_bridge.rs
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use jni::{
+    errors::Result as JniResult,
+    objects::{JClass, JStaticMethodID},
+    signature::{Primitive, ReturnType},
+    strings::JNIString,
+    Env,
+};
+
+/// JNI handle for the JVM `org.apache.comet.udf.CometUdfBridge` class.
+/// Mirrors the static-method pattern in `comet_exec.rs` (`CometScalarSubquery`).
+#[allow(dead_code)] // class field is held to keep JStaticMethodID alive
+pub struct CometUdfBridge<'a> {
+    pub class: JClass<'a>,
+    pub method_evaluate: JStaticMethodID,
+    pub method_evaluate_ret: ReturnType,
+}
+
+impl<'a> CometUdfBridge<'a> {
+    pub const JVM_CLASS: &'static str = "org/apache/comet/udf/CometUdfBridge";
+
+    pub fn new(env: &mut Env<'a>) -> JniResult<CometUdfBridge<'a>> {
+        let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
+        Ok(CometUdfBridge {
+            method_evaluate: env.get_static_method_id(
+                JNIString::new(Self::JVM_CLASS),
+                jni::jni_str!("evaluate"),
+                jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+            )?,
+            method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
+            class,
+        })
+    }
+}
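
For orientation, a minimal sketch of the kind of class this bridge dispatches to: any implementation of org.apache.comet.udf.CometUDF with a public no-arg constructor, resolved reflectively by the fully qualified name passed through evaluate(). The UpperCaseUDF class below is illustrative only and is not part of this patch.

package org.apache.comet.udf

import java.nio.charset.StandardCharsets

import org.apache.arrow.vector.{ValueVector, VarCharVector}

import org.apache.comet.CometArrowAllocator

// Hypothetical example, not part of this patch: a stateless scalar UDF that
// upper-cases a string column. CometUdfBridge would instantiate it via
// Class.forName("org.apache.comet.udf.UpperCaseUDF") and cache it per thread.
class UpperCaseUDF extends CometUDF {
  override def evaluate(inputs: Array[ValueVector]): ValueVector = {
    val in = inputs(0).asInstanceOf[VarCharVector]
    val n = in.getValueCount
    val out = new VarCharVector("upper_result", CometArrowAllocator)
    out.allocateNew(n)
    var i = 0
    while (i < n) {
      if (in.isNull(i)) {
        out.setNull(i)
      } else {
        val s = new String(in.get(i), StandardCharsets.UTF_8).toUpperCase
        out.setSafe(i, s.getBytes(StandardCharsets.UTF_8))
      }
      i += 1
    }
    out.setValueCount(n)
    out
  }
}

Because instances are cached per worker thread, an implementation like this must stay stateless across calls.
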
diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs
index 21c647135b..d72323c961 100644
--- a/native/jni-bridge/src/lib.rs
+++ b/native/jni-bridge/src/lib.rs
@@ -192,11 +192,13 @@ pub use comet_exec::*;
 mod batch_iterator;
 mod comet_metric_node;
 mod comet_task_memory_manager;
+mod comet_udf_bridge;
 mod shuffle_block_iterator;
 
 use batch_iterator::CometBatchIterator;
 pub use comet_metric_node::*;
 pub use comet_task_memory_manager::*;
+use comet_udf_bridge::CometUdfBridge;
 use shuffle_block_iterator::CometShuffleBlockIterator;
 
 /// The JVM classes that are used in the JNI calls.
@@ -228,6 +230,9 @@ pub struct JVMClasses<'a> {
     /// The CometTaskMemoryManager used for interacting with JVM side to
     /// acquire & release native memory.
     pub comet_task_memory_manager: CometTaskMemoryManager<'a>,
+    /// The CometUdfBridge class used to dispatch JVM scalar UDFs.
+    /// `None` if the class is not on the classpath.
+    pub comet_udf_bridge: Option<CometUdfBridge<'a>>,
 }
 
 unsafe impl Send for JVMClasses<'_> {}
@@ -298,6 +303,13 @@ impl JVMClasses<'_> {
             comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
             comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(),
             comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(),
+            comet_udf_bridge: {
+                let bridge = CometUdfBridge::new(env).ok();
+                if env.exception_check() {
+                    env.exception_clear();
+                }
+                bridge
+            },
         }
     });
 }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c7a305285d..90e3d87032 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -90,6 +90,7 @@ message Expr {
     ToCsv to_csv = 67;
     HoursTransform hours_transform = 68;
     ArraysZip arrays_zip = 69;
+    JvmScalarUdf jvm_scalar_udf = 70;
   }
 
   // Optional QueryContext for error reporting (contains SQL text and position)
@@ -514,3 +515,18 @@ message ArraysZip {
   repeated Expr values = 1;
   repeated string names = 2;
 }
+
+// Scalar UDF dispatched to the JVM via JNI. Native side exports input arrays
+// through Arrow C Data Interface, calls CometUdfBridge.evaluate, and imports
+// the result.
+message JvmScalarUdf {
+  // Fully-qualified Java/Scala class name implementing
+  // org.apache.comet.udf.CometUDF (must have a public no-arg constructor).
+  string class_name = 1;
+  // Argument expressions, evaluated by the native side before invocation.
+  repeated Expr args = 2;
+  // Expected return type. Used to import the result FFI_ArrowArray.
+  DataType return_type = 3;
+  // Whether the result column may contain nulls.
+  bool return_nullable = 4;
+}
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index e9a4a546c1..33ffc1c886 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -36,6 +36,8 @@ regex = { workspace = true }
 # preserve_order: needed for get_json_object to match Spark's JSON key ordering
 serde_json = { version = "1.0", features = ["preserve_order"] }
 datafusion-comet-common = { workspace = true }
+datafusion-comet-jni-bridge = { workspace = true }
+jni = "0.22.4"
 futures = { workspace = true }
 twox-hash = "2.1.2"
 rand = { workspace = true }
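
The args of a JvmScalarUdf message can carry plan-time state to execution as scalar literals; CometArrayExists (later in this patch) uses that to ship a CometLambdaRegistry key. A hedged sketch of the round trip, assuming registration and lookup happen in the same JVM (true in local mode; in a distributed deployment the registry must be populated in the executor JVM that evaluates the UDF):

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

import org.apache.comet.udf.CometLambdaRegistry

// Illustrative only, not part of this patch.
object LambdaKeyRoundTrip {
  // Plan time: stash the Catalyst expression, ship only the UUID key as a
  // string literal that becomes one of the JvmScalarUdf args.
  def planTime(lambdaExpr: Expression): Literal =
    Literal(CometLambdaRegistry.register(lambdaExpr))

  // Execution time: the UDF reads the key back from its length-1 scalar
  // input and resolves the expression to evaluate per element.
  def execTime(key: String): Expression =
    CometLambdaRegistry.get(key)
}
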
diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs
new file mode 100644
index 0000000000..668a2b6727
--- /dev/null
+++ b/native/spark-expr/src/jvm_udf/mod.rs
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+use arrow::array::{make_array, ArrayRef};
+use arrow::datatypes::{DataType, Schema};
+use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};
+use arrow::record_batch::RecordBatch;
+
+use datafusion::common::Result as DFResult;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+
+use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
+use datafusion_comet_jni_bridge::JVMClasses;
+use jni::objects::{JObject, JValue};
+
+/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
+/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
+#[derive(Debug)]
+pub struct JvmScalarUdfExpr {
+    class_name: String,
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    return_type: DataType,
+    return_nullable: bool,
+}
+
+impl JvmScalarUdfExpr {
+    pub fn new(
+        class_name: String,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: DataType,
+        return_nullable: bool,
+    ) -> Self {
+        Self {
+            class_name,
+            args,
+            return_type,
+            return_nullable,
+        }
+    }
+}
+
+impl Display for JvmScalarUdfExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "JvmScalarUdf({}", self.class_name)?;
+        for a in &self.args {
+            write!(f, ", {a}")?;
+        }
+        write!(f, ")")
+    }
+}
+
+impl Hash for JvmScalarUdfExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.class_name.hash(state);
+        for a in &self.args {
+            a.hash(state);
+        }
+        self.return_type.hash(state);
+        self.return_nullable.hash(state);
+    }
+}
+
+impl PartialEq for JvmScalarUdfExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.class_name == other.class_name
+            && self.return_type == other.return_type
+            && self.return_nullable == other.return_nullable
+            && self.args.len() == other.args.len()
+            && self.args.iter().zip(&other.args).all(|(a, b)| a.eq(b))
+    }
+}
+
+impl Eq for JvmScalarUdfExpr {}
+
+impl PhysicalExpr for JvmScalarUdfExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> DFResult<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> DFResult<bool> {
+        Ok(self.return_nullable)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
+        // Step 1: evaluate child expressions to get Arrow arrays. Scalar children
+        // (e.g. literal patterns) are sent as length-1 vectors rather than expanded
+        // to batch-row count, so the JVM bridge does not pay an O(rows) copy for
+        // values that never vary across the batch.
+        let arrays: Vec<ArrayRef> = self
+            .args
+            .iter()
+            .map(|e| match e.evaluate(batch)? {
+                ColumnarValue::Array(a) => Ok(a),
+                ColumnarValue::Scalar(s) => s.to_array_of_size(1),
+            })
+            .collect::<Result<_, _>>()?;
+
+        // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers.
+        // The JVM writes into the out_array/out_schema slots and reads from the in_ slots.
+        let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays
+            .iter()
+            .map(|arr| Box::new(FFI_ArrowArray::new(&arr.to_data())))
+            .collect();
+        let in_ffi_schemas: Vec<Box<FFI_ArrowSchema>> = arrays
+            .iter()
+            .map(|arr| {
+                FFI_ArrowSchema::try_from(arr.data_type())
+                    .map(Box::new)
+                    .map_err(|e| CometError::Arrow { source: e })
+            })
+            .collect::<Result<_, _>>()?;
+
+        let in_arr_ptrs: Vec<i64> = in_ffi_arrays
+            .iter()
+            .map(|b| b.as_ref() as *const FFI_ArrowArray as i64)
+            .collect();
+        let in_sch_ptrs: Vec<i64> = in_ffi_schemas
+            .iter()
+            .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64)
+            .collect();
+
+        // Allocate output FFI slots.
+        let mut out_array = Box::new(FFI_ArrowArray::empty());
+        let mut out_schema = Box::new(FFI_ArrowSchema::empty());
+        let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64;
+        let out_sch_ptr = out_schema.as_mut() as *mut FFI_ArrowSchema as i64;
+
+        let class_name = self.class_name.clone();
+        let n_args = arrays.len();
+
+        // Step 3: attach a JNI env for this thread and call the static bridge method.
+        JVMClasses::with_env(|env| {
+            let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| {
+                CometError::from(ExecutionError::GeneralError(
+                    "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \
+                     class was not found on the JVM classpath."
+                        .to_string(),
+                ))
+            })?;
+
+            // Build the JVM String for the class name.
+            let jclass_name = env
+                .new_string(&class_name)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            // Build the long[] arrays for input pointers.
+            let in_arr_java = env
+                .new_long_array(n_args)
+                .map_err(|e| CometError::JNI { source: e })?;
+            in_arr_java
+                .set_region(env, 0, &in_arr_ptrs)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            let in_sch_java = env
+                .new_long_array(n_args)
+                .map_err(|e| CometError::JNI { source: e })?;
+            in_sch_java
+                .set_region(env, 0, &in_sch_ptrs)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+            let ret = unsafe {
+                env.call_static_method_unchecked(
+                    &bridge.class,
+                    bridge.method_evaluate,
+                    bridge.method_evaluate_ret,
+                    &[
+                        JValue::from(&jclass_name).as_jni(),
+                        JValue::Object(JObject::from(in_arr_java).as_ref()).as_jni(),
+                        JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
+                        JValue::Long(out_arr_ptr).as_jni(),
+                        JValue::Long(out_sch_ptr).as_jni(),
+                    ],
+                )
+            };
+
+            if let Some(exception) = datafusion_comet_jni_bridge::check_exception(env)? {
+                return Err(exception);
+            }
+
+            ret.map_err(|e| CometError::JNI { source: e })?;
+            Ok(())
+        })?;
+
+        // Step 4: import the result from the FFI slots filled by the JVM.
+        // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap
+        // allocation is freed by the move), and `from_ffi` wraps it in an Arc that
+        // keeps the JVM-installed release callback alive until the resulting
+        // ArrayData drops. `out_schema` is borrowed; its release callback runs
+        // exactly once when the Box drops at end of scope.
+        let result_data = unsafe { from_ffi(*out_array, &out_schema) }
+            .map_err(|e| CometError::Arrow { source: e })?;
+        Ok(ColumnarValue::Array(make_array(result_data)))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        self.args.iter().collect()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> DFResult<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(JvmScalarUdfExpr::new(
+            self.class_name.clone(),
+            children,
+            self.return_type.clone(),
+            self.return_nullable,
+        )))
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index eddf2ff460..d5297f27fd 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -55,6 +55,8 @@ pub use cast::{spark_cast, Cast, SparkCastOptions};
 mod bloom_filter;
 pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain, SparkBloomFilterVersion};
 
+pub mod jvm_udf;
+
 mod conditional_funcs;
 mod conversion_funcs;
 mod map_funcs;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e73cf12f79..28f31c7c81 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -69,7 +69,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
     classOf[Flatten] -> CometFlatten,
     classOf[GetArrayItem] -> CometGetArrayItem,
     classOf[Size] -> CometSize,
-    classOf[ArraysZip] -> CometArraysZip)
+    classOf[ArraysZip] -> CometArraysZip,
+    classOf[ArrayExists] -> CometArrayExists)
 
   private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
     Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)
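
As a rough illustration of the gating performed by CometArrayExists below, the first query is eligible for native dispatch while the second falls back to Spark; the session setup is an assumption, not part of this patch:

import org.apache.spark.sql.SparkSession

// Illustrative only: which exists() calls this serde accepts.
object ArrayExistsGatingDemo extends App {
  // Assumes a session with Comet enabled; the DataFrame is only here to make
  // the queries concrete.
  val spark = SparkSession.builder().master("local[1]").getOrCreate()
  import spark.implicits._

  Seq((Seq(1, 2, 3), 2)).toDF("arr", "threshold").createOrReplaceTempView("t")

  // Accepted: the lambda references only its own variable.
  spark.sql("SELECT exists(arr, x -> x > 2) FROM t").show()

  // Rejected (falls back to Spark): the lambda captures the outer column
  // `threshold`, which getSupportLevel detects as an AttributeReference.
  spark.sql("SELECT exists(arr, x -> x > threshold) FROM t").show()

  spark.stop()
}
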
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 5edc08840a..782a3cdd1b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -22,7 +22,7 @@ package org.apache.comet.serde
 import scala.annotation.tailrec
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, AttributeReference, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size, SortArray}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -31,6 +31,7 @@ import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde._
 import org.apache.comet.shims.{CometExprShim, CometTypeShim}
+import org.apache.comet.udf.CometLambdaRegistry
 
 object CometArrayRemove
     extends CometExpressionSerde[ArrayRemove]
@@ -812,3 +813,71 @@ trait ArraysBase {
     }
   }
 }
+
+object CometArrayExists extends CometExpressionSerde[ArrayExists] {
+
+  private def isElementTypeSupported(dt: DataType): Boolean = dt match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+        _: DecimalType | DateType | TimestampType | StringType =>
+      true
+    case _ => false
+  }
+
+  override def getSupportLevel(expr: ArrayExists): SupportLevel = {
+    val ArrayType(elementType, _) = expr.argument.dataType
+    if (!isElementTypeSupported(elementType)) {
+      return Unsupported(Some(s"Unsupported array element type: $elementType"))
+    }
+    // Only support lambdas that reference the lambda variable alone (no captured columns)
+    expr.function match {
+      case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) =>
+        val capturedRefs = body.collect { case a: AttributeReference => a }
+        if (capturedRefs.nonEmpty) {
+          Unsupported(Some("Lambda references columns outside the array element"))
+        } else {
+          Compatible()
+        }
+      case _ =>
+        Unsupported(Some("Only single-argument lambda functions are supported"))
+    }
+  }
+
+  override def convert(
+      expr: ArrayExists,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val arrayProto = exprToProtoInternal(expr.argument, inputs, binding)
+    if (arrayProto.isEmpty) {
+      withInfo(expr, "Failed to serialize array argument")
+      return None
+    }
+
+    val registryKey = CometLambdaRegistry.register(expr)
+    val keyLiteral = Literal(registryKey)
+    val keyProto = exprToProtoInternal(keyLiteral, inputs, binding)
+    if (keyProto.isEmpty) {
+      CometLambdaRegistry.remove(registryKey)
+      withInfo(expr, "Failed to serialize registry key literal")
+      return None
+    }
+
+    val returnType = serializeDataType(BooleanType).getOrElse {
+      CometLambdaRegistry.remove(registryKey)
+      return None
+    }
+
+    val udfBuilder = ExprOuterClass.JvmScalarUdf
+      .newBuilder()
+      .setClassName("org.apache.comet.udf.ArrayExistsUDF")
+      .addArgs(arrayProto.get)
+      .addArgs(keyProto.get)
+      .setReturnType(returnType)
+      .setReturnNullable(expr.nullable)
+
+    Some(
+      ExprOuterClass.Expr
+        .newBuilder()
+        .setJvmScalarUdf(udfBuilder.build())
+        .build())
+  }
+}
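
The tests that follow pin down Spark's three-valued exists() semantics, which ArrayExistsUDF reimplements element by element. A quick illustrative reference (standard Spark behavior, not specific to this patch):

import org.apache.spark.sql.SparkSession

object ExistsSemanticsDemo extends App {
  val spark = SparkSession.builder().master("local[1]").getOrCreate()

  // exists() under three-valued logic (followThreeValuedLogic = true):
  //   exists(array(1, 2, 3),    x -> x > 2)  => true   (a match exists)
  //   exists(array(1, NULL, 3), x -> x > 10) => NULL   (no match, but the
  //                                            predicate was NULL for an element)
  //   exists(array(1, 2),       x -> x > 10) => false  (no match, no NULLs)
  //   exists(array(),           x -> x > 0)  => false  (vacuously false)
  spark.sql("SELECT exists(array(1, NULL, 3), x -> x > 10) AS r").show()

  spark.stop()
}
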
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 0abf4e4e9e..3b2f5fcc5c 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -243,7 +243,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     val df = spark.read
       .parquet(path.toString)
       .withColumn("arr", array(col("_4"), lit(null), col("_4")))
-      .withColumn("idx", udf((_: Int) => 1).apply(col("_4")))
+      .withColumn("idx", org.apache.spark.sql.functions.udf((_: Int) => 1).apply(col("_4")))
       .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)"))
     checkSparkAnswerAndFallbackReasons(
       df.select("arrUnsupportedArgs"),
@@ -1085,4 +1085,45 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       }
     }
   }
+
+  test("array_exists - integer predicate") {
+    withTable("t") {
+      sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
+      sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(4, 5, 6)), (array(-1, -2)), (NULL)")
+      checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 2) FROM t"))
+    }
+  }
+
+  test("array_exists - string predicate") {
+    withTable("t") {
+      sql("CREATE TABLE t (arr ARRAY<STRING>) USING parquet")
+      sql(
+        "INSERT INTO t VALUES (array('hello', 'world')), (array('foo')), (array(NULL, 'bar')), (NULL)")
+      checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x = 'world') FROM t"))
+    }
+  }
+
+  test("array_exists - null elements with three-valued logic") {
+    withTable("t") {
+      sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
+      sql("INSERT INTO t VALUES (array(1, NULL, 3)), (array(NULL, NULL)), (array(4, 5))")
+      checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 10) FROM t"))
+    }
+  }
+
+  test("array_exists - all elements match") {
+    withTable("t") {
+      sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
+      sql("INSERT INTO t VALUES (array(10, 20, 30)), (array(1, 2, 3))")
+      checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
+    }
+  }
+
+  test("array_exists - empty array") {
+    withTable("t") {
+      sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
+      sql("INSERT INTO t VALUES (array()), (array(1))")
+      checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
+    }
+  }
 }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala
index 44ef1a4735..287567bcc8 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala
@@ -156,6 +156,39 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase {
     }
   }
 
+  def arrayExistsBenchmark(values: Int): Unit = {
+    withTempPath { dir =>
+      withTempTable("parquetV1Table") {
+        prepareTable(
+          dir,
+          spark.sql(s"""SELECT
+               | array(
+               |   cast(value % 100 as int),
+               |   cast((value + 1) % 100 as int),
+               |   cast((value + 2) % 100 as int),
+               |   cast((value + 3) % 100 as int),
+               |   cast((value + 4) % 100 as int),
+               |   cast((value + 5) % 100 as int),
+               |   cast((value + 6) % 100 as int),
+               |   cast((value + 7) % 100 as int),
+               |   cast((value + 8) % 100 as int),
+               |   cast((value + 9) % 100 as int)
+               | ) as int_arr
+               |FROM $tbl""".stripMargin))
+
+        runExpressionBenchmark(
+          "array_exists - int array (x -> x > 50)",
+          values,
+          "SELECT exists(int_arr, x -> x > 50) FROM parquetV1Table")
+
+        runExpressionBenchmark(
+          "array_exists - int array (x -> x < 0)",
+          values,
+          "SELECT exists(int_arr, x -> x < 0) FROM parquetV1Table")
+      }
+    }
+  }
+
   override def runCometBenchmark(mainArgs: Array[String]): Unit = {
     val values = 4 * 1024 * 1024
 
@@ -178,5 +211,9 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase {
     runBenchmarkWithTable("ArrayPosition", values) { v =>
       arrayPositionBenchmark(v)
    }
+
+    runBenchmarkWithTable("ArrayExists", values) { v =>
+      arrayExistsBenchmark(v)
+    }
   }
 }
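
To exercise the new path end to end outside the test suite, a hedged smoke-test sketch; the plugin and configuration names below are the commonly documented Comet settings and are assumptions here, not something this patch introduces:

import org.apache.spark.sql.SparkSession

object ArrayExistsSmokeTest extends App {
  val spark = SparkSession
    .builder()
    .master("local[2]")
    .config("spark.plugins", "org.apache.spark.CometPlugin")
    .config("spark.comet.enabled", "true")
    .config("spark.comet.exec.enabled", "true")
    .getOrCreate()

  spark.sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
  spark.sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(4, 5, 6)), (NULL)")

  val df = spark.sql("SELECT exists(arr, x -> x > 2) FROM t")
  // Expect a Comet native operator carrying the JvmScalarUdf expression.
  df.explain()
  df.show()

  spark.stop()
}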