diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
index 644252b46b..7ff78713be 100644
--- a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
+++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
@@ -184,6 +184,17 @@ impl PhysicalExpr for WideDecimalBinaryExpr {
         let left_val = self.left.evaluate(batch)?;
         let right_val = self.right.evaluate(batch)?;
 
+        // Track scalar-ness so we can return a Scalar when both inputs are scalars.
+        // Without this, a (Scalar op Scalar) result would be returned as a length-1
+        // Array, and downstream comparisons against full batches would incorrectly
+        // see two Array operands with mismatched lengths instead of (Array, Scalar),
+        // crashing arrow-ord's compare_op with "Cannot compare arrays of different
+        // lengths". This pattern appears, for example, in TPC-DS q23's BHJ filter
+        // `0.95 * scalar_subquery > ssales`.
+        let both_scalar = matches!(
+            (&left_val, &right_val),
+            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_))
+        );
         let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) {
             (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
             (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
@@ -280,7 +291,16 @@ impl PhysicalExpr for WideDecimalBinaryExpr {
             result
         };
         let result = result.with_data_type(DataType::Decimal128(p_out, s_out));
-        Ok(ColumnarValue::Array(Arc::new(result)))
+        if both_scalar {
+            // Convert the length-1 result back into a Scalar so downstream
+            // expressions (binary ops, comparisons) can take the scalar fast-path
+            // and propagate the scalar tag (Datum::is_scalar) through arrow-rs
+            // kernels.
+            let scalar = datafusion::common::ScalarValue::try_from_array(&result, 0)?;
+            Ok(ColumnarValue::Scalar(scalar))
+        } else {
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
     }
 
     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -557,4 +577,66 @@ mod tests {
         let arr = result.as_primitive::<Decimal128Type>();
         assert_eq!(arr.value(0), 20000); // 2.0000
     }
+
+    /// Regression test for the Scalar x Scalar wide-decimal evaluation path.
+    ///
+    /// When both inputs are `ColumnarValue::Scalar`, `evaluate` must return a
+    /// `ColumnarValue::Scalar` -- not a length-1 `ColumnarValue::Array`. Otherwise
+    /// downstream comparisons against full batches see two `Array` operands with
+    /// mismatched lengths and arrow-ord's `compare_op` rejects them with
+    /// "Cannot compare arrays of different lengths, got N vs 1". This pattern
+    /// appears, for example, in TPC-DS q23's BHJ filter
+    /// `0.95 * scalar_subquery > ssales`.
+    #[test]
+    fn test_scalar_scalar_returns_scalar() {
+        use datafusion::common::ScalarValue;
+        use datafusion::physical_expr::expressions::Literal;
+
+        // 0.95 * 100.00 -- the same Scalar x Scalar decimal multiply pattern that
+        // appears in TPC-DS q23's filter `0.95 * scalar_subquery > ssales`.
+        let left: Arc<dyn PhysicalExpr> =
+            Arc::new(Literal::new(ScalarValue::Decimal128(Some(95), 38, 2)));
+        let right: Arc<dyn PhysicalExpr> =
+            Arc::new(Literal::new(ScalarValue::Decimal128(Some(10000), 38, 2)));
+
+        let expr = WideDecimalBinaryExpr::new(
+            left,
+            right,
+            WideDecimalOp::Multiply,
+            38,
+            2,
+            EvalMode::Legacy,
+        );
+
+        // Empty schema -- both inputs are Literals so no columns are needed.
+        let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+        match expr.evaluate(&batch).unwrap() {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(Some(v), 38, 2)) => {
+                // 0.95 * 100.00 = 95.00 -> at scale 2, integer 9500
+                assert_eq!(v, 9500);
+            }
+            ColumnarValue::Scalar(other) => {
+                panic!("expected Decimal128(Some(_), 38, 2), got {other:?}");
+            }
+            ColumnarValue::Array(_) => {
+                panic!("Scalar x Scalar must return ColumnarValue::Scalar, not Array");
+            }
+        }
+    }
+
+    /// Companion test: when at least one input is an Array, the result must remain an Array.
+    /// Guards against over-eager scalar-unwrapping in the fix.
+    #[test]
+    fn test_array_input_returns_array() {
+        let batch = make_batch(
+            vec![Some(150), Some(250)],
+            38,
+            2,
+            vec![Some(100), Some(200)],
+            38,
+            2,
+        );
+        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap();
+        assert_eq!(result.len(), 2);
+    }
 }