Skip to content
This repository was archived by the owner on Dec 20, 2018. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@ import org.apache.avro.Schema.Type._

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
import org.apache.avro.Schema.Type
import scala.collection.JavaConversions._

/**
* This object contains method that are used to convert sparkSQL schemas to avro schemas and vice
* versa.
*/
object SchemaConverters {

val LOGICAL_TYPE = "logicalType"
val DECIMAL = "decimal"
val PRECISION = "precision"
val SCALE = "scale"

class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex)

case class SchemaType(dataType: DataType, nullable: Boolean)
Expand All @@ -46,7 +53,16 @@ object SchemaConverters {
def toSqlType(avroSchema: Schema): SchemaType = {
avroSchema.getType match {
case INT => SchemaType(IntegerType, nullable = false)
case STRING => SchemaType(StringType, nullable = false)
case STRING => {
val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE)
if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) {
val precision = avroSchema.getJsonProp(PRECISION).asInt
val scale = avroSchema.getJsonProp(SCALE).asInt
SchemaType(DecimalType(precision, scale), nullable = false)
} else {
SchemaType(StringType, nullable = false)
}
}
case BOOLEAN => SchemaType(BooleanType, nullable = false)
case BYTES => SchemaType(BinaryType, nullable = false)
case DOUBLE => SchemaType(DoubleType, nullable = false)
Expand All @@ -57,6 +73,7 @@ object SchemaConverters {

case RECORD =>
val fields = avroSchema.getFields.asScala.map { f =>
f.getJsonProps.foreach(x => f.schema().addProp(x._1, x._2))
val schemaType = toSqlType(f.schema())
StructField(f.name, schemaType.dataType, schemaType.nullable)
}
Expand All @@ -80,7 +97,9 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
toSqlType(remainingUnionTypes.head).copy(nullable = true)
val remainingSchema = remainingUnionTypes.head
avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2))
toSqlType(remainingSchema).copy(nullable = true)
} else {
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
}
Expand Down Expand Up @@ -148,8 +167,21 @@ object SchemaConverters {
val avroType = avroSchema.getType
(sqlType, avroType) match {
// Avro strings are in Utf8, so we have to call toString on them
case (StringType, STRING) | (StringType, ENUM) =>
case (StringType, ENUM) =>
(item: AnyRef) => if (item == null) null else item.toString
case (_, STRING) =>
(item: AnyRef) => if (item == null) {
null
} else {
val logicalType = avroSchema.getJsonProp(LOGICAL_TYPE)
if (logicalType != null && logicalType.asText().equalsIgnoreCase(DECIMAL)) {
val precision = avroSchema.getJsonProp(PRECISION).asInt
val scale = avroSchema.getJsonProp(SCALE).asInt
Decimal.apply(BigDecimal.apply(item.toString()), precision, scale)
} else {
item.toString
}
}
// Byte arrays are reused by avro, so we have to make a copy of them.
case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
(FloatType, FLOAT) | (LongType, LONG) =>
Expand Down Expand Up @@ -185,6 +217,7 @@ object SchemaConverters {
val sqlField = struct.fields(i)
val avroField = avroSchema.getField(sqlField.name)
if (avroField != null) {
avroField.getJsonProps.foreach(x => avroField.schema().addProp(x._1, x._2))
val converter = createConverter(avroField.schema(), sqlField.dataType,
path :+ sqlField.name)
converters(i) = converter
Expand Down Expand Up @@ -256,7 +289,9 @@ object SchemaConverters {
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
createConverter(remainingUnionTypes.head, sqlType, path)
val remainingSchema = remainingUnionTypes.head
avroSchema.getJsonProps.foreach(x => remainingSchema.addProp(x._1, x._2))
createConverter(remainingSchema, sqlType, path)
} else {
createConverter(Schema.createUnion(remainingUnionTypes.asJava), sqlType, path)
}
Expand Down
Binary file added src/test/resources/users.avro
Binary file not shown.
13 changes: 13 additions & 0 deletions src/test/scala/com/databricks/spark/avro/AvroSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
class AvroSuite extends FunSuite with BeforeAndAfterAll {
val episodesFile = "src/test/resources/episodes.avro"
val testFile = "src/test/resources/test.avro"
val userFile = "src/test/resources/users.avro"

private var spark: SparkSession = _

Expand Down Expand Up @@ -773,4 +774,16 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
assert(readDf.collect().sameElements(writeDf.collect()))
}
}

test("Logical Types") {
val df = spark.read.avro(userFile)
val decimals = df.select("decimal").collect().map(x => Decimal.apply(x.getDecimal(0)))
val dec1 = Decimal.apply(BigDecimal.apply("55555.555550000"), 25, 9)
val dec2 = Decimal.apply(BigDecimal.apply("8747336654.536756000"), 25, 9)

assert(decimals.apply(0).equals(dec1))
assert(decimals.apply(1).equals(dec2))
assert(df.schema.apply("decimal").dataType == DecimalType(25,9))

}
}