Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 71 additions & 47 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
ArrowStreamCoGroupSerializer,
CogroupPandasUDFSerializer,
ApplyInPandasWithStateSerializer,
TransformWithStateInPandasSerializer,
TransformWithStateInPandasInitStateSerializer,
Expand Down Expand Up @@ -489,28 +488,6 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu
)


def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
    """Wrap a cogrouped-map pandas UDF so it can be driven with per-column series.

    Parameters
    ----------
    f
        The user function. It is called as ``f(left_df, right_df)`` when
        ``argspec.args`` has exactly two entries, and as
        ``f(key, left_df, right_df)`` otherwise.
    return_type
        Declared return type; passed to ``verify_pandas_result`` and paired
        with the result in the returned list.
    argspec
        Argument spec of the user function; only ``argspec.args`` is read.
    runner_conf
        Runner configuration; only ``assign_cols_by_name`` is read.

    Returns
    -------
    A callable ``(left_keys, left_values, right_keys, right_values)`` that
    returns ``[(result, return_type)]``.
    """
    # The arity never changes between groups, so read it once instead of on
    # every call.
    num_udf_args = len(argspec.args)

    def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
        import pandas as pd

        left_df = pd.concat(left_value_series, axis=1)
        right_df = pd.concat(right_value_series, axis=1)

        if num_udf_args == 2:
            result = f(left_df, right_df)
        else:
            # Any arity other than two takes the keyed call. A mismatching
            # signature then fails with a TypeError at the call site instead
            # of leaving ``result`` unbound below.
            # The key comes from whichever side actually has rows.
            key_series = left_key_series if not left_df.empty else right_key_series
            # ``iloc`` is positional, so the first row is picked regardless of
            # how the series happens to be indexed.
            key = tuple(s.iloc[0] for s in key_series)
            result = f(key, left_df, right_df)

        verify_pandas_result(
            result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
        )

        return result

    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]


def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
# the types of the fields have to be identical to return type
# an empty table can have no columns; if there are columns, they have to match
Expand Down Expand Up @@ -1020,7 +997,7 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
return func, args_offsets, return_type, len(argspec.args)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return func, args_offsets, return_type, len(argspec.args)
Expand Down Expand Up @@ -2250,14 +2227,7 @@ def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = ArrowStreamCoGroupSerializer(write_start_stream=True)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = CogroupPandasUDFSerializer(
timezone=runner_conf.timezone,
safecheck=runner_conf.safecheck,
assign_cols_by_name=runner_conf.assign_cols_by_name,
prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
arrow_cast=True,
)
ser = ArrowStreamCoGroupSerializer(write_start_stream=True)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
ser = ApplyInPandasWithStateSerializer(
timezone=runner_conf.timezone,
Expand Down Expand Up @@ -3067,6 +3037,75 @@ def cogrouped_func(
# profiling is not supported for UDF
return cogrouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
import pyarrow as pa
import pandas as pd

assert num_udfs == 1, "One COGROUPED_MAP_PANDAS UDF expected here."
cogrouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
parsed_offsets = extract_key_value_indexes(arg_offsets)

left_key_offsets, left_value_offsets = parsed_offsets[0]
right_key_offsets, right_value_offsets = parsed_offsets[1]
output_schema = StructType([StructField("_0", return_type)])

def cogrouped_func(
    split_index: int,
    data: Iterator[Tuple[list[pa.RecordBatch], list[pa.RecordBatch]]],
) -> Iterator[pa.RecordBatch]:
    """Run the cogrouped-map pandas UDF on each cogroup and stream the results.

    ``data`` yields one ``(left_batches, right_batches)`` pair per cogroup.
    Each side is assembled into an Arrow table and converted to pandas. The
    value columns are concatenated into a DataFrame per side, the UDF is
    invoked, its result is verified against ``return_type``, and the result is
    converted back to Arrow and yielded.

    ``split_index`` is not used by the body.
    """
    for lhs_batches, rhs_batches in data:
        lhs_table = pa.Table.from_batches(lhs_batches)
        rhs_table = pa.Table.from_batches(rhs_batches)

        # Both sides use the same Arrow -> pandas conversion settings.
        to_pandas_options = {
            "timezone": runner_conf.timezone,
            "prefer_int_ext_dtype": runner_conf.prefer_int_ext_dtype,
        }
        # NOTE(review): the conversion result is indexed by column offset
        # below; presumably it is one pandas Series per column -- confirm.
        lhs_columns = ArrowBatchTransformer.to_pandas(lhs_table, **to_pandas_options)
        rhs_columns = ArrowBatchTransformer.to_pandas(rhs_table, **to_pandas_options)

        lhs_frame = pd.concat([lhs_columns[i] for i in left_value_offsets], axis=1)
        rhs_frame = pd.concat([rhs_columns[i] for i in right_value_offsets], axis=1)

        if num_udf_args != 2:
            # Keyed form: take the key from whichever side has rows.
            if lhs_frame.empty:
                key_columns = [rhs_columns[i] for i in right_key_offsets]
            else:
                key_columns = [lhs_columns[i] for i in left_key_offsets]
            group_key = tuple(column.iloc[0] for column in key_columns)
            result = cogrouped_udf(group_key, lhs_frame, rhs_frame)
        else:
            result = cogrouped_udf(lhs_frame, rhs_frame)

        # Drop the inputs before verification and conversion so they are not
        # kept alive alongside the output.
        del lhs_batches, rhs_batches, lhs_table, rhs_table
        del lhs_columns, rhs_columns, lhs_frame, rhs_frame

        verify_pandas_result(
            result,
            return_type,
            runner_conf.assign_cols_by_name,
            truncate_return_schema=False,
        )

        # Yield the conversion result directly, without binding it to a local,
        # so this frame holds no extra reference while suspended.
        yield PandasToArrowConversion.convert(
            [result],
            output_schema,
            timezone=runner_conf.timezone,
            safecheck=runner_conf.safecheck,
            arrow_cast=True,
            prefers_large_types=runner_conf.use_large_var_types,
            assign_cols_by_name=runner_conf.assign_cols_by_name,
            int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
        )
        del result

# profiling is not supported for UDF
return cogrouped_func, None, ser, ser

if (
eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
and not runner_conf.use_legacy_pandas_udf_conversion
Expand Down Expand Up @@ -3554,21 +3593,6 @@ def mapper(a):

return f(keys, vals, state)

elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
arg_offsets, f = udfs[0]

parsed_offsets = extract_key_value_indexes(arg_offsets)

def mapper(a):
df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
return f(df1_keys, df1_vals, df2_keys, df2_vals)

elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
# We assume there is only one UDF here because grouped agg doesn't
# support combining multiple UDFs.
Expand Down