From fc7dc20c525ee70e12695256d8d19587c08af59f Mon Sep 17 00:00:00 2001 From: Andrew Ross Date: Wed, 4 Oct 2023 22:38:22 -0500 Subject: [PATCH] Refactor multipart download to a more async model (#10349) (#10373) * Refactor read context streams to async streams Signed-off-by: Kunal Kotwani * Refactor multipart download to a more async model The previous approach of kicking off the stream requests for all parts of a file did not work well for very large files. For example, a 20GiB file uploaded in 16MiB parts will consist of 1200+ parts. When we attempted to initiate streaming for all parts concurrently, some parts would hit a client timeout after 2 minutes without being able to get a connection due to the other parts not having been completed in that time frame. This refactoring adds yet another layer of indirection in order to allow the code that is actually writing the destination file to control the rate at which streams are started. This should allow for downloading files consisting of arbitrarily many parts at any connection speed. This commit also wires in the download rate limiter so that the `indices.recovery.max_bytes_per_sec` is properly honored. Signed-off-by: Andrew Ross --------- Signed-off-by: Kunal Kotwani Signed-off-by: Andrew Ross Co-authored-by: Kunal Kotwani (cherry picked from commit 28f185b347a3333c8670ca1a7bd7d0a85fed14e9) --- .../repositories/s3/S3BlobContainer.java | 28 +--- .../s3/S3BlobStoreContainerTests.java | 4 +- .../opensearch/index/shard/IndexShardIT.java | 3 +- .../mocks/MockFsAsyncBlobContainer.java | 5 +- .../AsyncMultiStreamBlobContainer.java | 16 -- ...syncMultiStreamEncryptedBlobContainer.java | 6 +- .../blobstore/stream/read/ReadContext.java | 32 +++- .../read/listener/FileCompletionListener.java | 47 ------ .../stream/read/listener/FilePartWriter.java | 70 ++------- .../read/listener/ReadContextListener.java | 139 +++++++++++++++--- .../common/settings/ClusterSettings.java | 1 + .../org/opensearch/index/IndexService.java | 4 +- .../opensearch/index/shard/IndexShard.java | 39 +++-- .../opensearch/index/shard/StoreRecovery.java | 11 +- .../index/store/RemoteDirectory.java | 8 +- .../store/RemoteSegmentStoreDirectory.java | 18 ++- .../RemoteSegmentStoreDirectoryFactory.java | 19 ++- .../indices/recovery/RecoverySettings.java | 25 ++++ .../RemoteStoreReplicationSource.java | 16 +- .../main/java/org/opensearch/node/Node.java | 6 +- .../blobstore/BlobStoreRepository.java | 6 +- .../org/opensearch/threadpool/ThreadPool.java | 6 + ...ultiStreamEncryptedBlobContainerTests.java | 10 +- .../listener/FileCompletionListenerTests.java | 58 -------- .../read/listener/FilePartWriterTests.java | 104 +------------ .../listener/ReadContextListenerTests.java | 54 +++++-- .../opensearch/index/IndexModuleTests.java | 3 +- .../RemoteStoreRefreshListenerTests.java | 4 +- ...moteSegmentStoreDirectoryFactoryTests.java | 7 +- .../RemoteSegmentStoreDirectoryTests.java | 29 ++-- .../snapshots/SnapshotResiliencyTests.java | 3 +- .../index/shard/IndexShardTestCase.java | 17 ++- .../recovery/DefaultRecoverySettings.java | 24 +++ 33 files changed, 402 insertions(+), 420 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java delete mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java create mode 100644 test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index c6ae58371e15c..fcfccf50ad326 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -241,37 +241,23 @@ public void readBlobAsync(String blobName, ActionListener listener) return; } - final List> blobPartInputStreamFutures = new ArrayList<>(); + final List blobPartInputStreamFutures = new ArrayList<>(); final long blobSize = blobMetadata.objectSize(); final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount(); final String blobChecksum = blobMetadata.checksum().checksumCRC32(); if (numberOfParts == null) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); + blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); } else { // S3 multipart files use 1 to n indexing for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber)); + final int innerPartNumber = partNumber; + blobPartInputStreamFutures.add( + () -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber) + ); } } - - CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)) - .whenComplete((unused, partThrowable) -> { - if (partThrowable == null) { - listener.onResponse( - new ReadContext( - blobSize, - blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()), - blobChecksum - ) - ); - } else { - Exception ex = partThrowable.getCause() instanceof Exception - ? (Exception) partThrowable.getCause() - : new Exception(partThrowable.getCause()); - listener.onFailure(ex); - } - }); + listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum)); }); } catch (Exception ex) { listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex)); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 9817d7cd520ef..2e54705e9cd78 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception { assertEquals(objectSize, readContext.getBlobSize()); for (int partNumber = 1; partNumber < objectPartCount; partNumber++) { - InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join(); final int offset = partNumber * partSize; assertEquals(partSize, inputStreamContainer.getContentLength()); assertEquals(offset, inputStreamContainer.getOffset()); @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception { assertEquals(checksum, readContext.getBlobChecksum()); assertEquals(objectSize, readContext.getBlobSize()); - InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get(); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join(); assertEquals(objectSize, inputStreamContainer.getContentLength()); assertEquals(0, inputStreamContainer.getOffset()); assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length); diff --git a/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java index e232d1c8fa7c8..44a900491d949 100644 --- a/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java @@ -689,7 +689,8 @@ public static final IndexShard newIndexShard( null, null, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL, - nodeId + nodeId, + null ); } diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java index 079753de95680..36987ac2d4991 100644 --- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java +++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java @@ -27,6 +27,7 @@ import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener listener) long contentLength = listBlobs().get(blobName).length(); long partSize = contentLength / 10; int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); - List blobPartStreams = new ArrayList<>(); + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < numberOfParts; partNumber++) { long offset = partNumber * partSize; InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset); - blobPartStreams.add(blobPartStream); + blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream)); } ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null); listener.onResponse(blobReadContext); diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java index e73a9f5cd0bc9..97f304d776f5c 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java @@ -10,13 +10,10 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; -import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.nio.file.Path; /** * An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow @@ -45,19 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer { @ExperimentalApi void readBlobAsync(String blobName, ActionListener listener); - /** - * Asynchronously downloads the blob to the specified location using an executor from the thread pool. - * @param blobName The name of the blob for which needs to be downloaded. - * @param fileLocation The path on local disk where the blob needs to be downloaded. - * @param threadPool The threadpool instance which will provide the executor for performing a multipart download. - * @param completionListener Listener which will be notified when the download is complete. - */ - @ExperimentalApi - default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener); - readBlobAsync(blobName, readContextListener); - } - /* * Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected * checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index c64dc6b9e3ae4..82bc7a0baed50 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -144,8 +144,10 @@ public long getBlobSize() { } @Override - public List getPartStreams() { - return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList()); + public List getPartStreams() { + return super.getPartStreams().stream() + .map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer)) + .collect(Collectors.toUnmodifiableList()); } /** diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index 2c305fb03c475..4bdce11ff4f9a 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -12,6 +12,8 @@ import org.opensearch.common.io.InputStreamContainer; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; /** * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync @@ -19,18 +21,18 @@ @ExperimentalApi public class ReadContext { private final long blobSize; - private final List partStreams; + private final List asyncPartStreams; private final String blobChecksum; - public ReadContext(long blobSize, List partStreams, String blobChecksum) { + public ReadContext(long blobSize, List asyncPartStreams, String blobChecksum) { this.blobSize = blobSize; - this.partStreams = partStreams; + this.asyncPartStreams = asyncPartStreams; this.blobChecksum = blobChecksum; } public ReadContext(ReadContext readContext) { this.blobSize = readContext.blobSize; - this.partStreams = readContext.partStreams; + this.asyncPartStreams = readContext.asyncPartStreams; this.blobChecksum = readContext.blobChecksum; } @@ -39,14 +41,30 @@ public String getBlobChecksum() { } public int getNumberOfParts() { - return partStreams.size(); + return asyncPartStreams.size(); } public long getBlobSize() { return blobSize; } - public List getPartStreams() { - return partStreams; + public List getPartStreams() { + return asyncPartStreams; + } + + /** + * Functional interface defining an instance that can create an async action + * to create a part of an object represented as an InputStreamContainer. + */ + @FunctionalInterface + public interface StreamPartCreator extends Supplier> { + /** + * Kicks off a async process to start streaming. + * + * @return When the returned future is completed, streaming has + * just begun. Clients must fully consume the resulting stream. + */ + @Override + CompletableFuture get(); } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java deleted file mode 100644 index aadd6e2ab304e..0000000000000 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.common.blobstore.stream.read.listener; - -import org.opensearch.common.annotation.InternalApi; -import org.opensearch.core.action.ActionListener; - -import java.util.concurrent.atomic.AtomicInteger; - -/** - * FileCompletionListener listens for completion of fetch on all the streams for a file, where - * individual streams are handled using {@link FilePartWriter}. The {@link FilePartWriter}(s) - * hold a reference to the file completion listener to be notified. - */ -@InternalApi -class FileCompletionListener implements ActionListener { - - private final int numberOfParts; - private final String fileName; - private final AtomicInteger completedPartsCount; - private final ActionListener completionListener; - - public FileCompletionListener(int numberOfParts, String fileName, ActionListener completionListener) { - this.completedPartsCount = new AtomicInteger(); - this.numberOfParts = numberOfParts; - this.fileName = fileName; - this.completionListener = completionListener; - } - - @Override - public void onResponse(Integer unused) { - if (completedPartsCount.incrementAndGet() == numberOfParts) { - completionListener.onResponse(fileName); - } - } - - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } -} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java index 84fd7ed9ffebf..1a403200249cd 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java @@ -8,83 +8,37 @@ package org.opensearch.common.blobstore.stream.read.listener; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.io.Channels; import org.opensearch.common.io.InputStreamContainer; -import org.opensearch.core.action.ActionListener; import java.io.IOException; import java.io.InputStream; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.UnaryOperator; /** * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel} - * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion. + * instance. */ @InternalApi -class FilePartWriter implements Runnable { - - private final int partNumber; - private final InputStreamContainer blobPartStreamContainer; - private final Path fileLocation; - private final AtomicBoolean anyPartStreamFailed; - private final ActionListener fileCompletionListener; - private static final Logger logger = LogManager.getLogger(FilePartWriter.class); - +class FilePartWriter { // 8 MB buffer for transfer private static final int BUFFER_SIZE = 8 * 1024 * 2024; - public FilePartWriter( - int partNumber, - InputStreamContainer blobPartStreamContainer, - Path fileLocation, - AtomicBoolean anyPartStreamFailed, - ActionListener fileCompletionListener - ) { - this.partNumber = partNumber; - this.blobPartStreamContainer = blobPartStreamContainer; - this.fileLocation = fileLocation; - this.anyPartStreamFailed = anyPartStreamFailed; - this.fileCompletionListener = fileCompletionListener; - } - - @Override - public void run() { - // Ensures no writes to the file if any stream fails. - if (anyPartStreamFailed.get() == false) { - try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { - try (InputStream inputStream = blobPartStreamContainer.getInputStream()) { - long streamOffset = blobPartStreamContainer.getOffset(); - final byte[] buffer = new byte[BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); - streamOffset += bytesRead; - } + public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator rateLimiter) throws IOException { + try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { + try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) { + long streamOffset = stream.getOffset(); + final byte[] buffer = new byte[BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); + streamOffset += bytesRead; } - } catch (IOException e) { - processFailure(e); - return; } - fileCompletionListener.onResponse(partNumber); - } - } - - void processFailure(Exception e) { - try { - Files.deleteIfExists(fileLocation); - } catch (IOException ex) { - // Die silently - logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); - } - if (anyPartStreamFailed.getAndSet(true) == false) { - fileCompletionListener.onFailure(e); } } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 4338bddb3fbe7..2914fd0c440fa 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,51 +10,73 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; import java.nio.file.Path; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.UnaryOperator; /** * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer} - * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are - * spread across a {@link ThreadPool} executor. + * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams. */ @InternalApi public class ReadContextListener implements ActionListener { + private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - private final String fileName; + private final String blobName; private final Path fileLocation; - private final ThreadPool threadPool; private final ActionListener completionListener; - private static final Logger logger = LogManager.getLogger(ReadContextListener.class); + private final ThreadPool threadPool; + private final UnaryOperator rateLimiter; + private final int maxConcurrentStreams; - public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - this.fileName = fileName; + public ReadContextListener( + String blobName, + Path fileLocation, + ActionListener completionListener, + ThreadPool threadPool, + UnaryOperator rateLimiter, + int maxConcurrentStreams + ) { + this.blobName = blobName; this.fileLocation = fileLocation; - this.threadPool = threadPool; this.completionListener = completionListener; + this.threadPool = threadPool; + this.rateLimiter = rateLimiter; + this.maxConcurrentStreams = maxConcurrentStreams; } @Override public void onResponse(ReadContext readContext) { - logger.trace("Streams received for blob {}", fileName); + logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName); final int numParts = readContext.getNumberOfParts(); - final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); - - for (int partNumber = 0; partNumber < numParts; partNumber++) { - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - readContext.getPartStreams().get(partNumber), - fileLocation, - anyPartStreamFailed, - fileCompletionListener - ); - threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter); + final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false); + final GroupedActionListener groupedListener = new GroupedActionListener<>( + ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure), + numParts + ); + final Queue queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams()); + final StreamPartProcessor processor = new StreamPartProcessor( + queue, + anyPartStreamFailed, + fileLocation, + groupedListener, + threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY), + rateLimiter + ); + for (int i = 0; i < Math.min(maxConcurrentStreams, queue.size()); i++) { + processor.process(queue.poll()); } } @@ -62,4 +84,79 @@ public void onResponse(ReadContext readContext) { public void onFailure(Exception e) { completionListener.onFailure(e); } + + private static class StreamPartProcessor { + private static final RuntimeException CANCELED_PART_EXCEPTION = new RuntimeException( + "Canceled part download due to previous failure" + ); + private final Queue queue; + private final AtomicBoolean anyPartStreamFailed; + private final Path fileLocation; + private final GroupedActionListener completionListener; + private final Executor executor; + private final UnaryOperator rateLimiter; + + private StreamPartProcessor( + Queue queue, + AtomicBoolean anyPartStreamFailed, + Path fileLocation, + GroupedActionListener completionListener, + Executor executor, + UnaryOperator rateLimiter + ) { + this.queue = queue; + this.anyPartStreamFailed = anyPartStreamFailed; + this.fileLocation = fileLocation; + this.completionListener = completionListener; + this.executor = executor; + this.rateLimiter = rateLimiter; + } + + private void process(ReadContext.StreamPartCreator supplier) { + if (supplier == null) { + return; + } + supplier.get().whenCompleteAsync((blobPartStreamContainer, throwable) -> { + if (throwable != null) { + processFailure(throwable instanceof Exception ? (Exception) throwable : new RuntimeException(throwable)); + } else if (anyPartStreamFailed.get()) { + processFailure(CANCELED_PART_EXCEPTION); + } else { + try { + FilePartWriter.write(fileLocation, blobPartStreamContainer, rateLimiter); + completionListener.onResponse(fileLocation.toString()); + + // Upon successfully completing a file part, pull another + // file part off the queue to trigger asynchronous processing + process(queue.poll()); + } catch (Exception e) { + processFailure(e); + } + } + }, executor); + } + + private void processFailure(Exception e) { + if (anyPartStreamFailed.getAndSet(true) == false) { + completionListener.onFailure(e); + + // Drain the queue of pending part downloads. These can be discarded + // since they haven't started any work yet, but the listener must be + // notified for each part. + Object item = queue.poll(); + while (item != null) { + completionListener.onFailure(CANCELED_PART_EXCEPTION); + item = queue.poll(); + } + } else { + completionListener.onFailure(e); + } + try { + Files.deleteIfExists(fileLocation); + } catch (IOException ex) { + // Die silently + logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); + } + } + } } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index fff34ad811b41..1f12971fe4771 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -289,6 +289,7 @@ public void apply(Settings value, Settings current, Settings previous) { RecoverySettings.INDICES_RECOVERY_INTERNAL_LONG_ACTION_TIMEOUT_SETTING, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING, + RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_INITIAL_PRIMARIES_RECOVERIES_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_INITIAL_REPLICAS_RECOVERIES_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_CONCURRENT_INCOMING_RECOVERIES_SETTING, diff --git a/server/src/main/java/org/opensearch/index/IndexService.java b/server/src/main/java/org/opensearch/index/IndexService.java index fdda8d4ce2497..df8e8070b8e03 100644 --- a/server/src/main/java/org/opensearch/index/IndexService.java +++ b/server/src/main/java/org/opensearch/index/IndexService.java @@ -89,6 +89,7 @@ import org.opensearch.index.shard.ShardNotInPrimaryModeException; import org.opensearch.index.shard.ShardPath; import org.opensearch.index.similarity.SimilarityService; +import org.opensearch.index.store.RemoteSegmentStoreDirectoryFactory; import org.opensearch.index.store.Store; import org.opensearch.index.translog.Translog; import org.opensearch.index.translog.TranslogFactory; @@ -520,7 +521,8 @@ public synchronized IndexShard createShard( remoteStore, remoteStoreStatsTrackerFactory, clusterRemoteTranslogBufferIntervalSupplier, - nodeEnv.nodeId() + nodeEnv.nodeId(), + (RemoteSegmentStoreDirectoryFactory) remoteDirectoryFactory ); eventListener.indexShardStateChanged(indexShard, null, indexShard.state(), "shard created"); eventListener.afterIndexShardCreated(indexShard); diff --git a/server/src/main/java/org/opensearch/index/shard/IndexShard.java b/server/src/main/java/org/opensearch/index/shard/IndexShard.java index 70e84cb15cde2..abc4f3dcd5c3d 100644 --- a/server/src/main/java/org/opensearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/opensearch/index/shard/IndexShard.java @@ -63,7 +63,6 @@ import org.opensearch.action.admin.indices.flush.FlushRequest; import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest; import org.opensearch.action.admin.indices.upgrade.post.UpgradeRequest; -import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.replication.PendingReplicationActions; import org.opensearch.action.support.replication.ReplicationResponse; @@ -164,6 +163,7 @@ import org.opensearch.index.shard.PrimaryReplicaSyncer.ResyncTask; import org.opensearch.index.similarity.SimilarityService; import org.opensearch.index.store.RemoteSegmentStoreDirectory; +import org.opensearch.index.store.RemoteSegmentStoreDirectoryFactory; import org.opensearch.index.store.Store; import org.opensearch.index.store.Store.MetadataSnapshot; import org.opensearch.index.store.StoreFileMetadata; @@ -343,6 +343,7 @@ Runnable getGlobalCheckpointSyncer() { private final RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory; private final List internalRefreshListener = new ArrayList<>(); + private final RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory; public IndexShard( final ShardRouting shardRouting, @@ -370,7 +371,11 @@ public IndexShard( @Nullable final Store remoteStore, final RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory, final Supplier clusterRemoteTranslogBufferIntervalSupplier, - final String nodeId + final String nodeId, + // Wiring a directory factory here breaks some intended abstractions, but this remote directory + // factory is used not as a Lucene directory but instead to copy files from a remote store when + // restoring a shallow snapshot. + @Nullable final RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory ) throws IOException { super(shardRouting.shardId(), indexSettings); assert shardRouting.initializing(); @@ -466,6 +471,7 @@ public boolean shouldCache(Query query) { ? false : mapperService.documentMapper().mappers().containsTimeStampField(); this.remoteStoreStatsTrackerFactory = remoteStoreStatsTrackerFactory; + this.remoteSegmentStoreDirectoryFactory = remoteSegmentStoreDirectoryFactory; } public ThreadPool getThreadPool() { @@ -2701,7 +2707,7 @@ public void restoreFromRemoteStore(ActionListener listener) { public void restoreFromSnapshotAndRemoteStore( Repository repository, - RepositoriesService repositoriesService, + RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory, ActionListener listener ) { try { @@ -2709,7 +2715,7 @@ public void restoreFromSnapshotAndRemoteStore( assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.SNAPSHOT : "invalid recovery type: " + recoveryState.getRecoverySource(); StoreRecovery storeRecovery = new StoreRecovery(shardId, logger); - storeRecovery.recoverFromSnapshotAndRemoteStore(this, repository, repositoriesService, listener, threadPool); + storeRecovery.recoverFromSnapshotAndRemoteStore(this, repository, remoteSegmentStoreDirectoryFactory, listener); } catch (Exception e) { listener.onFailure(e); } @@ -3555,7 +3561,7 @@ public void startRecovery( "from snapshot and remote store", recoveryState, recoveryListener, - l -> restoreFromSnapshotAndRemoteStore(repositoriesService.repository(repo), repositoriesService, l) + l -> restoreFromSnapshotAndRemoteStore(repositoriesService.repository(repo), remoteSegmentStoreDirectoryFactory, l) ); // indicesService.indexService(shardRouting.shardId().getIndex()).addMetadataListener(); } else { @@ -4930,24 +4936,17 @@ private void downloadSegments( RemoteSegmentStoreDirectory targetRemoteDirectory, Set toDownloadSegments, final Runnable onFileSync - ) { - final PlainActionFuture completionListener = PlainActionFuture.newFuture(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> null), - toDownloadSegments.size() - ); - - final ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, fileName -> { + ) throws IOException { + final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); + for (String segment : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + sourceRemoteDirectory.copyTo(segment, storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); onFileSync.run(); if (targetRemoteDirectory != null) { - targetRemoteDirectory.copyFrom(storeDirectory, fileName, fileName, IOContext.DEFAULT); + targetRemoteDirectory.copyFrom(storeDirectory, segment, segment, IOContext.DEFAULT); } - return null; - }); - - final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); - toDownloadSegments.forEach(file -> { sourceRemoteDirectory.copyTo(file, storeDirectory, indexPath, segmentsDownloadListener); }); - completionListener.actionGet(); + } } private boolean localDirectoryContains(Directory localDirectory, String file, long checksum) { diff --git a/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java b/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java index c0211e1257c8e..762aab51469d0 100644 --- a/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java +++ b/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java @@ -70,9 +70,7 @@ import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.indices.replication.common.ReplicationLuceneIndex; import org.opensearch.repositories.IndexId; -import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.nio.channels.FileChannel; @@ -362,9 +360,8 @@ void recoverFromRepository(final IndexShard indexShard, Repository repository, A void recoverFromSnapshotAndRemoteStore( final IndexShard indexShard, Repository repository, - RepositoriesService repositoriesService, - ActionListener listener, - ThreadPool threadPool + RemoteSegmentStoreDirectoryFactory directoryFactory, + ActionListener listener ) { try { if (canRecover(indexShard)) { @@ -392,10 +389,6 @@ void recoverFromSnapshotAndRemoteStore( remoteStoreRepository = shallowCopyShardMetadata.getRemoteStoreRepository(); } - RemoteSegmentStoreDirectoryFactory directoryFactory = new RemoteSegmentStoreDirectoryFactory( - () -> repositoriesService, - threadPool - ); RemoteSegmentStoreDirectory sourceRemoteDirectory = (RemoteSegmentStoreDirectory) directoryFactory.newDirectory( remoteStoreRepository, indexUUID, diff --git a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java index 594b7f99cd85a..eb75c39532d71 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java @@ -62,9 +62,9 @@ public class RemoteDirectory extends Directory { protected final BlobContainer blobContainer; private static final Logger logger = LogManager.getLogger(RemoteDirectory.class); - protected final UnaryOperator uploadRateLimiter; + private final UnaryOperator uploadRateLimiter; - protected final UnaryOperator downloadRateLimiter; + private final UnaryOperator downloadRateLimiter; /** * Number of bytes in the segment file to store checksum @@ -333,6 +333,10 @@ public boolean copyFrom( return false; } + protected UnaryOperator getDownloadRateLimiter() { + return downloadRateLimiter; + } + private void uploadBlob( Directory from, String src, diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java index 21a84f2b8c903..a97b22360716c 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java @@ -25,6 +25,7 @@ import org.apache.lucene.util.Version; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.logging.Loggers; @@ -37,6 +38,7 @@ import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadata; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadataHandler; +import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.threadpool.ThreadPool; @@ -90,6 +92,8 @@ public final class RemoteSegmentStoreDirectory extends FilterDirectory implement private final ThreadPool threadPool; + private final RecoverySettings recoverySettings; + /** * Keeps track of local segment filename to uploaded filename along with other attributes like checksum. * This map acts as a cache layer for uploaded segment filenames which helps avoid calling listAll() each time. @@ -122,13 +126,15 @@ public RemoteSegmentStoreDirectory( RemoteDirectory remoteMetadataDirectory, RemoteStoreLockManager mdLockManager, ThreadPool threadPool, - ShardId shardId + ShardId shardId, + RecoverySettings recoverySettings ) throws IOException { super(remoteDataDirectory); this.remoteDataDirectory = remoteDataDirectory; this.remoteMetadataDirectory = remoteMetadataDirectory; this.mdLockManager = mdLockManager; this.threadPool = threadPool; + this.recoverySettings = recoverySettings; this.logger = Loggers.getLogger(getClass(), shardId); init(); } @@ -488,7 +494,15 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) { final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer(); final Path destinationFilePath = destinationPath.resolve(source); - blobContainer.asyncBlobDownload(blobName, destinationFilePath, threadPool, fileCompletionListener); + final ReadContextListener readContextListener = new ReadContextListener( + blobName, + destinationFilePath, + fileCompletionListener, + threadPool, + remoteDataDirectory.getDownloadRateLimiter(), + recoverySettings.getMaxConcurrentRemoteStoreStreams() + ); + blobContainer.readBlobAsync(blobName, readContextListener); } else { // Fallback to older mechanism of downloading the file try { diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java index a64b8536aa946..cc55380894ecd 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java @@ -13,8 +13,9 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexSettings; import org.opensearch.index.shard.ShardPath; +import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; import org.opensearch.index.store.lockmanager.RemoteStoreLockManagerFactory; -import org.opensearch.index.store.lockmanager.RemoteStoreMetadataLockManager; +import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.plugins.IndexStorePlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; @@ -34,12 +35,18 @@ public class RemoteSegmentStoreDirectoryFactory implements IndexStorePlugin.Dire private static final String SEGMENTS = "segments"; private final Supplier repositoriesService; + private final RecoverySettings recoverySettings; private final ThreadPool threadPool; - public RemoteSegmentStoreDirectoryFactory(Supplier repositoriesService, ThreadPool threadPool) { + public RemoteSegmentStoreDirectoryFactory( + Supplier repositoriesService, + ThreadPool threadPool, + RecoverySettings recoverySettings + ) { this.repositoriesService = repositoriesService; this.threadPool = threadPool; + this.recoverySettings = recoverySettings; } @Override @@ -64,20 +71,16 @@ public Directory newDirectory(String repositoryName, String indexUUID, ShardId s RemoteDirectory metadataDirectory = new RemoteDirectory( blobStoreRepository.blobStore().blobContainer(commonBlobPath.add("metadata")) ); - RemoteStoreMetadataLockManager mdLockManager = RemoteStoreLockManagerFactory.newLockManager( + RemoteStoreLockManager mdLockManager = RemoteStoreLockManagerFactory.newLockManager( repositoriesService.get(), repositoryName, indexUUID, String.valueOf(shardId.id()) ); - return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, mdLockManager, threadPool, shardId); + return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, mdLockManager, threadPool, shardId, recoverySettings); } catch (RepositoryMissingException e) { throw new IllegalArgumentException("Repository should be created before creating index with remote_store enabled setting", e); } } - - private RemoteDirectory createRemoteDirectory(BlobStoreRepository repository, BlobPath commonBlobPath, String extension) { - return new RemoteDirectory(repository.blobStore().blobContainer(commonBlobPath.add(extension))); - } } diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java index e2346ae078339..ed9755bf824ea 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java @@ -84,6 +84,17 @@ public class RecoverySettings { Property.NodeScope ); + /** + * Controls the maximum number of streams that can be started concurrently when downloading from the remote store. + */ + public static final Setting INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING = Setting.intSetting( + "indices.recovery.max_concurrent_remote_store_streams", + 20, + 1, + Property.Dynamic, + Property.NodeScope + ); + /** * how long to wait before retrying after issues cause by cluster state syncing between nodes * i.e., local node is not yet known on remote node, remote shard not yet started etc. @@ -149,6 +160,7 @@ public class RecoverySettings { private volatile ByteSizeValue maxBytesPerSec; private volatile int maxConcurrentFileChunks; private volatile int maxConcurrentOperations; + private volatile int maxConcurrentRemoteStoreStreams; private volatile SimpleRateLimiter rateLimiter; private volatile TimeValue retryDelayStateSync; private volatile TimeValue retryDelayNetwork; @@ -163,6 +175,7 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { this.retryDelayStateSync = INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING.get(settings); this.maxConcurrentFileChunks = INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING.get(settings); this.maxConcurrentOperations = INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING.get(settings); + this.maxConcurrentRemoteStoreStreams = INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING.get(settings); // doesn't have to be fast as nodes are reconnected every 10s by default (see InternalClusterService.ReconnectToNodes) // and we want to give the cluster-manager time to remove a faulty node this.retryDelayNetwork = INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING.get(settings); @@ -184,6 +197,10 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, this::setMaxBytesPerSec); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING, this::setMaxConcurrentFileChunks); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING, this::setMaxConcurrentOperations); + clusterSettings.addSettingsUpdateConsumer( + INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING, + this::setMaxConcurrentRemoteStoreStreams + ); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING, this::setRetryDelayStateSync); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING, this::setRetryDelayNetwork); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_INTERNAL_ACTION_TIMEOUT_SETTING, this::setInternalActionTimeout); @@ -279,4 +296,12 @@ public int getMaxConcurrentOperations() { private void setMaxConcurrentOperations(int maxConcurrentOperations) { this.maxConcurrentOperations = maxConcurrentOperations; } + + public int getMaxConcurrentRemoteStoreStreams() { + return this.maxConcurrentRemoteStoreStreams; + } + + private void setMaxConcurrentRemoteStoreStreams(int maxConcurrentRemoteStoreStreams) { + this.maxConcurrentRemoteStoreStreams = maxConcurrentRemoteStoreStreams; + } } diff --git a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java index aeb690465905f..e17c5293c38ac 100644 --- a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java @@ -14,7 +14,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.util.Version; -import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.core.action.ActionListener; import org.opensearch.index.shard.IndexShard; @@ -141,14 +141,12 @@ private void downloadSegments( ActionListener completionListener ) { final Path indexPath = shardPath == null ? null : shardPath.resolveIndex(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> new GetSegmentFilesResponse(toDownloadSegments)), - toDownloadSegments.size() - ); - ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, result -> null); - toDownloadSegments.forEach( - fileMetadata -> remoteStoreDirectory.copyTo(fileMetadata.name(), storeDirectory, indexPath, segmentsDownloadListener) - ); + for (StoreFileMetadata storeFileMetadata : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + remoteStoreDirectory.copyTo(storeFileMetadata.name(), storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); + } + completionListener.onResponse(new GetSegmentFilesResponse(toDownloadSegments)); } @Override diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 711b32c5ade91..b88dd3a0ef72a 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -759,9 +759,12 @@ protected Node( rerouteServiceReference.set(rerouteService); clusterService.setRerouteService(rerouteService); + final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); + final IndexStorePlugin.DirectoryFactory remoteDirectoryFactory = new RemoteSegmentStoreDirectoryFactory( repositoriesServiceReference::get, - threadPool + threadPool, + recoverySettings ); final SearchRequestStats searchRequestStats = new SearchRequestStats(); @@ -952,7 +955,6 @@ protected Node( transportService.getTaskManager() ); - final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); RepositoriesModule repositoriesModule = new RepositoriesModule( this.environment, pluginsService.filterPlugins(RepositoryPlugin.class), diff --git a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java index 0e9ae0e06cdeb..74b1a44399514 100644 --- a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java @@ -1189,7 +1189,8 @@ private void executeStaleShardDelete( // see https://github.com/opensearch-project/OpenSearch/issues/8469 new RemoteSegmentStoreDirectoryFactory( remoteStoreLockManagerFactory.getRepositoriesService(), - threadPool + threadPool, + recoverySettings ).newDirectory( remoteStoreRepoForIndex, indexUUID, @@ -1659,7 +1660,8 @@ private void executeOneStaleIndexDelete( // see https://github.com/opensearch-project/OpenSearch/issues/8469 new RemoteSegmentStoreDirectoryFactory( remoteStoreLockManagerFactory.getRepositoriesService(), - threadPool + threadPool, + recoverySettings ).newDirectory( remoteStoreRepoForIndex, indexUUID, diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 8a860b598b6d8..a09c9efbd77cf 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -114,6 +114,7 @@ public static class Names { public static final String TRANSLOG_SYNC = "translog_sync"; public static final String REMOTE_PURGE = "remote_purge"; public static final String REMOTE_REFRESH_RETRY = "remote_refresh_retry"; + public static final String REMOTE_RECOVERY = "remote_recovery"; public static final String INDEX_SEARCHER = "index_searcher"; } @@ -183,6 +184,7 @@ public static ThreadPoolType fromType(String type) { map.put(Names.TRANSLOG_SYNC, ThreadPoolType.FIXED); map.put(Names.REMOTE_PURGE, ThreadPoolType.SCALING); map.put(Names.REMOTE_REFRESH_RETRY, ThreadPoolType.SCALING); + map.put(Names.REMOTE_RECOVERY, ThreadPoolType.SCALING); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { map.put(Names.INDEX_SEARCHER, ThreadPoolType.FIXED_AUTO_QUEUE_SIZE); } @@ -280,6 +282,10 @@ public ThreadPool( Names.REMOTE_REFRESH_RETRY, new ScalingExecutorBuilder(Names.REMOTE_REFRESH_RETRY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) ); + builders.put( + Names.REMOTE_RECOVERY, + new ScalingExecutorBuilder(Names.REMOTE_RECOVERY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) + ); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { builders.put( Names.INDEX_SEARCHER, diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java index 947a4f9b1c9ab..1780819390052 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -20,6 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -51,10 +52,12 @@ public void testReadBlobAsync() throws Exception { // Objects needed for API call final byte[] data = new byte[size]; Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); @@ -76,7 +79,7 @@ public void testReadBlobAsync() throws Exception { assertEquals(1, response.getNumberOfParts()); assertEquals(size, response.getBlobSize()); - InputStreamContainer responseContainer = response.getPartStreams().get(0); + InputStreamContainer responseContainer = response.getPartStreams().get(0).get().join(); assertEquals(0, responseContainer.getOffset()); assertEquals(size, responseContainer.getContentLength()); assertEquals(100, responseContainer.getInputStream().available()); @@ -99,7 +102,8 @@ public void testReadBlobAsyncException() throws Exception { final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java deleted file mode 100644 index fa13d90f42fa6..0000000000000 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.common.blobstore.stream.read.listener; - -import org.opensearch.test.OpenSearchTestCase; - -import java.io.IOException; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; - -public class FileCompletionListenerTests extends OpenSearchTestCase { - - public void testFileCompletionListener() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - for (int stream = 0; stream < numStreams; stream++) { - // Ensure completion listener called only when all streams are completed - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getResponseCount()); - assertEquals(fileName, completionListener.getResponse()); - } - - public void testFileCompletionListenerFailure() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - // Fail the listener initially - IOException exception = new IOException(); - fileCompletionListener.onFailure(exception); - - for (int stream = 0; stream < numStreams - 1; stream++) { - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - assertEquals(0, completionListener.getResponseCount()); - - fileCompletionListener.onFailure(exception); - assertEquals(2, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - } -} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java index 811566eb5767b..f2a758b9bbe10 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java @@ -13,14 +13,11 @@ import org.junit.Before; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.UUID; -import java.util.concurrent.atomic.AtomicBoolean; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; +import java.util.function.UnaryOperator; public class FilePartWriterTests extends OpenSearchTestCase { @@ -34,130 +31,37 @@ public void init() throws Exception { public void testFilePartWriter() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterWithOffset() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; int offset = 10; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength + offset, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterLargeInput() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 20 * 1024 * 1024; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); - } - - public void testFilePartWriterException() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; - int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - IOException ioException = new IOException(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - assertFalse(anyStreamFailed.get()); - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - // Fail stream again to simulate another stream failure for same file - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - assertEquals(0, fileCompletionListener.getResponseCount()); - assertEquals(1, fileCompletionListener.getFailureCount()); - assertEquals(ioException, fileCompletionListener.getException()); - } - - public void testFilePartWriterStreamFailed() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; - int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(true); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); - - assertFalse(Files.exists(segmentFilePath)); - assertEquals(0, fileCompletionListener.getResponseCount()); } } diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 21b7b47390a9b..7e4c96cbadcda 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -29,7 +29,9 @@ import java.util.ArrayList; import java.util.List; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.function.UnaryOperator; import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; @@ -46,6 +48,7 @@ public class ReadContextListenerTests extends OpenSearchTestCase { private static final int NUMBER_OF_PARTS = 5; private static final int PART_SIZE = 10; private static final String TEST_SEGMENT_FILE = "test_segment_file"; + private static final int MAX_CONCURRENT_STREAMS = 10; @BeforeClass public static void setup() { @@ -64,10 +67,17 @@ public void init() throws Exception { public void testReadContextListener() throws InterruptedException, IOException { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -79,10 +89,17 @@ public void testReadContextListener() throws InterruptedException, IOException { public void testReadContextListenerFailure() throws Exception { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); InputStream badInputStream = new InputStream() { @Override @@ -101,7 +118,13 @@ public int available() { } }; - blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS)); + blobPartStreams.add( + NUMBER_OF_PARTS, + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), + threadPool.generic() + ) + ); ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -112,18 +135,31 @@ public int available() { public void testReadContextListenerException() { Path fileLocation = path.resolve(UUID.randomUUID().toString()); CountingCompletionListener listener = new CountingCompletionListener(); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + listener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); IOException exception = new IOException(); readContextListener.onFailure(exception); assertEquals(1, listener.getFailureCount()); assertEquals(exception, listener.getException()); } - private List initializeBlobPartStreams() { - List blobPartStreams = new ArrayList<>(); + private List initializeBlobPartStreams() { + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)); - blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE)); + int finalPartNumber = partNumber; + blobPartStreams.add( + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(testStream, PART_SIZE, (long) finalPartNumber * PART_SIZE), + threadPool.generic() + ) + ); } return blobPartStreams; } diff --git a/server/src/test/java/org/opensearch/index/IndexModuleTests.java b/server/src/test/java/org/opensearch/index/IndexModuleTests.java index caa6aa23b5d89..e77a1f02b86e0 100644 --- a/server/src/test/java/org/opensearch/index/IndexModuleTests.java +++ b/server/src/test/java/org/opensearch/index/IndexModuleTests.java @@ -104,6 +104,7 @@ import org.opensearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason; import org.opensearch.indices.fielddata.cache.IndicesFieldDataCache; import org.opensearch.indices.mapper.MapperRegistry; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.plugins.IndexStorePlugin; import org.opensearch.repositories.RepositoriesService; @@ -256,7 +257,7 @@ private IndexService newIndexService(IndexModule module) throws IOException { writableRegistry(), () -> false, null, - new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool), + new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool, DefaultRecoverySettings.INSTANCE), translogFactorySupplier, () -> IndexSettings.DEFAULT_REFRESH_INTERVAL, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL diff --git a/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java b/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java index 5a13f57db2c87..941f2f48e71af 100644 --- a/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java +++ b/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java @@ -33,6 +33,7 @@ import org.opensearch.index.store.RemoteSegmentStoreDirectory.MetadataFilenameUtils; import org.opensearch.index.store.Store; import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.threadpool.ThreadPool; @@ -155,7 +156,8 @@ public void testRemoteDirectoryInitThrowsException() throws IOException { remoteMetadataDirectory, mock(RemoteStoreLockManager.class), mock(ThreadPool.class), - shardId + shardId, + DefaultRecoverySettings.INSTANCE ); FilterDirectory remoteStoreFilterDirectory = new RemoteStoreRefreshListenerTests.TestFilterDirectory( new RemoteStoreRefreshListenerTests.TestFilterDirectory(remoteSegmentStoreDirectory) diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java index cad5e47531cc6..78c7fe64cebd9 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java @@ -20,6 +20,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexSettings; import org.opensearch.index.shard.ShardPath; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.RepositoryMissingException; import org.opensearch.repositories.blobstore.BlobStoreRepository; @@ -57,7 +58,11 @@ public void setup() { repositoriesService = mock(RepositoriesService.class); threadPool = mock(ThreadPool.class); when(repositoriesServiceSupplier.get()).thenReturn(repositoriesService); - remoteSegmentStoreDirectoryFactory = new RemoteSegmentStoreDirectoryFactory(repositoriesServiceSupplier, threadPool); + remoteSegmentStoreDirectoryFactory = new RemoteSegmentStoreDirectoryFactory( + repositoriesServiceSupplier, + threadPool, + DefaultRecoverySettings.INSTANCE + ); } public void testNewDirectory() throws IOException { diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java index f154dddb0e7cc..b574ccaac55e1 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java @@ -25,12 +25,13 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lucene.store.ByteArrayIndexInput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; @@ -41,6 +42,7 @@ import org.opensearch.index.store.lockmanager.RemoteStoreMetadataLockManager; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadata; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadataHandler; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.threadpool.ThreadPool; import org.junit.After; @@ -56,9 +58,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -145,13 +149,16 @@ public void setup() throws IOException { remoteMetadataDirectory, mdLockManager, threadPool, - indexShard.shardId() + indexShard.shardId(), + DefaultRecoverySettings.INSTANCE ); try (Store store = indexShard.store()) { segmentInfos = store.readLastCommittedSegmentsInfo(); } + when(remoteDataDirectory.getDownloadRateLimiter()).thenReturn(UnaryOperator.identity()); when(threadPool.executor(ThreadPool.Names.REMOTE_PURGE)).thenReturn(executorService); + when(threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY)).thenReturn(executorService); } @After @@ -562,9 +569,6 @@ public void onFailure(Exception e) {} } public void testCopyFilesToMultipart() throws Exception { - Settings settings = Settings.builder().build(); - FeatureFlags.initializeFeatureFlags(settings); - String filename = "_0.cfe"; populateMetadata(); remoteSegmentStoreDirectory.init(); @@ -574,13 +578,15 @@ public void testCopyFilesToMultipart() throws Exception { when(remoteDataDirectory.getBlobContainer()).thenReturn(blobContainer); Mockito.doAnswer(invocation -> { - ActionListener completionListener = invocation.getArgument(3); - completionListener.onResponse(invocation.getArgument(0)); + ActionListener completionListener = invocation.getArgument(1); + final CompletableFuture future = new CompletableFuture<>(); + future.complete(new InputStreamContainer(new ByteArrayInputStream(new byte[] { 42 }), 0, 1)); + completionListener.onResponse(new ReadContext(1, List.of(() -> future), "")); return null; - }).when(blobContainer).asyncBlobDownload(any(), any(), any(), any()); + }).when(blobContainer).readBlobAsync(any(), any()); CountDownLatch downloadLatch = new CountDownLatch(1); - ActionListener completionListener = new ActionListener() { + ActionListener completionListener = new ActionListener<>() { @Override public void onResponse(String unused) { downloadLatch.countDown(); @@ -592,7 +598,7 @@ public void onFailure(Exception e) {} Path path = createTempDir(); remoteSegmentStoreDirectory.copyTo(filename, storeDirectory, path, completionListener); assertTrue(downloadLatch.await(5000, TimeUnit.SECONDS)); - verify(blobContainer, times(1)).asyncBlobDownload(contains(filename), eq(path.resolve(filename)), any(), any()); + verify(blobContainer, times(1)).readBlobAsync(contains(filename), any()); verify(storeDirectory, times(0)).copyFrom(any(), any(), any(), any()); } @@ -678,7 +684,8 @@ public void testCopyFilesFromMultipartIOException() throws Exception { remoteMetadataDirectory, mdLockManager, threadPool, - indexShard.shardId() + indexShard.shardId(), + DefaultRecoverySettings.INSTANCE ); populateMetadata(); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 09f5c1bea1a5e..80731b378f369 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -195,6 +195,7 @@ import org.opensearch.indices.analysis.AnalysisModule; import org.opensearch.indices.cluster.IndicesClusterStateService; import org.opensearch.indices.mapper.MapperRegistry; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.PeerRecoverySourceService; import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoverySettings; @@ -2066,7 +2067,7 @@ public void onFailure(final Exception e) { emptyMap(), null, emptyMap(), - new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool), + new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool, DefaultRecoverySettings.INSTANCE), repositoriesServiceReference::get, fileCacheCleaner, null, diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java index 466c00d0648dc..186c1c7e78f6b 100644 --- a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java @@ -119,6 +119,7 @@ import org.opensearch.indices.IndicesService; import org.opensearch.indices.breaker.HierarchyCircuitBreakerService; import org.opensearch.indices.recovery.AsyncRecoveryTarget; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoveryFailedException; import org.opensearch.indices.recovery.RecoveryResponse; @@ -640,7 +641,7 @@ protected IndexShard newShard( Collections.emptyList(), clusterSettings ); - Store remoteStore = null; + Store remoteStore; RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory = null; RepositoriesService mockRepoSvc = mock(RepositoriesService.class); @@ -659,6 +660,8 @@ protected IndexShard newShard( remoteStoreStatsTrackerFactory = new RemoteStoreStatsTrackerFactory(clusterService, indexSettings.getSettings()); BlobStoreRepository repo = createRepository(remotePath); when(mockRepoSvc.repository(any())).thenAnswer(invocationOnMock -> repo); + } else { + remoteStore = null; } final BiFunction translogFactorySupplier = (settings, shardRouting) -> { @@ -698,7 +701,8 @@ protected IndexShard newShard( remoteStore, remoteStoreStatsTrackerFactory, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL, - "dummy-node" + "dummy-node", + null ); indexShard.addShardFailureCallback(DEFAULT_SHARD_FAILURE_HANDLER); if (remoteStoreStatsTrackerFactory != null) { @@ -785,7 +789,14 @@ protected RemoteSegmentStoreDirectory createRemoteSegmentStoreDirectory(ShardId RemoteStoreLockManager remoteStoreLockManager = new RemoteStoreMetadataLockManager( new RemoteBufferedOutputDirectory(getBlobContainer(remoteShardPath.resolveIndex())) ); - return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, remoteStoreLockManager, threadPool, shardId); + return new RemoteSegmentStoreDirectory( + dataDirectory, + metadataDirectory, + remoteStoreLockManager, + threadPool, + shardId, + DefaultRecoverySettings.INSTANCE + ); } private RemoteDirectory newRemoteDirectory(Path f) throws IOException { diff --git a/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java b/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java new file mode 100644 index 0000000000000..359668f5dad71 --- /dev/null +++ b/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.recovery; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; + +/** + * Utility to provide a {@link RecoverySettings} instance containing all defaults + */ +public final class DefaultRecoverySettings { + private DefaultRecoverySettings() {} + + public static final RecoverySettings INSTANCE = new RecoverySettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); +}