Skip to content

Commit 5425a8d

Browse files
committed
[SPARK-53354] Simplify LiteralValueProtoConverter.toCatalystStruct
1 parent 7007e1c commit 5425a8d

File tree

1 file changed

+62
-97
lines changed

1 file changed

+62
-97
lines changed

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 62 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,7 @@ object LiteralValueProtoConverter {
328328
}
329329
}
330330

331-
private def getConverter(
332-
dataType: proto.DataType,
333-
inferDataType: Boolean = false): proto.Expression.Literal => Any = {
331+
private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
334332
dataType.getKindCase match {
335333
case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
336334
case proto.DataType.KindCase.INTEGER => v => v.getInteger
@@ -354,20 +352,15 @@ object LiteralValueProtoConverter {
354352
case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
355353
case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
356354
case proto.DataType.KindCase.STRUCT =>
357-
if (inferDataType) { v =>
358-
val (struct, structType) = toCatalystStruct(v.getStruct, None)
359-
LiteralValueWithDataType(
360-
struct,
361-
proto.DataType.newBuilder.setStruct(structType).build())
362-
} else { v =>
363-
toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
364-
}
355+
v => toCatalystStructLegacy(v.getStruct, dataType.getStruct)
365356
case _ =>
366357
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
367358
}
368359
}
369360

370-
private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
361+
private def getInferredDataType(
362+
literal: proto.Expression.Literal,
363+
recursive: Boolean = false): Option[proto.DataType] = {
371364
if (literal.hasNull) {
372365
return Some(literal.getNull)
373366
}
@@ -399,8 +392,31 @@ object LiteralValueProtoConverter {
399392
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
400393
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
401394
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
402-
// The type of the fields will be inferred from the literals of the fields in the struct.
403-
builder.setStruct(literal.getStruct.getStructType.getStruct)
395+
if (recursive) {
396+
val structType = literal.getStruct.getDataTypeStruct
397+
val structData = literal.getStruct.getElementsList.asScala
398+
val structTypeBuilder = proto.DataType.Struct.newBuilder
399+
for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) {
400+
if (field.hasDataType) {
401+
structTypeBuilder.addFields(field)
402+
} else {
403+
getInferredDataType(element, recursive = true) match {
404+
case Some(dataType) =>
405+
val fieldBuilder = structTypeBuilder.addFieldsBuilder()
406+
fieldBuilder.setName(field.getName)
407+
fieldBuilder.setDataType(dataType)
408+
fieldBuilder.setNullable(field.getNullable)
409+
if (field.hasMetadata) {
410+
fieldBuilder.setMetadata(field.getMetadata)
411+
}
412+
case None => return None
413+
}
414+
}
415+
}
416+
builder.setStruct(structTypeBuilder.build())
417+
} else {
418+
builder.setStruct(proto.DataType.Struct.newBuilder.build())
419+
}
404420
case _ =>
405421
// Not all data types support inferring the data type from the literal at the moment.
406422
// e.g. the type of DayTimeInterval contains extra information like start_field and
@@ -410,13 +426,6 @@ object LiteralValueProtoConverter {
410426
Some(builder.build())
411427
}
412428

413-
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
414-
getInferredDataType(literal).getOrElse {
415-
throw InvalidPlanInput(
416-
s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
417-
}
418-
}
419-
420429
def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
421430
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
422431
tag: ClassTag[T]): Array[T] = {
@@ -451,91 +460,47 @@ object LiteralValueProtoConverter {
451460
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
452461
}
453462

454-
def toCatalystStruct(
455-
struct: proto.Expression.Literal.Struct,
456-
structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
457-
def toTuple[A <: Object](data: Seq[A]): Product = {
458-
try {
459-
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
460-
tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
461-
} catch {
462-
case _: Exception =>
463-
throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
464-
}
463+
private def toTuple[A <: Object](data: Seq[A]): Product = {
464+
try {
465+
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
466+
tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
467+
} catch {
468+
case _: Exception =>
469+
throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
465470
}
471+
}
466472

467-
if (struct.hasDataTypeStruct) {
468-
// The new way to define and convert structs.
469-
val (structData, structType) = if (structTypeOpt.isDefined) {
470-
val structFields = structTypeOpt.get.getFieldsList.asScala
471-
val structData =
472-
struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
473-
getConverter(structField.getDataType)(element)
474-
}
475-
(structData, structTypeOpt.get)
476-
} else {
477-
def protoStructField(
478-
name: String,
479-
dataType: proto.DataType,
480-
nullable: Boolean,
481-
metadata: Option[String]): proto.DataType.StructField = {
482-
val builder = proto.DataType.StructField
483-
.newBuilder()
484-
.setName(name)
485-
.setDataType(dataType)
486-
.setNullable(nullable)
487-
metadata.foreach(builder.setMetadata)
488-
builder.build()
489-
}
490-
491-
val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
492-
493-
val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
494-
case (element, dataTypeField) =>
495-
if (dataTypeField.hasDataType) {
496-
(getConverter(dataTypeField.getDataType)(element), dataTypeField)
497-
} else {
498-
val outerDataType = getInferredDataTypeOrThrow(element)
499-
val (value, dataType) =
500-
getConverter(outerDataType, inferDataType = true)(element) match {
501-
case LiteralValueWithDataType(value, dataType) => (value, dataType)
502-
case value => (value, outerDataType)
503-
}
504-
(
505-
value,
506-
protoStructField(
507-
dataTypeField.getName,
508-
dataType,
509-
dataTypeField.getNullable,
510-
if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
511-
}
512-
}
473+
private def toCatalystStructLegacy(
474+
struct: proto.Expression.Literal.Struct,
475+
structType: proto.DataType.Struct): Any = {
476+
val elements = struct.getElementsList.asScala
477+
val dataTypes = structType.getFieldsList.asScala.map(_.getDataType)
478+
val structData = elements
479+
.zip(dataTypes)
480+
.map { case (element, dataType) =>
481+
getConverter(dataType)(element)
482+
}
483+
.asInstanceOf[scala.collection.Seq[Object]]
484+
.toSeq
513485

514-
val structType = proto.DataType.Struct
515-
.newBuilder()
516-
.addAllFields(structDataAndFields.map(_._2).asJava)
517-
.build()
486+
toTuple(structData)
487+
}
518488

519-
(structDataAndFields.map(_._1), structType)
489+
def toCatalystStruct(struct: proto.Expression.Literal.Struct): (Any, proto.DataType.Struct) = {
490+
if (struct.hasDataTypeStruct) {
491+
val literal = proto.Expression.Literal.newBuilder().setStruct(struct).build()
492+
getInferredDataType(literal, recursive = true) match {
493+
case Some(dataType) =>
494+
(toCatalystStructLegacy(struct, dataType.getStruct), dataType.getStruct)
495+
case None => throw InvalidPlanInput("Cannot infer data type from this struct literal.")
520496
}
521-
(toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
522497
} else if (struct.hasStructType) {
523498
// For backward compatibility, we still support the old way to define and convert structs.
524-
val elements = struct.getElementsList.asScala
525-
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
526-
val structData = elements
527-
.zip(dataTypes)
528-
.map { case (element, dataType) =>
529-
getConverter(dataType)(element)
530-
}
531-
.asInstanceOf[scala.collection.Seq[Object]]
532-
.toSeq
533-
534-
(toTuple(structData), struct.getStructType.getStruct)
499+
(
500+
toCatalystStructLegacy(struct, struct.getStructType.getStruct),
501+
struct.getStructType.getStruct)
535502
} else {
536503
throw InvalidPlanInput("Data type information is missing in the struct literal.")
537504
}
538505
}
539-
540-
private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
541506
}

0 commit comments

Comments
 (0)