Skip to content

[SPARK-53354][CONNECT] Simplify LiteralValueProtoConverter.toCatalystStruct #52098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -320,17 +320,15 @@ object LiteralValueProtoConverter {
toCatalystArray(literal.getArray)

case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
toCatalystStruct(literal.getStruct)._1
toCatalystStruct(literal.getStruct)

case other =>
throw new UnsupportedOperationException(
s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
}
}

private def getConverter(
dataType: proto.DataType,
inferDataType: Boolean = false): proto.Expression.Literal => Any = {
private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
dataType.getKindCase match {
case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
case proto.DataType.KindCase.INTEGER => v => v.getInteger
Expand All @@ -354,20 +352,15 @@ object LiteralValueProtoConverter {
case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
case proto.DataType.KindCase.STRUCT =>
if (inferDataType) { v =>
val (struct, structType) = toCatalystStruct(v.getStruct, None)
LiteralValueWithDataType(
struct,
proto.DataType.newBuilder.setStruct(structType).build())
} else { v =>
toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
}
v => toCatalystStructInternal(v.getStruct, dataType.getStruct)
case _ =>
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
}
}

private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
private def getInferredDataType(
literal: proto.Expression.Literal,
recursive: Boolean = false): Option[proto.DataType] = {
if (literal.hasNull) {
return Some(literal.getNull)
}
Expand Down Expand Up @@ -399,8 +392,31 @@ object LiteralValueProtoConverter {
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
// The type of the fields will be inferred from the literals of the fields in the struct.
builder.setStruct(literal.getStruct.getStructType.getStruct)
if (recursive) {
val structType = literal.getStruct.getDataTypeStruct
val structData = literal.getStruct.getElementsList.asScala
val structTypeBuilder = proto.DataType.Struct.newBuilder
for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) {
if (field.hasDataType) {
structTypeBuilder.addFields(field)
} else {
getInferredDataType(element, recursive = true) match {
case Some(dataType) =>
val fieldBuilder = structTypeBuilder.addFieldsBuilder()
fieldBuilder.setName(field.getName)
fieldBuilder.setDataType(dataType)
fieldBuilder.setNullable(field.getNullable)
if (field.hasMetadata) {
fieldBuilder.setMetadata(field.getMetadata)
}
case None => return None
}
}
}
builder.setStruct(structTypeBuilder.build())
} else {
builder.setStruct(proto.DataType.Struct.newBuilder.build())
}
case _ =>
// Not all data types support inferring the data type from the literal at the moment.
// e.g. the type of DayTimeInterval contains extra information like start_field and
Expand All @@ -410,13 +426,6 @@ object LiteralValueProtoConverter {
Some(builder.build())
}

/**
 * Infers the proto data type of `literal` by delegating to `getInferredDataType`, and
 * fails instead of returning an empty result when no type could be inferred.
 *
 * @param literal the literal whose data type should be inferred
 * @return the data type reported by `getInferredDataType`
 * @throws InvalidPlanInput if `getInferredDataType` yields `None` for this literal
 */
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
  getInferredDataType(literal) match {
    case Some(inferredType) => inferredType
    case None =>
      // The literal type case is only read on the failure path, to build the message.
      throw InvalidPlanInput(
        s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
  }
}

def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
tag: ClassTag[T]): Array[T] = {
Expand Down Expand Up @@ -451,9 +460,9 @@ object LiteralValueProtoConverter {
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
}

def toCatalystStruct(
private def toCatalystStructInternal(
struct: proto.Expression.Literal.Struct,
structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
structType: proto.DataType.Struct): Any = {
def toTuple[A <: Object](data: Seq[A]): Product = {
try {
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
Expand All @@ -464,78 +473,36 @@ object LiteralValueProtoConverter {
}
}

if (struct.hasDataTypeStruct) {
// The new way to define and convert structs.
val (structData, structType) = if (structTypeOpt.isDefined) {
val structFields = structTypeOpt.get.getFieldsList.asScala
val structData =
struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
getConverter(structField.getDataType)(element)
}
(structData, structTypeOpt.get)
} else {
def protoStructField(
name: String,
dataType: proto.DataType,
nullable: Boolean,
metadata: Option[String]): proto.DataType.StructField = {
val builder = proto.DataType.StructField
.newBuilder()
.setName(name)
.setDataType(dataType)
.setNullable(nullable)
metadata.foreach(builder.setMetadata)
builder.build()
}

val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala

val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
case (element, dataTypeField) =>
if (dataTypeField.hasDataType) {
(getConverter(dataTypeField.getDataType)(element), dataTypeField)
} else {
val outerDataType = getInferredDataTypeOrThrow(element)
val (value, dataType) =
getConverter(outerDataType, inferDataType = true)(element) match {
case LiteralValueWithDataType(value, dataType) => (value, dataType)
case value => (value, outerDataType)
}
(
value,
protoStructField(
dataTypeField.getName,
dataType,
dataTypeField.getNullable,
if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
}
}
val elements = struct.getElementsList.asScala
val dataTypes = structType.getFieldsList.asScala.map(_.getDataType)
val structData = elements
.zip(dataTypes)
.map { case (element, dataType) =>
getConverter(dataType)(element)
}
.asInstanceOf[scala.collection.Seq[Object]]
.toSeq

val structType = proto.DataType.Struct
.newBuilder()
.addAllFields(structDataAndFields.map(_._2).asJava)
.build()
toTuple(structData)
}

(structDataAndFields.map(_._1), structType)
def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
if (struct.hasDataTypeStruct) {
val literal = proto.Expression.Literal.newBuilder().setStruct(struct).build()
getInferredDataType(literal, recursive = true) match {
case Some(dataType) => dataType.getStruct
case None => throw InvalidPlanInput("Cannot infer data type from this struct literal.")
}
(toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
} else if (struct.hasStructType) {
// For backward compatibility, we still support the old way to define and convert structs.
val elements = struct.getElementsList.asScala
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
val structData = elements
.zip(dataTypes)
.map { case (element, dataType) =>
getConverter(dataType)(element)
}
.asInstanceOf[scala.collection.Seq[Object]]
.toSeq

(toTuple(structData), struct.getStructType.getStruct)
// For backward compatibility, we still support the old way to
// define and convert struct types.
struct.getStructType.getStruct
} else {
throw InvalidPlanInput("Data type information is missing in the struct literal.")
}
}

/** Pairs a converted literal `value` with the proto `dataType` that accompanies it. */
private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
/**
 * Converts a proto struct literal by first resolving its struct type with
 * `getProtoStructType` and then handing both to `toCatalystStructInternal`.
 *
 * @param struct the proto struct literal to convert
 * @return the value produced by `toCatalystStructInternal` for the literal and its
 *         resolved struct type
 */
def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
  // Resolve the type up front so any failure in type resolution surfaces before conversion.
  val resolvedStructType = getProtoStructType(struct)
  toCatalystStructInternal(struct, resolvedStructType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ object LiteralExpressionProtoConverter {
DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType)))

case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct)
val dataType = DataTypeProtoConverter.toCatalystType(
proto.DataType.newBuilder.setStruct(structType).build())
proto.DataType.newBuilder
.setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
.build())
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
expressions.Literal(convert(structData), dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
.addElements(LiteralValueProtoConverter.toLiteralProto("test"))
.build()

val (result, resultType) = LiteralValueProtoConverter.toCatalystStruct(structProto)
val result = LiteralValueProtoConverter.toCatalystStruct(structProto)
val resultType = LiteralValueProtoConverter.getProtoStructType(structProto)

// Verify the result is a tuple with correct values
assert(result.isInstanceOf[Product])
Expand Down Expand Up @@ -156,7 +157,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
assert(!structFields.get(1).getNullable)
assert(!structFields.get(1).hasMetadata)

val (_, structTypeProto) = LiteralValueProtoConverter.toCatalystStruct(literalProto.getStruct)
val structTypeProto = LiteralValueProtoConverter.getProtoStructType(literalProto.getStruct)
assert(structTypeProto.getFieldsList.get(0).getNullable)
assert(structTypeProto.getFieldsList.get(0).hasMetadata)
assert(structTypeProto.getFieldsList.get(0).getMetadata == """{"key":"value"}""")
Expand Down