diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py index 5119e0e827f6d..9c9712e62c6ad 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_map.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py @@ -79,6 +79,26 @@ def func(iterator): expected = df.collect() self.assertEqual(actual, expected) + def test_coerce_output_type_to_declared_schema(self): + # Regression test: when the user yields a batch whose Arrow type does + # not match the declared output schema, the worker should coerce it + # rather than letting the JVM fail later with an opaque getInt error + # on the wrong ArrowColumnVector accessor. + from pyspark.sql.types import IntegerType, StructField, StructType + + def double_x(iter_batches): + for batch in iter_batches: + # The input column is long (int64); produce int64 output even + # though the declared schema is IntegerType (int32). + yield pa.RecordBatch.from_arrays( + [pa.array([v * 2 for v in batch.column("x").to_pylist()], type=pa.int64())], + names=["x"], + ) + + src = self.spark.createDataFrame([(1,), (2,), (3,)], ["x"]) + out = src.mapInArrow(double_x, schema=StructType([StructField("x", IntegerType())])) + self.assertEqual([r.x for r in out.collect()], [2, 4, 6]) + def test_large_variable_width_types(self): with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}): data = [("foo", b"foo"), (None, None), ("bar", b"bar")] @@ -213,27 +233,30 @@ def test_negative_and_zero_batch_size(self): MapInArrowTests.test_map_in_arrow(self) def test_nested_extraneous_field(self): - with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}): - - def func(iterator): - for _ in iterator: - struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"]) - yield pa.RecordBatch.from_arrays([struct_arr], ["x"]) + # The worker now coerces each output batch against the declared output + # schema (same contract as sibling Arrow eval-type branches such as + # SQL_SCALAR_ARROW_UDF). A nested struct field that is not declared in + # the output schema is silently dropped by the pyarrow cast. + def func(iterator): + for _ in iterator: + struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"]) + yield pa.RecordBatch.from_arrays([struct_arr], ["x"]) - df = self.spark.range(1) - with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"): - df.mapInArrow(func, "x struct").collect() + df = self.spark.range(1) + rows = df.mapInArrow(func, "x struct").collect() + self.assertEqual([r.x.b for r in rows], [3, 4]) def test_top_level_wrong_order(self): - with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}): - - def func(iterator): - for _ in iterator: - yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"]) + # Same as above: when the user yields columns whose names match the + # declared output but in a different order, they are now silently + # reordered by name to align with the declared schema. + def func(iterator): + for _ in iterator: + yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"]) - df = self.spark.range(1) - with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"): - df.mapInArrow(func, "a int, b int").collect() + df = self.spark.range(1) + row = df.mapInArrow(func, "a int, b int").first() + self.assertEqual((row.a, row.b), (2, 1)) def test_nullability_widen(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 306d4f80dbe5d..f1377e730b6ae 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -989,7 +989,7 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - return func, None, None, None + return func, None, None, return_type elif eval_type in ( PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF, @@ -2371,6 +2371,17 @@ def extract_key_value_indexes(grouped_arg_offsets): assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0][0] + return_type = udfs[0][3] + + # Pre-compute target schema so each output batch can be coerced to the + # declared output types (e.g. int64 -> int32 from a pandas roundtrip). + # Without this, a type mismatch only surfaces deep in the JVM as an + # opaque getInt error on the wrong ArrowColumnVector accessor. + return_schema = to_arrow_schema( + return_type, + timezone="UTC", + prefers_large_types=runner_conf.use_large_var_types, + ) def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """Apply mapInArrow UDF""" @@ -2388,7 +2399,15 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record output_batches, Iterator[pa.RecordBatch], # type: ignore[type-abstract] ) - yield from map(ArrowBatchTransformer.wrap_struct, verified_iter) + coerced_iter = ( + ArrowBatchTransformer.enforce_schema( + batch, + return_schema, + reorder_by_name=runner_conf.assign_cols_by_name, + ) + for batch in verified_iter + ) + yield from map(ArrowBatchTransformer.wrap_struct, coerced_iter) # profiling is not supported for UDF return func, None, ser, ser