diff --git a/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala
new file mode 100644
index 00000000..ac3f794c
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/avro/AvroEncoder.scala
@@ -0,0 +1,591 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.databricks.spark.avro
+
+/**
+ * A Spark-SQL Encoder for Avro objects
+ */
+import java.util.{Map => JMap}
+import scala.collection.JavaConverters._
+import com.databricks.spark.avro.SchemaConverters.{IncompatibleSchemaException, SchemaType, resolveUnionType, toSqlType}
+import org.apache.avro.Schema
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.{GenericData, IndexedRecord}
+import org.apache.avro.reflect.ReflectData
+import org.apache.avro.specific.SpecificRecord
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
+
+/**
+ * Public entry point for obtaining Spark-SQL Encoders for Avro objects
+ */
+object AvroEncoder {
+ /**
+ * Provides an Encoder for Avro objects of the given class (delegates to AvroExpressionEncoder)
+ * @param avroClass the class of the Avro object for which to generate the Encoder
+ * @tparam T the type of the Avro class, must implement SpecificRecord
+ * @return an Encoder for the given Avro class
+ */
+ def of[T <: SpecificRecord](avroClass: Class[T]): Encoder[T] = {
+ AvroExpressionEncoder.of(avroClass)
+ }
+
+ /**
+ * Provides an Encoder for Avro objects implementing the given schema (same delegation as above)
+ * @param avroSchema the Schema of the Avro object for which to generate the Encoder
+ * @tparam T the type of the Avro class that implements the Schema, must implement IndexedRecord
+ * @return an Encoder for the given Avro Schema
+ */
+ def of[T <: IndexedRecord](avroSchema: Schema): Encoder[T] = {
+ AvroExpressionEncoder.of(avroSchema)
+ }
+}
+
+object AvroExpressionEncoder { // builds the ExpressionEncoders that AvroEncoder hands out
+ def of[T <: SpecificRecord](avroClass: Class[T]): ExpressionEncoder[T] = {
+ val schema = avroClass.getMethod("getClassSchema").invoke(null).asInstanceOf[Schema]
+ assert(toSqlType(schema).dataType.isInstanceOf[StructType]) // schema must map to a struct
+
+ val serializer = AvroTypeInference.serializerFor(avroClass, schema)
+ val deserializer = AvroTypeInference.deserializerFor(schema)
+
+ new ExpressionEncoder[T](
+ toSqlType(schema).dataType.asInstanceOf[StructType],
+ flat = false,
+ serializer.flatten,
+ deserializer = deserializer,
+ ClassTag[T](avroClass))
+ }
+
+ def of[T <: IndexedRecord](schema: Schema): ExpressionEncoder[T] = {
+ assert(toSqlType(schema).dataType.isInstanceOf[StructType]) // schema must map to a struct
+
+ val avroClass = Option(ReflectData.get.getClass(schema)) // None when the lookup yields null
+ .map(_.asSubclass(classOf[SpecificRecord]))
+ .getOrElse(classOf[GenericData.Record]) // fall back to the generic record class
+ val serializer = AvroTypeInference.serializerFor(avroClass, schema)
+ val deserializer = AvroTypeInference.deserializerFor(schema)
+
+ new ExpressionEncoder[T](
+ toSqlType(schema).dataType.asInstanceOf[StructType],
+ flat = false,
+ serializer.flatten,
+ deserializer,
+ ClassTag[T](avroClass))
+ }
+}
+
+/**
+ * Utilities for providing Avro object serializers and deserializers
+ */
+private object AvroTypeInference {
+ /**
+ * Picks the external Spark DataType used to read or build the value of an Avro Schema node.
+ * The Java Objects that back data in generated Generic and Specific records sometimes do not
+ * align with those suggested by Avro ReflectData, so the choice is driven by the converted SQL
+ * type, its nullability and the wrapping Avro Schema type.
+ */
+ private def inferExternalType(avroSchema: Schema): DataType = {
+ toSqlType(avroSchema) match {
+ // the non-nullable primitive types map to the matching Catalyst primitive
+ case SchemaType(BooleanType, false) => BooleanType
+ case SchemaType(IntegerType, false) => IntegerType
+ case SchemaType(LongType, false) =>
+ if (avroSchema.getType == UNION) { // a union resolved to long: accept any Number
+ ObjectType(classOf[java.lang.Number])
+ } else {
+ LongType
+ }
+ case SchemaType(FloatType, false) => FloatType
+ case SchemaType(DoubleType, false) =>
+ if (avroSchema.getType == UNION) { // a union resolved to double: accept any Number
+ ObjectType(classOf[java.lang.Number])
+ } else {
+ DoubleType
+ }
+ // the nullable primitive types map to their boxed Java classes
+ case SchemaType(BooleanType, true) => ObjectType(classOf[java.lang.Boolean])
+ case SchemaType(IntegerType, true) => ObjectType(classOf[java.lang.Integer])
+ case SchemaType(LongType, true) => ObjectType(classOf[java.lang.Long])
+ case SchemaType(FloatType, true) => ObjectType(classOf[java.lang.Float])
+ case SchemaType(DoubleType, true) => ObjectType(classOf[java.lang.Double])
+ // the binary types: FIXED uses its own class when one is found, BYTES uses ByteBuffer
+ case SchemaType(BinaryType, _) =>
+ if (avroSchema.getType == FIXED) {
+ Option(ReflectData.get.getClass(avroSchema))
+ .map(ObjectType(_))
+ .getOrElse(ObjectType(classOf[GenericData.Fixed]))
+ } else {
+ ObjectType(classOf[java.nio.ByteBuffer])
+ }
+ // the referenced types (collections, strings/enums, records)
+ case SchemaType(ArrayType(_, _), _) =>
+ ObjectType(classOf[java.util.List[Object]])
+ case SchemaType(StringType, _) =>
+ avroSchema.getType match { // both ENUM and STRING schemas convert to StringType
+ case ENUM =>
+ Option(ReflectData.get.getClass(avroSchema))
+ .map(ObjectType(_))
+ .getOrElse(ObjectType(classOf[GenericData.EnumSymbol]))
+ case _ =>
+ ObjectType(classOf[CharSequence])
+ }
+ case SchemaType(StructType(_), _) =>
+ Option(ReflectData.get.getClass(avroSchema))
+ .map(ObjectType(_))
+ .getOrElse(ObjectType(classOf[GenericData.Record]))
+ case SchemaType(MapType(_, _, _), _) =>
+ ObjectType(classOf[java.util.Map[Object, Object]])
+ }
+ }
+
+ /**
+ * Returns an expression that can be used to deserialize an InternalRow to an Avro object of
+ * type `T` (an IndexedRecord compatible with the given Schema), starting from an empty path
+ */
+ def deserializerFor[T <: IndexedRecord] (avroSchema: Schema): Expression = {
+ deserializerFor(avroSchema, None)
+ }
+
+ private def deserializerFor(avroSchema: Schema, path: Option[Expression]): Expression = {
+ def addToPath(part: String): Expression = path // extends the current path by a named field
+ .map(p => UnresolvedExtractValue(p, Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
+
+ def getPath: Expression = path.getOrElse( // the current path, or column 0 when there is none
+ GetColumnByOrdinal(0, inferExternalType(avroSchema)))
+
+ avroSchema.getType match { // builds one deserializer expression per Avro schema type
+ case BOOLEAN =>
+ NewInstance(
+ classOf[java.lang.Boolean],
+ getPath :: Nil,
+ ObjectType(classOf[java.lang.Boolean]))
+ case INT =>
+ NewInstance(
+ classOf[java.lang.Integer],
+ getPath :: Nil,
+ ObjectType(classOf[java.lang.Integer]))
+ case LONG =>
+ NewInstance(
+ classOf[java.lang.Long],
+ getPath :: Nil,
+ ObjectType(classOf[java.lang.Long]))
+ case FLOAT =>
+ NewInstance(
+ classOf[java.lang.Float],
+ getPath :: Nil,
+ ObjectType(classOf[java.lang.Float]))
+ case DOUBLE =>
+ NewInstance(
+ classOf[java.lang.Double],
+ getPath :: Nil,
+ ObjectType(classOf[java.lang.Double]))
+
+ case BYTES =>
+ StaticInvoke(
+ classOf[java.nio.ByteBuffer],
+ ObjectType(classOf[java.nio.ByteBuffer]),
+ "wrap",
+ getPath :: Nil)
+ case FIXED =>
+ val fixedClass = Option(ReflectData.get.getClass(avroSchema))
+ .getOrElse(classOf[GenericData.Fixed])
+ if (fixedClass == classOf[GenericData.Fixed]) { // generic fixed is built from (schema, value)
+ NewInstance(
+ fixedClass,
+ Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) ::
+ getPath ::
+ Nil,
+ ObjectType(fixedClass))
+ } else {
+ NewInstance(
+ fixedClass,
+ getPath :: Nil,
+ ObjectType(fixedClass))
+ }
+
+ case STRING =>
+ Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+ case ENUM =>
+ val enumClass = Option(ReflectData.get.getClass(avroSchema))
+ .getOrElse(classOf[GenericData.EnumSymbol])
+ if (enumClass == classOf[GenericData.EnumSymbol]) { // generic symbol: (schema, name)
+ NewInstance(
+ enumClass,
+ Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) ::
+ Invoke(getPath, "toString", ObjectType(classOf[String])) ::
+ Nil,
+ ObjectType(enumClass))
+ } else {
+ StaticInvoke(
+ enumClass,
+ ObjectType(enumClass),
+ "valueOf",
+ Invoke(getPath, "toString", ObjectType(classOf[String])) :: Nil)
+ }
+
+ case ARRAY =>
+ val elementSchema = avroSchema.getElementType
+ val elementType = toSqlType(elementSchema).dataType
+ val array = Invoke( // deserialize each element, then expose the result as an object array
+ MapObjects(element =>
+ deserializerFor(elementSchema, Some(element)),
+ getPath,
+ elementType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ classOf[java.util.Arrays],
+ ObjectType(classOf[java.util.List[Object]]),
+ "asList",
+ array :: Nil)
+
+ case MAP =>
+ val valueSchema = avroSchema.getValueType
+ val valueType = inferExternalType(valueSchema) match {
+ case t if t == ObjectType(classOf[java.lang.CharSequence]) =>
+ StringType
+ case other => other
+ }
+
+ val keyData = Invoke( // keys are always deserialized as Avro STRING values
+ MapObjects(
+ p => deserializerFor(Schema.create(STRING), Some(p)),
+ Invoke(getPath, "keyArray", ArrayType(StringType)),
+ StringType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+ val valueData = Invoke(
+ MapObjects(
+ p => deserializerFor(valueSchema, Some(p)),
+ Invoke(getPath, "valueArray", ArrayType(valueType)),
+ valueType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ ArrayBasedMapData.getClass,
+ ObjectType(classOf[JMap[_, _]]),
+ "toJavaMap",
+ keyData :: valueData :: Nil)
+
+ case UNION =>
+ val (resolvedSchema, _) = resolveUnionType(avroSchema)
+ if (resolvedSchema.getType == RECORD &&
+ avroSchema.getTypes.asScala.filterNot(_.getType == NULL).length > 1) {
+ // A Union that resolved to a record although it had more than 1 type once filtered
+ // of its nulls must be a complex union (one record field per union member)
+ val bottom = Literal.create(null, ObjectType(classOf[Object])).asInstanceOf[Expression]
+
+ resolvedSchema.getFields.foldLeft(bottom) { (tree: Expression, field: Schema.Field) =>
+ val fieldValue = ObjectCast(
+ deserializerFor(field.schema, Some(addToPath(field.name))),
+ ObjectType(classOf[Object]))
+
+ If(IsNull(fieldValue), tree, fieldValue) // a non-null member wins, else try earlier ones
+ }
+ } else {
+ deserializerFor(resolvedSchema, path)
+ }
+
+ case RECORD =>
+ val args = avroSchema.getFields.map { field =>
+ val position = Literal(field.pos)
+ val argument = deserializerFor(field.schema, Some(addToPath(field.name)))
+ (position, argument)
+ }.toList
+
+ val recordClass = Option(ReflectData.get.getClass(avroSchema))
+ .getOrElse(classOf[GenericData.Record])
+ val newInstance = if (recordClass == classOf[GenericData.Record]) {
+ NewInstance(
+ recordClass,
+ Literal.fromObject(avroSchema, ObjectType(classOf[Schema])) :: Nil,
+ ObjectType(recordClass))
+ } else {
+ NewInstance(
+ recordClass,
+ Nil,
+ ObjectType(recordClass))
+ }
+
+ val result = InitializeAvroObject(newInstance, args)
+
+ if (path.nonEmpty) { // only records reached through a path get a null guard
+ If(IsNull(getPath),
+ Literal.create(null, ObjectType(recordClass)),
+ result)
+ } else {
+ result
+ }
+
+ case NULL =>
+ /*
+ * Encountering NULL at this level implies it was the type of a Field, which should never
+ * be the case
+ */
+ throw new IncompatibleSchemaException("Null type should only be used in Union types")
+ }
+ }
+
+ /**
+ * Returns the CreateNamedStruct expression that serializes an Avro object of class `T`,
+ * compatible with the given Schema, to an InternalRow
+ */
+ def serializerFor[T <: IndexedRecord](avroClass: Class[T], avroSchema: Schema):
+ CreateNamedStruct = {
+ val inputObject = BoundReference(0, ObjectType(avroClass), nullable = true) // input ordinal 0
+ serializerFor(inputObject, avroSchema, topLevel = true).asInstanceOf[CreateNamedStruct]
+ }
+
+ def serializerFor(
+ inputObject: Expression,
+ avroSchema: Schema,
+ topLevel: Boolean = false): Expression = {
+
+ def toCatalystArray(inputObject: Expression, schema: Schema): Expression = {
+ val elementType = inferExternalType(schema)
+
+ if (elementType.isInstanceOf[ObjectType]) { // object elements are serialized one by one
+ MapObjects(element =>
+ serializerFor(element, schema),
+ Invoke(
+ inputObject,
+ "toArray",
+ ObjectType(classOf[Array[Object]])),
+ elementType)
+ } else {
+ NewInstance(
+ classOf[GenericArrayData],
+ inputObject :: Nil,
+ dataType = ArrayType(elementType, containsNull = false))
+ }
+ }
+
+ def toCatalystMap(inputObject: Expression, schema: Schema): Expression = {
+ val valueSchema = schema.getValueType
+ val valueType = inferExternalType(valueSchema)
+
+ ExternalMapToCatalyst(
+ inputObject,
+ ObjectType(classOf[org.apache.avro.util.Utf8]), // NOTE(review): assumes Utf8 keys - confirm
+ serializerFor(_, Schema.create(STRING)),
+ valueType,
+ serializerFor(_, valueSchema))
+ }
+
+ if (!inputObject.dataType.isInstanceOf[ObjectType]) { // not an object: returned unchanged
+ inputObject
+ } else {
+ avroSchema.getType match {
+ case BOOLEAN =>
+ Invoke(inputObject, "booleanValue", BooleanType)
+ case INT =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case LONG =>
+ Invoke(inputObject, "longValue", LongType)
+ case FLOAT =>
+ Invoke(inputObject, "floatValue", FloatType)
+ case DOUBLE =>
+ Invoke(inputObject, "doubleValue", DoubleType)
+
+ case BYTES =>
+ Invoke(inputObject, "array", BinaryType) // NOTE(review): array() ignores position/limit
+ case FIXED =>
+ Invoke(inputObject, "bytes", BinaryType)
+
+ case STRING =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil)
+
+ case ENUM =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ Invoke(inputObject, "toString", ObjectType(classOf[java.lang.String])) :: Nil)
+
+ case ARRAY =>
+ val elementSchema = avroSchema.getElementType
+ toCatalystArray(inputObject, elementSchema)
+
+ case MAP =>
+ toCatalystMap(inputObject, avroSchema)
+
+ case UNION =>
+ val unionWithoutNulls = Schema.createUnion(
+ avroSchema.getTypes.asScala.filterNot(_.getType == NULL))
+ val (resolvedSchema, nullable) = resolveUnionType(avroSchema)
+ if (resolvedSchema.getType == RECORD && unionWithoutNulls.getTypes.length > 1) {
+ // A Union that resolved to a record although it had more than 1 type once filtered
+ // of its nulls must be a complex union (one struct field per union member)
+ val complexStruct = CreateNamedStruct(
+ resolvedSchema.getFields.zipWithIndex.flatMap { case (field, index) =>
+ val unionIndex = StaticInvoke(
+ classOf[GenericData],
+ IntegerType,
+ "get().resolveUnion", // presumably emitted as GenericData.get().resolveUnion(...)
+ Literal.fromObject(
+ unionWithoutNulls,
+ ObjectType(classOf[Schema])) :: inputObject :: Nil)
+
+ val fieldValue = If(EqualTo(Literal(index), unionIndex), // only the matching member is set
+ serializerFor(
+ ObjectCast(
+ inputObject,
+ inferExternalType(field.schema())),
+ field.schema),
+ Literal.create(null, toSqlType(field.schema()).dataType))
+
+ Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil})
+
+ complexStruct
+
+ } else {
+ if (nullable) {
+ serializerFor(inputObject, resolvedSchema)
+ } else {
+ serializerFor(
+ AssertNotNull(inputObject, Seq(avroSchema.getTypes.toString)),
+ resolvedSchema)
+ }
+ }
+
+ case RECORD =>
+ val createStruct = CreateNamedStruct(
+ avroSchema.getFields.flatMap { field =>
+ val fieldValue = Invoke( // reads the field by position through get(int)
+ inputObject,
+ "get",
+ inferExternalType(field.schema),
+ Literal(field.pos) :: Nil)
+ Literal(field.name) :: serializerFor(fieldValue, field.schema) :: Nil})
+ if (topLevel) { // only nested records get a null guard
+ createStruct
+ } else {
+ If(IsNull(inputObject),
+ Literal.create(null, createStruct.dataType),
+ createStruct)
+ }
+
+ case NULL =>
+ /*
+ * Encountering NULL at this level implies it was the type of a Field, which should never
+ * be the case
+ */
+ throw new IncompatibleSchemaException("Null type should only be used in Union types")
+ }
+ }
+ }
+
+ /**
+ * Initializes an Avro Record instance (that implements the IndexedRecord interface) by calling
+ * the `put` method on the Record instance with the provided position and value arguments
+ * @param objectInstance an expression that will evaluate to the Record instance
+ * @param args a sequence of expression pairs that will respectively evaluate to the index of
+ * the record in which to insert, and the argument value to insert
+ */
+ private case class InitializeAvroObject(
+ objectInstance: Expression,
+ args: List[(Expression, Expression)]) extends Expression with NonSQLExpression {
+
+ override def nullable: Boolean = objectInstance.nullable
+ override def children: Seq[Expression] = objectInstance +: args.map { case (_, v) => v }
+ override def dataType: DataType = objectInstance.dataType
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val instanceGen = objectInstance.genCode(ctx)
+
+ val avroInstance = ctx.freshName("avroObject")
+ val avroInstanceJavaType = ctx.javaType(objectInstance.dataType)
+ ctx.addMutableState(avroInstanceJavaType, avroInstance, "") // presumably so split code sees it
+
+ val initialize = args.map { // one put(position, value) snippet per field
+ case (posExpr, argExpr) =>
+ val posGen = posExpr.genCode(ctx)
+ val argGen = argExpr.genCode(ctx)
+ s"""
+ ${posGen.code}
+ ${argGen.code}
+ $avroInstance.put(${posGen.value}, ${argGen.value});
+ """
+ }
+
+ val initExpressions = ctx.splitExpressions(ctx.INPUT_ROW, initialize)
+ val code =
+ s"""
+ ${instanceGen.code}
+ $avroInstance = ${instanceGen.value};
+ if (!${instanceGen.isNull}) {
+ $initExpressions
+ }
+ """
+ ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
+ }
+ }
+
+ /**
+ * Casts the result of an expression to another object type with a Java cast in generated code.
+ *
+ * @param value The expression whose result is cast
+ * @param resultType The type the value should be cast to; also reported as this dataType.
+ */
+ private case class ObjectCast(
+ value : Expression,
+ resultType: DataType) extends Expression with NonSQLExpression {
+
+ override def nullable: Boolean = value.nullable
+ override def dataType: DataType = resultType
+ override def children: Seq[Expression] = value :: Nil
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+
+ val javaType = ctx.javaType(resultType)
+ val obj = value.genCode(ctx)
+
+ val code = s"""
+ ${obj.code}
+ final $javaType ${ev.value} = ($javaType) ${obj.value};
+ """
+
+ ev.copy(code = code, isNull = obj.isNull) // nullness is taken from the wrapped value
+ }
+ }
+}
diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
index aa634d4c..1b8bc450 100644
--- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
+++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
@@ -18,13 +18,11 @@ package com.databricks.spark.avro
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
-
import org.apache.avro.generic.GenericData.Fixed
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.SchemaBuilder._
import org.apache.avro.Schema.Type._
-
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
@@ -74,38 +72,50 @@ object SchemaConverters {
nullable = false)
case UNION =>
- if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
- // In case of a union with null, eliminate it and make a recursive call
- val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
- if (remainingUnionTypes.size == 1) {
- toSqlType(remainingUnionTypes.head).copy(nullable = true)
- } else {
- toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
- }
- } else avroSchema.getTypes.asScala.map(_.getType) match {
- case Seq(t1) =>
- toSqlType(avroSchema.getTypes.get(0))
- case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
- SchemaType(LongType, nullable = false)
- case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
- SchemaType(DoubleType, nullable = false)
- case _ =>
- // Convert complex unions to struct types where field names are member0, member1, etc.
- // This is consistent with the behavior when converting between Avro and Parquet.
- val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
- case (s, i) =>
- val schemaType = toSqlType(s)
- // All fields are nullable because only one of them is set at a time
- StructField(s"member$i", schemaType.dataType, nullable = true)
- }
-
- SchemaType(StructType(fields), nullable = false)
+ resolveUnionType(avroSchema) match {
+ case (schema, nullable) => toSqlType(schema).copy(nullable = nullable)
}
case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
}
}
+ /**
+ * Resolves an avro UNION type to an SQL-compatible avro type, converting complex unions to
+ * records if necessary; yields (schema, nullable), nullable being true once a null is dropped.
+ */
+ def resolveUnionType(avroSchema: Schema, nullable: Boolean = false): (Schema, Boolean) = {
+ if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
+ // A union with null: drop the null branch and recurse, carrying nullable = true along
+ val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
+ if (remainingUnionTypes.size == 1) {
+ (remainingUnionTypes.head, true)
+ } else {
+ resolveUnionType(Schema.createUnion(remainingUnionTypes.asJava), nullable = true)
+ }
+ } else avroSchema.getTypes.asScala.map(_.getType) match {
+ case Seq(t1) => // honour the caller's nullability instead of forcing true
+ (avroSchema.getTypes.get(0), nullable)
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => // widened to long
+ (Schema.create(LONG), nullable) // NOTE(review): check AvroTypeInference for nullable here
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => // widened to double
+ (Schema.create(DOUBLE), nullable)
+ case _ =>
+ // Convert complex unions to records where field names are member0, member1, etc.
+ // This is consistent with the behavior when converting between Avro and Parquet.
+ val record = SchemaBuilder.record(avroSchema.getName).fields()
+ avroSchema.getTypes.asScala.zipWithIndex.foreach {
+ case (s, i) =>
+ // All fields are nullable because only one of them is set at a time
+ record.name(s"member$i").`type`(SchemaBuilder.unionOf()
+ .`type`(Schema.create(NULL)).and
+ .`type`(s).endUnion())
+ .withDefault(null)
+ }
+ (record.endRecord(), nullable) // the record itself is nullable only if a null was dropped
+ }
+ }
+
/**
* This function converts sparkSQL StructType into avro schema. This method uses two other
* converter methods in order to do the conversion.
diff --git a/src/test/java/com/databricks/spark/avro/SimpleEnums.java b/src/test/java/com/databricks/spark/avro/SimpleEnums.java
new file mode 100644
index 00000000..5989c620
--- /dev/null
+++ b/src/test/java/com/databricks/spark/avro/SimpleEnums.java
@@ -0,0 +1,13 @@
+/**
+ * Autogenerated by Avro
+ *
+ * DO NOT EDIT DIRECTLY
+ */
+package com.databricks.spark.avro;
+@SuppressWarnings("all")
+@org.apache.avro.specific.AvroGenerated
+public enum SimpleEnums {
+ SPADES, HEARTS, DIAMONDS, CLUBS ;
+ public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"namespace\":\"com.databricks.spark.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}");
+ public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; }
+}
diff --git a/src/test/java/com/databricks/spark/avro/SimpleFixed.java b/src/test/java/com/databricks/spark/avro/SimpleFixed.java
new file mode 100644
index 00000000..8318b65a
--- /dev/null
+++ b/src/test/java/com/databricks/spark/avro/SimpleFixed.java
@@ -0,0 +1,23 @@
+/**
+ * Autogenerated by Avro
+ *
+ * DO NOT EDIT DIRECTLY
+ */
+package com.databricks.spark.avro;
+@SuppressWarnings("all")
+@org.apache.avro.specific.FixedSize(16)
+@org.apache.avro.specific.AvroGenerated
+public class SimpleFixed extends org.apache.avro.specific.SpecificFixed {
+ public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"namespace\":\"com.databricks.spark.avro\",\"size\":16}");
+ public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; }
+
+ /** Creates a new SimpleFixed */
+ public SimpleFixed() {
+ super();
+ }
+
+ /** Creates a new SimpleFixed with the given bytes */
+ public SimpleFixed(byte[] bytes) {
+ super(bytes);
+ }
+}
diff --git a/src/test/java/com/databricks/spark/avro/SimpleRecord.java b/src/test/java/com/databricks/spark/avro/SimpleRecord.java
new file mode 100644
index 00000000..a36161ed
--- /dev/null
+++ b/src/test/java/com/databricks/spark/avro/SimpleRecord.java
@@ -0,0 +1,195 @@
+/**
+ * Autogenerated by Avro
+ *
+ * DO NOT EDIT DIRECTLY
+ */
+package com.databricks.spark.avro;
+@SuppressWarnings("all")
+@org.apache.avro.specific.AvroGenerated
+public class SimpleRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord {
+ public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"SimpleRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]}");
+ public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; }
+ @Deprecated public int nested1;
+ @Deprecated public java.lang.CharSequence nested2;
+
+ /**
+ * Default constructor. Note that this does not initialize fields
+ * to their default values from the schema. If that is desired then
+ * one should use newBuilder().
+ */
+ public SimpleRecord() {}
+
+ /**
+ * All-args constructor.
+ */
+ public SimpleRecord(java.lang.Integer nested1, java.lang.CharSequence nested2) {
+ this.nested1 = nested1;
+ this.nested2 = nested2;
+ }
+
+ public org.apache.avro.Schema getSchema() { return SCHEMA$; }
+ // Used by DatumWriter. Applications should not call.
+ public java.lang.Object get(int field$) {
+ switch (field$) {
+ case 0: return nested1;
+ case 1: return nested2;
+ default: throw new org.apache.avro.AvroRuntimeException("Bad index");
+ }
+ }
+ // Used by DatumReader. Applications should not call.
+ @SuppressWarnings(value="unchecked")
+ public void put(int field$, java.lang.Object value$) {
+ switch (field$) {
+ case 0: nested1 = (java.lang.Integer)value$; break;
+ case 1: nested2 = (java.lang.CharSequence)value$; break;
+ default: throw new org.apache.avro.AvroRuntimeException("Bad index");
+ }
+ }
+
+ /**
+ * Gets the value of the 'nested1' field.
+ */
+ public java.lang.Integer getNested1() {
+ return nested1;
+ }
+
+ /**
+ * Sets the value of the 'nested1' field.
+ * @param value the value to set.
+ */
+ public void setNested1(java.lang.Integer value) {
+ this.nested1 = value;
+ }
+
+ /**
+ * Gets the value of the 'nested2' field.
+ */
+ public java.lang.CharSequence getNested2() {
+ return nested2;
+ }
+
+ /**
+ * Sets the value of the 'nested2' field.
+ * @param value the value to set.
+ */
+ public void setNested2(java.lang.CharSequence value) {
+ this.nested2 = value;
+ }
+
+ /** Creates a new SimpleRecord RecordBuilder */
+ public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder() {
+ return new com.databricks.spark.avro.SimpleRecord.Builder();
+ }
+
+ /** Creates a new SimpleRecord RecordBuilder by copying an existing Builder */
+ public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord.Builder other) {
+ return new com.databricks.spark.avro.SimpleRecord.Builder(other);
+ }
+
+ /** Creates a new SimpleRecord RecordBuilder by copying an existing SimpleRecord instance */
+ public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord other) {
+ return new com.databricks.spark.avro.SimpleRecord.Builder(other);
+ }
+
+ /**
+ * RecordBuilder for SimpleRecord instances.
+ */
+ public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase
+ implements org.apache.avro.data.RecordBuilder {
+
+ private int nested1;
+ private java.lang.CharSequence nested2;
+
+ /** Creates a new Builder */
+ private Builder() {
+ super(com.databricks.spark.avro.SimpleRecord.SCHEMA$);
+ }
+
+ /** Creates a Builder by copying an existing Builder */
+ private Builder(com.databricks.spark.avro.SimpleRecord.Builder other) {
+ super(other);
+ if (isValidValue(fields()[0], other.nested1)) {
+ this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1);
+ fieldSetFlags()[0] = true;
+ }
+ if (isValidValue(fields()[1], other.nested2)) {
+ this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2);
+ fieldSetFlags()[1] = true;
+ }
+ }
+
+ /** Creates a Builder by copying an existing SimpleRecord instance */
+ private Builder(com.databricks.spark.avro.SimpleRecord other) {
+ super(com.databricks.spark.avro.SimpleRecord.SCHEMA$);
+ if (isValidValue(fields()[0], other.nested1)) {
+ this.nested1 = data().deepCopy(fields()[0].schema(), other.nested1);
+ fieldSetFlags()[0] = true;
+ }
+ if (isValidValue(fields()[1], other.nested2)) {
+ this.nested2 = data().deepCopy(fields()[1].schema(), other.nested2);
+ fieldSetFlags()[1] = true;
+ }
+ }
+
+ /** Gets the value of the 'nested1' field */
+ public java.lang.Integer getNested1() {
+ return nested1;
+ }
+
+ /** Sets the value of the 'nested1' field */
+ public com.databricks.spark.avro.SimpleRecord.Builder setNested1(int value) {
+ validate(fields()[0], value);
+ this.nested1 = value;
+ fieldSetFlags()[0] = true;
+ return this;
+ }
+
+ /** Checks whether the 'nested1' field has been set */
+ public boolean hasNested1() {
+ return fieldSetFlags()[0];
+ }
+
+ /** Clears the value of the 'nested1' field */
+ public com.databricks.spark.avro.SimpleRecord.Builder clearNested1() {
+ fieldSetFlags()[0] = false;
+ return this;
+ }
+
+ /** Gets the value of the 'nested2' field */
+ public java.lang.CharSequence getNested2() {
+ return nested2;
+ }
+
+ /** Sets the value of the 'nested2' field */
+ public com.databricks.spark.avro.SimpleRecord.Builder setNested2(java.lang.CharSequence value) {
+ validate(fields()[1], value);
+ this.nested2 = value;
+ fieldSetFlags()[1] = true;
+ return this;
+ }
+
+ /** Checks whether the 'nested2' field has been set */
+ public boolean hasNested2() {
+ return fieldSetFlags()[1];
+ }
+
+ /** Clears the value of the 'nested2' field */
+ public com.databricks.spark.avro.SimpleRecord.Builder clearNested2() {
+ nested2 = null;
+ fieldSetFlags()[1] = false;
+ return this;
+ }
+
+ @Override
+ public SimpleRecord build() {
+ try {
+ SimpleRecord record = new SimpleRecord();
+ record.nested1 = fieldSetFlags()[0] ? this.nested1 : (java.lang.Integer) defaultValue(fields()[0]);
+ record.nested2 = fieldSetFlags()[1] ? this.nested2 : (java.lang.CharSequence) defaultValue(fields()[1]);
+ return record;
+ } catch (Exception e) {
+ throw new org.apache.avro.AvroRuntimeException(e);
+ }
+ }
+ }
+}
diff --git a/src/test/java/com/databricks/spark/avro/TestRecord.java b/src/test/java/com/databricks/spark/avro/TestRecord.java
new file mode 100644
index 00000000..dd323bb7
--- /dev/null
+++ b/src/test/java/com/databricks/spark/avro/TestRecord.java
@@ -0,0 +1,893 @@
+/**
+ * Autogenerated by Avro
+ *
+ * DO NOT EDIT DIRECTLY
+ */
+package com.databricks.spark.avro;
+@SuppressWarnings("all")
+@org.apache.avro.specific.AvroGenerated
+public class TestRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord {
+  public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"TestRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"boolean\",\"type\":\"boolean\",\"default\":true},{\"name\":\"int\",\"type\":\"int\",\"default\":0},{\"name\":\"long\",\"type\":\"long\",\"default\":0},{\"name\":\"float\",\"type\":\"float\",\"default\":0.0},{\"name\":\"double\",\"type\":\"double\",\"default\":0.0},{\"name\":\"string\",\"type\":\"string\",\"default\":\"value\"},{\"name\":\"bytes\",\"type\":\"bytes\",\"default\":\"\\u00ff\"},{\"name\":\"nested\",\"type\":{\"type\":\"record\",\"name\":\"SimpleRecord\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]},\"default\":{\"nested1\":0,\"nested2\":\"string\"}},{\"name\":\"enum\",\"type\":{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},\"default\":\"SPADES\"},{\"name\":\"fixed\",\"type\":{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"size\":16},\"default\":\"string_length_16\"},{\"name\":\"intArray\",\"type\":{\"type\":\"array\",\"items\":\"int\"},\"default\":[1,2,3]},{\"name\":\"stringArray\",\"type\":{\"type\":\"array\",\"items\":\"string\"},\"default\":[\"a\",\"b\",\"c\"]},{\"name\":\"recordArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleRecord\"},\"default\":[{\"nested1\":0,\"nested2\":\"value\"},{\"nested1\":0,\"nested2\":\"value\"}]},{\"name\":\"enumArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleEnums\"},\"default\":[\"SPADES\",\"HEARTS\",\"SPADES\"]},{\"name\":\"fixedArray\",\"type\":{\"type\":\"array\",\"items\":\"SimpleFixed\"},\"default\":[\"foo\",\"bar\",\"baz\"]}]}");
+  public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; }
+  @Deprecated public boolean boolean$;
+  @Deprecated public int int$;
+  @Deprecated public long long$;
+  @Deprecated public float float$;
+  @Deprecated public double double$;
+  @Deprecated public java.lang.CharSequence string;
+  @Deprecated public java.nio.ByteBuffer bytes;
+  @Deprecated public com.databricks.spark.avro.SimpleRecord nested;
+  @Deprecated public com.databricks.spark.avro.SimpleEnums enum$;
+  @Deprecated public com.databricks.spark.avro.SimpleFixed fixed;
+  @Deprecated public java.util.List intArray;
+  @Deprecated public java.util.List stringArray;
+  @Deprecated public java.util.List recordArray;
+  @Deprecated public java.util.List enumArray;
+  @Deprecated public java.util.List fixedArray;
+
+  /**
+   * Default constructor.  Note that this does not initialize fields
+   * to their default values from the schema.  If that is desired then
+   * one should use newBuilder().
+   */
+  public TestRecord() {}
+
+  /**
+   * All-args constructor; the five boxed primitive arguments are unboxed on assignment, so passing null for any of them throws NullPointerException.
+   */
+  public TestRecord(java.lang.Boolean boolean$, java.lang.Integer int$, java.lang.Long long$, java.lang.Float float$, java.lang.Double double$, java.lang.CharSequence string, java.nio.ByteBuffer bytes, com.databricks.spark.avro.SimpleRecord nested, com.databricks.spark.avro.SimpleEnums enum$, com.databricks.spark.avro.SimpleFixed fixed, java.util.List intArray, java.util.List stringArray, java.util.List recordArray, java.util.List enumArray, java.util.List fixedArray) {
+    this.boolean$ = boolean$;
+    this.int$ = int$;
+    this.long$ = long$;
+    this.float$ = float$;
+    this.double$ = double$;
+    this.string = string;
+    this.bytes = bytes;
+    this.nested = nested;
+    this.enum$ = enum$;
+    this.fixed = fixed;
+    this.intArray = intArray;
+    this.stringArray = stringArray;
+    this.recordArray = recordArray;
+    this.enumArray = enumArray;
+    this.fixedArray = fixedArray;
+  }
+
+  public org.apache.avro.Schema getSchema() { return SCHEMA$; }
+  // Positional read in schema field order (0 = boolean ... 14 = fixedArray). Used by DatumWriter.  Applications should not call.
+  public java.lang.Object get(int field$) {
+    switch (field$) {
+    case 0: return boolean$;
+    case 1: return int$;
+    case 2: return long$;
+    case 3: return float$;
+    case 4: return double$;
+    case 5: return string;
+    case 6: return bytes;
+    case 7: return nested;
+    case 8: return enum$;
+    case 9: return fixed;
+    case 10: return intArray;
+    case 11: return stringArray;
+    case 12: return recordArray;
+    case 13: return enumArray;
+    case 14: return fixedArray;
+    default: throw new org.apache.avro.AvroRuntimeException("Bad index");
+    }
+  }
+  // Positional write in schema field order; each value is cast to the field's type. Used by DatumReader.  Applications should not call.
+  @SuppressWarnings(value="unchecked")
+  public void put(int field$, java.lang.Object value$) {
+    switch (field$) {
+    case 0: boolean$ = (java.lang.Boolean)value$; break;
+    case 1: int$ = (java.lang.Integer)value$; break;
+    case 2: long$ = (java.lang.Long)value$; break;
+    case 3: float$ = (java.lang.Float)value$; break;
+    case 4: double$ = (java.lang.Double)value$; break;
+    case 5: string = (java.lang.CharSequence)value$; break;
+    case 6: bytes = (java.nio.ByteBuffer)value$; break;
+    case 7: nested = (com.databricks.spark.avro.SimpleRecord)value$; break;
+    case 8: enum$ = (com.databricks.spark.avro.SimpleEnums)value$; break;
+    case 9: fixed = (com.databricks.spark.avro.SimpleFixed)value$; break;
+    case 10: intArray = (java.util.List)value$; break;
+    case 11: stringArray = (java.util.List)value$; break;
+    case 12: recordArray = (java.util.List)value$; break;
+    case 13: enumArray = (java.util.List)value$; break;
+    case 14: fixedArray = (java.util.List)value$; break;
+    default: throw new org.apache.avro.AvroRuntimeException("Bad index");
+    }
+  }
+
+  /**
+   * Gets the value of the 'boolean' schema field (Java member boolean$).
+   */
+  public java.lang.Boolean getBoolean$() {
+    return boolean$;
+  }
+
+  /**
+   * Sets the value of the 'boolean' schema field (Java member boolean$).
+   * @param value the value to set; must not be null because it is unboxed into a primitive.
+   */
+  public void setBoolean$(java.lang.Boolean value) {
+    this.boolean$ = value;
+  }
+
+  /**
+   * Gets the value of the 'int' schema field (Java member int$).
+   */
+  public java.lang.Integer getInt$() {
+    return int$;
+  }
+
+  /**
+   * Sets the value of the 'int' schema field (Java member int$).
+   * @param value the value to set; must not be null because it is unboxed into a primitive.
+   */
+  public void setInt$(java.lang.Integer value) {
+    this.int$ = value;
+  }
+
+  /**
+   * Gets the value of the 'long' schema field (Java member long$).
+   */
+  public java.lang.Long getLong$() {
+    return long$;
+  }
+
+  /**
+   * Sets the value of the 'long' schema field (Java member long$).
+   * @param value the value to set; must not be null because it is unboxed into a primitive.
+   */
+  public void setLong$(java.lang.Long value) {
+    this.long$ = value;
+  }
+
+  /**
+   * Gets the value of the 'float' schema field (Java member float$).
+   */
+  public java.lang.Float getFloat$() {
+    return float$;
+  }
+
+  /**
+   * Sets the value of the 'float' schema field (Java member float$).
+   * @param value the value to set; must not be null because it is unboxed into a primitive.
+   */
+  public void setFloat$(java.lang.Float value) {
+    this.float$ = value;
+  }
+
+  /**
+   * Gets the value of the 'double' schema field (Java member double$).
+   */
+  public java.lang.Double getDouble$() {
+    return double$;
+  }
+
+  /**
+   * Sets the value of the 'double' schema field (Java member double$).
+   * @param value the value to set; must not be null because it is unboxed into a primitive.
+   */
+  public void setDouble$(java.lang.Double value) {
+    this.double$ = value;
+  }
+
+  /**
+   * Gets the value of the 'string' field.
+   */
+  public java.lang.CharSequence getString() {
+    return string;
+  }
+
+  /**
+   * Sets the value of the 'string' field.
+   * @param value the value to set.
+   */
+  public void setString(java.lang.CharSequence value) {
+    this.string = value;
+  }
+
+  /**
+   * Gets the value of the 'bytes' field.
+   */
+  public java.nio.ByteBuffer getBytes() {
+    return bytes;
+  }
+
+  /**
+   * Sets the value of the 'bytes' field.
+   * @param value the value to set.
+   */
+  public void setBytes(java.nio.ByteBuffer value) {
+    this.bytes = value;
+  }
+
+  /**
+   * Gets the value of the 'nested' field.
+   */
+  public com.databricks.spark.avro.SimpleRecord getNested() {
+    return nested;
+  }
+
+  /**
+   * Sets the value of the 'nested' field.
+   * @param value the value to set.
+   */
+  public void setNested(com.databricks.spark.avro.SimpleRecord value) {
+    this.nested = value;
+  }
+
+  /**
+   * Gets the value of the 'enum' schema field (Java member enum$).
+   */
+  public com.databricks.spark.avro.SimpleEnums getEnum$() {
+    return enum$;
+  }
+
+  /**
+   * Sets the value of the 'enum' schema field (Java member enum$).
+   * @param value the value to set.
+   */
+  public void setEnum$(com.databricks.spark.avro.SimpleEnums value) {
+    this.enum$ = value;
+  }
+
+  /**
+   * Gets the value of the 'fixed' field.
+   */
+  public com.databricks.spark.avro.SimpleFixed getFixed() {
+    return fixed;
+  }
+
+  /**
+   * Sets the value of the 'fixed' field.
+   * @param value the value to set.
+   */
+  public void setFixed(com.databricks.spark.avro.SimpleFixed value) {
+    this.fixed = value;
+  }
+
+  /**
+   * Gets the value of the 'intArray' field.
+   */
+  public java.util.List getIntArray() {
+    return intArray;
+  }
+
+  /**
+   * Sets the value of the 'intArray' field.
+   * @param value the value to set.
+   */
+  public void setIntArray(java.util.List value) {
+    this.intArray = value;
+  }
+
+  /**
+   * Gets the value of the 'stringArray' field.
+   */
+  public java.util.List getStringArray() {
+    return stringArray;
+  }
+
+  /**
+   * Sets the value of the 'stringArray' field.
+   * @param value the value to set.
+   */
+  public void setStringArray(java.util.List value) {
+    this.stringArray = value;
+  }
+
+  /**
+   * Gets the value of the 'recordArray' field.
+   */
+  public java.util.List getRecordArray() {
+    return recordArray;
+  }
+
+  /**
+   * Sets the value of the 'recordArray' field.
+   * @param value the value to set.
+   */
+  public void setRecordArray(java.util.List value) {
+    this.recordArray = value;
+  }
+
+  /**
+   * Gets the value of the 'enumArray' field.
+   */
+  public java.util.List getEnumArray() {
+    return enumArray;
+  }
+
+  /**
+   * Sets the value of the 'enumArray' field.
+   * @param value the value to set.
+   */
+  public void setEnumArray(java.util.List value) {
+    this.enumArray = value;
+  }
+
+  /**
+   * Gets the value of the 'fixedArray' field.
+   */
+  public java.util.List getFixedArray() {
+    return fixedArray;
+  }
+
+  /**
+   * Sets the value of the 'fixedArray' field.
+   * @param value the value to set.
+   */
+  public void setFixedArray(java.util.List value) {
+    this.fixedArray = value;
+  }
+
+  /** Creates a new TestRecord RecordBuilder */
+  public static com.databricks.spark.avro.TestRecord.Builder newBuilder() {
+    return new com.databricks.spark.avro.TestRecord.Builder();
+  }
+
+  /** Creates a new TestRecord RecordBuilder by copying an existing Builder */
+  public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord.Builder other) {
+    return new com.databricks.spark.avro.TestRecord.Builder(other);
+  }
+
+  /** Creates a new TestRecord RecordBuilder by copying an existing TestRecord instance */
+  public static com.databricks.spark.avro.TestRecord.Builder newBuilder(com.databricks.spark.avro.TestRecord other) {
+    return new com.databricks.spark.avro.TestRecord.Builder(other);
+  }
+
+  /**
+   * RecordBuilder for TestRecord instances; build() substitutes the schema default for every field whose set-flag is false.
+   */
+  public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase
+    implements org.apache.avro.data.RecordBuilder {
+
+    private boolean boolean$;
+    private int int$;
+    private long long$;
+    private float float$;
+    private double double$;
+    private java.lang.CharSequence string;
+    private java.nio.ByteBuffer bytes;
+    private com.databricks.spark.avro.SimpleRecord nested;
+    private com.databricks.spark.avro.SimpleEnums enum$;
+    private com.databricks.spark.avro.SimpleFixed fixed;
+    private java.util.List intArray;
+    private java.util.List stringArray;
+    private java.util.List recordArray;
+    private java.util.List enumArray;
+    private java.util.List fixedArray;
+
+    /** Creates a new Builder */
+    private Builder() {
+      super(com.databricks.spark.avro.TestRecord.SCHEMA$);
+    }
+
+    /** Creates a Builder by deep-copying another Builder's values; a field is flagged as set only if it was set on {@code other}, so unset primitives still build to their schema defaults */
+    private Builder(com.databricks.spark.avro.TestRecord.Builder other) {
+      super(other);
+      if (isValidValue(fields()[0], other.boolean$)) {
+        this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$);
+        fieldSetFlags()[0] = other.fieldSetFlags()[0];
+      }
+      if (isValidValue(fields()[1], other.int$)) {
+        this.int$ = data().deepCopy(fields()[1].schema(), other.int$);
+        fieldSetFlags()[1] = other.fieldSetFlags()[1];
+      }
+      if (isValidValue(fields()[2], other.long$)) {
+        this.long$ = data().deepCopy(fields()[2].schema(), other.long$);
+        fieldSetFlags()[2] = other.fieldSetFlags()[2];
+      }
+      if (isValidValue(fields()[3], other.float$)) {
+        this.float$ = data().deepCopy(fields()[3].schema(), other.float$);
+        fieldSetFlags()[3] = other.fieldSetFlags()[3];
+      }
+      if (isValidValue(fields()[4], other.double$)) {
+        this.double$ = data().deepCopy(fields()[4].schema(), other.double$);
+        fieldSetFlags()[4] = other.fieldSetFlags()[4];
+      }
+      if (isValidValue(fields()[5], other.string)) {
+        this.string = data().deepCopy(fields()[5].schema(), other.string);
+        fieldSetFlags()[5] = other.fieldSetFlags()[5];
+      }
+      if (isValidValue(fields()[6], other.bytes)) {
+        this.bytes = data().deepCopy(fields()[6].schema(), other.bytes);
+        fieldSetFlags()[6] = other.fieldSetFlags()[6];
+      }
+      if (isValidValue(fields()[7], other.nested)) {
+        this.nested = data().deepCopy(fields()[7].schema(), other.nested);
+        fieldSetFlags()[7] = other.fieldSetFlags()[7];
+      }
+      if (isValidValue(fields()[8], other.enum$)) {
+        this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$);
+        fieldSetFlags()[8] = other.fieldSetFlags()[8];
+      }
+      if (isValidValue(fields()[9], other.fixed)) {
+        this.fixed = data().deepCopy(fields()[9].schema(), other.fixed);
+        fieldSetFlags()[9] = other.fieldSetFlags()[9];
+      }
+      if (isValidValue(fields()[10], other.intArray)) {
+        this.intArray = data().deepCopy(fields()[10].schema(), other.intArray);
+        fieldSetFlags()[10] = other.fieldSetFlags()[10];
+      }
+      if (isValidValue(fields()[11], other.stringArray)) {
+        this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray);
+        fieldSetFlags()[11] = other.fieldSetFlags()[11];
+      }
+      if (isValidValue(fields()[12], other.recordArray)) {
+        this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray);
+        fieldSetFlags()[12] = other.fieldSetFlags()[12];
+      }
+      if (isValidValue(fields()[13], other.enumArray)) {
+        this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray);
+        fieldSetFlags()[13] = other.fieldSetFlags()[13];
+      }
+      if (isValidValue(fields()[14], other.fixedArray)) {
+        this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray);
+        fieldSetFlags()[14] = other.fieldSetFlags()[14];
+      }
+    }
+
+    /** Creates a Builder by deep-copying every field of an existing TestRecord instance that passes isValidValue, flagging each copied field as set */
+    private Builder(com.databricks.spark.avro.TestRecord other) {
+            super(com.databricks.spark.avro.TestRecord.SCHEMA$);
+      if (isValidValue(fields()[0], other.boolean$)) {
+        this.boolean$ = data().deepCopy(fields()[0].schema(), other.boolean$);
+        fieldSetFlags()[0] = true;
+      }
+      if (isValidValue(fields()[1], other.int$)) {
+        this.int$ = data().deepCopy(fields()[1].schema(), other.int$);
+        fieldSetFlags()[1] = true;
+      }
+      if (isValidValue(fields()[2], other.long$)) {
+        this.long$ = data().deepCopy(fields()[2].schema(), other.long$);
+        fieldSetFlags()[2] = true;
+      }
+      if (isValidValue(fields()[3], other.float$)) {
+        this.float$ = data().deepCopy(fields()[3].schema(), other.float$);
+        fieldSetFlags()[3] = true;
+      }
+      if (isValidValue(fields()[4], other.double$)) {
+        this.double$ = data().deepCopy(fields()[4].schema(), other.double$);
+        fieldSetFlags()[4] = true;
+      }
+      if (isValidValue(fields()[5], other.string)) {
+        this.string = data().deepCopy(fields()[5].schema(), other.string);
+        fieldSetFlags()[5] = true;
+      }
+      if (isValidValue(fields()[6], other.bytes)) {
+        this.bytes = data().deepCopy(fields()[6].schema(), other.bytes);
+        fieldSetFlags()[6] = true;
+      }
+      if (isValidValue(fields()[7], other.nested)) {
+        this.nested = data().deepCopy(fields()[7].schema(), other.nested);
+        fieldSetFlags()[7] = true;
+      }
+      if (isValidValue(fields()[8], other.enum$)) {
+        this.enum$ = data().deepCopy(fields()[8].schema(), other.enum$);
+        fieldSetFlags()[8] = true;
+      }
+      if (isValidValue(fields()[9], other.fixed)) {
+        this.fixed = data().deepCopy(fields()[9].schema(), other.fixed);
+        fieldSetFlags()[9] = true;
+      }
+      if (isValidValue(fields()[10], other.intArray)) {
+        this.intArray = data().deepCopy(fields()[10].schema(), other.intArray);
+        fieldSetFlags()[10] = true;
+      }
+      if (isValidValue(fields()[11], other.stringArray)) {
+        this.stringArray = data().deepCopy(fields()[11].schema(), other.stringArray);
+        fieldSetFlags()[11] = true;
+      }
+      if (isValidValue(fields()[12], other.recordArray)) {
+        this.recordArray = data().deepCopy(fields()[12].schema(), other.recordArray);
+        fieldSetFlags()[12] = true;
+      }
+      if (isValidValue(fields()[13], other.enumArray)) {
+        this.enumArray = data().deepCopy(fields()[13].schema(), other.enumArray);
+        fieldSetFlags()[13] = true;
+      }
+      if (isValidValue(fields()[14], other.fixedArray)) {
+        this.fixedArray = data().deepCopy(fields()[14].schema(), other.fixedArray);
+        fieldSetFlags()[14] = true;
+      }
+    }
+
+    /** Gets the value of the 'boolean' schema field (Java member boolean$) */
+    public java.lang.Boolean getBoolean$() {
+      return boolean$;
+    }
+
+    /** Sets the value of the 'boolean' schema field (Java member boolean$) */
+    public com.databricks.spark.avro.TestRecord.Builder setBoolean$(boolean value) {
+      validate(fields()[0], value);
+      this.boolean$ = value;
+      fieldSetFlags()[0] = true;
+      return this;
+    }
+
+    /** Checks whether the 'boolean' schema field (Java member boolean$) has been set */
+    public boolean hasBoolean$() {
+      return fieldSetFlags()[0];
+    }
+
+    /** Marks the 'boolean' schema field as unset so build() uses the schema default; the stored primitive is left as is */
+    public com.databricks.spark.avro.TestRecord.Builder clearBoolean$() {
+      fieldSetFlags()[0] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'int' schema field (Java member int$) */
+    public java.lang.Integer getInt$() {
+      return int$;
+    }
+
+    /** Sets the value of the 'int' schema field (Java member int$) */
+    public com.databricks.spark.avro.TestRecord.Builder setInt$(int value) {
+      validate(fields()[1], value);
+      this.int$ = value;
+      fieldSetFlags()[1] = true;
+      return this;
+    }
+
+    /** Checks whether the 'int' schema field (Java member int$) has been set */
+    public boolean hasInt$() {
+      return fieldSetFlags()[1];
+    }
+
+    /** Marks the 'int' schema field as unset so build() uses the schema default; the stored primitive is left as is */
+    public com.databricks.spark.avro.TestRecord.Builder clearInt$() {
+      fieldSetFlags()[1] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'long' schema field (Java member long$) */
+    public java.lang.Long getLong$() {
+      return long$;
+    }
+
+    /** Sets the value of the 'long' schema field (Java member long$) */
+    public com.databricks.spark.avro.TestRecord.Builder setLong$(long value) {
+      validate(fields()[2], value);
+      this.long$ = value;
+      fieldSetFlags()[2] = true;
+      return this;
+    }
+
+    /** Checks whether the 'long' schema field (Java member long$) has been set */
+    public boolean hasLong$() {
+      return fieldSetFlags()[2];
+    }
+
+    /** Marks the 'long' schema field as unset so build() uses the schema default; the stored primitive is left as is */
+    public com.databricks.spark.avro.TestRecord.Builder clearLong$() {
+      fieldSetFlags()[2] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'float' schema field (Java member float$) */
+    public java.lang.Float getFloat$() {
+      return float$;
+    }
+
+    /** Sets the value of the 'float' schema field (Java member float$) */
+    public com.databricks.spark.avro.TestRecord.Builder setFloat$(float value) {
+      validate(fields()[3], value);
+      this.float$ = value;
+      fieldSetFlags()[3] = true;
+      return this;
+    }
+
+    /** Checks whether the 'float' schema field (Java member float$) has been set */
+    public boolean hasFloat$() {
+      return fieldSetFlags()[3];
+    }
+
+    /** Marks the 'float' schema field as unset so build() uses the schema default; the stored primitive is left as is */
+    public com.databricks.spark.avro.TestRecord.Builder clearFloat$() {
+      fieldSetFlags()[3] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'double' schema field (Java member double$) */
+    public java.lang.Double getDouble$() {
+      return double$;
+    }
+
+    /** Sets the value of the 'double' schema field (Java member double$) */
+    public com.databricks.spark.avro.TestRecord.Builder setDouble$(double value) {
+      validate(fields()[4], value);
+      this.double$ = value;
+      fieldSetFlags()[4] = true;
+      return this;
+    }
+
+    /** Checks whether the 'double' schema field (Java member double$) has been set */
+    public boolean hasDouble$() {
+      return fieldSetFlags()[4];
+    }
+
+    /** Marks the 'double' schema field as unset so build() uses the schema default; the stored primitive is left as is */
+    public com.databricks.spark.avro.TestRecord.Builder clearDouble$() {
+      fieldSetFlags()[4] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'string' field */
+    public java.lang.CharSequence getString() {
+      return string;
+    }
+
+    /** Sets the value of the 'string' field */
+    public com.databricks.spark.avro.TestRecord.Builder setString(java.lang.CharSequence value) {
+      validate(fields()[5], value);
+      this.string = value;
+      fieldSetFlags()[5] = true;
+      return this;
+    }
+
+    /** Checks whether the 'string' field has been set */
+    public boolean hasString() {
+      return fieldSetFlags()[5];
+    }
+
+    /** Clears the value of the 'string' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearString() {
+      string = null;
+      fieldSetFlags()[5] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'bytes' field */
+    public java.nio.ByteBuffer getBytes() {
+      return bytes;
+    }
+
+    /** Sets the value of the 'bytes' field */
+    public com.databricks.spark.avro.TestRecord.Builder setBytes(java.nio.ByteBuffer value) {
+      validate(fields()[6], value);
+      this.bytes = value;
+      fieldSetFlags()[6] = true;
+      return this;
+    }
+
+    /** Checks whether the 'bytes' field has been set */
+    public boolean hasBytes() {
+      return fieldSetFlags()[6];
+    }
+
+    /** Clears the value of the 'bytes' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearBytes() {
+      bytes = null;
+      fieldSetFlags()[6] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'nested' field */
+    public com.databricks.spark.avro.SimpleRecord getNested() {
+      return nested;
+    }
+
+    /** Sets the value of the 'nested' field */
+    public com.databricks.spark.avro.TestRecord.Builder setNested(com.databricks.spark.avro.SimpleRecord value) {
+      validate(fields()[7], value);
+      this.nested = value;
+      fieldSetFlags()[7] = true;
+      return this;
+    }
+
+    /** Checks whether the 'nested' field has been set */
+    public boolean hasNested() {
+      return fieldSetFlags()[7];
+    }
+
+    /** Clears the value of the 'nested' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearNested() {
+      nested = null;
+      fieldSetFlags()[7] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'enum' schema field (Java member enum$) */
+    public com.databricks.spark.avro.SimpleEnums getEnum$() {
+      return enum$;
+    }
+
+    /** Sets the value of the 'enum' schema field (Java member enum$) */
+    public com.databricks.spark.avro.TestRecord.Builder setEnum$(com.databricks.spark.avro.SimpleEnums value) {
+      validate(fields()[8], value);
+      this.enum$ = value;
+      fieldSetFlags()[8] = true;
+      return this;
+    }
+
+    /** Checks whether the 'enum' schema field (Java member enum$) has been set */
+    public boolean hasEnum$() {
+      return fieldSetFlags()[8];
+    }
+
+    /** Clears the value of the 'enum' schema field (Java member enum$) */
+    public com.databricks.spark.avro.TestRecord.Builder clearEnum$() {
+      enum$ = null;
+      fieldSetFlags()[8] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'fixed' field */
+    public com.databricks.spark.avro.SimpleFixed getFixed() {
+      return fixed;
+    }
+
+    /** Sets the value of the 'fixed' field */
+    public com.databricks.spark.avro.TestRecord.Builder setFixed(com.databricks.spark.avro.SimpleFixed value) {
+      validate(fields()[9], value);
+      this.fixed = value;
+      fieldSetFlags()[9] = true;
+      return this;
+    }
+
+    /** Checks whether the 'fixed' field has been set */
+    public boolean hasFixed() {
+      return fieldSetFlags()[9];
+    }
+
+    /** Clears the value of the 'fixed' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearFixed() {
+      fixed = null;
+      fieldSetFlags()[9] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'intArray' field */
+    public java.util.List getIntArray() {
+      return intArray;
+    }
+
+    /** Sets the value of the 'intArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder setIntArray(java.util.List value) {
+      validate(fields()[10], value);
+      this.intArray = value;
+      fieldSetFlags()[10] = true;
+      return this;
+    }
+
+    /** Checks whether the 'intArray' field has been set */
+    public boolean hasIntArray() {
+      return fieldSetFlags()[10];
+    }
+
+    /** Clears the value of the 'intArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearIntArray() {
+      intArray = null;
+      fieldSetFlags()[10] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'stringArray' field */
+    public java.util.List getStringArray() {
+      return stringArray;
+    }
+
+    /** Sets the value of the 'stringArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder setStringArray(java.util.List value) {
+      validate(fields()[11], value);
+      this.stringArray = value;
+      fieldSetFlags()[11] = true;
+      return this;
+    }
+
+    /** Checks whether the 'stringArray' field has been set */
+    public boolean hasStringArray() {
+      return fieldSetFlags()[11];
+    }
+
+    /** Clears the value of the 'stringArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearStringArray() {
+      stringArray = null;
+      fieldSetFlags()[11] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'recordArray' field */
+    public java.util.List getRecordArray() {
+      return recordArray;
+    }
+
+    /** Sets the value of the 'recordArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder setRecordArray(java.util.List value) {
+      validate(fields()[12], value);
+      this.recordArray = value;
+      fieldSetFlags()[12] = true;
+      return this;
+    }
+
+    /** Checks whether the 'recordArray' field has been set */
+    public boolean hasRecordArray() {
+      return fieldSetFlags()[12];
+    }
+
+    /** Clears the value of the 'recordArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearRecordArray() {
+      recordArray = null;
+      fieldSetFlags()[12] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'enumArray' field */
+    public java.util.List getEnumArray() {
+      return enumArray;
+    }
+
+    /** Sets the value of the 'enumArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder setEnumArray(java.util.List value) {
+      validate(fields()[13], value);
+      this.enumArray = value;
+      fieldSetFlags()[13] = true;
+      return this;
+    }
+
+    /** Checks whether the 'enumArray' field has been set */
+    public boolean hasEnumArray() {
+      return fieldSetFlags()[13];
+    }
+
+    /** Clears the value of the 'enumArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearEnumArray() {
+      enumArray = null;
+      fieldSetFlags()[13] = false;
+      return this;
+    }
+
+    /** Gets the value of the 'fixedArray' field */
+    public java.util.List getFixedArray() {
+      return fixedArray;
+    }
+
+    /** Sets the value of the 'fixedArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder setFixedArray(java.util.List value) {
+      validate(fields()[14], value);
+      this.fixedArray = value;
+      fieldSetFlags()[14] = true;
+      return this;
+    }
+
+    /** Checks whether the 'fixedArray' field has been set */
+    public boolean hasFixedArray() {
+      return fieldSetFlags()[14];
+    }
+
+    /** Clears the value of the 'fixedArray' field */
+    public com.databricks.spark.avro.TestRecord.Builder clearFixedArray() {
+      fixedArray = null;
+      fieldSetFlags()[14] = false;
+      return this;
+    }
+
+    @Override
+    public TestRecord build() {
+      try { // a field whose set-flag is true takes the builder's value; otherwise the schema default from defaultValue(...)
+        TestRecord record = new TestRecord();
+        record.boolean$ = fieldSetFlags()[0] ? this.boolean$ : (java.lang.Boolean) defaultValue(fields()[0]);
+        record.int$ = fieldSetFlags()[1] ? this.int$ : (java.lang.Integer) defaultValue(fields()[1]);
+        record.long$ = fieldSetFlags()[2] ? this.long$ : (java.lang.Long) defaultValue(fields()[2]);
+        record.float$ = fieldSetFlags()[3] ? this.float$ : (java.lang.Float) defaultValue(fields()[3]);
+        record.double$ = fieldSetFlags()[4] ? this.double$ : (java.lang.Double) defaultValue(fields()[4]);
+        record.string = fieldSetFlags()[5] ? this.string : (java.lang.CharSequence) defaultValue(fields()[5]);
+        record.bytes = fieldSetFlags()[6] ? this.bytes : (java.nio.ByteBuffer) defaultValue(fields()[6]);
+        record.nested = fieldSetFlags()[7] ? this.nested : (com.databricks.spark.avro.SimpleRecord) defaultValue(fields()[7]);
+        record.enum$ = fieldSetFlags()[8] ? this.enum$ : (com.databricks.spark.avro.SimpleEnums) defaultValue(fields()[8]);
+        record.fixed = fieldSetFlags()[9] ? this.fixed : (com.databricks.spark.avro.SimpleFixed) defaultValue(fields()[9]);
+        record.intArray = fieldSetFlags()[10] ? this.intArray : (java.util.List) defaultValue(fields()[10]);
+        record.stringArray = fieldSetFlags()[11] ? this.stringArray : (java.util.List) defaultValue(fields()[11]);
+        record.recordArray = fieldSetFlags()[12] ? this.recordArray : (java.util.List) defaultValue(fields()[12]);
+        record.enumArray = fieldSetFlags()[13] ? this.enumArray : (java.util.List) defaultValue(fields()[13]);
+        record.fixedArray = fieldSetFlags()[14] ? this.fixedArray : (java.util.List) defaultValue(fields()[14]);
+        return record;
+      } catch (Exception e) {
+        throw new org.apache.avro.AvroRuntimeException(e);
+      }
+    }
+  }
+}
diff --git a/src/test/resources/specific.avsc b/src/test/resources/specific.avsc
new file mode 100644
index 00000000..dbbc1da6
--- /dev/null
+++ b/src/test/resources/specific.avsc
@@ -0,0 +1,40 @@
+{
+ "namespace": "com.databricks.spark.avro",
+ "type": "record",
+ "name": "TestRecord",
+ "fields": [
+ {"name": "boolean", "type": "boolean", "default": true},
+ {"name": "int", "type": "int", "default": 0},
+ {"name": "long", "type": "long", "default": 0},
+ {"name": "float", "type": "float", "default": 0.0},
+ {"name": "double", "type": "double", "default": 0.0},
+ {"name": "string", "type": "string", "default": "value"},
+ {"name": "bytes", "type": "bytes", "default": "\u00ff"},
+ {"name": "nested", "type": {
+ "type": "record", "name": "SimpleRecord", "fields": [
+ {"name": "nested1", "type": "int", "default": 0},
+ {"name": "nested2", "type": "string", "default": "string"}]},
+ "default": {"nested1": 0, "nested2": "string"}},
+ {"name": "enum", "type": {
+ "name": "SimpleEnums", "type": "enum", "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]},
+ "default": "SPADES"},
+ {"name": "fixed", "type": {
+ "name": "SimpleFixed", "type": "fixed", "size": 16},
+ "default": "string_length_16"},
+ {"name": "intArray",
+ "type": {"type": "array", "items": "int"},
+ "default": [1, 2, 3]},
+ {"name": "stringArray",
+ "type": {"type": "array", "items": "string"},
+ "default": ["a", "b", "c"]},
+ {"name": "recordArray",
+ "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleRecord"},
+ "default": [{"nested1": 0, "nested2": "value"}, {"nested1": 0, "nested2": "value"}]},
+ {"name": "enumArray",
+ "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleEnums"},
+ "default": ["SPADES", "HEARTS", "SPADES"]},
+ {"name": "fixedArray",
+ "type": {"type": "array", "items": "com.databricks.spark.avro.SimpleFixed"},
+ "default": ["foo", "bar", "baz"]}
+ ]
+}
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
index 1b5d07aa..e0aa7a26 100644
--- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
+++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -23,17 +23,19 @@ import java.sql.Timestamp
import java.util.UUID
import scala.collection.JavaConversions._
-
import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
import org.apache.avro.Schema
import org.apache.avro.Schema.{Field, Type}
+import org.apache.avro.SchemaBuilder
import org.apache.avro.file.DataFileWriter
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
-import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
+import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord, GenericRecordBuilder}
import org.apache.commons.io.FileUtils
-
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class AvroSuite extends FunSuite with BeforeAndAfterAll {
@@ -674,4 +676,209 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
assert(input.rdd.partitions.size > 2)
}
}
+
+ // Round-trips a GenericData.Record populated only from schema defaults through the
+ // Avro encoder (record -> row -> record) and expects an equal record back.
+ test("generic record converts to row and back") {
+   val nested =
+     SchemaBuilder.record("simple_record").fields()
+       .name("nested1").`type`("int").withDefault(0)
+       .name("nested2").`type`("string").withDefault("string").endRecord()
+
+   // Avro named types (record/enum/fixed) must be unique within one schema. Each named
+   // type is built once here, and the 3-byte and 16-byte fixed types get distinct names:
+   // declaring "simple_fixed" with two different sizes produces a schema that Avro
+   // refuses to write out ("Can't redefine").
+   val suits = SchemaBuilder.enumeration("simple_enums")
+     .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS")
+   val fixed3 = SchemaBuilder.fixed("simple_fixed").size(3)
+   val fixed16 = SchemaBuilder.fixed("simple_fixed_16").size(16)
+
+   val schema = SchemaBuilder.record("record").fields()
+     .name("boolean").`type`("boolean").withDefault(false)
+     .name("int").`type`("int").withDefault(0)
+     .name("long").`type`("long").withDefault(0L)
+     .name("float").`type`("float").withDefault(0.0F)
+     .name("double").`type`("double").withDefault(0.0)
+     .name("string").`type`("string").withDefault("string")
+     .name("bytes").`type`("bytes").withDefault(java.nio.ByteBuffer.wrap("bytes".getBytes))
+     .name("nested").`type`(nested).withDefault(new GenericRecordBuilder(nested).build)
+     .name("enum").`type`(suits)
+     .withDefault("SPADES")
+     .name("int_array").`type`(
+       SchemaBuilder.array().items().`type`("int"))
+     .withDefault(java.util.Arrays.asList(1, 2, 3))
+     .name("string_array").`type`(
+       SchemaBuilder.array().items().`type`("string"))
+     .withDefault(java.util.Arrays.asList("a", "b", "c"))
+     .name("record_array").`type`(
+       SchemaBuilder.array.items.`type`(nested))
+     .withDefault(java.util.Arrays.asList(
+       new GenericRecordBuilder(nested).build,
+       new GenericRecordBuilder(nested).build))
+     .name("enum_array").`type`(
+       SchemaBuilder.array.items.`type`(suits))
+     .withDefault(java.util.Arrays.asList("SPADES", "HEARTS", "SPADES"))
+     .name("fixed_array").`type`(
+       SchemaBuilder.array.items().`type`(fixed3))
+     .withDefault(java.util.Arrays.asList("foo", "bar", "baz"))
+     // The default below is exactly 16 characters, matching the fixed size.
+     .name("fixed").`type`(fixed16)
+     .withDefault("string_length_16")
+     .endRecord()
+
+   val encoder = AvroEncoder.of[GenericData.Record](schema)
+   val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]]
+   val record = new GenericRecordBuilder(schema).build
+   val row = expressionEncoder.toRow(record)
+   // Deserialization needs an encoder that has been resolved and bound to its own schema.
+   val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row)
+
+   assert(record == recordFromRow)
+ }
+
+ // Round-trips a generated SpecificRecord (TestRecord) through two encoders: one derived
+ // from the record class, one derived from the parsed .avsc schema. Both must reproduce
+ // the original record.
+ test("specific record converts to row and back") {
+   val avscFile = new File("src/test/resources/specific.avsc")
+   val parsedSchema = new Schema.Parser().parse(avscFile)
+   val original = TestRecord.newBuilder().build()
+
+   // Serializes `original` to a row, then reads it back with a resolved-and-bound copy
+   // of the same encoder.
+   def roundTrip(enc: ExpressionEncoder[TestRecord]): TestRecord = {
+     val encoded = enc.toRow(original)
+     enc.resolveAndBind().fromRow(encoded)
+   }
+
+   val fromClass =
+     AvroEncoder.of[TestRecord](classOf[TestRecord]).asInstanceOf[ExpressionEncoder[TestRecord]]
+   val decodedViaClass = roundTrip(fromClass)
+   assert(original == decodedViaClass)
+
+   val fromSchema =
+     AvroEncoder.of[TestRecord](parsedSchema).asInstanceOf[ExpressionEncoder[TestRecord]]
+   val decodedViaSchema = roundTrip(fromSchema)
+   assert(original == decodedViaSchema)
+ }
+
+ // Round-trips a record whose fields are two-member unions (nullable int, nullable
+ // string, int|long, float|double), first with the schema defaults and then with the
+ // nullable fields set to concrete values.
+ test("encoder resolves union types to rows") {
+   // Builds the Avro union of two primitive types given by name.
+   def unionOfTwo(first: String, second: String) =
+     SchemaBuilder.unionOf.`type`(first).and.`type`(second).endUnion
+
+   val schema = SchemaBuilder.record("record").fields()
+     .name("int_null_union").`type`(unionOfTwo("null", "int")).withDefault(null)
+     .name("string_null_union").`type`(unionOfTwo("null", "string")).withDefault(null)
+     .name("int_long_union").`type`(unionOfTwo("int", "long")).withDefault(0)
+     .name("float_double_union").`type`(unionOfTwo("float", "double")).withDefault(0.0)
+     .endRecord
+
+   val expressionEncoder =
+     AvroEncoder.of[GenericData.Record](schema).asInstanceOf[ExpressionEncoder[GenericData.Record]]
+   val record = new GenericRecordBuilder(schema).build
+
+   // Defaults: every one of the four fields must survive the round trip.
+   val defaultsRow = expressionEncoder.toRow(record)
+   val defaultsDecoded = expressionEncoder.resolveAndBind().fromRow(defaultsRow)
+   for (i <- 0 to 3) {
+     assert(record.get(i) == defaultsDecoded.get(i))
+   }
+
+   // Replace the null defaults of the two nullable unions with real values.
+   record.put(0, 0)
+   record.put(1, "value")
+
+   val updatedRow = expressionEncoder.toRow(record)
+   val updatedDecoded = expressionEncoder.resolveAndBind().fromRow(updatedRow)
+   for (i <- 0 to 1) {
+     assert(record.get(i) == updatedDecoded.get(i))
+   }
+ }
+
+ // Round-trips a record with an int-valued and a string-valued Avro map (both supplied
+ // as field defaults) and compares the decoded maps with the originals.
+ test("encoder resolves map types to rows") {
+   val expectedInts = new java.util.HashMap[java.lang.String, java.lang.Integer]
+   Seq("foo" -> 1, "bar" -> 2, "baz" -> 3).foreach { case (key, value) =>
+     expectedInts.put(key, value)
+   }
+
+   val expectedStrings = new java.util.HashMap[java.lang.String, java.lang.String]
+   Seq("foo" -> "a", "bar" -> "b", "baz" -> "c").foreach { case (key, value) =>
+     expectedStrings.put(key, value)
+   }
+
+   // The maps are fully populated before being handed to the builder as defaults.
+   val intMapType = SchemaBuilder.map.values.`type`("int")
+   val stringMapType = SchemaBuilder.map.values.`type`("string")
+   val schema = SchemaBuilder.record("record").fields()
+     .name("int_map").`type`(intMapType).withDefault(expectedInts)
+     .name("string_map").`type`(stringMapType).withDefault(expectedStrings)
+     .endRecord()
+
+   val expressionEncoder =
+     AvroEncoder.of[GenericData.Record](schema).asInstanceOf[ExpressionEncoder[GenericData.Record]]
+   val record = new GenericRecordBuilder(schema).build
+   val encoded = expressionEncoder.toRow(record)
+   val decoded = expressionEncoder.resolveAndBind().fromRow(encoded)
+
+   val decodedInts = decoded.get(0)
+   assert(expectedInts == decodedInts)
+
+   val decodedStrings = decoded.get(1)
+   assert(expectedStrings == decodedStrings)
+ }
+
+ // Exercises a union with more than two members (null, int, float, string, record).
+ // As the assertions below show, the encoded row exposes the union as a 4-field struct
+ // with one slot per non-null member, in declaration order: 0 = int, 1 = float,
+ // 2 = string, 3 = nested record. For each member in turn the test checks that the
+ // selected slot carries the written value, that every other slot is null, and that the
+ // record survives the row round trip.
+ test("encoder resolves complex unions to rows") {
+   val nested =
+     SchemaBuilder.record("simple_record").fields()
+       .name("nested1").`type`("int").withDefault(0)
+       .name("nested2").`type`("string").withDefault("foo").endRecord()
+   val schema = SchemaBuilder.record("record").fields()
+     .name("int_float_string_record").`type`(
+       SchemaBuilder.unionOf()
+         .`type`("null").and()
+         .`type`("int").and()
+         .`type`("float").and()
+         .`type`("string").and()
+         .`type`(nested).endUnion()
+     ).withDefault(null).endRecord()
+
+   val encoder = AvroEncoder.of[GenericData.Record](schema)
+   val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]]
+   val record = new GenericRecordBuilder(schema).build
+
+   // Encodes the current state of `record`, decodes it again, and returns the union's
+   // struct view together with the decoded record. The struct is only inspected before
+   // the next call, so it never observes a later encoding.
+   def roundTrip() = {
+     val row = expressionEncoder.toRow(record)
+     val decoded = expressionEncoder.resolveAndBind().fromRow(row)
+     (row.getStruct(0, 4), decoded)
+   }
+
+   // Default (null): no member slot is populated.
+   val (nullMembers, nullDecoded) = roundTrip()
+   assert(nullMembers.get(0, IntegerType) == null)
+   assert(nullMembers.get(1, FloatType) == null)
+   assert(nullMembers.get(2, StringType) == null)
+   assert(nullMembers.getStruct(3, 2) == null)
+   assert(record == nullDecoded)
+
+   // int member selected.
+   record.put(0, 1)
+   val (intMembers, intDecoded) = roundTrip()
+   assert(intMembers.getInt(0) == 1)
+   assert(intMembers.get(1, FloatType) == null)
+   assert(intMembers.get(2, StringType) == null)
+   assert(intMembers.getStruct(3, 2) == null)
+   assert(record == intDecoded)
+
+   // float member selected.
+   record.put(0, 1F)
+   val (floatMembers, floatDecoded) = roundTrip()
+   assert(floatMembers.get(0, IntegerType) == null)
+   assert(floatMembers.getFloat(1) == 1F)
+   assert(floatMembers.get(2, StringType) == null)
+   assert(floatMembers.getStruct(3, 2) == null)
+   assert(record == floatDecoded)
+
+   // string member selected.
+   record.put(0, "bar")
+   val (stringMembers, stringDecoded) = roundTrip()
+   assert(stringMembers.get(0, IntegerType) == null)
+   assert(stringMembers.get(1, FloatType) == null)
+   assert(stringMembers.getUTF8String(2).toString == "bar")
+   assert(stringMembers.getStruct(3, 2) == null)
+   assert(record == stringDecoded)
+
+   // nested record member selected; its fields carry the nested schema's defaults.
+   record.put(0, new GenericRecordBuilder(nested).build())
+   val (recordMembers, recordDecoded) = roundTrip()
+   assert(recordMembers.get(0, IntegerType) == null)
+   assert(recordMembers.get(1, FloatType) == null)
+   assert(recordMembers.get(2, StringType) == null)
+   assert(!recordMembers.isNullAt(3))
+   assert(recordMembers.getStruct(3, 2).getInt(0) == 0)
+   assert(recordMembers.getStruct(3, 2).getUTF8String(1).toString == "foo")
+   assert(record == recordDecoded)
+ }
+
}