Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.reactivestreams.Subscriber;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class ParallelPresignedUrlMultipartDownloaderSubscriber

private final AtomicInteger partNumber = new AtomicInteger(0);
private final AtomicInteger completedParts = new AtomicInteger(0);
private final AtomicInteger inFlightRequestsNum = new AtomicInteger(0);
private final Semaphore inFlightPermits;
private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false);
private final AtomicBoolean processingPending = new AtomicBoolean(false);
private final Map<Integer, CompletableFuture<GetObjectResponse>> inFlightRequests = new ConcurrentHashMap<>();
Expand All @@ -90,6 +91,7 @@ public ParallelPresignedUrlMultipartDownloaderSubscriber(
this.configuredPartSizeInBytes = configuredPartSizeInBytes;
this.resultFuture = resultFuture;
this.maxInFlightParts = maxInFlightParts;
this.inFlightPermits = new Semaphore(maxInFlightParts);
}

@Override
Expand Down Expand Up @@ -128,21 +130,25 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj
PresignedUrlDownloadRequest partRequest = createRangedGetRequest(0);
log.debug(() -> "Sending first range request with range=" + partRequest.range());

if (!inFlightPermits.tryAcquire()) {
throw new IllegalStateException("Failed to acquire permit for first request");
}

CompletableFuture<GetObjectResponse> response =
s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer);

inFlightRequests.put(0, response);
inFlightRequestsNum.incrementAndGet();
CompletableFutureUtils.forwardExceptionTo(resultFuture, response);

response.whenComplete((res, error) -> {
inFlightRequests.remove(0);
inFlightPermits.release();

if (error != null || isCompletedExceptionally.get()) {
handlePartError(error, 0);
return;
}

inFlightRequests.remove(0);
inFlightRequestsNum.decrementAndGet();
completedParts.incrementAndGet();

this.eTag = res.eTag();
Expand Down Expand Up @@ -188,7 +194,7 @@ private void processRequest(AsyncResponseTransformer<GetObjectResponse, GetObjec
return;
}

if (inFlightRequestsNum.get() >= maxInFlightParts) {
if (!inFlightPermits.tryAcquire()) {
pendingTransformers.offer(Pair.of(currentPart, transformer));
return;
}
Expand All @@ -200,6 +206,7 @@ private void processRequest(AsyncResponseTransformer<GetObjectResponse, GetObjec
private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse> transformer,
int partIndex) {
if (isCompletedExceptionally.get()) {
inFlightPermits.release();
return;
}

Expand All @@ -210,10 +217,12 @@ private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
s3AsyncClient.presignedUrlExtension().getObject(partRequest, transformer);

inFlightRequests.put(partIndex, response);
inFlightRequestsNum.incrementAndGet();
CompletableFutureUtils.forwardExceptionTo(resultFuture, response);

response.whenComplete((res, error) -> {
inFlightRequests.remove(partIndex);
inFlightPermits.release();

if (error != null || isCompletedExceptionally.get()) {
handlePartError(error, partIndex);
return;
Expand All @@ -226,8 +235,6 @@ private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
}

log.debug(() -> "Completed part: " + partIndex);
inFlightRequests.remove(partIndex);
inFlightRequestsNum.decrementAndGet();
int totalComplete = completedParts.incrementAndGet();

if (totalComplete == totalParts) {
Expand All @@ -245,22 +252,27 @@ private void sendPartRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
}

private void processPendingTransformers() {
// Re-check after releasing the gate to catch permits that arrived
// while exiting — prevents "missed signal" where no thread drains the queue.
do {
if (!processingPending.compareAndSet(false, true)) {
return;
}
try {
while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts) {
// Drain pending queue while permits are available
while (!pendingTransformers.isEmpty() && inFlightPermits.tryAcquire()) {
Pair<Integer, AsyncResponseTransformer<GetObjectResponse, GetObjectResponse>> pendingPart =
pendingTransformers.poll();
if (pendingPart != null && pendingPart.left() < totalParts) {
sendPartRequest(pendingPart.right(), pendingPart.left());
} else {
inFlightPermits.release();
}
}
} finally {
processingPending.set(false);
}
} while (!pendingTransformers.isEmpty() && inFlightRequestsNum.get() < maxInFlightParts);
} while (!pendingTransformers.isEmpty() && inFlightPermits.availablePermits() > 0);
Comment thread
alextwoods marked this conversation as resolved.
}

private Optional<SdkClientException> validatePartResponse(GetObjectResponse response, int partIndex) {
Expand Down Expand Up @@ -336,4 +348,4 @@ public void onError(Throwable t) {
public void onComplete() {
// Completion is handled by resultFuture
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.awssdk.services.s3.internal.multipart;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.findAll;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
Expand All @@ -27,6 +28,7 @@

import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
import com.github.tomakehurst.wiremock.junit5.WireMockTest;
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
Expand All @@ -35,6 +37,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -234,6 +237,45 @@ void onNext_withNullTransformer_shouldThrowNPE() {
.isInstanceOf(NullPointerException.class);
}

@Test
void multiPartDownload_manyParts_shouldCompleteSuccessfully() throws Exception {
// 13 parts to exceed maxInFlightParts (10)
byte[] data = new byte[208]; // 13 × 16 bytes
Arrays.fill(data, (byte) 'X');

stubFor(get(urlEqualTo(PRESIGNED_URL_PATH))
.withHeader("Range", matching("bytes=0-15"))
.willReturn(aResponse().withStatus(206)
.withHeader("Content-Length", "16")
.withHeader("Content-Range", "bytes 0-15/208")
.withHeader("ETag", "\"etag\"")
.withBody(Arrays.copyOfRange(data, 0, 16))));

for (int i = 1; i < 13; i++) {
int start = i * 16;
int end = start + 15;
stubFor(get(urlEqualTo(PRESIGNED_URL_PATH))
.withHeader("Range", matching("bytes=" + start + "-" + end))
.willReturn(aResponse().withStatus(206)
.withHeader("Content-Length", "16")
.withHeader("Content-Range", "bytes " + start + "-" + end + "/208")
.withHeader("ETag", "\"etag\"")
.withBody(Arrays.copyOfRange(data, start, end + 1))));
}

tempFile = createTempFileUnchecked();
PresignedUrlDownloadRequest request = PresignedUrlDownloadRequest.builder()
.presignedUrl(presignedUrl)
.build();

s3AsyncClient.presignedUrlExtension()
.getObject(request, AsyncResponseTransformer.toFile(tempFile))
.join();

assertThat(Files.readAllBytes(tempFile)).isEqualTo(data);
verify(13, getRequestedFor(urlEqualTo(PRESIGNED_URL_PATH)));
}

private static Path createTempFile() throws IOException {
Path path = Files.createTempFile("parallel-test-" + UUID.randomUUID(), ".tmp");
Files.deleteIfExists(path);
Expand Down
Loading