diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index ff0de0c4696c8..ae0a4af2a2e34 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -21,7 +21,7 @@
 from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
-from pyspark.sql.types import VarcharType
+from pyspark.sql.types import DayTimeIntervalType, VarcharType, StructType, StructField, StringType
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -243,6 +243,55 @@ def test_udf_use_arrow_and_session_conf(self):
             udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF
         )
 
+    def test_day_time_interval_type_casting(self):
+        """Test that DayTimeIntervalType UDFs work with Arrow and preserve field specifications."""
+
+        # HOUR TO SECOND
+        @udf(useArrow=True, returnType=DayTimeIntervalType(1, 3))
+        def return_interval(x):
+            return x
+
+        # UDF input: HOUR TO SECOND, UDF output: HOUR TO SECOND
+        df = self.spark.sql("SELECT INTERVAL '200:13:50.3' HOUR TO SECOND as value").select(
+            return_interval("value").alias("result")
+        )
+        self.assertEqual(df.schema.fields[0].dataType, DayTimeIntervalType(1, 3))
+        self.assertIsNotNone(df.collect()[0]["result"])
+
+        # UDF input: DAY TO SECOND, UDF output: HOUR TO SECOND
+        df2 = self.spark.sql("SELECT INTERVAL '1 10:30:45.123' DAY TO SECOND as value").select(
+            return_interval("value").alias("result")
+        )
+        self.assertEqual(df2.schema.fields[0].dataType, DayTimeIntervalType(1, 3))
+        self.assertIsNotNone(df2.collect()[0]["result"])
+
+    def test_day_time_interval_in_struct(self):
+        """Test that DayTimeIntervalType works within StructType with Arrow UDFs."""
+
+        struct_type = StructType(
+            [
+                StructField("interval_field", DayTimeIntervalType(1, 3)),
+                StructField("name", StringType()),
+            ]
+        )
+
+        @udf(useArrow=True, returnType=struct_type)
+        def create_struct_with_interval(interval_val, name_val):
+            return Row(interval_field=interval_val, name=name_val)
+
+        df = self.spark.sql(
+            """
+            SELECT INTERVAL '15:30:45.678' HOUR TO SECOND as interval_val,
+                   'test_name' as name_val
+            """
+        ).select(create_struct_with_interval("interval_val", "name_val").alias("result"))
+
+        self.assertEqual(df.schema.fields[0].dataType, struct_type)
+        self.assertEqual(df.schema.fields[0].dataType.fields[0].dataType, DayTimeIntervalType(1, 3))
+        result = df.collect()[0]["result"]
+        self.assertIsNotNone(result["interval_field"])
+        self.assertEqual(result["name"], "test_name")
+
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
     pandas_requirement_message or pyarrow_requirement_message
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index fce0807a7d230..3ecc84a1578a4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -458,6 +458,8 @@ object DataType {
     def transform: PartialFunction[DataType, DataType] = {
       case dt @ (_: CharType | _: VarcharType) => dt
       case _: StringType => StringType
+      // SPARK-53330 (see the DayTimeIntervalType comment below)
+      case _: DayTimeIntervalType => DayTimeIntervalType.DEFAULT
     }
 
     if (checkComplexTypes) {
@@ -465,6 +467,10 @@ object DataType {
     } else {
      (from, to) match {
        case (a: StringType, b: StringType) => a.constraint == b.constraint
+       // SPARK-53330: Arrow serialization always returns DayTimeIntervalType(0, 3).
+       // Since that is the widest field range, any day-time interval can be
+       // matched with the target type.
+       case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
        case (fromDataType, toDataType) => fromDataType == toDataType
      }
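
A minimal usage sketch of what this change enables (not part of the patch): an Arrow-optimized Python UDF can declare a narrowed DayTimeIntervalType return type and keep that field specification, even though Arrow hands the result back as DayTimeIntervalType(0, 3). It assumes a running SparkSession bound to the name spark and the pandas/pyarrow requirements installed; the UDF name is illustrative only.

from pyspark.sql.functions import udf
from pyspark.sql.types import DayTimeIntervalType

# Declared return type is HOUR TO SECOND (startField=1, endField=3).
@udf(useArrow=True, returnType=DayTimeIntervalType(1, 3))
def passthrough(interval):
    return interval

df = spark.sql("SELECT INTERVAL '200:13:50.3' HOUR TO SECOND AS value")
result = df.select(passthrough("value").alias("result"))
# The compatibility check in DataType.scala now treats any two day-time interval
# types as matching, so the result schema keeps the declared HOUR TO SECOND type.
print(result.schema)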