diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala index 297c39d6..ea0e4c06 100644 --- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala +++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala @@ -45,7 +45,8 @@ private[avro] class AvroOutputWriter( recordName: String, recordNamespace: String) extends OutputWriter { - private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace) + private lazy val converter = + SchemaConverters.createConverterToAvro(schema, recordName, recordNamespace) /** * Overrides the couple of methods responsible for generating the output streams / files so @@ -72,79 +73,4 @@ private[avro] class AvroOutputWriter( } override def close(): Unit = recordWriter.close(context) - - /** - * This function constructs converter function for a given sparkSQL datatype. This is used in - * writing Avro records out to disk - */ - private def createConverterToAvro( - dataType: DataType, - structName: String, - recordNamespace: String): (Any) => Any = { - dataType match { - case BinaryType => (item: Any) => item match { - case null => null - case bytes: Array[Byte] => ByteBuffer.wrap(bytes) - } - case ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | StringType | BooleanType => identity - case _: DecimalType => (item: Any) => if (item == null) null else item.toString - case TimestampType => (item: Any) => - if (item == null) null else item.asInstanceOf[Timestamp].getTime - case DateType => (item: Any) => - if (item == null) null else item.asInstanceOf[Date].getTime - case ArrayType(elementType, _) => - val elementConverter = createConverterToAvro(elementType, structName, recordNamespace) - (item: Any) => { - if (item == null) { - null - } else { - val sourceArray = item.asInstanceOf[Seq[Any]] - val sourceArraySize = sourceArray.size - val targetArray = new Array[Any](sourceArraySize) - var 
idx = 0 - while (idx < sourceArraySize) { - targetArray(idx) = elementConverter(sourceArray(idx)) - idx += 1 - } - targetArray - } - } - case MapType(StringType, valueType, _) => - val valueConverter = createConverterToAvro(valueType, structName, recordNamespace) - (item: Any) => { - if (item == null) { - null - } else { - val javaMap = new HashMap[String, Any]() - item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => - javaMap.put(key, valueConverter(value)) - } - javaMap - } - } - case structType: StructType => - val builder = SchemaBuilder.record(structName).namespace(recordNamespace) - val schema: Schema = SchemaConverters.convertStructToAvro( - structType, builder, recordNamespace) - val fieldConverters = structType.fields.map(field => - createConverterToAvro(field.dataType, field.name, recordNamespace)) - (item: Any) => { - if (item == null) { - null - } else { - val record = new Record(schema) - val convertersIterator = fieldConverters.iterator - val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator - val rowIterator = item.asInstanceOf[Row].toSeq.iterator - - while (convertersIterator.hasNext) { - val converter = convertersIterator.next() - record.put(fieldNamesIterator.next(), converter(rowIterator.next())) - } - record - } - } - } - } } diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala index 7f8e20f4..a163e942 100644 --- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala +++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala @@ -16,18 +16,21 @@ package com.databricks.spark.avro import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} +import java.util.HashMap -import scala.collection.JavaConverters._ - -import org.apache.avro.generic.GenericData.Fixed +import org.apache.avro.Schema.Type._ +import org.apache.avro.SchemaBuilder._ +import org.apache.avro.generic.GenericData.{Fixed, Record} 
import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.{Schema, SchemaBuilder} -import org.apache.avro.SchemaBuilder._ -import org.apache.avro.Schema.Type._ - +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types._ +import scala.collection.JavaConverters._ +import scala.collection.immutable.Map + /** * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice * versa. @@ -113,16 +116,20 @@ object SchemaConverters { def convertStructToAvro[T]( structType: StructType, schemaBuilder: RecordBuilder[T], - recordNamespace: String): T = { + recordNamespace: String, + structName: String = "", + schemaMap: collection.mutable.Map[String, Object] = + collection.mutable.Map[String, Object]()): T = { val fieldsAssembler: FieldAssembler[T] = schemaBuilder.fields() structType.fields.foreach { field => val newField = fieldsAssembler.name(field.name).`type`() if (field.nullable) { - convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace) + convertFieldTypeToAvro(field.dataType, newField.nullable(), field.name, recordNamespace, + schemaMap) .noDefault } else { - convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace) + convertFieldTypeToAvro(field.dataType, newField, field.name, recordNamespace, schemaMap) .noDefault } } @@ -315,7 +322,7 @@ object SchemaConverters { dataType: DataType, schemaBuilder: BaseTypeBuilder[T], structName: String, - recordNamespace: String): T = { + recordNamespace: String, schemaMap: collection.mutable.Map[String, Object]): T = { dataType match { case ByteType => schemaBuilder.intType() case ShortType => schemaBuilder.intType() @@ -332,19 +339,21 @@ object SchemaConverters { case ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro(elementType, builder, structName, 
recordNamespace) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace, + schemaMap) schemaBuilder.array().items(elementSchema) case MapType(StringType, valueType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace, + schemaMap) schemaBuilder.map().values(valueSchema) case structType: StructType => convertStructToAvro( structType, schemaBuilder.record(structName).namespace(recordNamespace), - recordNamespace) + recordNamespace, structName, schemaMap) case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") } @@ -359,7 +368,8 @@ object SchemaConverters { dataType: DataType, newFieldBuilder: BaseFieldTypeBuilder[T], structName: String, - recordNamespace: String): FieldDefault[T, _] = { + recordNamespace: String, + schemaMap: collection.mutable.Map[String, Object]): FieldDefault[T, _] = { dataType match { case ByteType => newFieldBuilder.intType() case ShortType => newFieldBuilder.intType() @@ -376,19 +386,28 @@ object SchemaConverters { case ArrayType(elementType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull) - val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace) + val elementSchema = convertTypeToAvro(elementType, builder, structName, recordNamespace, + schemaMap) newFieldBuilder.array().items(elementSchema) case MapType(StringType, valueType, _) => val builder = getSchemaBuilder(dataType.asInstanceOf[MapType].valueContainsNull) - val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace) + val valueSchema = convertTypeToAvro(valueType, builder, structName, recordNamespace, + schemaMap) newFieldBuilder.map().values(valueSchema) case structType: StructType => - convertStructToAvro( - structType, - 
newFieldBuilder.record(structName).namespace(recordNamespace), - recordNamespace) + val schemaKey = s"$recordNamespace.$structName" + if (schemaMap.contains(schemaKey)) { + val schema = schemaMap.get(schemaKey).get + schema.asInstanceOf[RecordDefault[T]] + } else { + val schema : RecordDefault[T] = SchemaConverters.convertStructToAvro( + structType, newFieldBuilder.record(structName).namespace(recordNamespace), + recordNamespace, structName, schemaMap) + schemaMap.put(schemaKey, schema) + schema + } case other => throw new IncompatibleSchemaException(s"Unexpected type $dataType.") } @@ -401,4 +420,83 @@ object SchemaConverters { SchemaBuilder.builder() } } + + /** + * Builds a converter function that maps a Spark SQL value of the given data type to the value + * put into an Avro record; this is used in writing Avro records out to disk. Null maps to null. + */ + def createConverterToAvro( + dataType: DataType, + structName: String, + recordNamespace: String, + schemaMap: collection.mutable.Map[String, Object] = + collection.mutable.Map[String, Object]()): (Any) => Any = { + dataType match { + case BinaryType => (item: Any) => item match { + case null => null + case bytes: Array[Byte] => ByteBuffer.wrap(bytes) + } + case ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | StringType | BooleanType => identity + case _: DecimalType => (item: Any) => if (item == null) null else item.toString + case TimestampType => (item: Any) => + if (item == null) null else item.asInstanceOf[Timestamp].getTime + case DateType => (item: Any) => + if (item == null) null else item.asInstanceOf[Date].getTime + case ArrayType(elementType, _) => + val elementConverter = createConverterToAvro(elementType, structName, recordNamespace, + schemaMap) + (item: Any) => { + if (item == null) { + null + } else { + val sourceArray = item.asInstanceOf[Seq[Any]] + val sourceArraySize = sourceArray.size + val targetArray = new Array[Any](sourceArraySize) + var idx = 0 + while (idx < 
sourceArraySize) { + targetArray(idx) = 
elementConverter(sourceArray(idx)) + idx += 1 + } + targetArray + } + } + case MapType(StringType, valueType, _) => + val valueConverter = createConverterToAvro(valueType, structName, recordNamespace, + schemaMap) + (item: Any) => { + if (item == null) { + null + } else { + val javaMap = new HashMap[String, Any]() + item.asInstanceOf[Map[String, Any]].foreach { case (key, value) => + javaMap.put(key, valueConverter(value)) + } + javaMap + } + } + case structType: StructType => + val builder = SchemaBuilder.record(structName).namespace(recordNamespace) + val schema: Schema = SchemaConverters.convertStructToAvro( + structType, builder, recordNamespace, structName, schemaMap) + val fieldConverters = structType.fields.map(field => + createConverterToAvro(field.dataType, field.name, recordNamespace, schemaMap)) + (item: Any) => { + if (item == null) { + null + } else { + val record = new Record(schema) + val convertersIterator = fieldConverters.iterator + val fieldNamesIterator = dataType.asInstanceOf[StructType].fieldNames.iterator + val rowIterator = item.asInstanceOf[Row].toSeq.iterator + + while (convertersIterator.hasNext) { + val converter = convertersIterator.next() + record.put(fieldNamesIterator.next(), converter(rowIterator.next())) + } + record + } + } + } + } }