From 35376ceb082790fa77701169e068396486a28c8f Mon Sep 17 00:00:00 2001 From: Pete Bentley Date: Sat, 11 Oct 2025 16:58:37 +0100 Subject: [PATCH] Fix null checks for engine unwrap buffer lists. Fixes #1372 I verified on OpenJDK 8, 11 and 21 that this is the RI expected behaviour, i.e. a null entry in a ByteBuffer array that is outside the selected range should *not* cause wrap or unwrap to fail. Also made the unwrap check methods for srcs and dests more aligned with each other, although ultimately this should probably be more like the wrap approach where we simply take a copy of the selected ByteBuffer array range. Now we have better tests we can do that in a future CL. This is a slight behaviour change on all platforms (including Mainline) but as there were no tests for it, it won't break any CI, and as it is making the checks slightly more lenient it seems very unlikely to break any application code. No impact on SSLSocket implementation as that only ever uses a single buffer. I added a preconditions check for unwrap along the lines of the one for wrap (and converted to assertThrows). Also because my spidey sense was tingling and I wasn't sure that this miscount could ever cause unwrap to produce more data than it was meant to I added an array offsets test along the lines of the one for wrap, only a bit more involved as it first checks that the exact amount of data flows as well as returning BUFFER_UNDERFLOW if you try to unwrap more data than is supposed to fit. Didn't find any bugs but it's still a good regression test. --- .../java/org/conscrypt/ConscryptEngine.java | 20 +- .../javax/net/ssl/SSLEngineTest.java | 352 ++++++++++++------ .../main/java/org/conscrypt/TestUtils.java | 9 + 3 files changed, 246 insertions(+), 135 deletions(-) diff --git a/common/src/main/java/org/conscrypt/ConscryptEngine.java b/common/src/main/java/org/conscrypt/ConscryptEngine.java index 818fa93cc..c573ef73d 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngine.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngine.java @@ -753,8 +753,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe checkPositionIndexes(dstsOffset, dstsOffset + dstsLength, dsts.length); // Determine the output capacity. - final int dstLength = calcDstsLength(dsts, dstsOffset, dstsLength); - final int endOffset = dstsOffset + dstsLength; + final int dstsEndOffset = dstsOffset + dstsLength; + final long dstLength = calcDstsLength(dsts, dstsOffset, dstsEndOffset); final int srcsEndOffset = srcsOffset + srcsLength; final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset); @@ -863,7 +863,7 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe try { if (dstLength > 0) { // Write decrypted data to dsts buffers - for (int idx = dstsOffset; idx < endOffset; ++idx) { + for (int idx = dstsOffset; idx < dstsEndOffset; ++idx) { ByteBuffer dst = dsts[idx]; if (!dst.hasRemaining()) { continue; @@ -933,17 +933,15 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe } } - private static int calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsLength) { - int capacity = 0; - for (int i = 0; i < dsts.length; i++) { + private static long calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsEndOffset) { + long capacity = 0; + for (int i = dstsOffset; i < dstsEndOffset; i++) { ByteBuffer dst = dsts[i]; checkArgument(dst != null, "dsts[%d] is null", i); if (dst.isReadOnly()) { throw new ReadOnlyBufferException(); } - if (i >= dstsOffset && i < dstsOffset + dstsLength) { - capacity += dst.remaining(); - } + capacity += dst.remaining(); } return capacity; } @@ -952,9 +950,7 @@ private static long calcSrcsLength(ByteBuffer[] srcs, int srcsOffset, int srcsEn long len = 0; for (int i = srcsOffset; i < srcsEndOffset; i++) { ByteBuffer src = srcs[i]; - if (src == null) { - throw new IllegalArgumentException("srcs[" + i + "] is null"); - } + checkArgument(src != null, "srcs[%d] is null", i); len += src.remaining(); } return len; diff --git a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java index c49648955..06e95f94a 100644 --- a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java +++ b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java @@ -22,9 +22,18 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import org.conscrypt.TestUtils; +import org.conscrypt.TestUtils.BufferType; +import org.conscrypt.java.security.StandardNames; +import org.conscrypt.java.security.TestKeyStore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + import java.io.IOException; import java.net.Socket; import java.nio.ByteBuffer; @@ -35,6 +44,7 @@ import java.util.HashSet; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; + import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import javax.net.ssl.KeyManager; @@ -49,13 +59,6 @@ import javax.net.ssl.SSLSession; import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedTrustManager; -import org.conscrypt.TestUtils; -import org.conscrypt.TestUtils.BufferType; -import org.conscrypt.java.security.StandardNames; -import org.conscrypt.java.security.TestKeyStore; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SSLEngineTest { @@ -928,150 +931,253 @@ public void test_SSLEngine_setSSLParameters() throws Exception { @Test public void wrapPreconditions() throws Exception { - ByteBuffer buffer = ByteBuffer.allocate(10); + ByteBuffer buffer = ByteBuffer.allocate(128); + ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); ByteBuffer[] buffers = new ByteBuffer[] { buffer, buffer, buffer }; ByteBuffer[] badBuffers = new ByteBuffer[] { buffer, buffer, null, buffer }; // Client/server mode not set => IllegalStateException - try { - newUnconnectedEngine().wrap(buffer, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().wrap(buffer, buffer)); + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().wrap(buffers, buffer)); + assertThrows(IllegalStateException.class, + () -> newUnconnectedEngine().wrap(buffers, 0, 1, buffer)); - try { - newUnconnectedEngine().wrap(buffers, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + // Read-only destination => ReadOnlyBufferException + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffer, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffers, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffers, 0, 1, readOnlyBuffer)); - try { - newUnconnectedEngine().wrap(buffers, 0, 1, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + // Null destination => IllegalArgumentException + assertThrows(IllegalArgumentException.class, () -> newConnectedEngine().wrap(buffer, null)); + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().wrap(buffers, null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(buffers, 0, 1, null)); - // Read-only destination => ReadOnlyBufferException - try { - newConnectedEngine().wrap(buffer, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Null source => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap((ByteBuffer) null, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap((ByteBuffer[]) null, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(null, 0, 1, buffer)); - try { - newConnectedEngine().wrap(buffers, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Null entries in buffer array => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(badBuffers, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(badBuffers, 0, badBuffers.length, buffer)); + // But not if they are outside the selected offset and length + newConnectedEngine().wrap(badBuffers, 0, 2, buffer); + newConnectedEngine().wrap(badBuffers, 3, 1, buffer); - try { - newConnectedEngine().wrap(buffers, 0, 1, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Bad offset or length => IndexOutOfBoundsException + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().wrap(buffers, 0, buffers.length + 1, buffer)); + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().wrap(buffers, buffers.length, 1, buffer)); + } - // Null destination => IllegalArgumentException - try { - newConnectedEngine().wrap(buffer, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + @Test + public void unwrapPreconditions() throws Exception { + ByteBuffer buffer = ByteBuffer.allocate(128); + ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); + ByteBuffer[] buffers = new ByteBuffer[] {buffer, buffer, buffer}; + ByteBuffer[] badBuffers = new ByteBuffer[] {buffer, buffer, null, buffer}; + ByteBuffer[] readOnlyBuffers = new ByteBuffer[] {buffer, readOnlyBuffer, buffer}; - try { - newConnectedEngine().wrap(buffers, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Client/server mode not set => IllegalStateException + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().unwrap(buffer, buffer)); + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().unwrap(buffer, buffers)); + assertThrows(IllegalStateException.class, + () -> newUnconnectedEngine().unwrap(buffer, buffers, 0, 1)); - try { - newConnectedEngine().wrap(buffers, 0, 1, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Read-only destination => ReadOnlyBufferException + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().unwrap(buffer, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().unwrap(buffer, readOnlyBuffers)); + assertThrows(ReadOnlyBufferException.class, + () + -> newConnectedEngine().unwrap( + buffer, readOnlyBuffers, 0, readOnlyBuffers.length)); + + // Null destination => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, (ByteBuffer) null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, (ByteBuffer[]) null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, null, 0, 1)); // Null source => IllegalArgumentException - try { - newConnectedEngine().wrap((ByteBuffer) null, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().unwrap(null, buffer)); + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().unwrap(null, buffers)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(null, buffers, 0, 1)); - try { - newConnectedEngine().wrap((ByteBuffer[]) null, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected + // Null entries in buffer array => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, badBuffers)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, badBuffers, 0, badBuffers.length)); + // But not if they are outside the selected offset and length + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + doUnwrap(pair, badBuffers, 0, 2); + doUnwrap(pair, badBuffers, 3, 1); } - try { - newConnectedEngine().wrap(null, 0, 1, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Bad offset or length => IndexOutOfBoundsException + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().unwrap(buffer, buffers, 0, buffers.length + 1)); + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().unwrap(buffer, buffers, badBuffers.length, 1)); + } - // Null entries in buffer array => IllegalArgumentException - try { - newConnectedEngine().wrap(badBuffers, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + private void doUnwrap(TestSSLEnginePair pair, ByteBuffer[] dest, int offset, int length) + throws Exception { + int bufferSize = 128; + ByteBuffer src = ByteBuffer.allocate(bufferSize); + ByteBuffer tlsBuffer = ByteBuffer.allocate(bufferSize + 128); + tlsBuffer.clear(); + SSLEngineResult result = pair.client.wrap(src, tlsBuffer); + assertEquals(Status.OK, result.getStatus()); - try { - newConnectedEngine().wrap(badBuffers, 0, badBuffers.length, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected + tlsBuffer.flip(); + for (int i = offset; i < offset + length; i++) { + dest[i].clear(); } + result = pair.server.unwrap(tlsBuffer, dest, offset, length); + assertEquals(Status.OK, result.getStatus()); + } - // Bad offset or length => IndexOutOfBoundsException - try { - newConnectedEngine().wrap(buffers, 0, 7, buffer); - fail(); - } catch (IndexOutOfBoundsException e) { - // Expected + @Test + public void bufferArrayOffsets_Wrap() throws Exception { + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + int dataSize = 1024; // Should be less than SSL3_RT_MAX_PLAIN_LENGTH + int bufferSize = 128; + int bufferArrayLength = dataSize / bufferSize; + int[] sizeArray = new int[bufferArrayLength]; + Arrays.fill(sizeArray, bufferSize); + ByteBuffer tlsBuffer = ByteBuffer.allocate(dataSize); + + for (BufferType bufferType : BufferType.values()) { + ByteBuffer[] sourceBuffers = bufferType.newRandomBuffers(sizeArray); + for (int offset = 0; offset < sourceBuffers.length; offset++) { + for (int length = 1; length < sourceBuffers.length - offset; length++) { + String statusMessage = + String.format("offset=%d, length=%d", offset, length); + // Reset source buffers (only some are emptied on each iteration) + for (ByteBuffer buffer : sourceBuffers) { + if (buffer.remaining() == 0) { + buffer.flip(); + } + assertEquals(bufferSize, buffer.remaining()); + } + // Make an array copy of what we expect to send, for later comparison + byte[] sourceBytes = copyDataFromBuffers(sourceBuffers, offset, length); + byte[] destinationBytes = new byte[sourceBytes.length]; + ByteBuffer destination = ByteBuffer.wrap(destinationBytes); + + // Encrypt from the selected buffers + tlsBuffer.clear(); + SSLEngineResult result = + pair.client.wrap(sourceBuffers, offset, length, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, sourceBytes.length, result.bytesConsumed()); + int produced = result.bytesProduced(); + + // Decrypt and compare + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destination); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, sourceBytes.length, result.bytesProduced()); + assertEquals(statusMessage, produced, result.bytesConsumed()); + assertArrayEquals(sourceBytes, destinationBytes); + } + } + } } } @Test - public void bufferArrayOffsets() throws Exception{ - TestSSLEnginePair pair = TestSSLEnginePair.create(); - ByteBuffer tlsBuffer = ByteBuffer.allocate(600); - int bufferSize = 100; - - for (BufferType bufferType : BufferType.values()) { - ByteBuffer[] sourceBuffers = bufferType.newRandomBuffers( - bufferSize, bufferSize, bufferSize, bufferSize, bufferSize); - for (int offset = 0; offset < sourceBuffers.length; offset++) { - for (int length = 1; length < sourceBuffers.length - offset; length++) { - // Reset source buffers - for (ByteBuffer buffer : sourceBuffers) { - if (buffer.remaining() == 0) { + public void bufferArrayOffsets_Unwrap() throws Exception { + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + int dataSize = 1024; // Should be less than SSL3_RT_MAX_PLAIN_LENGTH - 1 + int bufferSize = 128; + int bufferArrayLength = dataSize / bufferSize; + int[] sizeArray = new int[bufferArrayLength]; + Arrays.fill(sizeArray, bufferSize); + + for (BufferType bufferType : BufferType.values()) { + ByteBuffer[] destBuffers = bufferType.newEmptyBuffers(sizeArray); + for (int offset = 0; offset < destBuffers.length; offset++) { + for (int length = 1; length < destBuffers.length - offset; length++) { + String statusMessage = + String.format("offset=%d, length=%d", offset, length); + // Reset all the destination buffers + for (ByteBuffer buffer : destBuffers) { + buffer.clear(); + } + int expectedSize = bufferSize * length; + ByteBuffer sourceData = bufferType.newRandomBuffer(expectedSize); + + // Encrypt enough data to exactly fill our selected buffers + ByteBuffer tlsBuffer = ByteBuffer.allocate(expectedSize + 128); + SSLEngineResult result = pair.client.wrap(sourceData, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, expectedSize, result.bytesConsumed()); + int produced = result.bytesProduced(); + + // Decrypt into our selected destination buffers + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destBuffers, offset, length); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, expectedSize, result.bytesProduced()); + assertEquals(statusMessage, produced, result.bytesConsumed()); + + // Copy data out and compare + for (ByteBuffer buffer : destBuffers) { buffer.flip(); } - assertEquals(bufferSize, buffer.remaining()); + byte[] decrypted = copyDataFromBuffers(destBuffers, 0, destBuffers.length); + byte[] expectedData = new byte[expectedSize]; + sourceData.flip(); + sourceData.get(expectedData); + assertArrayEquals(expectedData, decrypted); + + // Ensure destination capacity is no bigger than expected + // by sending more data than can fit in our selected buffers + int extraBytes = 32; + int overflowSize = expectedSize + extraBytes; + sourceData = bufferType.newRandomBuffer(overflowSize); + tlsBuffer.clear(); + result = pair.client.wrap(sourceData, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, overflowSize, result.bytesConsumed()); + + for (ByteBuffer buffer : destBuffers) { + buffer.clear(); + } + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destBuffers, offset, length); + assertEquals(statusMessage, Status.BUFFER_OVERFLOW, result.getStatus()); + + // Discard the rest of the data ready for the next iteration + ByteBuffer discard = ByteBuffer.allocate(extraBytes); + result = pair.server.unwrap(tlsBuffer, discard); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, extraBytes, result.bytesProduced()); } - // Make an array copy of what we expect to send - byte[] sourceBytes = copyDataFromBuffers(sourceBuffers, offset, length); - byte[] destinationBytes = new byte[sourceBytes.length]; - ByteBuffer destination = ByteBuffer.wrap(destinationBytes); - // Send and compare - tlsBuffer.clear(); - pair.client.wrap(sourceBuffers, offset, length, tlsBuffer); - tlsBuffer.flip(); - pair.server.unwrap(tlsBuffer, destination); - assertArrayEquals(sourceBytes, destinationBytes); } } } diff --git a/testing/src/main/java/org/conscrypt/TestUtils.java b/testing/src/main/java/org/conscrypt/TestUtils.java index c239c0aa3..6efe0368c 100644 --- a/testing/src/main/java/org/conscrypt/TestUtils.java +++ b/testing/src/main/java/org/conscrypt/TestUtils.java @@ -109,6 +109,15 @@ ByteBuffer newBuffer(int size) { private static final Random random = new Random(System.currentTimeMillis()); abstract ByteBuffer newBuffer(int size); + public ByteBuffer[] newEmptyBuffers(int... sizes) { + int numBuffers = sizes.length; + ByteBuffer[] result = new ByteBuffer[numBuffers]; + for (int i = 0; i < numBuffers; i++) { + result[i] = newBuffer(sizes[i]); + } + return result; + } + public ByteBuffer[] newRandomBuffers(int... sizes) { int numBuffers = sizes.length; ByteBuffer[] result = new ByteBuffer[numBuffers];