diff --git a/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloader.java index 771e229981d429..8b12dfa321872a 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloader.java @@ -14,6 +14,7 @@ package com.google.devtools.build.lib.remote; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture; import build.bazel.remote.execution.v2.Digest; @@ -27,11 +28,20 @@ import com.google.devtools.build.lib.remote.util.Utils; import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.LinkedBlockingQueue; import javax.annotation.Nullable; -/** Downloads blobs by sequentially fetching chunks via the SplitBlob API. */ +/** Downloads blobs by fetching chunks through a per-blob sliding window via the SplitBlob API. */ public class ChunkedBlobDownloader { + // Guard against pathological fanout from a single large chunked blob. This is only a per-blob + // cap; chunk requests still flow through CombinedCache and the shared remote cache transport + // stack below it, which is what bounds active remote RPC concurrency across blobs. + private static final int MAX_IN_FLIGHT_CHUNK_DOWNLOADS = 16; + private final GrpcCacheClient grpcCacheClient; private final CombinedCache combinedCache; private final DigestUtil digestUtil; @@ -45,8 +55,8 @@ public ChunkedBlobDownloader( /** * Downloads a blob using chunked download via the SplitBlob API. This should be called with - * virtual threads, as it blocks on futures via {@link - * com.google.devtools.build.lib.remote.util.Utils#getFromFuture}. 
+ * virtual threads, as it may block while waiting for chunk metadata and completed chunk + * downloads. */ public void downloadChunked( RemoteActionExecutionContext context, Digest blobDigest, OutputStream out) @@ -81,11 +91,135 @@ private List getChunkDigests(RemoteActionExecutionContext context, Diges return chunkDigests; } + private static final class PendingDownload { + private final Digest digest; + private final ListenableFuture future; + private final List chunkIndices = new ArrayList<>(1); + + PendingDownload(Digest digest, ListenableFuture future, int firstChunkIndex) { + this.digest = digest; + this.future = future; + chunkIndices.add(firstChunkIndex); + } + + void addChunkIndex(int chunkIndex) { + chunkIndices.add(chunkIndex); + } + + Digest digest() { + return digest; + } + + ListenableFuture future() { + return future; + } + + List chunkIndices() { + return chunkIndices; + } + } + private void downloadAndReassembleChunks( RemoteActionExecutionContext context, List chunkDigests, OutputStream out) throws IOException, InterruptedException { - for (Digest chunkDigest : chunkDigests) { - getFromFuture(combinedCache.downloadBlob(context, chunkDigest, out)); + new DownloadSession(context, chunkDigests, out).run(); + } + + private final class DownloadSession { + private final LinkedBlockingQueue completedDownloads = + new LinkedBlockingQueue<>(); + private final Map activeDownloads = + new HashMap<>(MAX_IN_FLIGHT_CHUNK_DOWNLOADS); + private final Map readyChunks = + new HashMap<>(MAX_IN_FLIGHT_CHUNK_DOWNLOADS); + private final RemoteActionExecutionContext context; + private final List chunkDigests; + private final OutputStream out; + private int nextToStart = 0; + private int nextToWrite = 0; + + DownloadSession( + RemoteActionExecutionContext context, List chunkDigests, OutputStream out) { + this.context = context; + this.chunkDigests = chunkDigests; + this.out = out; + } + + void run() throws IOException, InterruptedException { + try { + fillWindow(); + 
while (nextToWrite < chunkDigests.size()) { + drainCompletedDownloads(); + drainReadyChunks(); + fillWindow(); + } + } finally { + cancelAllDownloads(); + } + } + + private void fillWindow() { + while (nextToStart < chunkDigests.size()) { + if (nextToStart - nextToWrite >= MAX_IN_FLIGHT_CHUNK_DOWNLOADS) { + return; + } + Digest chunkDigest = chunkDigests.get(nextToStart); + PendingDownload existing = activeDownloads.get(chunkDigest); + if (existing != null) { + existing.addChunkIndex(nextToStart); + nextToStart++; + continue; + } + startDownload(chunkDigest, nextToStart); + nextToStart++; + } + } + + private void startDownload(Digest chunkDigest, int chunkIndex) { + PendingDownload download = + new PendingDownload( + chunkDigest, combinedCache.downloadBlob(context, chunkDigest), chunkIndex); + activeDownloads.put(chunkDigest, download); + download.future().addListener(() -> completedDownloads.add(download), directExecutor()); + } + + private void drainCompletedDownloads() throws IOException, InterruptedException { + PendingDownload download = completedDownloads.take(); + do { + processCompletedDownload(download); + download = completedDownloads.poll(); + } while (download != null); + } + + private void processCompletedDownload(PendingDownload download) + throws IOException, InterruptedException { + activeDownloads.remove(download.digest()); + byte[] chunkData = getFromFuture(download.future()); + for (int chunkIndex : download.chunkIndices()) { + if (chunkIndex == nextToWrite) { + out.write(chunkData); + nextToWrite++; + } else { + readyChunks.put(chunkIndex, chunkData); + } + } + } + + private void drainReadyChunks() throws IOException { + while (true) { + byte[] chunk = readyChunks.remove(nextToWrite); + if (chunk == null) { + return; + } + out.write(chunk); + nextToWrite++; + } + } + + private void cancelAllDownloads() { + for (PendingDownload download : activeDownloads.values()) { + download.future().cancel(/* mayInterruptIfRunning= */ true); + } } } } diff 
--git a/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobUploader.java b/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobUploader.java index 6cf21eaa09ebd4..836334fe40937c 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobUploader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ChunkedBlobUploader.java @@ -14,22 +14,28 @@ package com.google.devtools.build.lib.remote; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture; import build.bazel.remote.execution.v2.Digest; import com.google.common.collect.ImmutableSet; import com.google.common.io.ByteStreams; +import com.google.common.util.concurrent.ListenableFuture; import com.google.devtools.build.lib.remote.chunking.ChunkingConfig; import com.google.devtools.build.lib.remote.chunking.FastCdcChunker; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.vfs.Path; -import com.google.protobuf.ByteString; +import java.io.EOFException; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.channels.FileChannel; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; /** * Uploads blobs in chunks using Content-Defined Chunking with FastCDC 2020. @@ -44,6 +50,10 @@ * */ public class ChunkedBlobUploader { + // Guard against pathological fanout from a single large chunked blob. This is only a per-blob + // cap; chunk uploads still flow through CombinedCache and the shared remote cache transport + // stack below it, which is what bounds active remote RPC concurrency across blobs. 
+ private static final int MAX_IN_FLIGHT_CHUNK_UPLOADS = 16; private final GrpcCacheClient grpcCacheClient; private final CombinedCache combinedCache; @@ -104,18 +114,139 @@ private void uploadMissingChunks( if (missingDigests.isEmpty()) { return; } + new UploadSession(context, missingDigests, chunkDigests).run(file); + } - Set uploaded = new HashSet<>(); - try (InputStream input = file.getInputStream()) { - for (Digest chunkDigest : chunkDigests) { - if (missingDigests.contains(chunkDigest) && uploaded.add(chunkDigest)) { - ByteString.Output out = ByteString.newOutput((int) chunkDigest.getSizeBytes()); - ByteStreams.limit(input, chunkDigest.getSizeBytes()).transferTo(out); - getFromFuture(combinedCache.uploadBlob(context, chunkDigest, out.toByteString())); - } else { - input.skipNBytes(chunkDigest.getSizeBytes()); + private final class UploadSession { + private final LinkedBlockingQueue> completedUploads = + new LinkedBlockingQueue<>(); + private final Set> inFlightUploads = + new HashSet<>(MAX_IN_FLIGHT_CHUNK_UPLOADS); + private final Set scheduledDigests = new HashSet<>(); + private final RemoteActionExecutionContext context; + private final ImmutableSet missingDigests; + private final List chunkDigests; + + UploadSession( + RemoteActionExecutionContext context, + ImmutableSet missingDigests, + List chunkDigests) { + this.context = context; + this.missingDigests = missingDigests; + this.chunkDigests = chunkDigests; + } + + void run(Path file) throws IOException, InterruptedException { + try { + long offset = 0; + for (Digest chunkDigest : chunkDigests) { + drainCompletedUploads(); + long chunkOffset = offset; + offset += chunkDigest.getSizeBytes(); + if (!shouldScheduleUpload(chunkDigest)) { + continue; + } + if (inFlightUploads.size() >= MAX_IN_FLIGHT_CHUNK_UPLOADS) { + awaitCompletedUpload(); + } + startUpload(file, chunkOffset, chunkDigest); + } + while (!inFlightUploads.isEmpty()) { + awaitCompletedUpload(); + } + } finally { + cancelAllUploads(); + } + } + 
+ private boolean shouldScheduleUpload(Digest chunkDigest) { + return missingDigests.contains(chunkDigest) && scheduledDigests.add(chunkDigest); + } + + private void startUpload(Path file, long chunkOffset, Digest chunkDigest) { + ListenableFuture upload = + combinedCache.uploadBlob( + context, chunkDigest, new ChunkBlob(file, chunkOffset, chunkDigest)); + inFlightUploads.add(upload); + upload.addListener(() -> completedUploads.add(upload), directExecutor()); + } + + private void drainCompletedUploads() throws IOException, InterruptedException { + while (true) { + ListenableFuture upload = completedUploads.poll(); + if (upload == null) { + return; + } + finishUpload(upload); + } + } + + private void awaitCompletedUpload() throws IOException, InterruptedException { + finishUpload(completedUploads.take()); + drainCompletedUploads(); + } + + private void finishUpload(ListenableFuture upload) + throws IOException, InterruptedException { + inFlightUploads.remove(upload); + getFromFuture(upload); + } + + private void cancelAllUploads() { + for (ListenableFuture upload : inFlightUploads) { + upload.cancel(/* mayInterruptIfRunning= */ true); + } + } + } + + private static final class ChunkBlob implements Blob { + private final Path file; + private final long offset; + private final Digest digest; + + private ChunkBlob(Path file, long offset, Digest digest) { + this.file = file; + this.offset = offset; + this.digest = digest; + } + + @Override + public InputStream get() throws IOException { + InputStream input = file.getInputStream(); + boolean success = false; + try { + seekOrSkip(input, offset); + InputStream limitedInput = ByteStreams.limit(input, digest.getSizeBytes()); + success = true; + return limitedInput; + } catch (EOFException e) { + throw new IOException("file was concurrently modified during upload: " + file, e); + } finally { + if (!success) { + input.close(); } } } + + @Override + public String description() { + return "chunk %s at offset %d of file %s" + 
.formatted(DigestUtil.toString(digest), offset, file); + } + } + + private static void seekOrSkip(InputStream input, long offset) throws IOException { + if (offset == 0) { + return; + } + if (input instanceof FileInputStream fileInputStream) { + FileChannel channel = fileInputStream.getChannel(); + if (channel.size() < offset) { + throw new EOFException(); + } + channel.position(offset); + return; + } + input.skipNBytes(offset); } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/CombinedCache.java b/src/main/java/com/google/devtools/build/lib/remote/CombinedCache.java index 7772fbebf72406..63641fd3a04409 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/CombinedCache.java +++ b/src/main/java/com/google/devtools/build/lib/remote/CombinedCache.java @@ -46,6 +46,7 @@ import com.google.devtools.build.lib.remote.common.ProgressStatusListener; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; import com.google.devtools.build.lib.remote.common.RemoteCacheClient; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; import com.google.devtools.build.lib.remote.disk.DiskCacheClient; import com.google.devtools.build.lib.remote.util.AsyncTaskCache; import com.google.devtools.build.lib.remote.util.DigestUtil; @@ -69,6 +70,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -364,11 +366,6 @@ public ListenableFuture uploadActionResult( */ public ListenableFuture uploadFile( RemoteActionExecutionContext context, Digest digest, Path file) { - return uploadFile(context, digest, file, /* force= */ false); - } - - protected ListenableFuture uploadFile( - RemoteActionExecutionContext context, Digest digest, Path file, boolean force) { if (digest.getSizeBytes() == 0) { return COMPLETED_SUCCESS; } 
@@ -388,19 +385,20 @@ protected ListenableFuture uploadFile( ListenableFuture remoteCacheFuture = Futures.immediateVoidFuture(); if (remoteCacheClient != null && context.getWriteCachePolicy().allowRemoteCache()) { if (chunkingSupported && digest.getSizeBytes() > chunking.config().chunkingThreshold()) { - remoteCacheFuture = - virtualThreadExecutor.submit( - () -> { - chunking.uploader().uploadChunked(context, digest, file); - return null; - }); + Completable upload = + casUploadCache.execute( + digest, + RxFutures.toCompletable( + () -> uploadChunked(context, digest, file), directExecutor()), + /* force= */ false); + remoteCacheFuture = RxFutures.toListenableFuture(upload); } else { Completable upload = casUploadCache.execute( digest, RxFutures.toCompletable( () -> remoteCacheClient.uploadFile(context, digest, file), directExecutor()), - force); + /* force= */ false); remoteCacheFuture = RxFutures.toListenableFuture(upload); } } @@ -409,30 +407,45 @@ protected ListenableFuture uploadFile( .call(() -> null, directExecutor()); } + private ListenableFuture uploadChunked( + RemoteActionExecutionContext context, Digest digest, Path file) { + return virtualThreadExecutor.submit( + () -> { + chunking.uploader().uploadChunked(context, digest, file); + return null; + }); + } + /** - * Upload sequence of bytes to the remote cache. + * Uploads a sequence of bytes to the cache. * *

<p>Trying to upload the same BLOB multiple times concurrently, results in only one upload being * performed. * * @param context the context for the action. - * @param digest the digest of the file. + * @param digest the digest of the BLOB. * @param data the BLOB to upload. */ public ListenableFuture uploadBlob( RemoteActionExecutionContext context, Digest digest, ByteString data) { - return uploadBlob(context, digest, data, /* force= */ false); + return uploadBlob(context, digest, (Blob) data::newInput); } - protected ListenableFuture uploadBlob( - RemoteActionExecutionContext context, Digest digest, ByteString data, boolean force) { + /** + * Uploads a blob to the cache from a repeatable stream supplier. + * + *

<p>The supplier may be opened more than once, including concurrently when both disk and remote + * cache writes are enabled. + */ + public ListenableFuture uploadBlob( + RemoteActionExecutionContext context, Digest digest, Blob blob) { if (digest.getSizeBytes() == 0) { return COMPLETED_SUCCESS; } ListenableFuture diskCacheFuture = Futures.immediateVoidFuture(); if (diskCacheClient != null && context.getWriteCachePolicy().allowDiskCache()) { - diskCacheFuture = diskCacheClient.uploadBlob(digest, data); + diskCacheFuture = diskCacheClient.uploadBlob(digest, blob); } ListenableFuture remoteCacheFuture = Futures.immediateVoidFuture(); @@ -441,8 +454,8 @@ protected ListenableFuture uploadBlob( casUploadCache.execute( digest, RxFutures.toCompletable( - () -> remoteCacheClient.uploadBlob(context, digest, data), directExecutor()), - force); + () -> remoteCacheClient.uploadBlob(context, digest, blob), directExecutor()), + /* force= */ false); remoteCacheFuture = RxFutures.toListenableFuture(upload); } @@ -806,6 +819,7 @@ protected void deallocate() { diskCacheClient.close(); } casUploadCache.shutdown(); + virtualThreadExecutor.shutdown(); if (remoteCacheClient != null) { remoteCacheClient.close(); } @@ -829,11 +843,13 @@ public CombinedCache retain() { public void awaitTermination() throws InterruptedException { casUploadCache.awaitTermination(); closeCountDownLatch.await(); + virtualThreadExecutor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS); } /** Shuts the cache down and cancels active network I/Os.
*/ public void shutdownNow() { casUploadCache.shutdownNow(); + virtualThreadExecutor.shutdownNow(); } public static FailureDetail createFailureDetail(String message, Code detailedCode) { diff --git a/src/main/java/com/google/devtools/build/lib/remote/disk/DiskCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/disk/DiskCacheClient.java index 5cdd86e25c80f7..138bad34eb40e2 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/disk/DiskCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/disk/DiskCacheClient.java @@ -33,6 +33,7 @@ import com.google.devtools.build.lib.remote.common.ActionKey; import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.common.MaybePathBacked; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.Utils; import com.google.devtools.build.lib.vfs.FileSystemUtils; @@ -277,9 +278,14 @@ public ListenableFuture uploadFile(Digest digest, Path file) { } public ListenableFuture uploadBlob(Digest digest, ByteString data) { + return uploadBlob(digest, (Blob) data::newInput); + } + + /** Uploads a blob from a stream supplier. 
*/ + public ListenableFuture uploadBlob(Digest digest, Blob blob) { return executorService.submit( () -> { - try (InputStream in = data.newInput()) { + try (InputStream in = blob.get()) { saveFile(digest, Store.CAS, in); } return null; diff --git a/src/test/java/com/google/devtools/build/lib/remote/BUILD b/src/test/java/com/google/devtools/build/lib/remote/BUILD index 380f6a6bccc3b7..3677a61144156e 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/BUILD @@ -1,4 +1,5 @@ load("@rules_java//java:defs.bzl", "java_library", "java_test") +load("//src:java_opt_binary.bzl", "java_opt_binary") package( default_applicable_licenses = ["//:license"], @@ -65,6 +66,7 @@ java_library( "RemoteActionFileSystemTestBase.java", "BuildWithoutTheBytesIntegrationTest.java", "BuildWithoutTheBytesIntegrationTestBase.java", + "ChunkedTransferBenchmark.java", "ChunkedCacheIntegrationTest.java", "ChunkedDiskCacheIntegrationTest.java", "DiskCacheIntegrationTest.java", @@ -251,6 +253,27 @@ java_test( ], ) +java_opt_binary( + name = "ChunkedTransferBenchmark", + srcs = ["ChunkedTransferBenchmark.java"], + main_class = "org.openjdk.jmh.Main", + deps = [ + "@com_google_protobuf//java/core:lite_runtime_only", + "//src/main/java/com/google/devtools/build/lib/clock", + "//src/main/java/com/google/devtools/build/lib/remote:combined_cache", + "//src/main/java/com/google/devtools/build/lib/remote:grpc_cache_client", + "//src/main/java/com/google/devtools/build/lib/remote/chunking", + "//src/main/java/com/google/devtools/build/lib/remote/common", + "//src/main/java/com/google/devtools/build/lib/remote/util:digest_utils", + "//src/main/java/com/google/devtools/build/lib/vfs", + "//src/main/java/com/google/devtools/build/lib/vfs/inmemoryfs", + "//third_party:guava", + "//third_party:jmh", + "//third_party:mockito", + "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto", + ], +) + java_library( name = 
"build_without_the_bytes_integration_test_base", srcs = [ diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java index d33582e9d9d2ba..cb4cc3820df177 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamBuildEventArtifactUploaderTest.java @@ -54,6 +54,7 @@ import com.google.devtools.build.lib.remote.Retrier.ResultClassifier.Result; import com.google.devtools.build.lib.remote.common.MissingDigestsFinder; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; import com.google.devtools.build.lib.remote.options.RemoteBuildEventUploadMode; import com.google.devtools.build.lib.remote.options.RemoteOptions; import com.google.devtools.build.lib.remote.util.DigestUtil; @@ -67,6 +68,7 @@ import com.google.devtools.build.lib.vfs.bazel.BazelHashFunctions; import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; import com.google.devtools.common.options.Options; +import com.google.protobuf.ByteString; import io.grpc.Server; import io.grpc.Status; import io.grpc.inprocess.InProcessChannelBuilder; @@ -456,7 +458,8 @@ public void remoteFileShouldNotBeUploaded_actionFs() throws Exception { + "/" + digest.getSizeBytes()); verify(combinedCache, times(0)).uploadFile(any(), any(), any()); - verify(combinedCache, times(0)).uploadBlob(any(), any(), any()); + verify(combinedCache, times(0)).uploadBlob(any(), any(), any(ByteString.class)); + verify(combinedCache, times(0)).uploadBlob(any(), any(), any(Blob.class)); } @Test diff --git a/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloaderTest.java 
index c46a1bca7b32c1..00fb44dbf118da 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobDownloaderTest.java @@ -17,12 +17,14 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.SplitBlobResponse; import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.SettableFuture; import com.google.devtools.build.lib.remote.common.CacheNotFoundException; import com.google.devtools.build.lib.remote.common.OutputDigestMismatchException; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; @@ -31,7 +33,10 @@ import com.google.devtools.build.lib.vfs.SyscallCache; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -46,6 +51,7 @@ public class ChunkedBlobDownloaderTest { private static final DigestUtil DIGEST_UTIL = new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256); + private static final int MAX_IN_FLIGHT_CHUNK_DOWNLOADS = 16; @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @@ -81,13 +87,8 @@ public void downloadChunked_singleChunk_downloadsAndReassembles() throws Excepti SplitBlobResponse.newBuilder().addChunkDigests(chunkDigest).build(); when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) .thenReturn(Futures.immediateFuture(splitResponse)); - when(combinedCache.downloadBlob(any(), eq(chunkDigest), any())) - .thenAnswer( - invocation 
-> { - OutputStream out = invocation.getArgument(2); - out.write(chunkData); - return Futures.immediateFuture(null); - }); + when(combinedCache.downloadBlob(any(), eq(chunkDigest))) + .thenReturn(Futures.immediateFuture(chunkData)); ByteArrayOutputStream out = new ByteArrayOutputStream(); downloader.downloadChunked(context, blobDigest, out); @@ -113,35 +114,203 @@ public void downloadChunked_multipleChunks_downloadsAndReassemblesInOrder() thro .build(); when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) .thenReturn(Futures.immediateFuture(splitResponse)); - when(combinedCache.downloadBlob(any(), eq(chunk1Digest), any())) + when(combinedCache.downloadBlob(any(), eq(chunk1Digest))) + .thenReturn(Futures.immediateFuture(chunk1Data)); + when(combinedCache.downloadBlob(any(), eq(chunk2Digest))) + .thenReturn(Futures.immediateFuture(chunk2Data)); + when(combinedCache.downloadBlob(any(), eq(chunk3Digest))) + .thenReturn(Futures.immediateFuture(chunk3Data)); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + downloader.downloadChunked(context, blobDigest, out); + + assertThat(out.toByteArray()).isEqualTo(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); + verify(combinedCache).downloadBlob(any(), eq(chunk1Digest)); + verify(combinedCache).downloadBlob(any(), eq(chunk2Digest)); + verify(combinedCache).downloadBlob(any(), eq(chunk3Digest)); + } + + @Test + public void downloadChunked_windowRefillsAfterOneChunkCompletes() throws Exception { + List chunkDigests = new ArrayList<>(MAX_IN_FLIGHT_CHUNK_DOWNLOADS + 1); + List> chunkFutures = new ArrayList<>(MAX_IN_FLIGHT_CHUNK_DOWNLOADS + 1); + byte[] expectedData = new byte[MAX_IN_FLIGHT_CHUNK_DOWNLOADS + 1]; + SplitBlobResponse.Builder splitResponse = SplitBlobResponse.newBuilder(); + for (int i = 0; i < MAX_IN_FLIGHT_CHUNK_DOWNLOADS + 1; i++) { + byte[] chunkData = new byte[] {(byte) (i + 1)}; + expectedData[i] = chunkData[0]; + chunkDigests.add(DIGEST_UTIL.compute(chunkData)); + chunkFutures.add(SettableFuture.create()); 
+ splitResponse.addChunkDigests(chunkDigests.get(i)); + } + Digest blobDigest = DIGEST_UTIL.compute(expectedData); + + when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) + .thenReturn(Futures.immediateFuture(splitResponse.build())); + + CountDownLatch firstWindowRequested = new CountDownLatch(MAX_IN_FLIGHT_CHUNK_DOWNLOADS); + CountDownLatch overflowChunkRequested = new CountDownLatch(1); + + when(combinedCache.downloadBlob(any(), any(Digest.class))) + .thenAnswer( + invocation -> { + Digest digest = invocation.getArgument(1); + int chunkIndex = chunkDigests.indexOf(digest); + if (chunkIndex < MAX_IN_FLIGHT_CHUNK_DOWNLOADS) { + firstWindowRequested.countDown(); + } else if (chunkIndex == MAX_IN_FLIGHT_CHUNK_DOWNLOADS) { + overflowChunkRequested.countDown(); + } + return chunkFutures.get(chunkIndex); + }); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Thread downloadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + downloader.downloadChunked(context, blobDigest, out); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + downloadThread.start(); + + assertThat(firstWindowRequested.await(1, TimeUnit.SECONDS)).isTrue(); + assertThat(overflowChunkRequested.await(100, TimeUnit.MILLISECONDS)).isFalse(); + + chunkFutures.get(0).set(new byte[] {expectedData[0]}); + assertThat(overflowChunkRequested.await(1, TimeUnit.SECONDS)).isTrue(); + + for (int i = 0; i < chunkFutures.size(); i++) { + SettableFuture future = chunkFutures.get(i); + if (!future.isDone()) { + future.set(new byte[] {expectedData[i]}); + } + } + downloadThread.join(TimeUnit.SECONDS.toMillis(1)); + + assertThat(downloadThread.isAlive()).isFalse(); + assertThat(out.toByteArray()).isEqualTo(expectedData); + } + + @Test + public void downloadChunked_duplicateInFlightChunks_reusesDownload() throws Exception { + byte[] chunkData = new byte[] {1, 2, 3}; + Digest chunkDigest = DIGEST_UTIL.compute(chunkData); + Digest blobDigest = 
DIGEST_UTIL.compute(new byte[] {1, 2, 3, 1, 2, 3}); + + SplitBlobResponse splitResponse = + SplitBlobResponse.newBuilder() + .addChunkDigests(chunkDigest) + .addChunkDigests(chunkDigest) + .build(); + when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) + .thenReturn(Futures.immediateFuture(splitResponse)); + + SettableFuture chunkFuture = SettableFuture.create(); + when(combinedCache.downloadBlob(any(), eq(chunkDigest))).thenReturn(chunkFuture); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Thread downloadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + downloader.downloadChunked(context, blobDigest, out); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + downloadThread.start(); + + chunkFuture.set(chunkData); + downloadThread.join(TimeUnit.SECONDS.toMillis(1)); + + assertThat(downloadThread.isAlive()).isFalse(); + assertThat(out.toByteArray()).isEqualTo(new byte[] {1, 2, 3, 1, 2, 3}); + verify(combinedCache, times(1)).downloadBlob(any(), eq(chunkDigest)); + } + + @Test + public void downloadChunked_longDuplicateRun_resumesAfterDrain() throws Exception { + byte[] firstChunkData = new byte[] {1}; + byte[] duplicateChunkData = new byte[] {2}; + byte[] finalChunkData = new byte[] {3}; + Digest firstChunkDigest = DIGEST_UTIL.compute(firstChunkData); + Digest duplicateChunkDigest = DIGEST_UTIL.compute(duplicateChunkData); + Digest finalChunkDigest = DIGEST_UTIL.compute(finalChunkData); + + byte[] blobData = new byte[MAX_IN_FLIGHT_CHUNK_DOWNLOADS + 1]; + blobData[0] = firstChunkData[0]; + for (int i = 1; i < MAX_IN_FLIGHT_CHUNK_DOWNLOADS; i++) { + blobData[i] = duplicateChunkData[0]; + } + blobData[MAX_IN_FLIGHT_CHUNK_DOWNLOADS] = finalChunkData[0]; + Digest blobDigest = DIGEST_UTIL.compute(blobData); + + SplitBlobResponse.Builder splitResponse = SplitBlobResponse.newBuilder(); + splitResponse.addChunkDigests(firstChunkDigest); + for (int i = 1; i < MAX_IN_FLIGHT_CHUNK_DOWNLOADS; i++) { + 
splitResponse.addChunkDigests(duplicateChunkDigest); + } + splitResponse.addChunkDigests(finalChunkDigest); + when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) + .thenReturn(Futures.immediateFuture(splitResponse.build())); + + SettableFuture firstChunkFuture = SettableFuture.create(); + SettableFuture duplicateChunkFuture = SettableFuture.create(); + SettableFuture finalChunkFuture = SettableFuture.create(); + CountDownLatch initialDownloadsRequested = new CountDownLatch(2); + CountDownLatch finalChunkRequested = new CountDownLatch(1); + + when(combinedCache.downloadBlob(any(), eq(firstChunkDigest))) .thenAnswer( invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunk1Data); - return Futures.immediateFuture(null); + initialDownloadsRequested.countDown(); + return firstChunkFuture; }); - when(combinedCache.downloadBlob(any(), eq(chunk2Digest), any())) + when(combinedCache.downloadBlob(any(), eq(duplicateChunkDigest))) .thenAnswer( invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunk2Data); - return Futures.immediateFuture(null); + initialDownloadsRequested.countDown(); + return duplicateChunkFuture; }); - when(combinedCache.downloadBlob(any(), eq(chunk3Digest), any())) + when(combinedCache.downloadBlob(any(), eq(finalChunkDigest))) .thenAnswer( invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunk3Data); - return Futures.immediateFuture(null); + finalChunkRequested.countDown(); + return finalChunkFuture; }); ByteArrayOutputStream out = new ByteArrayOutputStream(); - downloader.downloadChunked(context, blobDigest, out); - - assertThat(out.toByteArray()).isEqualTo(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); - verify(combinedCache).downloadBlob(any(), eq(chunk1Digest), any()); - verify(combinedCache).downloadBlob(any(), eq(chunk2Digest), any()); - verify(combinedCache).downloadBlob(any(), eq(chunk3Digest), any()); + Thread downloadThread = + Thread.ofVirtual() + .unstarted( + () -> { + 
try { + downloader.downloadChunked(context, blobDigest, out); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + downloadThread.start(); + + assertThat(initialDownloadsRequested.await(1, TimeUnit.SECONDS)).isTrue(); + assertThat(finalChunkRequested.await(100, TimeUnit.MILLISECONDS)).isFalse(); + + duplicateChunkFuture.set(duplicateChunkData); + assertThat(finalChunkRequested.await(100, TimeUnit.MILLISECONDS)).isFalse(); + + firstChunkFuture.set(firstChunkData); + assertThat(finalChunkRequested.await(1, TimeUnit.SECONDS)).isTrue(); + + finalChunkFuture.set(finalChunkData); + downloadThread.join(TimeUnit.SECONDS.toMillis(1)); + + assertThat(downloadThread.isAlive()).isFalse(); + assertThat(out.toByteArray()).isEqualTo(blobData); } @Test @@ -159,7 +328,7 @@ public void downloadChunked_emptyChunkList_producesEmptyOutput() throws Exceptio } @Test - public void downloadChunked_chunkFailsAfterPartialWrite_throwsIOException() throws Exception { + public void downloadChunked_chunkFails_throwsIOException() throws Exception { byte[] chunk1Data = new byte[] {1, 2, 3}; byte[] chunk2Data = new byte[] {4, 5, 6}; Digest chunk1Digest = DIGEST_UTIL.compute(chunk1Data); @@ -173,14 +342,9 @@ public void downloadChunked_chunkFailsAfterPartialWrite_throwsIOException() thro .build(); when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) .thenReturn(Futures.immediateFuture(splitResponse)); - when(combinedCache.downloadBlob(any(), eq(chunk1Digest), any())) - .thenAnswer( - invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunk1Data); - return Futures.immediateFuture(null); - }); - when(combinedCache.downloadBlob(any(), eq(chunk2Digest), any())) + when(combinedCache.downloadBlob(any(), eq(chunk1Digest))) + .thenReturn(Futures.immediateFuture(chunk1Data)); + when(combinedCache.downloadBlob(any(), eq(chunk2Digest))) .thenReturn(Futures.immediateFailedFuture(new IOException("connection reset"))); ByteArrayOutputStream out 
= new ByteArrayOutputStream(); @@ -197,13 +361,8 @@ public void downloadChunked_blobDigestMismatch_throwsOutputDigestMismatch() thro SplitBlobResponse.newBuilder().addChunkDigests(chunkDigest).build(); when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) .thenReturn(Futures.immediateFuture(splitResponse)); - when(combinedCache.downloadBlob(any(), eq(chunkDigest), any())) - .thenAnswer( - invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunkData); - return Futures.immediateFuture(null); - }); + when(combinedCache.downloadBlob(any(), eq(chunkDigest))) + .thenReturn(Futures.immediateFuture(chunkData)); OutputDigestMismatchException e = assertThrows( @@ -225,17 +384,86 @@ public void downloadChunked_blobDigestMismatchVerificationDisabled_succeeds() th SplitBlobResponse.newBuilder().addChunkDigests(chunkDigest).build(); when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) .thenReturn(Futures.immediateFuture(splitResponse)); - when(combinedCache.downloadBlob(any(), eq(chunkDigest), any())) - .thenAnswer( - invocation -> { - OutputStream out = invocation.getArgument(2); - out.write(chunkData); - return Futures.immediateFuture(null); - }); + when(combinedCache.downloadBlob(any(), eq(chunkDigest))) + .thenReturn(Futures.immediateFuture(chunkData)); ByteArrayOutputStream out = new ByteArrayOutputStream(); downloader.downloadChunked(context, blobDigest, out); assertThat(out.toByteArray()).isEqualTo(chunkData); } + + @Test + public void downloadChunked_cancelledChunk_throwsInterruptedException() throws Exception { + byte[] chunkData = new byte[] {1, 2, 3}; + Digest chunkDigest = DIGEST_UTIL.compute(chunkData); + Digest blobDigest = chunkDigest; + + SplitBlobResponse splitResponse = + SplitBlobResponse.newBuilder().addChunkDigests(chunkDigest).build(); + when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) + .thenReturn(Futures.immediateFuture(splitResponse)); + + SettableFuture cancelledDownload = SettableFuture.create(); + 
cancelledDownload.cancel(/* mayInterruptIfRunning= */ true); + when(combinedCache.downloadBlob(any(), eq(chunkDigest))).thenReturn(cancelledDownload); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + assertThrows( + InterruptedException.class, () -> downloader.downloadChunked(context, blobDigest, out)); + } + + @Test + public void downloadChunked_chunkFails_cancelsOtherInFlightDownloads() throws Exception { + byte[] chunk1Data = new byte[] {1, 2, 3}; + byte[] chunk2Data = new byte[] {4, 5, 6}; + Digest chunk1Digest = DIGEST_UTIL.compute(chunk1Data); + Digest chunk2Digest = DIGEST_UTIL.compute(chunk2Data); + Digest blobDigest = DIGEST_UTIL.compute(new byte[] {1, 2, 3, 4, 5, 6}); + + SplitBlobResponse splitResponse = + SplitBlobResponse.newBuilder() + .addChunkDigests(chunk1Digest) + .addChunkDigests(chunk2Digest) + .build(); + when(grpcCacheClient.splitBlob(any(), eq(blobDigest))) + .thenReturn(Futures.immediateFuture(splitResponse)); + + SettableFuture failedDownload = SettableFuture.create(); + SettableFuture cancelledDownload = SettableFuture.create(); + CountDownLatch downloadsStarted = new CountDownLatch(2); + when(combinedCache.downloadBlob(any(), eq(chunk1Digest))) + .thenAnswer( + invocation -> { + downloadsStarted.countDown(); + return failedDownload; + }); + when(combinedCache.downloadBlob(any(), eq(chunk2Digest))) + .thenAnswer( + invocation -> { + downloadsStarted.countDown(); + return cancelledDownload; + }); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Thread downloadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + downloader.downloadChunked(context, blobDigest, out); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + downloadThread.start(); + + assertThat(downloadsStarted.await(1, TimeUnit.SECONDS)).isTrue(); + failedDownload.setException(new IOException("connection reset")); + + downloadThread.join(TimeUnit.SECONDS.toMillis(1)); + + 
assertThat(downloadThread.isAlive()).isFalse(); + assertThat(cancelledDownload.isCancelled()).isTrue(); + } } diff --git a/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobUploaderTest.java b/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobUploaderTest.java index 5088eab5242304..9e6ca94bb37890 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobUploaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ChunkedBlobUploaderTest.java @@ -16,18 +16,23 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import build.bazel.remote.execution.v2.Digest; import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.SettableFuture; import com.google.devtools.build.lib.clock.JavaClock; import com.google.devtools.build.lib.remote.chunking.ChunkingConfig; import com.google.devtools.build.lib.remote.chunking.FastCdcChunker; import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.vfs.DigestHashFunction; import com.google.devtools.build.lib.vfs.FileSystem; @@ -35,15 +40,21 @@ import com.google.devtools.build.lib.vfs.SyscallCache; import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; import com.google.protobuf.ByteString; +import java.io.ByteArrayInputStream; +import 
java.io.EOFException; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -59,6 +70,7 @@ public class ChunkedBlobUploaderTest { private static final DigestUtil DIGEST_UTIL = new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256); + private static final int MAX_IN_FLIGHT_CHUNK_UPLOADS = 16; @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @@ -105,7 +117,7 @@ public void uploadChunked_allChunksMissing_uploadsAllChunks() throws Exception { List digests = invocation.getArgument(1); return immediateFuture(ImmutableSet.copyOf(digests)); }); - when(combinedCache.uploadBlob(any(), any(Digest.class), any())) + when(combinedCache.uploadBlob(any(), any(Digest.class), any(Blob.class))) .thenReturn(immediateVoidFuture()); when(grpcCacheClient.spliceBlob(any(), any(), any())).thenReturn(immediateVoidFuture()); @@ -132,7 +144,7 @@ public void uploadChunked_noChunksMissing_skipsChunkUpload() throws Exception { uploader.uploadChunked(context, blobDigest, file); - verify(combinedCache, never()).uploadBlob(any(), any(Digest.class), any()); + verify(combinedCache, never()).uploadBlob(any(), any(Digest.class), any(Blob.class)); verify(grpcCacheClient).spliceBlob(any(), eq(blobDigest), any()); } @@ -176,12 +188,14 @@ public void uploadChunked_someChunksMissing_uploadsOnlyMissingWithCorrectData() when(grpcCacheClient.findMissingDigests(any(), any())) .thenReturn(immediateFuture(ImmutableSet.copyOf(digestsToReportMissing))); Map actualUploads = new HashMap<>(); - when(combinedCache.uploadBlob(any(), any(Digest.class), any())) + when(combinedCache.uploadBlob(any(), 
any(Digest.class), any(Blob.class))) .thenAnswer( invocation -> { Digest d = invocation.getArgument(1); - ByteString bs = invocation.getArgument(2); - actualUploads.put(d, bs); + Blob blob = invocation.getArgument(2); + try (InputStream in = blob.get()) { + actualUploads.put(d, ByteString.readFrom(in)); + } return immediateVoidFuture(); }); when(grpcCacheClient.spliceBlob(any(), any(), any())).thenReturn(immediateVoidFuture()); @@ -195,6 +209,267 @@ public void uploadChunked_someChunksMissing_uploadsOnlyMissingWithCorrectData() verify(grpcCacheClient).spliceBlob(any(), eq(blobDigest), eq(allChunkDigests)); } + @Test + @SuppressWarnings("unchecked") + public void uploadChunked_windowRefillsAfterOneChunkCompletes() throws Exception { + Path file = execRoot.getRelative("test_window.txt"); + byte[] data = new byte[262144]; + new Random(42).nextBytes(data); + writeFile(file, data); + Digest blobDigest = DIGEST_UTIL.compute(data); + + FastCdcChunker testChunker = new FastCdcChunker(new ChunkingConfig(1024, 2, 0), DIGEST_UTIL); + List chunkDigests; + try (InputStream input = file.getInputStream()) { + chunkDigests = testChunker.chunkToDigests(input); + } + + List uniqueChunkDigests = new ArrayList<>(); + Set seen = new HashSet<>(); + for (Digest chunkDigest : chunkDigests) { + if (seen.add(chunkDigest)) { + uniqueChunkDigests.add(chunkDigest); + } + if (uniqueChunkDigests.size() == MAX_IN_FLIGHT_CHUNK_UPLOADS + 1) { + break; + } + } + assertThat(uniqueChunkDigests).hasSize(MAX_IN_FLIGHT_CHUNK_UPLOADS + 1); + + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(immediateFuture(ImmutableSet.copyOf(uniqueChunkDigests))); + when(grpcCacheClient.spliceBlob(any(), any(), any())).thenReturn(immediateVoidFuture()); + + List> uploads = new ArrayList<>(uniqueChunkDigests.size()); + for (int i = 0; i < uniqueChunkDigests.size(); i++) { + uploads.add(SettableFuture.create()); + } + CountDownLatch firstWindowRequested = new 
CountDownLatch(MAX_IN_FLIGHT_CHUNK_UPLOADS); + CountDownLatch overflowUploadRequested = new CountDownLatch(1); + + when(combinedCache.uploadBlob(any(), any(Digest.class), any(Blob.class))) + .thenAnswer( + invocation -> { + Digest digest = invocation.getArgument(1); + int chunkIndex = uniqueChunkDigests.indexOf(digest); + if (chunkIndex < MAX_IN_FLIGHT_CHUNK_UPLOADS) { + firstWindowRequested.countDown(); + } else if (chunkIndex == MAX_IN_FLIGHT_CHUNK_UPLOADS) { + overflowUploadRequested.countDown(); + } + return uploads.get(chunkIndex); + }); + + Thread uploadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + uploader.uploadChunked(context, blobDigest, file); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + uploadThread.start(); + + assertThat(firstWindowRequested.await(1, TimeUnit.SECONDS)).isTrue(); + assertThat(overflowUploadRequested.await(100, TimeUnit.MILLISECONDS)).isFalse(); + + uploads.get(1).set(null); + assertThat(overflowUploadRequested.await(1, TimeUnit.SECONDS)).isTrue(); + + for (SettableFuture upload : uploads) { + if (!upload.isDone()) { + upload.set(null); + } + } + uploadThread.join(TimeUnit.SECONDS.toMillis(1)); + + assertThat(uploadThread.isAlive()).isFalse(); + verify(grpcCacheClient).spliceBlob(any(), eq(blobDigest), eq(chunkDigests)); + } + + @Test + @SuppressWarnings("unchecked") + public void uploadChunked_chunkFails_cancelsOtherInFlightUploads() throws Exception { + Path file = execRoot.getRelative("test_failure.txt"); + byte[] data = new byte[16384]; + new Random(42).nextBytes(data); + writeFile(file, data); + Digest blobDigest = DIGEST_UTIL.compute(data); + + FastCdcChunker testChunker = new FastCdcChunker(new ChunkingConfig(1024, 2, 0), DIGEST_UTIL); + List chunkDigests; + try (InputStream input = file.getInputStream()) { + chunkDigests = testChunker.chunkToDigests(input); + } + + List uniqueChunkDigests = new ArrayList<>(); + Set seen = new HashSet<>(); + for (Digest 
chunkDigest : chunkDigests) { + if (seen.add(chunkDigest)) { + uniqueChunkDigests.add(chunkDigest); + } + if (uniqueChunkDigests.size() == 2) { + break; + } + } + assertThat(uniqueChunkDigests).hasSize(2); + + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(immediateFuture(ImmutableSet.copyOf(uniqueChunkDigests))); + + SettableFuture failedUpload = SettableFuture.create(); + SettableFuture cancelledUpload = SettableFuture.create(); + CountDownLatch uploadsStarted = new CountDownLatch(2); + when(combinedCache.uploadBlob(any(), any(Digest.class), any(Blob.class))) + .thenAnswer( + invocation -> { + Digest digest = invocation.getArgument(1); + uploadsStarted.countDown(); + if (digest.equals(uniqueChunkDigests.get(0))) { + return failedUpload; + } + if (digest.equals(uniqueChunkDigests.get(1))) { + return cancelledUpload; + } + return immediateVoidFuture(); + }); + + Thread uploadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + uploader.uploadChunked(context, blobDigest, file); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + uploadThread.start(); + + assertThat(uploadsStarted.await(1, TimeUnit.SECONDS)).isTrue(); + failedUpload.setException(new IOException("upload failed")); + + uploadThread.join(TimeUnit.SECONDS.toMillis(1)); + + assertThat(uploadThread.isAlive()).isFalse(); + assertThat(cancelledUpload.isCancelled()).isTrue(); + verify(grpcCacheClient, never()).spliceBlob(any(), any(), any()); + } + + @Test + @SuppressWarnings("unchecked") + public void uploadChunked_cancelledUpload_throwsInterruptedException() throws Exception { + Path file = execRoot.getRelative("test_cancelled.txt"); + byte[] data = new byte[8192]; + new Random(42).nextBytes(data); + writeFile(file, data); + Digest blobDigest = DIGEST_UTIL.compute(data); + + FastCdcChunker testChunker = new FastCdcChunker(new ChunkingConfig(1024, 2, 0), DIGEST_UTIL); + List chunkDigests; + try (InputStream input = 
file.getInputStream()) { + chunkDigests = testChunker.chunkToDigests(input); + } + Digest firstChunkDigest = chunkDigests.get(0); + + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(immediateFuture(ImmutableSet.of(firstChunkDigest))); + + SettableFuture cancelledUpload = SettableFuture.create(); + cancelledUpload.cancel(/* mayInterruptIfRunning= */ true); + when(combinedCache.uploadBlob(any(), eq(firstChunkDigest), any(Blob.class))) + .thenReturn(cancelledUpload); + + assertThrows( + InterruptedException.class, () -> uploader.uploadChunked(context, blobDigest, file)); + verify(grpcCacheClient, never()).spliceBlob(any(), any(), any()); + } + + @Test + @SuppressWarnings("unchecked") + public void uploadChunked_failedUploadDuringPendingChunks_surfacesBeforeOpeningChunkStream() + throws Exception { + byte[] data = new byte[16384]; + new Random(42).nextBytes(data); + Digest blobDigest = DIGEST_UTIL.compute(data); + + FastCdcChunker testChunker = new FastCdcChunker(new ChunkingConfig(1024, 2, 0), DIGEST_UTIL); + List chunkDigests; + try (InputStream input = new ByteArrayInputStream(data)) { + chunkDigests = testChunker.chunkToDigests(input); + } + assertThat(chunkDigests.size()).isAtLeast(2); + + Path file = mock(Path.class); + when(file.getInputStream()).thenReturn(new ByteArrayInputStream(data)); + + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(immediateFuture(ImmutableSet.of(chunkDigests.get(0)))); + + SettableFuture failedUpload = SettableFuture.create(); + failedUpload.setException(new IOException("upload failed")); + when(combinedCache.uploadBlob(any(), eq(chunkDigests.get(0)), any(Blob.class))) + .thenReturn(failedUpload); + + Thread uploadThread = + Thread.ofVirtual() + .unstarted( + () -> { + try { + uploader.uploadChunked(context, blobDigest, file); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + uploadThread.start(); + + uploadThread.join(TimeUnit.SECONDS.toMillis(1)); 
+ assertThat(uploadThread.isAlive()).isFalse(); + verify(file, times(1)).getInputStream(); + verify(grpcCacheClient, never()).spliceBlob(any(), any(), any()); + } + + @Test + @SuppressWarnings("unchecked") + public void uploadChunked_fileTruncatedBeforeChunkUpload_reportsConcurrentModification() + throws Exception { + byte[] data = new byte[8192]; + new Random(42).nextBytes(data); + Digest blobDigest = DIGEST_UTIL.compute(data); + + FastCdcChunker testChunker = new FastCdcChunker(new ChunkingConfig(1024, 2, 0), DIGEST_UTIL); + List chunkDigests; + try (InputStream input = new ByteArrayInputStream(data)) { + chunkDigests = testChunker.chunkToDigests(input); + } + assertThat(chunkDigests.size()).isAtLeast(2); + + Digest secondChunkDigest = chunkDigests.get(1); + Path file = mock(Path.class); + when(file.getInputStream()) + .thenReturn(new ByteArrayInputStream(data), new ByteArrayInputStream(new byte[0])); + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(immediateFuture(ImmutableSet.of(secondChunkDigest))); + when(combinedCache.uploadBlob(any(), eq(secondChunkDigest), any(Blob.class))) + .thenAnswer( + invocation -> { + Blob blob = invocation.getArgument(2); + try (InputStream in = blob.get()) { + ByteString unused = ByteString.readFrom(in); + } + return immediateVoidFuture(); + }); + + IOException e = + assertThrows(IOException.class, () -> uploader.uploadChunked(context, blobDigest, file)); + + assertThat(e).hasMessageThat().contains("file was concurrently modified during upload"); + assertThat(e).hasCauseThat().isInstanceOf(EOFException.class); + verify(grpcCacheClient, never()).spliceBlob(any(), any(), any()); + } + private void writeFile(Path path, byte[] data) throws IOException { try (var out = path.getOutputStream()) { out.write(data); diff --git a/src/test/java/com/google/devtools/build/lib/remote/ChunkedTransferBenchmark.java b/src/test/java/com/google/devtools/build/lib/remote/ChunkedTransferBenchmark.java new file mode 100644 index 
00000000000000..8579a6334e6512 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/ChunkedTransferBenchmark.java @@ -0,0 +1,243 @@ +// Copyright 2026 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.devtools.build.lib.remote; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import build.bazel.remote.execution.v2.Digest; +import build.bazel.remote.execution.v2.RequestMetadata; +import build.bazel.remote.execution.v2.SplitBlobResponse; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.devtools.build.lib.clock.JavaClock; +import com.google.devtools.build.lib.remote.chunking.ChunkingConfig; +import com.google.devtools.build.lib.remote.chunking.FastCdcChunker; +import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext; +import com.google.devtools.build.lib.remote.common.RemoteCacheClient.Blob; +import com.google.devtools.build.lib.remote.util.DigestUtil; +import com.google.devtools.build.lib.vfs.DigestHashFunction; +import com.google.devtools.build.lib.vfs.FileSystem; +import com.google.devtools.build.lib.vfs.Path; +import com.google.devtools.build.lib.vfs.SyscallCache; +import 
com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +/** Benchmark for chunk download/upload with per-chunk latency jitter. 
*/ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 1, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 3, time = 3, timeUnit = TimeUnit.SECONDS) +@Fork(1) +public class ChunkedTransferBenchmark { + private static final DigestUtil DIGEST_UTIL = + new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256); + private static final RemoteActionExecutionContext CONTEXT = + RemoteActionExecutionContext.create(RequestMetadata.getDefaultInstance()); + + @Benchmark + public void downloadChunked(DownloadState state) throws Exception { + state.downloader.downloadChunked(CONTEXT, state.blobDigest, OutputStream.nullOutputStream()); + } + + @Benchmark + public void uploadChunked(UploadState state) throws Exception { + state.uploader.uploadChunked(CONTEXT, state.blobDigest, state.file); + } + + @State(Scope.Thread) + public static class DownloadState { + @Param({"1", "2", "4", "8"}) + public int schedulerThreads; + + @Param({"32"}) + public int chunkCount; + + @Param({"1024"}) + public int chunkSizeBytes; + + @Param({"25"}) + public int delayMillis; + + @Param({"10"}) + public int jitterMillis; + + private ScheduledExecutorService scheduler; + private ChunkedBlobDownloader downloader; + private Digest blobDigest; + private Random latencyJitter; + + @Setup(Level.Trial) + public void setup() throws Exception { + scheduler = Executors.newScheduledThreadPool(schedulerThreads); + latencyJitter = new Random(12345L); + + GrpcCacheClient grpcCacheClient = mock(GrpcCacheClient.class); + CombinedCache combinedCache = mock(CombinedCache.class); + + List chunkDigests = new ArrayList<>(chunkCount); + Map chunkDataByDigest = new HashMap<>(chunkCount); + long totalBytes = 0; + for (int i = 0; i < chunkCount; i++) { + byte[] chunkData = new byte[chunkSizeBytes]; + new Random(i).nextBytes(chunkData); + Digest chunkDigest = DIGEST_UTIL.compute(chunkData); + chunkDigests.add(chunkDigest); + chunkDataByDigest.put(chunkDigest, 
chunkData); + totalBytes += chunkData.length; + } + + when(combinedCache.downloadBlob(any(), any(Digest.class))) + .thenAnswer( + invocation -> + delayedFuture( + chunkDataByDigest.get(invocation.getArgument(1)), + delayMillis, + jitterMillis, + latencyJitter, + scheduler)); + + blobDigest = + Digest.newBuilder() + .setHash("chunked-transfer-benchmark-download-" + chunkCount + "-" + chunkSizeBytes) + .setSizeBytes(totalBytes) + .build(); + + SplitBlobResponse splitBlobResponse = + SplitBlobResponse.newBuilder().addAllChunkDigests(chunkDigests).build(); + when(grpcCacheClient.splitBlob(any(), any(Digest.class))) + .thenReturn(Futures.immediateFuture(splitBlobResponse)); + + downloader = new ChunkedBlobDownloader(grpcCacheClient, combinedCache, DIGEST_UTIL); + } + + @TearDown(Level.Trial) + public void tearDown() { + scheduler.shutdownNow(); + } + } + + @State(Scope.Thread) + public static class UploadState { + @Param({"1", "2", "4", "8"}) + public int schedulerThreads; + + @Param({"32768"}) + public int fileSizeBytes; + + @Param({"1024"}) + public int avgChunkSizeBytes; + + @Param({"25"}) + public int delayMillis; + + @Param({"10"}) + public int jitterMillis; + + private ScheduledExecutorService scheduler; + private ChunkedBlobUploader uploader; + private Path file; + private Digest blobDigest; + private Random latencyJitter; + + @Setup(Level.Trial) + public void setup() throws Exception { + scheduler = Executors.newScheduledThreadPool(schedulerThreads); + latencyJitter = new Random(54321L); + + GrpcCacheClient grpcCacheClient = mock(GrpcCacheClient.class); + CombinedCache combinedCache = mock(CombinedCache.class); + + byte[] data = new byte[fileSizeBytes]; + new Random(42).nextBytes(data); + blobDigest = DIGEST_UTIL.compute(data); + + FileSystem fs = new InMemoryFileSystem(new JavaClock(), DigestHashFunction.SHA256); + file = fs.getPath("/bench/blob.bin"); + file.getParentDirectory().createDirectoryAndParents(); + try (var out = file.getOutputStream()) { + 
out.write(data); + } + + ChunkingConfig chunkingConfig = new ChunkingConfig(avgChunkSizeBytes, 2, 0); + uploader = new ChunkedBlobUploader(grpcCacheClient, combinedCache, chunkingConfig, DIGEST_UTIL); + + List chunkDigests; + try (var input = file.getInputStream()) { + chunkDigests = new FastCdcChunker(chunkingConfig, DIGEST_UTIL).chunkToDigests(input); + } + + when(grpcCacheClient.findMissingDigests(any(), any())) + .thenReturn(Futures.immediateFuture(ImmutableSet.copyOf(chunkDigests))); + when(grpcCacheClient.spliceBlob(any(), any(Digest.class), any())) + .thenReturn(Futures.immediateVoidFuture()); + when(combinedCache.uploadBlob(any(), any(Digest.class), any(Blob.class))) + .thenAnswer( + invocation -> + delayedFuture(null, delayMillis, jitterMillis, latencyJitter, scheduler)); + } + + @TearDown(Level.Trial) + public void tearDown() { + scheduler.shutdownNow(); + } + } + + private static ListenableFuture delayedFuture( + T value, + int delayMillis, + int jitterMillis, + Random latencyJitter, + ScheduledExecutorService scheduler) { + SettableFuture future = SettableFuture.create(); + scheduler.schedule( + () -> future.set(value), + jitteredDelayMillis(delayMillis, jitterMillis, latencyJitter), + TimeUnit.MILLISECONDS); + return future; + } + + private static int jitteredDelayMillis(int delayMillis, int jitterMillis, Random latencyJitter) { + if (jitterMillis == 0) { + return delayMillis; + } + return Math.max( + 0, delayMillis + latencyJitter.nextInt((jitterMillis * 2) + 1) - jitterMillis); + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/CombinedCacheTest.java b/src/test/java/com/google/devtools/build/lib/remote/CombinedCacheTest.java index c18aab417ecea7..ce35e10bfe3a4b 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/CombinedCacheTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/CombinedCacheTest.java @@ -26,8 +26,11 @@ import static org.mockito.Mockito.verify; import 
build.bazel.remote.execution.v2.ActionResult; +import build.bazel.remote.execution.v2.CacheCapabilities; import build.bazel.remote.execution.v2.Digest; +import build.bazel.remote.execution.v2.FastCdc2020Params; import build.bazel.remote.execution.v2.RequestMetadata; +import build.bazel.remote.execution.v2.ServerCapabilities; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -713,6 +716,58 @@ public void shutdownNow_cancelInProgressUploads() throws Exception { assertThat(upload.isCancelled()).isTrue(); } + @Test + public void uploadFile_chunkedUpload_deduplicatesRemoteUpload() throws Exception { + GrpcCacheClient grpcCacheClient = mock(GrpcCacheClient.class); + doAnswer(unused -> chunkingCapabilities()).when(grpcCacheClient).getServerCapabilities(); + doAnswer(unused -> immediateFuture(ImmutableSet.of())) + .when(grpcCacheClient) + .findMissingDigests(any(), any()); + + CountDownLatch spliceStarted = new CountDownLatch(1); + SettableFuture spliceFuture = SettableFuture.create(); + doAnswer( + unused -> { + spliceStarted.countDown(); + return spliceFuture; + }) + .when(grpcCacheClient) + .spliceBlob(any(), any(), any()); + + CombinedCache combinedCache = + new CombinedCache( + grpcCacheClient, + /* diskCacheClient= */ null, + /* symlinkTemplate= */ null, + digestUtil, + /* chunkingEnabled= */ true); + byte[] data = new byte[8192]; + Path file = execRoot.getRelative("chunked-output"); + try (var out = file.getOutputStream()) { + out.write(data); + } + Digest digest = digestUtil.compute(data); + + try { + ListenableFuture firstUpload = + combinedCache.uploadFile(remoteActionExecutionContext, digest, file); + assertThat(spliceStarted.await(1, TimeUnit.SECONDS)).isTrue(); + + ListenableFuture secondUpload = + combinedCache.uploadFile(remoteActionExecutionContext, digest, file); + + assertThat(combinedCache.casUploadCache.getSubscriberCount(digest)).isEqualTo(2); + 
verify(grpcCacheClient).findMissingDigests(any(), any()); + verify(grpcCacheClient).spliceBlob(any(), any(), any()); + + spliceFuture.set(null); + getFromFuture(firstUpload); + getFromFuture(secondUpload); + } finally { + combinedCache.release(); + } + } + private InMemoryCombinedCache newCombinedCache() { return new InMemoryCombinedCache(digestUtil); } @@ -739,4 +794,13 @@ private RemoteExecutionCache newRemoteExecutionCache(RemoteCacheClient remoteCac digestUtil, /* chunkingEnabled= */ false); } + + private static ServerCapabilities chunkingCapabilities() { + return ServerCapabilities.newBuilder() + .setCacheCapabilities( + CacheCapabilities.newBuilder() + .setFastCdc2020Params( + FastCdc2020Params.newBuilder().setAvgChunkSizeBytes(1024).build())) + .build(); + } } diff --git a/src/tools/remote/src/main/java/com/google/devtools/build/remote/worker/OnDiskBlobStoreCache.java b/src/tools/remote/src/main/java/com/google/devtools/build/remote/worker/OnDiskBlobStoreCache.java index 362cb760d08c3e..a181bb65111877 100644 --- a/src/tools/remote/src/main/java/com/google/devtools/build/remote/worker/OnDiskBlobStoreCache.java +++ b/src/tools/remote/src/main/java/com/google/devtools/build/remote/worker/OnDiskBlobStoreCache.java @@ -34,7 +34,6 @@ import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.vfs.Path; import com.google.devtools.build.lib.vfs.PathFragment; -import com.google.protobuf.ByteString; import java.io.IOException; import java.util.HashSet; import java.util.concurrent.ConcurrentHashMap; @@ -138,18 +137,6 @@ public DigestUtil getDigestUtil() { return digestUtil; } - @Override - public ListenableFuture uploadBlob( - RemoteActionExecutionContext context, Digest digest, ByteString data) { - return uploadBlob(context, digest, data, /* force= */ true); - } - - @Override - public ListenableFuture uploadFile( - RemoteActionExecutionContext context, Digest digest, Path file) { - return uploadFile(context, digest, file, /* 
force= */ true); - } - public DiskCacheClient getDiskCacheClient() { return checkNotNull(diskCacheClient); }