Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions python/pyspark/sql/tests/arrow/test_arrow_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def func(iterator):
expected = df.collect()
self.assertEqual(actual, expected)

def test_coerce_output_type_to_declared_schema(self):
    """Check that mapInArrow output is coerced to the declared output schema.

    Regression test: when the user yields a batch whose Arrow type does not
    match the declared output schema, the worker should coerce it rather
    than letting the JVM fail later with an opaque getInt error on the
    wrong ArrowColumnVector accessor.
    """
    from pyspark.sql.types import IntegerType, StructField, StructType

    def double_x(iter_batches):
        # Deliberately emit int64 values (the input column is long) even
        # though the declared output column is IntegerType (int32).
        for batch in iter_batches:
            doubled = [v * 2 for v in batch.column("x").to_pylist()]
            yield pa.RecordBatch.from_arrays(
                [pa.array(doubled, type=pa.int64())],
                names=["x"],
            )

    declared = StructType([StructField("x", IntegerType())])
    src = self.spark.createDataFrame([(1,), (2,), (3,)], ["x"])
    collected = src.mapInArrow(double_x, schema=declared).collect()
    self.assertEqual([row.x for row in collected], [2, 4, 6])

def test_large_variable_width_types(self):
with self.sql_conf({"spark.sql.execution.arrow.useLargeVarTypes": True}):
data = [("foo", b"foo"), (None, None), ("bar", b"bar")]
Expand Down Expand Up @@ -213,27 +233,30 @@ def test_negative_and_zero_batch_size(self):
MapInArrowTests.test_map_in_arrow(self)

def test_nested_extraneous_field(self):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}):

def func(iterator):
for _ in iterator:
struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
yield pa.RecordBatch.from_arrays([struct_arr], ["x"])
# The worker now coerces each output batch against the declared output
# schema (same contract as sibling Arrow eval-type branches such as
# SQL_SCALAR_ARROW_UDF). A nested struct field that is not declared in
# the output schema is silently dropped by the pyarrow cast.
def func(iterator):
for _ in iterator:
struct_arr = pa.StructArray.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
yield pa.RecordBatch.from_arrays([struct_arr], ["x"])

df = self.spark.range(1)
with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
df.mapInArrow(func, "x struct<b:int>").collect()
df = self.spark.range(1)
rows = df.mapInArrow(func, "x struct<b:int>").collect()
self.assertEqual([r.x.b for r in rows], [3, 4])

def test_top_level_wrong_order(self):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}):

def func(iterator):
for _ in iterator:
yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])
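# The worker coerces each output batch against the declared output schema.
# When the user yields columns whose names match the declared output but in
# a different order, they are silently reordered by name to align with the
# declared schema (the worker passes
# reorder_by_name=runner_conf.assign_cols_by_name when enforcing it).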
def func(iterator):
for _ in iterator:
yield pa.RecordBatch.from_arrays([[1], [2]], ["b", "a"])

df = self.spark.range(1)
with self.assertRaisesRegex(Exception, r"ARROW_TYPE_MISMATCH.*SQL_MAP_ARROW_ITER_UDF"):
df.mapInArrow(func, "a int, b int").collect()
df = self.spark.range(1)
row = df.mapInArrow(func, "a int, b int").first()
self.assertEqual((row.a, row.b), (2, 1))

def test_nullability_widen(self):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.validateSchema.enabled": True}):
Expand Down
23 changes: 21 additions & 2 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
return func, None, None, None
return func, None, None, return_type
elif eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
Expand Down Expand Up @@ -2371,6 +2371,17 @@ def extract_key_value_indexes(grouped_arg_offsets):

assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0][0]
return_type = udfs[0][3]

# Pre-compute target schema so each output batch can be coerced to the
# declared output types (e.g. int64 -> int32 from a pandas roundtrip).
# Without this, a type mismatch only surfaces deep in the JVM as an
# opaque getInt error on the wrong ArrowColumnVector accessor.
return_schema = to_arrow_schema(
return_type,
timezone="UTC",
prefers_large_types=runner_conf.use_large_var_types,
)

def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
"""Apply mapInArrow UDF"""
Expand All @@ -2388,7 +2399,15 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
output_batches,
Iterator[pa.RecordBatch], # type: ignore[type-abstract]
)
yield from map(ArrowBatchTransformer.wrap_struct, verified_iter)
coerced_iter = (
ArrowBatchTransformer.enforce_schema(
batch,
return_schema,
reorder_by_name=runner_conf.assign_cols_by_name,
)
for batch in verified_iter
)
yield from map(ArrowBatchTransformer.wrap_struct, coerced_iter)

# profiling is not supported for UDF
return func, None, ser, ser
Expand Down