Skip to content
This repository was archived by the owner on Dec 20, 2018. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
591 changes: 591 additions & 0 deletions src/main/scala/com/databricks/spark/avro/AvroEncoder.scala

Large diffs are not rendered by default.

66 changes: 38 additions & 28 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ package com.databricks.spark.avro
import java.nio.ByteBuffer

import scala.collection.JavaConverters._

import org.apache.avro.generic.GenericData.Fixed
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.SchemaBuilder._
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -74,38 +72,50 @@ object SchemaConverters {
nullable = false)

case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
toSqlType(remainingUnionTypes.head).copy(nullable = true)
} else {
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType) match {
case Seq(t1) =>
toSqlType(avroSchema.getTypes.get(0))
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
SchemaType(DoubleType, nullable = false)
case _ =>
// Convert complex unions to struct types where field names are member0, member1, etc.
// This is consistent with the behavior when converting between Avro and Parquet.
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlType(s)
// All fields are nullable because only one of them is set at a time
StructField(s"member$i", schemaType.dataType, nullable = true)
}

SchemaType(StructType(fields), nullable = false)
resolveUnionType(avroSchema) match {
case (schema, nullable) => toSqlType(schema).copy(nullable = nullable)
}

case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
}
}

/**
 * Resolves an Avro UNION schema to an SQL-compatible Avro schema, paired with a flag telling
 * whether the resolved type has to be treated as nullable.
 *
 * Resolution rules:
 *  - a NULL member is removed from the union and the result is marked nullable;
 *  - a union with a single member resolves to that member;
 *  - a union of INT and LONG resolves to LONG, a union of FLOAT and DOUBLE resolves to DOUBLE;
 *  - any other union is converted to a record whose fields are named member0, member1, etc.
 *
 * @param avroSchema the UNION schema to resolve
 * @param nullable nullability already established for this union; the recursive call passes
 *                 `true` after a NULL member has been stripped
 * @return the resolved schema together with its nullability
 */
def resolveUnionType(avroSchema: Schema, nullable: Boolean = false): (Schema, Boolean) = {
  val memberSchemas = avroSchema.getTypes.asScala
  if (memberSchemas.exists(_.getType == NULL)) {
    // In case of a union with null, eliminate it, and make a recursive call
    val remainingUnionTypes = memberSchemas.filterNot(_.getType == NULL)
    if (remainingUnionTypes.size == 1) {
      (remainingUnionTypes.head, true)
    } else {
      resolveUnionType(Schema.createUnion(remainingUnionTypes.asJava), nullable = true)
    }
  } else memberSchemas.map(_.getType) match {
    // Every branch below reports the incoming `nullable` flag rather than a constant, so the
    // nullability established by stripping a NULL member survives the recursive call and a
    // union that never contained NULL is not reported as nullable.
    case Seq(_) =>
      (avroSchema.getTypes.get(0), nullable)
    case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
      (Schema.create(LONG), nullable)
    case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
      (Schema.create(DOUBLE), nullable)
    case _ =>
      // Convert complex unions to records where field names are member0, member1, etc.
      // This is consistent with the behavior when converting between Avro and Parquet.
      val record = SchemaBuilder.record(avroSchema.getName).fields()
      memberSchemas.zipWithIndex.foreach {
        case (s, i) =>
          // All fields are nullable because only one of them is set at a time
          record.name(s"member$i").`type`(SchemaBuilder.unionOf()
            .`type`(Schema.create(NULL)).and
            .`type`(s).endUnion())
            .withDefault(null)
      }
      (record.endRecord(), nullable)
  }
}

/**
* This function converts sparkSQL StructType into avro schema. This method uses two other
* converter methods in order to do the conversion.
Expand Down
13 changes: 13 additions & 0 deletions src/test/java/com/databricks/spark/avro/SimpleEnums.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/**
* Autogenerated by Avro
*
* DO NOT EDIT DIRECTLY
*/
package com.databricks.spark.avro;
@SuppressWarnings("all")
@org.apache.avro.specific.AvroGenerated
public enum SimpleEnums {
  SPADES,
  HEARTS,
  DIAMONDS,
  CLUBS;

  /** Avro schema for this enum, parsed from its JSON definition. */
  public static final org.apache.avro.Schema SCHEMA$ =
      new org.apache.avro.Schema.Parser().parse("{\"type\":\"enum\",\"name\":\"SimpleEnums\",\"namespace\":\"com.databricks.spark.avro\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}");

  /** Returns the schema held in {@link #SCHEMA$}. */
  public static org.apache.avro.Schema getClassSchema() {
    return SCHEMA$;
  }
}
23 changes: 23 additions & 0 deletions src/test/java/com/databricks/spark/avro/SimpleFixed.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/**
* Autogenerated by Avro
*
* DO NOT EDIT DIRECTLY
*/
package com.databricks.spark.avro;
@SuppressWarnings("all")
@org.apache.avro.specific.FixedSize(16)
@org.apache.avro.specific.AvroGenerated
public class SimpleFixed extends org.apache.avro.specific.SpecificFixed {

  /** Avro schema for this fixed type, parsed from its JSON definition (declared size: 16). */
  public static final org.apache.avro.Schema SCHEMA$ =
      new org.apache.avro.Schema.Parser().parse("{\"type\":\"fixed\",\"name\":\"SimpleFixed\",\"namespace\":\"com.databricks.spark.avro\",\"size\":16}");

  /** Returns the schema held in {@link #SCHEMA$}. */
  public static org.apache.avro.Schema getClassSchema() {
    return SCHEMA$;
  }

  /** Creates a new SimpleFixed through the no-argument superclass constructor. */
  public SimpleFixed() {
    super();
  }

  /**
   * Creates a new SimpleFixed, handing the given bytes to the superclass constructor.
   *
   * @param bytes the content for this fixed value
   */
  public SimpleFixed(byte[] bytes) {
    super(bytes);
  }
}
195 changes: 195 additions & 0 deletions src/test/java/com/databricks/spark/avro/SimpleRecord.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/**
* Autogenerated by Avro
*
* DO NOT EDIT DIRECTLY
*/
package com.databricks.spark.avro;
@SuppressWarnings("all")
@org.apache.avro.specific.AvroGenerated
public class SimpleRecord extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord {

  /** Avro schema for this record: an int field 'nested1' and a string field 'nested2'. */
  public static final org.apache.avro.Schema SCHEMA$ =
      new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"SimpleRecord\",\"namespace\":\"com.databricks.spark.avro\",\"fields\":[{\"name\":\"nested1\",\"type\":\"int\",\"default\":0},{\"name\":\"nested2\",\"type\":\"string\",\"default\":\"string\"}]}");

  /** Returns the schema held in {@link #SCHEMA$}. */
  public static org.apache.avro.Schema getClassSchema() {
    return SCHEMA$;
  }

  @Deprecated public int nested1;
  @Deprecated public java.lang.CharSequence nested2;

  /**
   * Default constructor.  Note that this does not initialize fields
   * to their default values from the schema.  If that is desired then
   * one should use <code>newBuilder()</code>.
   */
  public SimpleRecord() {}

  /**
   * All-args constructor.
   *
   * @param nested1 value for the 'nested1' field (unboxed on assignment)
   * @param nested2 value for the 'nested2' field
   */
  public SimpleRecord(java.lang.Integer nested1, java.lang.CharSequence nested2) {
    this.nested1 = nested1;
    this.nested2 = nested2;
  }

  /** Returns the schema of this record, i.e. {@link #SCHEMA$}. */
  public org.apache.avro.Schema getSchema() {
    return SCHEMA$;
  }

  /**
   * Positional field read. Used by DatumWriter.  Applications should not call.
   *
   * @param field$ 0 for 'nested1', 1 for 'nested2'
   * @return the value of the addressed field
   * @throws org.apache.avro.AvroRuntimeException for any other index
   */
  public java.lang.Object get(int field$) {
    if (field$ == 0) {
      return nested1;
    }
    if (field$ == 1) {
      return nested2;
    }
    throw new org.apache.avro.AvroRuntimeException("Bad index");
  }

  /**
   * Positional field write. Used by DatumReader.  Applications should not call.
   *
   * @param field$ 0 for 'nested1', 1 for 'nested2'
   * @param value$ the new value, cast to the type of the addressed field
   * @throws org.apache.avro.AvroRuntimeException for any other index
   */
  @SuppressWarnings(value="unchecked")
  public void put(int field$, java.lang.Object value$) {
    if (field$ == 0) {
      nested1 = (java.lang.Integer) value$;
    } else if (field$ == 1) {
      nested2 = (java.lang.CharSequence) value$;
    } else {
      throw new org.apache.avro.AvroRuntimeException("Bad index");
    }
  }

  /**
   * Gets the value of the 'nested1' field.
   */
  public java.lang.Integer getNested1() {
    return nested1;
  }

  /**
   * Sets the value of the 'nested1' field.
   * @param value the value to set.
   */
  public void setNested1(java.lang.Integer value) {
    this.nested1 = value;
  }

  /**
   * Gets the value of the 'nested2' field.
   */
  public java.lang.CharSequence getNested2() {
    return nested2;
  }

  /**
   * Sets the value of the 'nested2' field.
   * @param value the value to set.
   */
  public void setNested2(java.lang.CharSequence value) {
    this.nested2 = value;
  }

  /** Creates a new SimpleRecord RecordBuilder */
  public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder() {
    return new com.databricks.spark.avro.SimpleRecord.Builder();
  }

  /** Creates a new SimpleRecord RecordBuilder by copying an existing Builder */
  public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord.Builder other) {
    return new com.databricks.spark.avro.SimpleRecord.Builder(other);
  }

  /** Creates a new SimpleRecord RecordBuilder by copying an existing SimpleRecord instance */
  public static com.databricks.spark.avro.SimpleRecord.Builder newBuilder(com.databricks.spark.avro.SimpleRecord other) {
    return new com.databricks.spark.avro.SimpleRecord.Builder(other);
  }

  /**
   * RecordBuilder for SimpleRecord instances.
   */
  public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase<SimpleRecord>
    implements org.apache.avro.data.RecordBuilder<SimpleRecord> {

    private int nested1;
    private java.lang.CharSequence nested2;

    /** Creates a new Builder */
    private Builder() {
      super(com.databricks.spark.avro.SimpleRecord.SCHEMA$);
    }

    /** Creates a Builder by copying an existing Builder */
    private Builder(com.databricks.spark.avro.SimpleRecord.Builder other) {
      super(other);
      copyFrom(other.nested1, other.nested2);
    }

    /** Creates a Builder by copying an existing SimpleRecord instance */
    private Builder(com.databricks.spark.avro.SimpleRecord other) {
      super(com.databricks.spark.avro.SimpleRecord.SCHEMA$);
      copyFrom(other.nested1, other.nested2);
    }

    /**
     * Shared body of the two copying constructors: each source value that passes
     * isValidValue for its field is deep-copied into this builder and flagged as set.
     */
    private void copyFrom(int sourceNested1, java.lang.CharSequence sourceNested2) {
      if (isValidValue(fields()[0], sourceNested1)) {
        this.nested1 = data().deepCopy(fields()[0].schema(), sourceNested1);
        fieldSetFlags()[0] = true;
      }
      if (isValidValue(fields()[1], sourceNested2)) {
        this.nested2 = data().deepCopy(fields()[1].schema(), sourceNested2);
        fieldSetFlags()[1] = true;
      }
    }

    /** Gets the value of the 'nested1' field */
    public java.lang.Integer getNested1() {
      return nested1;
    }

    /** Sets the value of the 'nested1' field */
    public com.databricks.spark.avro.SimpleRecord.Builder setNested1(int value) {
      validate(fields()[0], value);
      this.nested1 = value;
      fieldSetFlags()[0] = true;
      return this;
    }

    /** Checks whether the 'nested1' field has been set */
    public boolean hasNested1() {
      return fieldSetFlags()[0];
    }

    /** Clears the value of the 'nested1' field */
    public com.databricks.spark.avro.SimpleRecord.Builder clearNested1() {
      // Only the set-flag is reset here; the stored int is left as it is.
      fieldSetFlags()[0] = false;
      return this;
    }

    /** Gets the value of the 'nested2' field */
    public java.lang.CharSequence getNested2() {
      return nested2;
    }

    /** Sets the value of the 'nested2' field */
    public com.databricks.spark.avro.SimpleRecord.Builder setNested2(java.lang.CharSequence value) {
      validate(fields()[1], value);
      this.nested2 = value;
      fieldSetFlags()[1] = true;
      return this;
    }

    /** Checks whether the 'nested2' field has been set */
    public boolean hasNested2() {
      return fieldSetFlags()[1];
    }

    /** Clears the value of the 'nested2' field */
    public com.databricks.spark.avro.SimpleRecord.Builder clearNested2() {
      nested2 = null;
      fieldSetFlags()[1] = false;
      return this;
    }

    /**
     * Builds a SimpleRecord: each field takes the builder's value when its set-flag is on,
     * otherwise the result of defaultValue for that field. Any exception raised while
     * building is rethrown wrapped in an AvroRuntimeException.
     */
    @Override
    public SimpleRecord build() {
      try {
        SimpleRecord built = new SimpleRecord();
        if (fieldSetFlags()[0]) {
          built.nested1 = this.nested1;
        } else {
          built.nested1 = (java.lang.Integer) defaultValue(fields()[0]);
        }
        if (fieldSetFlags()[1]) {
          built.nested2 = this.nested2;
        } else {
          built.nested2 = (java.lang.CharSequence) defaultValue(fields()[1]);
        }
        return built;
      } catch (Exception e) {
        throw new org.apache.avro.AvroRuntimeException(e);
      }
    }
  }
}
Loading