Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyspark.sql.functions import arrow_udf, ArrowUDFType
from pyspark.util import PythonEvalType
from pyspark.sql import Row
from pyspark.sql.types import ArrayType, YearMonthIntervalType
from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
Expand Down Expand Up @@ -811,6 +812,31 @@ def agg_int2(id: pa.Array) -> int:
# Integer value 2147483657 not in range: -2147483648 to 2147483647
result3.collect()

def test_unsupported_return_types(self):
    """Grouped-aggregate Arrow UDFs must reject return types that cannot be
    mapped to Arrow (here a nested YearMonthIntervalType), raising
    NotImplementedError from both the direct call and the decorator form.
    """
    import pyarrow as pa

    # The same unsupported return type and expected error pattern are used
    # for both registration styles below.
    bad_return_type = ArrayType(ArrayType(YearMonthIntervalType()))
    expected_error = (
        "Invalid return type with grouped aggregate "
        "Arrow UDFs.*ArrayType.*YearMonthIntervalType"
    )

    with self.quiet():
        # Direct-call form.
        with self.assertRaisesRegex(NotImplementedError, expected_error):
            arrow_udf(lambda x: x, bad_return_type, ArrowUDFType.GROUPED_AGG)

        # Decorator form: the error is raised at decoration time.
        with self.assertRaisesRegex(NotImplementedError, expected_error):

            @arrow_udf(bad_return_type, ArrowUDFType.GROUPED_AGG)
            def func_a(a: pa.Array) -> pa.Scalar:
                return pa.compute.max(a)


class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
    """Concrete test case: runs the grouped-aggregate Arrow UDF test mixin
    against the reusable SQL test fixture."""
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Row,
MapType,
BinaryType,
YearMonthIntervalType,
)
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
Expand Down Expand Up @@ -1022,6 +1023,26 @@ def test_return_type_coercion(self):
# Integer value 2147483652 not in range: -2147483648 to 2147483647
result3.collect()

def test_unsupported_return_types(self):
    """Scalar Arrow UDFs must reject Arrow-unmappable return types (here
    ArrayType(YearMonthIntervalType())), raising NotImplementedError for
    both SCALAR and SCALAR_ITER eval types and for the decorator form.
    """
    import pyarrow as pa

    bad_return_type = ArrayType(YearMonthIntervalType())
    expected_error = "Invalid return type.*scalar Arrow UDF.*ArrayType.*YearMonthIntervalType"

    with self.quiet():
        # Direct-call form, checked for both scalar eval types.
        for eval_type in (ArrowUDFType.SCALAR, ArrowUDFType.SCALAR_ITER):
            with self.assertRaisesRegex(NotImplementedError, expected_error):
                arrow_udf(lambda x: x, bad_return_type, eval_type)

        # Decorator form (default eval type): raised at decoration time.
        with self.assertRaisesRegex(NotImplementedError, expected_error):

            @arrow_udf(bad_return_type)
            def func_a(a: pa.Array) -> pa.Array:
                return a


class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
@classmethod
Expand Down
25 changes: 25 additions & 0 deletions python/pyspark/sql/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,19 @@ def _check_return_type(returnType: DataType, evalType: int) -> None:
"feature": f"Invalid return type with scalar Pandas UDFs: " f"{returnType}"
},
)
elif (
evalType == PythonEvalType.SQL_SCALAR_ARROW_UDF
or evalType == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
):
try:
to_arrow_type(returnType)
except TypeError:
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={
"feature": f"Invalid return type with scalar Arrow UDFs: " f"{returnType}"
},
)
elif (
evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
Expand Down Expand Up @@ -355,6 +368,18 @@ def _check_return_type(returnType: DataType, evalType: int) -> None:
f"{returnType}"
},
)
elif evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
try:
# Different from SQL_GROUPED_AGG_PANDAS_UDF, StructType is allowed here
to_arrow_type(returnType)
except TypeError:
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={
"feature": f"Invalid return type with grouped aggregate Arrow UDFs: "
f"{returnType}"
},
)

@property
def returnType(self) -> DataType:
Expand Down