Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion convex-db/src/main/java/convex/db/psql/PgMessageDecoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.CorruptedFrameException;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
Expand All @@ -14,6 +15,9 @@
*/
public class PgMessageDecoder extends ByteToMessageDecoder {

/** Maximum accepted length of a single Postgres protocol message. */
private static final int MAX_MESSAGE_LENGTH = 32 * 1024 * 1024; // 32 MB

private boolean startupComplete = false;

@Override
Expand All @@ -36,6 +40,9 @@ private void decodeStartup(ByteBuf in, List<Object> out) {
in.markReaderIndex();
int length = in.readInt();

if (length < 8 || length > MAX_MESSAGE_LENGTH)
throw new CorruptedFrameException("Invalid startup message length: " + length);

if (in.readableBytes() < length - 4) {
in.resetReaderIndex();
return; // Wait for full message
Expand Down Expand Up @@ -87,6 +94,9 @@ private void decodeRegular(ByteBuf in, List<Object> out) {
byte type = in.readByte();
int length = in.readInt();

if (length < 4 || length > MAX_MESSAGE_LENGTH)
throw new CorruptedFrameException("Invalid message length: " + length);

if (in.readableBytes() < length - 4) {
in.resetReaderIndex();
return; // Wait for full message
Expand All @@ -101,6 +111,7 @@ private void decodeRegular(ByteBuf in, List<Object> out) {
String name = readCString(in);
String query = readCString(in);
short paramCount = in.readShort();
if (paramCount < 0) throw new CorruptedFrameException("Negative paramCount: " + paramCount);
int[] paramTypes = new int[paramCount];
for (int i = 0; i < paramCount; i++) {
paramTypes[i] = in.readInt();
Expand All @@ -113,26 +124,31 @@ private void decodeRegular(ByteBuf in, List<Object> out) {

// Parameter format codes
short numParamFormats = in.readShort();
if (numParamFormats < 0) throw new CorruptedFrameException("Negative numParamFormats: " + numParamFormats);
short[] paramFormats = new short[numParamFormats];
for (int i = 0; i < numParamFormats; i++) {
paramFormats[i] = in.readShort();
}

// Parameter values
short numParams = in.readShort();
if (numParams < 0) throw new CorruptedFrameException("Negative numParams: " + numParams);
byte[][] paramValues = new byte[numParams][];
for (int i = 0; i < numParams; i++) {
int paramLen = in.readInt();
if (paramLen == -1) {
paramValues[i] = null; // NULL
} else {
if (paramLen < 0 || paramLen > MAX_MESSAGE_LENGTH)
throw new CorruptedFrameException("Invalid paramLen: " + paramLen);
paramValues[i] = new byte[paramLen];
in.readBytes(paramValues[i]);
Comment on lines +142 to 145
}
}

// Result format codes
short numResultFormats = in.readShort();
if (numResultFormats < 0) throw new CorruptedFrameException("Negative numResultFormats: " + numResultFormats);
short[] resultFormats = new short[numResultFormats];
for (int i = 0; i < numResultFormats; i++) {
resultFormats[i] = in.readShort();
Expand Down Expand Up @@ -171,10 +187,12 @@ private void decodeRegular(ByteBuf in, List<Object> out) {

private String readCString(ByteBuf buf) {
int start = buf.readerIndex();
int limit = buf.writerIndex();
int end = start;
while (buf.getByte(end) != 0) {
while (end < limit && buf.getByte(end) != 0) {
end++;
Comment on lines 189 to 193
}
if (end >= limit) throw new CorruptedFrameException("Unterminated C-string");
byte[] bytes = new byte[end - start];
buf.readBytes(bytes);
buf.readByte(); // consume null terminator
Expand Down
179 changes: 179 additions & 0 deletions convex-db/src/test/java/convex/db/psql/PgMessageDecoderTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package convex.db.psql;

import static org.junit.jupiter.api.Assertions.*;

import java.nio.charset.StandardCharsets;

import org.junit.jupiter.api.Test;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.CorruptedFrameException;

/**
* Unit tests for {@link PgMessageDecoder}, exercising malformed / hostile frames
* against the raw decoder via a Netty {@link EmbeddedChannel}.
*
* <p>These cover pre-authentication denial-of-service and crash inputs: the decoder
* runs on untrusted bytes before any authentication step, so it must reject
* malformed length/count fields with a {@link CorruptedFrameException} rather than
* buffering unbounded memory or throwing {@code NegativeArraySizeException} /
* {@code IndexOutOfBoundsException}.
*/
public class PgMessageDecoderTest {

private static final int PROTOCOL_3_0 = (3 << 16); // major 3, minor 0

/** Writes a NUL-terminated C-string. */
private static void putCString(ByteBuf buf, String s) {
buf.writeBytes(s.getBytes(StandardCharsets.UTF_8));
buf.writeByte(0);
}

/** Builds a well-formed startup message frame (length prefix + version + params). */
private static ByteBuf startupFrame() {
ByteBuf body = Unpooled.buffer();
body.writeInt(PROTOCOL_3_0);
putCString(body, "user");
putCString(body, "test");
body.writeByte(0); // empty key terminates parameter list

ByteBuf frame = Unpooled.buffer();
frame.writeInt(body.readableBytes() + 4); // length includes the length field
frame.writeBytes(body);
return frame;
}

/** Feeds a valid startup so the decoder flips into regular-message mode. */
private static EmbeddedChannel authenticatedChannel() {
EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder());
assertTrue(ch.writeInbound(startupFrame()));
Object msg = ch.readInbound();
assertInstanceOf(PgMessageDecoder.StartupMessage.class, msg);
return ch;
}

private static void assertRejected(EmbeddedChannel ch, ByteBuf frame) {
Throwable t = assertThrows(Throwable.class, () -> ch.writeInbound(frame));
// CorruptedFrameException is a DecoderException; ByteToMessageDecoder rethrows it as-is.
boolean found = false;
for (Throwable c = t; c != null; c = c.getCause()) {
if (c instanceof CorruptedFrameException) { found = true; break; }
}
assertTrue(found, "Expected CorruptedFrameException, got: " + t);
}

@Test
public void testOversizedStartupLength() {
EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder());
ByteBuf frame = Unpooled.buffer();
frame.writeInt(Integer.MAX_VALUE); // absurd length, must be rejected before buffering
frame.writeInt(PROTOCOL_3_0);
assertRejected(ch, frame);
}

@Test
public void testTooSmallStartupLength() {
EmbeddedChannel ch = new EmbeddedChannel(new PgMessageDecoder());
ByteBuf frame = Unpooled.buffer();
frame.writeInt(7); // below the 8-byte minimum
frame.writeInt(PROTOCOL_3_0);
assertRejected(ch, frame);
}

@Test
public void testOversizedRegularLength() {
EmbeddedChannel ch = authenticatedChannel();
ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.QUERY);
frame.writeInt(Integer.MAX_VALUE);
assertRejected(ch, frame);
}

@Test
public void testTooSmallRegularLength() {
EmbeddedChannel ch = authenticatedChannel();
ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.QUERY);
frame.writeInt(3); // below the 4-byte minimum
frame.writeByte('x');
assertRejected(ch, frame);
}

@Test
public void testBindNegativeNumParams() {
EmbeddedChannel ch = authenticatedChannel();
ByteBuf body = Unpooled.buffer();
putCString(body, ""); // portal
putCString(body, ""); // statement
body.writeShort(0); // numParamFormats
body.writeShort(-1); // numParams <-- hostile

ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.BIND);
frame.writeInt(body.readableBytes() + 4);
frame.writeBytes(body);
assertRejected(ch, frame);
}

@Test
public void testBindNegativeParamLen() {
EmbeddedChannel ch = authenticatedChannel();
ByteBuf body = Unpooled.buffer();
putCString(body, ""); // portal
putCString(body, ""); // statement
body.writeShort(0); // numParamFormats
body.writeShort(1); // numParams
body.writeInt(-2); // paramLen (-1 is NULL sentinel; -2 is hostile)

ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.BIND);
frame.writeInt(body.readableBytes() + 4);
frame.writeBytes(body);
assertRejected(ch, frame);
}

@Test
public void testParseNegativeParamCount() {
EmbeddedChannel ch = authenticatedChannel();
ByteBuf body = Unpooled.buffer();
putCString(body, ""); // statement name
putCString(body, ""); // query
body.writeShort(-1); // paramCount <-- hostile

ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.PARSE);
frame.writeInt(body.readableBytes() + 4);
frame.writeBytes(body);
assertRejected(ch, frame);
}

@Test
public void testUnterminatedCString() {
EmbeddedChannel ch = authenticatedChannel();
byte[] noNul = "SELECT 1".getBytes(StandardCharsets.UTF_8);
ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.QUERY);
frame.writeInt(noNul.length + 4); // valid length, but no NUL terminator in body
frame.writeBytes(noNul);
Comment on lines +155 to +159
assertRejected(ch, frame);
}

@Test
public void testWellFormedStartupAndQueryStillDecode() {
EmbeddedChannel ch = authenticatedChannel(); // asserts StartupMessage decodes

byte[] sql = "SELECT 1".getBytes(StandardCharsets.UTF_8);
ByteBuf frame = Unpooled.buffer();
frame.writeByte(PgMessage.QUERY);
frame.writeInt(sql.length + 1 + 4); // body = sql + NUL, plus length field
frame.writeBytes(sql);
frame.writeByte(0);

assertTrue(ch.writeInbound(frame));
Object msg = ch.readInbound();
PgMessageDecoder.Query q = assertInstanceOf(PgMessageDecoder.Query.class, msg);
assertEquals("SELECT 1", q.sql());
}
}