@@ -328,9 +328,7 @@ object LiteralValueProtoConverter {
328
328
}
329
329
}
330
330
331
- private def getConverter (
332
- dataType : proto.DataType ,
333
- inferDataType : Boolean = false ): proto.Expression .Literal => Any = {
331
+ private def getConverter (dataType : proto.DataType ): proto.Expression .Literal => Any = {
334
332
dataType.getKindCase match {
335
333
case proto.DataType .KindCase .SHORT => v => v.getShort.toShort
336
334
case proto.DataType .KindCase .INTEGER => v => v.getInteger
@@ -354,20 +352,15 @@ object LiteralValueProtoConverter {
354
352
case proto.DataType .KindCase .ARRAY => v => toCatalystArray(v.getArray)
355
353
case proto.DataType .KindCase .MAP => v => toCatalystMap(v.getMap)
356
354
case proto.DataType .KindCase .STRUCT =>
357
- if (inferDataType) { v =>
358
- val (struct, structType) = toCatalystStruct(v.getStruct, None )
359
- LiteralValueWithDataType (
360
- struct,
361
- proto.DataType .newBuilder.setStruct(structType).build())
362
- } else { v =>
363
- toCatalystStruct(v.getStruct, Some (dataType.getStruct))._1
364
- }
355
+ v => toCatalystStructLegacy(v.getStruct, dataType.getStruct)
365
356
case _ =>
366
357
throw InvalidPlanInput (s " Unsupported Literal Type: $dataType) " )
367
358
}
368
359
}
369
360
370
- private def getInferredDataType (literal : proto.Expression .Literal ): Option [proto.DataType ] = {
361
+ private def getInferredDataType (
362
+ literal : proto.Expression .Literal ,
363
+ recursive : Boolean = false ): Option [proto.DataType ] = {
371
364
if (literal.hasNull) {
372
365
return Some (literal.getNull)
373
366
}
@@ -399,8 +392,31 @@ object LiteralValueProtoConverter {
399
392
case proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL =>
400
393
builder.setCalendarInterval(proto.DataType .CalendarInterval .newBuilder.build())
401
394
case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
402
- // The type of the fields will be inferred from the literals of the fields in the struct.
403
- builder.setStruct(literal.getStruct.getStructType.getStruct)
395
+ if (recursive) {
396
+ val structType = literal.getStruct.getDataTypeStruct
397
+ val structData = literal.getStruct.getElementsList.asScala
398
+ val structTypeBuilder = proto.DataType .Struct .newBuilder
399
+ for ((element, field) <- structData.zip(structType.getFieldsList.asScala)) {
400
+ if (field.hasDataType) {
401
+ structTypeBuilder.addFields(field)
402
+ } else {
403
+ getInferredDataType(element, recursive = true ) match {
404
+ case Some (dataType) =>
405
+ val fieldBuilder = structTypeBuilder.addFieldsBuilder()
406
+ fieldBuilder.setName(field.getName)
407
+ fieldBuilder.setDataType(dataType)
408
+ fieldBuilder.setNullable(field.getNullable)
409
+ if (field.hasMetadata) {
410
+ fieldBuilder.setMetadata(field.getMetadata)
411
+ }
412
+ case None => return None
413
+ }
414
+ }
415
+ }
416
+ builder.setStruct(structTypeBuilder.build())
417
+ } else {
418
+ builder.setStruct(proto.DataType .Struct .newBuilder.build())
419
+ }
404
420
case _ =>
405
421
// Not all data types support inferring the data type from the literal at the moment.
406
422
// e.g. the type of DayTimeInterval contains extra information like start_field and
@@ -410,13 +426,6 @@ object LiteralValueProtoConverter {
410
426
Some (builder.build())
411
427
}
412
428
413
- private def getInferredDataTypeOrThrow (literal : proto.Expression .Literal ): proto.DataType = {
414
- getInferredDataType(literal).getOrElse {
415
- throw InvalidPlanInput (
416
- s " Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}" )
417
- }
418
- }
419
-
420
429
def toCatalystArray (array : proto.Expression .Literal .Array ): Array [_] = {
421
430
def makeArrayData [T ](converter : proto.Expression .Literal => T )(implicit
422
431
tag : ClassTag [T ]): Array [T ] = {
@@ -451,91 +460,47 @@ object LiteralValueProtoConverter {
451
460
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
452
461
}
453
462
454
- def toCatalystStruct (
455
- struct : proto.Expression .Literal .Struct ,
456
- structTypeOpt : Option [proto.DataType .Struct ] = None ): (Any , proto.DataType .Struct ) = {
457
- def toTuple [A <: Object ](data : Seq [A ]): Product = {
458
- try {
459
- val tupleClass = SparkClassUtils .classForName(s " scala.Tuple ${data.length}" )
460
- tupleClass.getConstructors.head.newInstance(data : _* ).asInstanceOf [Product ]
461
- } catch {
462
- case _ : Exception =>
463
- throw InvalidPlanInput (s " Unsupported Literal: ${data.mkString(" Array(" , " , " , " )" )}) " )
464
- }
463
+ private def toTuple [A <: Object ](data : Seq [A ]): Product = {
464
+ try {
465
+ val tupleClass = SparkClassUtils .classForName(s " scala.Tuple ${data.length}" )
466
+ tupleClass.getConstructors.head.newInstance(data : _* ).asInstanceOf [Product ]
467
+ } catch {
468
+ case _ : Exception =>
469
+ throw InvalidPlanInput (s " Unsupported Literal: ${data.mkString(" Array(" , " , " , " )" )}) " )
465
470
}
471
+ }
466
472
467
- if (struct.hasDataTypeStruct) {
468
- // The new way to define and convert structs.
469
- val (structData, structType) = if (structTypeOpt.isDefined) {
470
- val structFields = structTypeOpt.get.getFieldsList.asScala
471
- val structData =
472
- struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
473
- getConverter(structField.getDataType)(element)
474
- }
475
- (structData, structTypeOpt.get)
476
- } else {
477
- def protoStructField (
478
- name : String ,
479
- dataType : proto.DataType ,
480
- nullable : Boolean ,
481
- metadata : Option [String ]): proto.DataType .StructField = {
482
- val builder = proto.DataType .StructField
483
- .newBuilder()
484
- .setName(name)
485
- .setDataType(dataType)
486
- .setNullable(nullable)
487
- metadata.foreach(builder.setMetadata)
488
- builder.build()
489
- }
490
-
491
- val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
492
-
493
- val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
494
- case (element, dataTypeField) =>
495
- if (dataTypeField.hasDataType) {
496
- (getConverter(dataTypeField.getDataType)(element), dataTypeField)
497
- } else {
498
- val outerDataType = getInferredDataTypeOrThrow(element)
499
- val (value, dataType) =
500
- getConverter(outerDataType, inferDataType = true )(element) match {
501
- case LiteralValueWithDataType (value, dataType) => (value, dataType)
502
- case value => (value, outerDataType)
503
- }
504
- (
505
- value,
506
- protoStructField(
507
- dataTypeField.getName,
508
- dataType,
509
- dataTypeField.getNullable,
510
- if (dataTypeField.hasMetadata) Some (dataTypeField.getMetadata) else None ))
511
- }
512
- }
473
+ private def toCatalystStructLegacy (
474
+ struct : proto.Expression .Literal .Struct ,
475
+ structType : proto.DataType .Struct ): Any = {
476
+ val elements = struct.getElementsList.asScala
477
+ val dataTypes = structType.getFieldsList.asScala.map(_.getDataType)
478
+ val structData = elements
479
+ .zip(dataTypes)
480
+ .map { case (element, dataType) =>
481
+ getConverter(dataType)(element)
482
+ }
483
+ .asInstanceOf [scala.collection.Seq [Object ]]
484
+ .toSeq
513
485
514
- val structType = proto.DataType .Struct
515
- .newBuilder()
516
- .addAllFields(structDataAndFields.map(_._2).asJava)
517
- .build()
486
+ toTuple(structData)
487
+ }
518
488
519
- (structDataAndFields.map(_._1), structType)
489
+ def toCatalystStruct (struct : proto.Expression .Literal .Struct ): (Any , proto.DataType .Struct ) = {
490
+ if (struct.hasDataTypeStruct) {
491
+ val literal = proto.Expression .Literal .newBuilder().setStruct(struct).build()
492
+ getInferredDataType(literal, recursive = true ) match {
493
+ case Some (dataType) =>
494
+ (toCatalystStructLegacy(struct, dataType.getStruct), dataType.getStruct)
495
+ case None => throw InvalidPlanInput (" Cannot infer data type from this struct literal." )
520
496
}
521
- (toTuple(structData.toSeq.asInstanceOf [Seq [Object ]]), structType)
522
497
} else if (struct.hasStructType) {
523
498
// For backward compatibility, we still support the old way to define and convert structs.
524
- val elements = struct.getElementsList.asScala
525
- val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
526
- val structData = elements
527
- .zip(dataTypes)
528
- .map { case (element, dataType) =>
529
- getConverter(dataType)(element)
530
- }
531
- .asInstanceOf [scala.collection.Seq [Object ]]
532
- .toSeq
533
-
534
- (toTuple(structData), struct.getStructType.getStruct)
499
+ (
500
+ toCatalystStructLegacy(struct, struct.getStructType.getStruct),
501
+ struct.getStructType.getStruct)
535
502
} else {
536
503
throw InvalidPlanInput (" Data type information is missing in the struct literal." )
537
504
}
538
505
}
539
-
540
- private case class LiteralValueWithDataType (value : Any , dataType : proto.DataType )
541
506
}
0 commit comments