diff --git a/README.md b/README.md index ca6f168b..4cdbe22b 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ This library supports writing of all Spark SQL types into Avro. For most types, | ---------------|-----------| | ByteType | int | | ShortType | int | -| DecimalType | string | +| DecimalType | bytes | | BinaryType | bytes | | TimestampType | long | | StructType | record | diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index c746b50c..e1ae624f 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -96,7 +96,10 @@ private[avro] class AvroOutputWriter( } case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | BooleanType => identity - case _: DecimalType => (item: Any) => if (item == null) null else item.toString + case decimalType: DecimalType => (item: Any) => if (item == null) null else { + val decimal = item.asInstanceOf[java.math.BigDecimal] + ByteBuffer.wrap(decimal.unscaledValue().toByteArray) + } case TimestampType => (item: Any) => if (item == null) null else item.asInstanceOf[Timestamp].getTime case DateType => (item: Any) => diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index 7f8e20f4..63b22900 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -38,6 +38,20 @@ object SchemaConverters { case class SchemaType(dataType: DataType, nullable: Boolean) + /** + * Indicator of a field with decimal logical type and scale property. + */ + private def isDecimalField(avroSchema: Schema): Boolean = { + val nullableLogicalTypeNode = avroSchema.getJsonProp("logicalType") + val logicalTypeOption = Option(nullableLogicalTypeNode).map(_.asText()) + val matchLogicalType = logicalTypeOption == Some("decimal") + val hasScale = Option(decimalScaleProp(avroSchema)) + .map(_.asInt(Int.MinValue)).exists(_ >= 0) + val hasPrecision = Option(decimalPrecisionProp(avroSchema)) + .map(_.asInt(Int.MinValue)).exists(_ > 0) + matchLogicalType && hasScale && hasPrecision + } + /** * This function takes an avro schema and returns a sql schema. */ @@ -46,7 +60,12 @@ object SchemaConverters { case INT => SchemaType(IntegerType, nullable = false) case STRING => SchemaType(StringType, nullable = false) case BOOLEAN => SchemaType(BooleanType, nullable = false) - case BYTES => SchemaType(BinaryType, nullable = false) + case BYTES => if (isDecimalField(avroSchema)) { + SchemaType(DecimalType( + decimalPrecisionProp(avroSchema).asInt, + decimalScaleProp(avroSchema).asInt + ), nullable = false) + } else SchemaType(BinaryType, nullable = false) case DOUBLE => SchemaType(DoubleType, nullable = false) case FLOAT => SchemaType(FloatType, nullable = false) case LONG => SchemaType(LongType, nullable = false) @@ -106,6 +125,14 @@ object SchemaConverters { } } + private def decimalScaleProp(avroSchema: Schema) = { + avroSchema.getJsonProp("scale") + } + + private def decimalPrecisionProp(avroSchema: Schema) = { + avroSchema.getJsonProp("precision") + } + /** * This function converts sparkSQL StructType into avro schema. This method uses two other * converter methods in order to do the conversion. @@ -170,6 +197,17 @@ object SchemaConverters { bytes } + case (decimalType: DecimalType, BYTES) => + (item: AnyRef) => + if (item == null) { + null + } else { + val byteBuffer = item.asInstanceOf[ByteBuffer] + val bytes = new Array[Byte](byteBuffer.remaining) + byteBuffer.get(bytes) + BigDecimal(BigInt(bytes), decimalType.scale) + } + case (struct: StructType, RECORD) => val length = struct.fields.length val converters = new Array[AnyRef => AnyRef](length) @@ -323,7 +361,8 @@ object SchemaConverters { case LongType => schemaBuilder.longType() case FloatType => schemaBuilder.floatType() case DoubleType => schemaBuilder.doubleType() - case _: DecimalType => schemaBuilder.stringType() + case decimalType: DecimalType => + createBytesWithDecimalLogicalType(schemaBuilder.bytesBuilder(), decimalType) case StringType => schemaBuilder.stringType() case BinaryType => schemaBuilder.bytesType() case BooleanType => schemaBuilder.booleanType() @@ -350,6 +389,15 @@ object SchemaConverters { } } + private def createBytesWithDecimalLogicalType[T]( + bytesBuilder: BytesBuilder[T], decimalType: DecimalType) = { + bytesBuilder + .prop("logicalType", "decimal") + .prop("precision", decimalType.precision.toString) + .prop("scale", decimalType.scale.toString) + .endBytes() + } + /** * This function is used to construct fields of the avro record, where schema of the field is * specified by avro representation of dataType. Since builders for record fields are different @@ -367,7 +415,8 @@ object SchemaConverters { case LongType => newFieldBuilder.longType() case FloatType => newFieldBuilder.floatType() case DoubleType => newFieldBuilder.doubleType() - case _: DecimalType => newFieldBuilder.stringType() + case decimalType: DecimalType => + createBytesWithDecimalLogicalType(newFieldBuilder.bytesBuilder(), decimalType) case StringType => newFieldBuilder.stringType() case BinaryType => newFieldBuilder.bytesType() case BooleanType => newFieldBuilder.booleanType() diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala index 4843ad46..5cc910fb 100644 --- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala +++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala @@ -478,6 +478,9 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { for (i <- arrayOfByte.indices) { arrayOfByte(i) = i.toByte } + + val decimalBytes = new java.math.BigDecimal("3.14").unscaledValue().toByteArray + val cityRDD = spark.sparkContext.parallelize(Seq( Row("San Francisco", 12, new Timestamp(666), null, arrayOfByte), Row("Palo Alto", null, new Timestamp(777), null, arrayOfByte), @@ -492,9 +495,9 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll { val times = spark.read.avro(avroDir).select("Time").collect() assert(times.map(_(0)).toSet == Set(666, 777, 42)) - // DecimalType should be converted to string + // DecimalType should be converted to java.math.BigDecimal val decimals = spark.read.avro(avroDir).select("Decimal").collect() - assert(decimals.map(_(0)).contains("3.14")) + assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14"))) // There should be a null entry val length = spark.read.avro(avroDir).select("Length").collect()