diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index 985486c5..854e4b04 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -29,6 +29,8 @@ import org.apache.avro.Schema.Type._ import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ +import org.apache.avro.Schema.Type +import scala.collection.JavaConversions._ /** * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice @@ -36,6 +38,11 @@ import org.apache.spark.sql.types._ */ object SchemaConverters { + val LOGICAL_TYPE = "logicalType" + val DECIMAL = "decimal" + val PRECISION = "precision" + val SCALE = "scale" + class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) case class SchemaType(dataType: DataType, nullable: Boolean) @@ -46,7 +53,16 @@ object SchemaConverters { def toSqlType(avroSchema: Schema): SchemaType = { avroSchema.getType match { case INT => SchemaType(IntegerType, nullable = false) - case STRING => SchemaType(StringType, nullable = false) + case STRING => { + val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE) + if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) { + val precision = avroSchema.getJsonProp(PRECISION).asInt + val scale = avroSchema.getJsonProp(SCALE).asInt + SchemaType(DecimalType(precision, scale), nullable = false) + } else { + SchemaType(StringType, nullable = false) + } + } case BOOLEAN => SchemaType(BooleanType, nullable = false) case BYTES => SchemaType(BinaryType, nullable = false) case DOUBLE => SchemaType(DoubleType, nullable = false) @@ -57,6 +73,7 @@ object SchemaConverters { case RECORD => val fields = avroSchema.getFields.asScala.map { f => + f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2)) val schemaType = toSqlType(f.schema()) StructField(f.name, schemaType.dataType, schemaType.nullable) } @@ -80,7 +97,9 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) if (remainingUnionTypes.size == 1) { - toSqlType(remainingUnionTypes.head).copy(nullable = true) + val remainingSchema = remainingUnionTypes.head + avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2)) + toSqlType(remainingSchema).copy(nullable = true) } else { toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true) } @@ -148,8 +167,21 @@ object SchemaConverters { val avroType = avroSchema.getType (sqlType, avroType) match { // Avro strings are in Utf8, so we have to call toString on them - case (StringType, STRING) | (StringType, ENUM) => + case (StringType, ENUM) => (item: AnyRef) => if (item == null) null else item.toString + case (_, STRING) => + (item: AnyRef) => if (item == null) { + null + } else { + val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE) + if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) { + val precision = avroSchema.getJsonProp(PRECISION).asInt + val scale = avroSchema.getJsonProp(SCALE).asInt + Decimal.apply(BigDecimal.apply(item.toString()), precision, scale) + } else { + item.toString + } + } // Byte arrays are reused by avro, so we have to make a copy of them. case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) | (FloatType, FLOAT) | (LongType, LONG) => @@ -185,6 +217,7 @@ object SchemaConverters { val sqlField = struct.fields(i) val avroField = avroSchema.getField(sqlField.name) if (avroField != null) { + avroField.getJsonProps.foreach(x => avroField.schema().addProp(x._1, x._2)) val converter = createConverter(avroField.schema(), sqlField.dataType, path :+ sqlField.name) converters(i) = converter @@ -256,7 +289,9 @@ object SchemaConverters { if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) if (remainingUnionTypes.size == 1) { - createConverter(remainingUnionTypes.head, sqlType, path) + val remainingSchema = remainingUnionTypes.head + avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2)) + createConverter(remainingSchema, sqlType, path) } else { createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path) } diff --git a/src/test/resources/users.avro b/src/test/resources/users.avro new file mode 100644 index 00000000..95050de4 Binary files /dev/null and b/src/test/resources/users.avro differ diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 0051a0f0..e7c14768 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -36,6 +36,7 @@ import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException class AvroSuite extends FunSuite with BeforeAndAfterAll { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" + val userFile = "src/test/resources/users.avro" private var spark: SparkSession = _ @@ -773,4 +774,16 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { assert(readDf.collect().sameElements(writeDf.collect())) } } + + test("Logical Types") { + val df = spark.read.avro(userFile) + val decimals = df.select("decimal").collect().map(x => Decimal.apply(x.getDecimal(0))) + val dec1 = Decimal.apply(BigDecimal.apply("55555.555550000"), 25, 9) + val dec2 = Decimal.apply(BigDecimal.apply("8747336654.536756000"), 25, 9) + + assert(decimals.apply(0).equals(dec1)) + assert(decimals.apply(1).equals(dec2)) + assert(df.schema.apply("decimal").dataType == DecimalType(25,9)) + + } }