From 6262e243fe6d2aca62115d50a827f97f068f8f93 Mon Sep 17 00:00:00 2001 From: Prazwal Ratti Date: Fri, 3 Jul 2026 00:59:33 +0530 Subject: [PATCH] convex-db: validate Pg wire frame lengths/counts to prevent pre-auth DoS --- .../java/convex/db/psql/PgMessageDecoder.java | 20 +- .../convex/db/psql/PgMessageDecoderTest.java | 179 ++++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 convex-db/src/test/java/convex/db/psql/PgMessageDecoderTest.java diff --git a/convex-db/src/main/java/convex/db/psql/PgMessageDecoder.java b/convex-db/src/main/java/convex/db/psql/PgMessageDecoder.java index 301a5db9f..e1af2b71a 100644 --- a/convex-db/src/main/java/convex/db/psql/PgMessageDecoder.java +++ b/convex-db/src/main/java/convex/db/psql/PgMessageDecoder.java @@ -3,6 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CorruptedFrameException; import java.nio.charset.StandardCharsets; import java.util.HashMap; @@ -14,6 +15,9 @@ */ public class PgMessageDecoder extends ByteToMessageDecoder { + /** Maximum accepted length of a single Postgres protocol message. */ + private static final int MAX_MESSAGE_LENGTH = 32 * 1024 * 1024; // 32 MB + private boolean startupComplete = false; @Override @@ -36,6 +40,9 @@ private void decodeStartup(ByteBuf in, List out) { in.markReaderIndex(); int length = in.readInt(); + if (length < 8 || length > MAX_MESSAGE_LENGTH) + throw new CorruptedFrameException("Invalid startup message length: " + length); + if (in.readableBytes() < length - 4) { in.resetReaderIndex(); return; // Wait for full message @@ -87,6 +94,9 @@ private void decodeRegular(ByteBuf in, List out) { byte type = in.readByte(); int length = in.readInt(); + if (length < 4 || length > MAX_MESSAGE_LENGTH) + throw new CorruptedFrameException("Invalid message length: " + length); + if (in.readableBytes() < length - 4) { in.resetReaderIndex(); return; // Wait for full message @@ -101,6 +111,7 @@ private void decodeRegular(ByteBuf in, List out) { String name = readCString(in); String query = readCString(in); short paramCount = in.readShort(); + if (paramCount < 0) throw new CorruptedFrameException("Negative paramCount: " + paramCount); int[] paramTypes = new int[paramCount]; for (int i = 0; i < paramCount; i++) { paramTypes[i] = in.readInt(); @@ -113,6 +124,7 @@ private void decodeRegular(ByteBuf in, List out) { // Parameter format codes short numParamFormats = in.readShort(); + if (numParamFormats < 0) throw new CorruptedFrameException("Negative numParamFormats: " + numParamFormats); short[] paramFormats = new short[numParamFormats]; for (int i = 0; i < numParamFormats; i++) { paramFormats[i] = in.readShort(); @@ -120,12 +132,15 @@ private void decodeRegular(ByteBuf in, List out) { // Parameter values short numParams = in.readShort(); + if (numParams < 0) throw new CorruptedFrameException("Negative numParams: " + numParams); byte[][] paramValues = new byte[numParams][]; for (int i = 0; i < numParams; i++) { int paramLen = in.readInt(); if (paramLen == -1) { paramValues[i] = null; // NULL } else { + if (paramLen < 0 || paramLen > MAX_MESSAGE_LENGTH) + throw new CorruptedFrameException("Invalid paramLen: " + paramLen); paramValues[i] = new byte[paramLen]; in.readBytes(paramValues[i]); } @@ -133,6 +148,7 @@ private void decodeRegular(ByteBuf in, List out) { // Result format codes short numResultFormats = in.readShort(); + if (numResultFormats < 0) throw new CorruptedFrameException("Negative numResultFormats: " + numResultFormats); short[] resultFormats = new short[numResultFormats]; for (int i = 0; i < numResultFormats; i++) { resultFormats[i] = in.readShort(); @@ -171,10 +187,12 @@ private void decodeRegular(ByteBuf in, List out) { private String readCString(ByteBuf buf) { int start = buf.readerIndex(); + int limit = buf.writerIndex(); int end = start; - while (buf.getByte(end) != 0) { + while (end < limit && buf.getByte(end) != 0) { end++; } + if (end >= limit) throw new CorruptedFrameException("Unterminated C-string"); byte[] bytes = new byte[end - start]; buf.readBytes(bytes); buf.readByte(); // consume null terminator diff --git a/convex-db/src/test/java/convex/db/psql/PgMessageDecoderTest.java b/convex-db/src/test/java/convex/db/psql/PgMessageDecoderTest.java new file mode 100644 index 000000000..50c939268 --- /dev/null +++ b/convex-db/src/test/java/convex/db/psql/PgMessageDecoderTest.java @@ -0,0 +1,179 @@ +package convex.db.psql; + +import static org.junit.jupiter.api.Assertions.*; + +import java.nio.charset.StandardCharsets; + +import org.junit.jupiter.api.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CorruptedFrameException; + +/** + * Unit tests for {@link PgMessageDecoder}, exercising malformed / hostile frames + * against the raw decoder via a Netty {@link EmbeddedChannel}. + * + *

These cover pre-authentication denial-of-service and crash inputs: the decoder + * runs on untrusted bytes before any authentication step, so it must reject + * malformed length/count fields with a {@link CorruptedFrameException} rather than + * buffering unbounded memory or throwing {@code NegativeArraySizeException} / + * {@code IndexOutOfBoundsException}. + */ +public class PgMessageDecoderTest { + + private static final int PROTOCOL_3_0 = (3 << 16); // major 3, minor 0 + + /** Writes a NUL-terminated C-string. */ + private static void putCString(ByteBuf buf, String s) { + buf.writeBytes(s.getBytes(StandardCharsets.UTF_8)); + buf.writeByte(0); + } + + /** Builds a well-formed startup message frame (length prefix + version + params). */ + private static ByteBuf startupFrame() { + ByteBuf body = Unpooled.buffer(); + body.writeInt(PROTOCOL_3_0); + putCString(body, "user"); + putCString(body, "test"); + body.writeByte(0); // empty key terminates parameter list + + ByteBuf frame = Unpooled.buffer(); + frame.writeInt(body.readableBytes() + 4); // length includes the length field + frame.writeBytes(body); + return frame; + } + + /** Feeds a valid startup so the decoder flips into regular-message mode. */ + private static EmbeddedChannel authenticatedChannel() { + EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder()); + assertTrue(ch.writeInbound(startupFrame())); + Object msg = ch.readInbound(); + assertInstanceOf(PgMessageDecoder.StartupMessage.class, msg); + return ch; + } + + private static void assertRejected(EmbeddedChannel ch, ByteBuf frame) { + Throwable t = assertThrows(Throwable.class, () -> ch.writeInbound(frame)); + // CorruptedFrameException is a DecoderException; ByteToMessageDecoder rethrows it as-is. + boolean found = false; + for (Throwable c = t; c != null; c = c.getCause()) { + if (c instanceof CorruptedFrameException) { found = true; break; } + } + assertTrue(found, "Expected CorruptedFrameException, got: " + t); + } + + @Test + public void testOversizedStartupLength() { + EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder()); + ByteBuf frame = Unpooled.buffer(); + frame.writeInt(Integer.MAX_VALUE); // absurd length, must be rejected before buffering + frame.writeInt(PROTOCOL_3_0); + assertRejected(ch, frame); + } + + @Test + public void testTooSmallStartupLength() { + EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder()); + ByteBuf frame = Unpooled.buffer(); + frame.writeInt(7); // below the 8-byte minimum + frame.writeInt(PROTOCOL_3_0); + assertRejected(ch, frame); + } + + @Test + public void testOversizedRegularLength() { + EmbeddedChannel ch = authenticatedChannel(); + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.QUERY); + frame.writeInt(Integer.MAX_VALUE); + assertRejected(ch, frame); + } + + @Test + public void testTooSmallRegularLength() { + EmbeddedChannel ch = authenticatedChannel(); + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.QUERY); + frame.writeInt(3); // below the 4-byte minimum + frame.writeByte('x'); + assertRejected(ch, frame); + } + + @Test + public void testBindNegativeNumParams() { + EmbeddedChannel ch = authenticatedChannel(); + ByteBuf body = Unpooled.buffer(); + putCString(body, ""); // portal + putCString(body, ""); // statement + body.writeShort(0); // numParamFormats + body.writeShort(-1); // numParams <-- hostile + + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.BIND); + frame.writeInt(body.readableBytes() + 4); + frame.writeBytes(body); + assertRejected(ch, frame); + } + + @Test + public void testBindNegativeParamLen() { + EmbeddedChannel ch = authenticatedChannel(); + ByteBuf body = Unpooled.buffer(); + putCString(body, ""); // portal + putCString(body, ""); // statement + body.writeShort(0); // numParamFormats + body.writeShort(1); // numParams + body.writeInt(-2); // paramLen (-1 is NULL sentinel; -2 is hostile) + + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.BIND); + frame.writeInt(body.readableBytes() + 4); + frame.writeBytes(body); + assertRejected(ch, frame); + } + + @Test + public void testParseNegativeParamCount() { + EmbeddedChannel ch = authenticatedChannel(); + ByteBuf body = Unpooled.buffer(); + putCString(body, ""); // statement name + putCString(body, ""); // query + body.writeShort(-1); // paramCount <-- hostile + + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.PARSE); + frame.writeInt(body.readableBytes() + 4); + frame.writeBytes(body); + assertRejected(ch, frame); + } + + @Test + public void testUnterminatedCString() { + EmbeddedChannel ch = authenticatedChannel(); + byte[] noNul = "SELECT 1".getBytes(StandardCharsets.UTF_8); + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.QUERY); + frame.writeInt(noNul.length + 4); // valid length, but no NUL terminator in body + frame.writeBytes(noNul); + assertRejected(ch, frame); + } + + @Test + public void testWellFormedStartupAndQueryStillDecode() { + EmbeddedChannel ch = authenticatedChannel(); // asserts StartupMessage decodes + + byte[] sql = "SELECT 1".getBytes(StandardCharsets.UTF_8); + ByteBuf frame = Unpooled.buffer(); + frame.writeByte(PgMessage.QUERY); + frame.writeInt(sql.length + 1 + 4); // body = sql + NUL, plus length field + frame.writeBytes(sql); + frame.writeByte(0); + + assertTrue(ch.writeInbound(frame)); + Object msg = ch.readInbound(); + PgMessageDecoder.Query q = assertInstanceOf(PgMessageDecoder.Query.class, msg); + assertEquals("SELECT 1", q.sql()); + } +}