Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public ColumnarMap getMap(int ordinal) {

@Override
public Object get(int ordinal, DataType dataType) {
return SpecializedGettersReader.read(this, ordinal, dataType, false, false);
return SpecializedGettersReader.read(this, ordinal, dataType, false, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ public Object get(int ordinal, DataType dataType) {
return getMap(ordinal);
} else if (dataType instanceof VariantType) {
return getVariant(ordinal);
} else if (dataType instanceof UserDefinedType<?> udt) {
return get(ordinal, udt.sqlType());
} else {
throw new SparkUnsupportedOperationException(
"_LEGACY_ERROR_TEMP_3152", Map.of("dataType", String.valueOf(dataType)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ public Object get(int ordinal, DataType dataType) {
return getMap(ordinal);
} else if (dataType instanceof VariantType) {
return getVariant(ordinal);
} else if (dataType instanceof UserDefinedType<?> udt) {
return get(ordinal, udt.sqlType());
} else {
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3155");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,38 @@ import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
import org.apache.spark.util.ArrayImplicits._

/**
* A minimal UDT backed by IntegerType, used by SPARK-55897 tests.
*/
@SQLUserDefinedType(udt = classOf[TestIntUDT])
private case class TestIntWrapper(value: Int)

private class TestIntUDT extends UserDefinedType[TestIntWrapper] {
override def sqlType: DataType = IntegerType
override def serialize(obj: TestIntWrapper): Any = obj.value
override def userClass: Class[TestIntWrapper] = classOf[TestIntWrapper]
override def deserialize(datum: Any): TestIntWrapper = datum match {
case v: Int => TestIntWrapper(v)
}
}

/**
* A minimal UDT backed by StructType, used by SPARK-55897 tests.
*/
@SQLUserDefinedType(udt = classOf[TestStructWrapperUDT])
private case class TestStructWrapper(x: Int, y: Long)

private class TestStructWrapperUDT extends UserDefinedType[TestStructWrapper] {
override def sqlType: DataType = new StructType()
.add("x", IntegerType)
.add("y", LongType)
override def serialize(obj: TestStructWrapper): Any = InternalRow(obj.x, obj.y)
override def userClass: Class[TestStructWrapper] = classOf[TestStructWrapper]
override def deserialize(datum: Any): TestStructWrapper = datum match {
case row: InternalRow => TestStructWrapper(row.getInt(0), row.getLong(1))
}
}

@ExtendedSQLTest
class ColumnarBatchSuite extends SparkFunSuite {

Expand Down Expand Up @@ -2025,4 +2057,93 @@ class ColumnarBatchSuite extends SparkFunSuite {
}
}
}

testVector(
"SPARK-55897: ColumnarRow.get with primitive-backed UDT",
10,
new StructType().add("name", StringType).add("udt_field", IntegerType)) { column =>
column.getChild(0).putByteArray(0, "hello".getBytes)
column.getChild(1).putInt(0, 42)

val row = column.getStruct(0)
assert(row.get(1, new TestIntUDT()) === 42)
}

testVector(
"SPARK-55897: ColumnarRow.get with struct-backed UDT",
10,
new StructType()
.add("id", IntegerType)
.add("nested", new StructType().add("x", IntegerType).add("y", LongType))) { column =>
column.getChild(0).putInt(0, 1)
column.getChild(1).getChild(0).putInt(0, 10)
column.getChild(1).getChild(1).putLong(0, 20L)

val row = column.getStruct(0)
val nested = row.get(1, new TestStructWrapperUDT()).asInstanceOf[InternalRow]
assert(nested.getInt(0) === 10)
assert(nested.getLong(1) === 20L)
}

testVector(
"SPARK-55897: ColumnarArray.get with primitive-backed UDT",
10,
new ArrayType(IntegerType, false)) { column =>
val data = column.arrayData()
data.putInt(0, 10)
data.putInt(1, 20)
column.putArray(0, 0, 2)

val arr = column.getArray(0)
assert(arr.get(0, new TestIntUDT()) === 10)
assert(arr.get(1, new TestIntUDT()) === 20)
}

testVector(
"SPARK-55897: ColumnarArray.get with struct-backed UDT",
10,
new ArrayType(new StructType().add("x", IntegerType).add("y", LongType), false)) { column =>
val data = column.arrayData()
data.getChild(0).putInt(0, 100)
data.getChild(1).putLong(0, 200L)
column.putArray(0, 0, 1)

val arr = column.getArray(0)
val row = arr.get(0, new TestStructWrapperUDT()).asInstanceOf[InternalRow]
assert(row.getInt(0) === 100)
assert(row.getLong(1) === 200L)
}

test("SPARK-55897: ColumnarBatchRow.get with primitive-backed UDT") {
Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { memMode =>
val col = allocate(10, IntegerType, memMode)
try {
col.putInt(0, 99)
val batchRow = new ColumnarBatchRow(Array(col))
batchRow.rowId = 0
assert(batchRow.get(0, new TestIntUDT()) === 99)
} finally {
col.close()
}
}
}

test("SPARK-55897: ColumnarBatchRow.get with struct-backed UDT") {
Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { memMode =>
val col = allocate(10,
new StructType().add("x", IntegerType).add("y", LongType), memMode)
try {
col.getChild(0).putInt(0, 5)
col.getChild(1).putLong(0, 15L)
val batchRow = new ColumnarBatchRow(Array(col))
batchRow.rowId = 0

val row = batchRow.get(0, new TestStructWrapperUDT()).asInstanceOf[InternalRow]
assert(row.getInt(0) === 5)
assert(row.getLong(1) === 15L)
} finally {
col.close()
}
}
}
}