Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor read context streams to async streams #10284

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -254,24 +254,7 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
}
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception {
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber);
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
Expand Down Expand Up @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception {
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<InputStreamContainer> blobPartStreams = new ArrayList<>();
List<CompletableFuture<InputStreamContainer>> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(blobPartStream);
blobPartStreams.add(CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
Expand Down Expand Up @@ -49,12 +48,11 @@
* Asynchronously downloads the blob to the specified location using an executor from the thread pool.
* @param blobName The name of the blob for which needs to be downloaded.
* @param fileLocation The path on local disk where the blob needs to be downloaded.
* @param threadPool The threadpool instance which will provide the executor for performing a multipart download.
* @param completionListener Listener which will be notified when the download is complete.
*/
@ExperimentalApi
default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
default void asyncBlobDownload(String blobName, Path fileLocation, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, completionListener);

Check warning on line 55 in server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java#L55

Added line #L55 was not covered by tests
readBlobAsync(blobName, readContextListener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -144,8 +145,10 @@ public long getBlobSize() {
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
return super.getPartStreams().stream()
.map(cf -> cf.thenApply(this::decryptInputStreamContainer))
.collect(Collectors.toUnmodifiableList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,26 @@
import org.opensearch.common.io.InputStreamContainer;

import java.util.List;
import java.util.concurrent.CompletableFuture;

/**
* ReadContext is used to encapsulate all data needed by <code>BlobContainer#readBlobAsync</code>
*/
@ExperimentalApi
public class ReadContext {
private final long blobSize;
private final List<InputStreamContainer> partStreams;
private final List<CompletableFuture<InputStreamContainer>> asyncPartStreams;
private final String blobChecksum;

public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
public ReadContext(long blobSize, List<CompletableFuture<InputStreamContainer>> asyncPartStreams, String blobChecksum) {
this.blobSize = blobSize;
this.partStreams = partStreams;
this.asyncPartStreams = asyncPartStreams;
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.asyncPartStreams = readContext.asyncPartStreams;
this.blobChecksum = readContext.blobChecksum;
}

Expand All @@ -39,14 +40,14 @@ public String getBlobChecksum() {
}

public int getNumberOfParts() {
return partStreams.size();
return asyncPartStreams.size();
}

public long getBlobSize() {
return blobSize;
}

public List<InputStreamContainer> getPartStreams() {
return partStreams;
public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
return asyncPartStreams;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

/**
* FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
* instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
*/
@InternalApi
class FilePartWriter implements Runnable {
class FilePartWriter implements BiConsumer<InputStreamContainer, Throwable> {

private final int partNumber;
private final InputStreamContainer blobPartStreamContainer;
private final Path fileLocation;
private final AtomicBoolean anyPartStreamFailed;
private final ActionListener<Integer> fileCompletionListener;
Expand All @@ -42,20 +42,26 @@

public FilePartWriter(
int partNumber,
InputStreamContainer blobPartStreamContainer,
Path fileLocation,
AtomicBoolean anyPartStreamFailed,
ActionListener<Integer> fileCompletionListener
) {
this.partNumber = partNumber;
this.blobPartStreamContainer = blobPartStreamContainer;
this.fileLocation = fileLocation;
this.anyPartStreamFailed = anyPartStreamFailed;
this.fileCompletionListener = fileCompletionListener;
}

@Override
public void run() {
public void accept(InputStreamContainer blobPartStreamContainer, Throwable throwable) {
if (throwable != null) {
if (throwable instanceof Exception) {
processFailure((Exception) throwable);

Check warning on line 59 in server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java#L59

Added line #L59 was not covered by tests
} else {
processFailure(new Exception(throwable));

Check warning on line 61 in server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java#L61

Added line #L61 was not covered by tests
}
return;

Check warning on line 63 in server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java#L63

Added line #L63 was not covered by tests
}
// Ensures no writes to the file if any stream fails.
if (anyPartStreamFailed.get() == false) {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,25 @@
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer}
* using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are
* spread across a {@link ThreadPool} executor.
* using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams.
*/
@InternalApi
public class ReadContextListener implements ActionListener<ReadContext> {

private final String fileName;
private final Path fileLocation;
private final ThreadPool threadPool;
private final ActionListener<String> completionListener;
private static final Logger logger = LogManager.getLogger(ReadContextListener.class);

public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
public ReadContextListener(String fileName, Path fileLocation, ActionListener<String> completionListener) {
this.fileName = fileName;
this.fileLocation = fileLocation;
this.threadPool = threadPool;
this.completionListener = completionListener;
}

Expand All @@ -47,14 +43,9 @@ public void onResponse(ReadContext readContext) {
FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener);

for (int partNumber = 0; partNumber < numParts; partNumber++) {
FilePartWriter filePartWriter = new FilePartWriter(
partNumber,
readContext.getPartStreams().get(partNumber),
fileLocation,
anyPartStreamFailed,
fileCompletionListener
);
threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter);
readContext.getPartStreams()
.get(partNumber)
.whenComplete(new FilePartWriter(partNumber, fileLocation, anyPartStreamFailed, fileCompletionListener));
}
}

Expand Down
24 changes: 8 additions & 16 deletions server/src/main/java/org/opensearch/index/shard/IndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.opensearch.action.admin.indices.flush.FlushRequest;
import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest;
import org.opensearch.action.admin.indices.upgrade.post.UpgradeRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.support.replication.PendingReplicationActions;
import org.opensearch.action.support.replication.ReplicationResponse;
Expand Down Expand Up @@ -4914,24 +4913,17 @@ private void downloadSegments(
RemoteSegmentStoreDirectory targetRemoteDirectory,
Set<String> toDownloadSegments,
final Runnable onFileSync
) {
final PlainActionFuture<Void> completionListener = PlainActionFuture.newFuture();
final GroupedActionListener<Void> batchDownloadListener = new GroupedActionListener<>(
ActionListener.map(completionListener, v -> null),
toDownloadSegments.size()
);

final ActionListener<String> segmentsDownloadListener = ActionListener.map(batchDownloadListener, fileName -> {
) throws IOException {
final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex();
for (String segment : toDownloadSegments) {
final PlainActionFuture<String> segmentListener = PlainActionFuture.newFuture();
sourceRemoteDirectory.copyTo(segment, storeDirectory, indexPath, segmentListener);
segmentListener.actionGet();
onFileSync.run();
if (targetRemoteDirectory != null) {
targetRemoteDirectory.copyFrom(storeDirectory, fileName, fileName, IOContext.DEFAULT);
targetRemoteDirectory.copyFrom(storeDirectory, segment, segment, IOContext.DEFAULT);
}
return null;
});

final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex();
toDownloadSegments.forEach(file -> { sourceRemoteDirectory.copyTo(file, storeDirectory, indexPath, segmentsDownloadListener); });
completionListener.actionGet();
}
}

private boolean localDirectoryContains(Directory localDirectory, String file, long checksum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati
if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) {
final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer();
final Path destinationFilePath = destinationPath.resolve(source);
blobContainer.asyncBlobDownload(blobName, destinationFilePath, threadPool, fileCompletionListener);
blobContainer.asyncBlobDownload(blobName, destinationFilePath, fileCompletionListener);
} else {
// Fallback to older mechanism of downloading the file
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.Version;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.concurrent.GatedCloseable;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.shard.IndexShard;
Expand Down Expand Up @@ -141,14 +141,12 @@ private void downloadSegments(
ActionListener<GetSegmentFilesResponse> completionListener
) {
final Path indexPath = shardPath == null ? null : shardPath.resolveIndex();
final GroupedActionListener<Void> batchDownloadListener = new GroupedActionListener<>(
ActionListener.map(completionListener, v -> new GetSegmentFilesResponse(toDownloadSegments)),
toDownloadSegments.size()
);
ActionListener<String> segmentsDownloadListener = ActionListener.map(batchDownloadListener, result -> null);
toDownloadSegments.forEach(
fileMetadata -> remoteStoreDirectory.copyTo(fileMetadata.name(), storeDirectory, indexPath, segmentsDownloadListener)
);
for (StoreFileMetadata storeFileMetadata : toDownloadSegments) {
final PlainActionFuture<String> segmentListener = PlainActionFuture.newFuture();
remoteStoreDirectory.copyTo(storeFileMetadata.name(), storeDirectory, indexPath, segmentListener);
segmentListener.actionGet();
}
completionListener.onResponse(new GetSegmentFilesResponse(toDownloadSegments));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.UnaryOperator;

import org.mockito.Mockito;
Expand Down Expand Up @@ -51,10 +52,12 @@ public void testReadBlobAsync() throws Exception {
// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);

final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
final CompletableFuture<InputStreamContainer> streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer);
final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
Expand All @@ -76,7 +79,7 @@ public void testReadBlobAsync() throws Exception {
assertEquals(1, response.getNumberOfParts());
assertEquals(size, response.getBlobSize());

InputStreamContainer responseContainer = response.getPartStreams().get(0);
InputStreamContainer responseContainer = response.getPartStreams().get(0).get();
assertEquals(0, responseContainer.getOffset());
assertEquals(size, responseContainer.getContentLength());
assertEquals(100, responseContainer.getInputStream().available());
Expand All @@ -99,7 +102,8 @@ public void testReadBlobAsyncException() throws Exception {
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);
final CompletableFuture<InputStreamContainer> streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer);
final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
Expand Down
Loading