diff --git a/common/src/main/java/org/conscrypt/ConscryptEngine.java b/common/src/main/java/org/conscrypt/ConscryptEngine.java index 818fa93cc..c573ef73d 100644 --- a/common/src/main/java/org/conscrypt/ConscryptEngine.java +++ b/common/src/main/java/org/conscrypt/ConscryptEngine.java @@ -753,8 +753,8 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe checkPositionIndexes(dstsOffset, dstsOffset + dstsLength, dsts.length); // Determine the output capacity. - final int dstLength = calcDstsLength(dsts, dstsOffset, dstsLength); - final int endOffset = dstsOffset + dstsLength; + final int dstsEndOffset = dstsOffset + dstsLength; + final long dstLength = calcDstsLength(dsts, dstsOffset, dstsEndOffset); final int srcsEndOffset = srcsOffset + srcsLength; final long srcLength = calcSrcsLength(srcs, srcsOffset, srcsEndOffset); @@ -863,7 +863,7 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe try { if (dstLength > 0) { // Write decrypted data to dsts buffers - for (int idx = dstsOffset; idx < endOffset; ++idx) { + for (int idx = dstsOffset; idx < dstsEndOffset; ++idx) { ByteBuffer dst = dsts[idx]; if (!dst.hasRemaining()) { continue; @@ -933,17 +933,15 @@ SSLEngineResult unwrap(final ByteBuffer[] srcs, int srcsOffset, final int srcsLe } } - private static int calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsLength) { - int capacity = 0; - for (int i = 0; i < dsts.length; i++) { + private static long calcDstsLength(ByteBuffer[] dsts, int dstsOffset, int dstsEndOffset) { + long capacity = 0; + for (int i = dstsOffset; i < dstsEndOffset; i++) { ByteBuffer dst = dsts[i]; checkArgument(dst != null, "dsts[%d] is null", i); if (dst.isReadOnly()) { throw new ReadOnlyBufferException(); } - if (i >= dstsOffset && i < dstsOffset + dstsLength) { - capacity += dst.remaining(); - } + capacity += dst.remaining(); } return capacity; } @@ -952,9 +950,7 @@ private static long calcSrcsLength(ByteBuffer[] srcs, int srcsOffset, int srcsEn long len = 0; for (int i = srcsOffset; i < srcsEndOffset; i++) { ByteBuffer src = srcs[i]; - if (src == null) { - throw new IllegalArgumentException("srcs[" + i + "] is null"); - } + checkArgument(src != null, "srcs[%d] is null", i); len += src.remaining(); } return len; diff --git a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java index c49648955..06e95f94a 100644 --- a/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java +++ b/common/src/test/java/org/conscrypt/javax/net/ssl/SSLEngineTest.java @@ -22,9 +22,18 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import org.conscrypt.TestUtils; +import org.conscrypt.TestUtils.BufferType; +import org.conscrypt.java.security.StandardNames; +import org.conscrypt.java.security.TestKeyStore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + import java.io.IOException; import java.net.Socket; import java.nio.ByteBuffer; @@ -35,6 +44,7 @@ import java.util.HashSet; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; + import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import javax.net.ssl.KeyManager; @@ -49,13 +59,6 @@ import javax.net.ssl.SSLSession; import javax.net.ssl.X509ExtendedKeyManager; import javax.net.ssl.X509ExtendedTrustManager; -import org.conscrypt.TestUtils; -import org.conscrypt.TestUtils.BufferType; -import org.conscrypt.java.security.StandardNames; -import org.conscrypt.java.security.TestKeyStore; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class SSLEngineTest { @@ -928,150 +931,253 @@ public void test_SSLEngine_setSSLParameters() throws Exception { @Test public void wrapPreconditions() throws Exception { - ByteBuffer buffer = ByteBuffer.allocate(10); + ByteBuffer buffer = ByteBuffer.allocate(128); + ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); ByteBuffer[] buffers = new ByteBuffer[] { buffer, buffer, buffer }; ByteBuffer[] badBuffers = new ByteBuffer[] { buffer, buffer, null, buffer }; // Client/server mode not set => IllegalStateException - try { - newUnconnectedEngine().wrap(buffer, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().wrap(buffer, buffer)); + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().wrap(buffers, buffer)); + assertThrows(IllegalStateException.class, + () -> newUnconnectedEngine().wrap(buffers, 0, 1, buffer)); - try { - newUnconnectedEngine().wrap(buffers, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + // Read-only destination => ReadOnlyBufferException + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffer, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffers, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().wrap(buffers, 0, 1, readOnlyBuffer)); - try { - newUnconnectedEngine().wrap(buffers, 0, 1, buffer); - fail(); - } catch (IllegalStateException e) { - // Expected - } + // Null destination => IllegalArgumentException + assertThrows(IllegalArgumentException.class, () -> newConnectedEngine().wrap(buffer, null)); + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().wrap(buffers, null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(buffers, 0, 1, null)); - // Read-only destination => ReadOnlyBufferException - try { - newConnectedEngine().wrap(buffer, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Null source => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap((ByteBuffer) null, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap((ByteBuffer[]) null, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(null, 0, 1, buffer)); - try { - newConnectedEngine().wrap(buffers, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Null entries in buffer array => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(badBuffers, buffer)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().wrap(badBuffers, 0, badBuffers.length, buffer)); + // But not if they are outside the selected offset and length + newConnectedEngine().wrap(badBuffers, 0, 2, buffer); + newConnectedEngine().wrap(badBuffers, 3, 1, buffer); - try { - newConnectedEngine().wrap(buffers, 0, 1, buffer.asReadOnlyBuffer()); - fail(); - } catch (ReadOnlyBufferException e) { - // Expected - } + // Bad offset or length => IndexOutOfBoundsException + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().wrap(buffers, 0, buffers.length + 1, buffer)); + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().wrap(buffers, buffers.length, 1, buffer)); + } - // Null destination => IllegalArgumentException - try { - newConnectedEngine().wrap(buffer, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + @Test + public void unwrapPreconditions() throws Exception { + ByteBuffer buffer = ByteBuffer.allocate(128); + ByteBuffer readOnlyBuffer = buffer.asReadOnlyBuffer(); + ByteBuffer[] buffers = new ByteBuffer[] {buffer, buffer, buffer}; + ByteBuffer[] badBuffers = new ByteBuffer[] {buffer, buffer, null, buffer}; + ByteBuffer[] readOnlyBuffers = new ByteBuffer[] {buffer, readOnlyBuffer, buffer}; - try { - newConnectedEngine().wrap(buffers, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Client/server mode not set => IllegalStateException + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().unwrap(buffer, buffer)); + assertThrows( + IllegalStateException.class, () -> newUnconnectedEngine().unwrap(buffer, buffers)); + assertThrows(IllegalStateException.class, + () -> newUnconnectedEngine().unwrap(buffer, buffers, 0, 1)); - try { - newConnectedEngine().wrap(buffers, 0, 1, null); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Read-only destination => ReadOnlyBufferException + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().unwrap(buffer, readOnlyBuffer)); + assertThrows(ReadOnlyBufferException.class, + () -> newConnectedEngine().unwrap(buffer, readOnlyBuffers)); + assertThrows(ReadOnlyBufferException.class, + () + -> newConnectedEngine().unwrap( + buffer, readOnlyBuffers, 0, readOnlyBuffers.length)); + + // Null destination => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, (ByteBuffer) null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, (ByteBuffer[]) null)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, null, 0, 1)); // Null source => IllegalArgumentException - try { - newConnectedEngine().wrap((ByteBuffer) null, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().unwrap(null, buffer)); + assertThrows( + IllegalArgumentException.class, () -> newConnectedEngine().unwrap(null, buffers)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(null, buffers, 0, 1)); - try { - newConnectedEngine().wrap((ByteBuffer[]) null, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected + // Null entries in buffer array => IllegalArgumentException + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, badBuffers)); + assertThrows(IllegalArgumentException.class, + () -> newConnectedEngine().unwrap(buffer, badBuffers, 0, badBuffers.length)); + // But not if they are outside the selected offset and length + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + doUnwrap(pair, badBuffers, 0, 2); + doUnwrap(pair, badBuffers, 3, 1); } - try { - newConnectedEngine().wrap(null, 0, 1, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + // Bad offset or length => IndexOutOfBoundsException + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().unwrap(buffer, buffers, 0, buffers.length + 1)); + assertThrows(IndexOutOfBoundsException.class, + () -> newConnectedEngine().unwrap(buffer, buffers, badBuffers.length, 1)); + } - // Null entries in buffer array => IllegalArgumentException - try { - newConnectedEngine().wrap(badBuffers, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected - } + private void doUnwrap(TestSSLEnginePair pair, ByteBuffer[] dest, int offset, int length) + throws Exception { + int bufferSize = 128; + ByteBuffer src = ByteBuffer.allocate(bufferSize); + ByteBuffer tlsBuffer = ByteBuffer.allocate(bufferSize + 128); + tlsBuffer.clear(); + SSLEngineResult result = pair.client.wrap(src, tlsBuffer); + assertEquals(Status.OK, result.getStatus()); - try { - newConnectedEngine().wrap(badBuffers, 0, badBuffers.length, buffer); - fail(); - } catch (IllegalArgumentException e) { - // Expected + tlsBuffer.flip(); + for (int i = offset; i < offset + length; i++) { + dest[i].clear(); } + result = pair.server.unwrap(tlsBuffer, dest, offset, length); + assertEquals(Status.OK, result.getStatus()); + } - // Bad offset or length => IndexOutOfBoundsException - try { - newConnectedEngine().wrap(buffers, 0, 7, buffer); - fail(); - } catch (IndexOutOfBoundsException e) { - // Expected + @Test + public void bufferArrayOffsets_Wrap() throws Exception { + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + int dataSize = 1024; // Should be less than SSL3_RT_MAX_PLAIN_LENGTH + int bufferSize = 128; + int bufferArrayLength = dataSize / bufferSize; + int[] sizeArray = new int[bufferArrayLength]; + Arrays.fill(sizeArray, bufferSize); + ByteBuffer tlsBuffer = ByteBuffer.allocate(dataSize); + + for (BufferType bufferType : BufferType.values()) { + ByteBuffer[] sourceBuffers = bufferType.newRandomBuffers(sizeArray); + for (int offset = 0; offset < sourceBuffers.length; offset++) { + for (int length = 1; length < sourceBuffers.length - offset; length++) { + String statusMessage = + String.format("offset=%d, length=%d", offset, length); + // Reset source buffers (only some are emptied on each iteration) + for (ByteBuffer buffer : sourceBuffers) { + if (buffer.remaining() == 0) { + buffer.flip(); + } + assertEquals(bufferSize, buffer.remaining()); + } + // Make an array copy of what we expect to send, for later comparison + byte[] sourceBytes = copyDataFromBuffers(sourceBuffers, offset, length); + byte[] destinationBytes = new byte[sourceBytes.length]; + ByteBuffer destination = ByteBuffer.wrap(destinationBytes); + + // Encrypt from the selected buffers + tlsBuffer.clear(); + SSLEngineResult result = + pair.client.wrap(sourceBuffers, offset, length, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, sourceBytes.length, result.bytesConsumed()); + int produced = result.bytesProduced(); + + // Decrypt and compare + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destination); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, sourceBytes.length, result.bytesProduced()); + assertEquals(statusMessage, produced, result.bytesConsumed()); + assertArrayEquals(sourceBytes, destinationBytes); + } + } + } } } @Test - public void bufferArrayOffsets() throws Exception{ - TestSSLEnginePair pair = TestSSLEnginePair.create(); - ByteBuffer tlsBuffer = ByteBuffer.allocate(600); - int bufferSize = 100; - - for (BufferType bufferType : BufferType.values()) { - ByteBuffer[] sourceBuffers = bufferType.newRandomBuffers( - bufferSize, bufferSize, bufferSize, bufferSize, bufferSize); - for (int offset = 0; offset < sourceBuffers.length; offset++) { - for (int length = 1; length < sourceBuffers.length - offset; length++) { - // Reset source buffers - for (ByteBuffer buffer : sourceBuffers) { - if (buffer.remaining() == 0) { + public void bufferArrayOffsets_Unwrap() throws Exception { + try (TestSSLEnginePair pair = TestSSLEnginePair.create()) { + int dataSize = 1024; // Should be less than SSL3_RT_MAX_PLAIN_LENGTH - 1 + int bufferSize = 128; + int bufferArrayLength = dataSize / bufferSize; + int[] sizeArray = new int[bufferArrayLength]; + Arrays.fill(sizeArray, bufferSize); + + for (BufferType bufferType : BufferType.values()) { + ByteBuffer[] destBuffers = bufferType.newEmptyBuffers(sizeArray); + for (int offset = 0; offset < destBuffers.length; offset++) { + for (int length = 1; length < destBuffers.length - offset; length++) { + String statusMessage = + String.format("offset=%d, length=%d", offset, length); + // Reset all the destination buffers + for (ByteBuffer buffer : destBuffers) { + buffer.clear(); + } + int expectedSize = bufferSize * length; + ByteBuffer sourceData = bufferType.newRandomBuffer(expectedSize); + + // Encrypt enough data to exactly fill our selected buffers + ByteBuffer tlsBuffer = ByteBuffer.allocate(expectedSize + 128); + SSLEngineResult result = pair.client.wrap(sourceData, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, expectedSize, result.bytesConsumed()); + int produced = result.bytesProduced(); + + // Decrypt into our selected destination buffers + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destBuffers, offset, length); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, expectedSize, result.bytesProduced()); + assertEquals(statusMessage, produced, result.bytesConsumed()); + + // Copy data out and compare + for (ByteBuffer buffer : destBuffers) { buffer.flip(); } - assertEquals(bufferSize, buffer.remaining()); + byte[] decrypted = copyDataFromBuffers(destBuffers, 0, destBuffers.length); + byte[] expectedData = new byte[expectedSize]; + sourceData.flip(); + sourceData.get(expectedData); + assertArrayEquals(expectedData, decrypted); + + // Ensure destination capacity is no bigger than expected + // by sending more data than can fit in our selected buffers + int extraBytes = 32; + int overflowSize = expectedSize + extraBytes; + sourceData = bufferType.newRandomBuffer(overflowSize); + tlsBuffer.clear(); + result = pair.client.wrap(sourceData, tlsBuffer); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, overflowSize, result.bytesConsumed()); + + for (ByteBuffer buffer : destBuffers) { + buffer.clear(); + } + tlsBuffer.flip(); + result = pair.server.unwrap(tlsBuffer, destBuffers, offset, length); + assertEquals(statusMessage, Status.BUFFER_OVERFLOW, result.getStatus()); + + // Discard the rest of the data ready for the next iteration + ByteBuffer discard = ByteBuffer.allocate(extraBytes); + result = pair.server.unwrap(tlsBuffer, discard); + assertEquals(statusMessage, Status.OK, result.getStatus()); + assertEquals(statusMessage, extraBytes, result.bytesProduced()); } - // Make an array copy of what we expect to send - byte[] sourceBytes = copyDataFromBuffers(sourceBuffers, offset, length); - byte[] destinationBytes = new byte[sourceBytes.length]; - ByteBuffer destination = ByteBuffer.wrap(destinationBytes); - // Send and compare - tlsBuffer.clear(); - pair.client.wrap(sourceBuffers, offset, length, tlsBuffer); - tlsBuffer.flip(); - pair.server.unwrap(tlsBuffer, destination); - assertArrayEquals(sourceBytes, destinationBytes); } } } diff --git a/testing/src/main/java/org/conscrypt/TestUtils.java b/testing/src/main/java/org/conscrypt/TestUtils.java index c239c0aa3..6efe0368c 100644 --- a/testing/src/main/java/org/conscrypt/TestUtils.java +++ b/testing/src/main/java/org/conscrypt/TestUtils.java @@ -109,6 +109,15 @@ ByteBuffer newBuffer(int size) { private static final Random random = new Random(System.currentTimeMillis()); abstract ByteBuffer newBuffer(int size); + public ByteBuffer[] newEmptyBuffers(int... sizes) { + int numBuffers = sizes.length; + ByteBuffer[] result = new ByteBuffer[numBuffers]; + for (int i = 0; i < numBuffers; i++) { + result[i] = newBuffer(sizes[i]); + } + return result; + } + public ByteBuffer[] newRandomBuffers(int... sizes) { int numBuffers = sizes.length; ByteBuffer[] result = new ByteBuffer[numBuffers];