diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index c6ae58371e15c..fcfccf50ad326 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -241,37 +241,23 @@ public void readBlobAsync(String blobName, ActionListener listener) return; } - final List> blobPartInputStreamFutures = new ArrayList<>(); + final List blobPartInputStreamFutures = new ArrayList<>(); final long blobSize = blobMetadata.objectSize(); final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount(); final String blobChecksum = blobMetadata.checksum().checksumCRC32(); if (numberOfParts == null) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); + blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null)); } else { // S3 multipart files use 1 to n indexing for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) { - blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber)); + final int innerPartNumber = partNumber; + blobPartInputStreamFutures.add( + () -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber) + ); } } - - CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)) - .whenComplete((unused, partThrowable) -> { - if (partThrowable == null) { - listener.onResponse( - new ReadContext( - blobSize, - blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()), - blobChecksum - ) - ); - } else { - Exception ex = partThrowable.getCause() instanceof Exception - ? (Exception) partThrowable.getCause() - : new Exception(partThrowable.getCause()); - listener.onFailure(ex); - } - }); + listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum)); }); } catch (Exception ex) { listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex)); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 9817d7cd520ef..2e54705e9cd78 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception { assertEquals(objectSize, readContext.getBlobSize()); for (int partNumber = 1; partNumber < objectPartCount; partNumber++) { - InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join(); final int offset = partNumber * partSize; assertEquals(partSize, inputStreamContainer.getContentLength()); assertEquals(offset, inputStreamContainer.getOffset()); @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception { assertEquals(checksum, readContext.getBlobChecksum()); assertEquals(objectSize, readContext.getBlobSize()); - InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get(); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join(); assertEquals(objectSize, inputStreamContainer.getContentLength()); assertEquals(0, inputStreamContainer.getOffset()); assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length); diff --git a/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java index 07f85496f13cf..c394a1f631690 100644 --- a/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/index/shard/IndexShardIT.java @@ -712,7 +712,8 @@ public static final IndexShard newIndexShard( null, null, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL, - nodeId + nodeId, + null ); } diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java index 079753de95680..36987ac2d4991 100644 --- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java +++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java @@ -27,6 +27,7 @@ import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener listener) long contentLength = listBlobs().get(blobName).length(); long partSize = contentLength / 10; int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); - List blobPartStreams = new ArrayList<>(); + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < numberOfParts; partNumber++) { long offset = partNumber * partSize; InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset); - blobPartStreams.add(blobPartStream); + blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream)); } ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null); listener.onResponse(blobReadContext); diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java index e73a9f5cd0bc9..97f304d776f5c 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java @@ -10,13 +10,10 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; -import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; -import java.nio.file.Path; /** * An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow @@ -45,19 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer { @ExperimentalApi void readBlobAsync(String blobName, ActionListener listener); - /** - * Asynchronously downloads the blob to the specified location using an executor from the thread pool. - * @param blobName The name of the blob for which needs to be downloaded. - * @param fileLocation The path on local disk where the blob needs to be downloaded. - * @param threadPool The threadpool instance which will provide the executor for performing a multipart download. - * @param completionListener Listener which will be notified when the download is complete. - */ - @ExperimentalApi - default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener); - readBlobAsync(blobName, readContextListener); - } - /* * Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected * checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index c64dc6b9e3ae4..82bc7a0baed50 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -144,8 +144,10 @@ public long getBlobSize() { } @Override - public List getPartStreams() { - return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList()); + public List getPartStreams() { + return super.getPartStreams().stream() + .map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer)) + .collect(Collectors.toUnmodifiableList()); } /** diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index 2c305fb03c475..4bdce11ff4f9a 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -12,6 +12,8 @@ import org.opensearch.common.io.InputStreamContainer; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; /** * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync @@ -19,18 +21,18 @@ @ExperimentalApi public class ReadContext { private final long blobSize; - private final List partStreams; + private final List asyncPartStreams; private final String blobChecksum; - public ReadContext(long blobSize, List partStreams, String blobChecksum) { + public ReadContext(long blobSize, List asyncPartStreams, String blobChecksum) { this.blobSize = blobSize; - this.partStreams = partStreams; + this.asyncPartStreams = asyncPartStreams; this.blobChecksum = blobChecksum; } public ReadContext(ReadContext readContext) { this.blobSize = readContext.blobSize; - this.partStreams = readContext.partStreams; + this.asyncPartStreams = readContext.asyncPartStreams; this.blobChecksum = readContext.blobChecksum; } @@ -39,14 +41,30 @@ public String getBlobChecksum() { } public int getNumberOfParts() { - return partStreams.size(); + return asyncPartStreams.size(); } public long getBlobSize() { return blobSize; } - public List getPartStreams() { - return partStreams; + public List getPartStreams() { + return asyncPartStreams; + } + + /** + * Functional interface defining an instance that can create an async action + * to create a part of an object represented as an InputStreamContainer. + */ + @FunctionalInterface + public interface StreamPartCreator extends Supplier> { + /** + * Kicks off a async process to start streaming. + * + * @return When the returned future is completed, streaming has + * just begun. Clients must fully consume the resulting stream. + */ + @Override + CompletableFuture get(); } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java deleted file mode 100644 index aadd6e2ab304e..0000000000000 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.common.blobstore.stream.read.listener; - -import org.opensearch.common.annotation.InternalApi; -import org.opensearch.core.action.ActionListener; - -import java.util.concurrent.atomic.AtomicInteger; - -/** - * FileCompletionListener listens for completion of fetch on all the streams for a file, where - * individual streams are handled using {@link FilePartWriter}. The {@link FilePartWriter}(s) - * hold a reference to the file completion listener to be notified. - */ -@InternalApi -class FileCompletionListener implements ActionListener { - - private final int numberOfParts; - private final String fileName; - private final AtomicInteger completedPartsCount; - private final ActionListener completionListener; - - public FileCompletionListener(int numberOfParts, String fileName, ActionListener completionListener) { - this.completedPartsCount = new AtomicInteger(); - this.numberOfParts = numberOfParts; - this.fileName = fileName; - this.completionListener = completionListener; - } - - @Override - public void onResponse(Integer unused) { - if (completedPartsCount.incrementAndGet() == numberOfParts) { - completionListener.onResponse(fileName); - } - } - - @Override - public void onFailure(Exception e) { - completionListener.onFailure(e); - } -} diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java index 84fd7ed9ffebf..1a403200249cd 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java @@ -8,83 +8,37 @@ package org.opensearch.common.blobstore.stream.read.listener; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.io.Channels; import org.opensearch.common.io.InputStreamContainer; -import org.opensearch.core.action.ActionListener; import java.io.IOException; import java.io.InputStream; import java.nio.channels.FileChannel; -import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.UnaryOperator; /** * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel} - * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion. + * instance. */ @InternalApi -class FilePartWriter implements Runnable { - - private final int partNumber; - private final InputStreamContainer blobPartStreamContainer; - private final Path fileLocation; - private final AtomicBoolean anyPartStreamFailed; - private final ActionListener fileCompletionListener; - private static final Logger logger = LogManager.getLogger(FilePartWriter.class); - +class FilePartWriter { // 8 MB buffer for transfer private static final int BUFFER_SIZE = 8 * 1024 * 2024; - public FilePartWriter( - int partNumber, - InputStreamContainer blobPartStreamContainer, - Path fileLocation, - AtomicBoolean anyPartStreamFailed, - ActionListener fileCompletionListener - ) { - this.partNumber = partNumber; - this.blobPartStreamContainer = blobPartStreamContainer; - this.fileLocation = fileLocation; - this.anyPartStreamFailed = anyPartStreamFailed; - this.fileCompletionListener = fileCompletionListener; - } - - @Override - public void run() { - // Ensures no writes to the file if any stream fails. - if (anyPartStreamFailed.get() == false) { - try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { - try (InputStream inputStream = blobPartStreamContainer.getInputStream()) { - long streamOffset = blobPartStreamContainer.getOffset(); - final byte[] buffer = new byte[BUFFER_SIZE]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); - streamOffset += bytesRead; - } + public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator rateLimiter) throws IOException { + try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { + try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) { + long streamOffset = stream.getOffset(); + final byte[] buffer = new byte[BUFFER_SIZE]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset); + streamOffset += bytesRead; } - } catch (IOException e) { - processFailure(e); - return; } - fileCompletionListener.onResponse(partNumber); - } - } - - void processFailure(Exception e) { - try { - Files.deleteIfExists(fileLocation); - } catch (IOException ex) { - // Die silently - logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); - } - if (anyPartStreamFailed.getAndSet(true) == false) { - fileCompletionListener.onFailure(e); } } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 4338bddb3fbe7..2914fd0c440fa 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,51 +10,73 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; import java.nio.file.Path; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.UnaryOperator; /** * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer} - * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are - * spread across a {@link ThreadPool} executor. + * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams. */ @InternalApi public class ReadContextListener implements ActionListener { + private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - private final String fileName; + private final String blobName; private final Path fileLocation; - private final ThreadPool threadPool; private final ActionListener completionListener; - private static final Logger logger = LogManager.getLogger(ReadContextListener.class); + private final ThreadPool threadPool; + private final UnaryOperator rateLimiter; + private final int maxConcurrentStreams; - public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - this.fileName = fileName; + public ReadContextListener( + String blobName, + Path fileLocation, + ActionListener completionListener, + ThreadPool threadPool, + UnaryOperator rateLimiter, + int maxConcurrentStreams + ) { + this.blobName = blobName; this.fileLocation = fileLocation; - this.threadPool = threadPool; this.completionListener = completionListener; + this.threadPool = threadPool; + this.rateLimiter = rateLimiter; + this.maxConcurrentStreams = maxConcurrentStreams; } @Override public void onResponse(ReadContext readContext) { - logger.trace("Streams received for blob {}", fileName); + logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName); final int numParts = readContext.getNumberOfParts(); - final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); - - for (int partNumber = 0; partNumber < numParts; partNumber++) { - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - readContext.getPartStreams().get(partNumber), - fileLocation, - anyPartStreamFailed, - fileCompletionListener - ); - threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter); + final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false); + final GroupedActionListener groupedListener = new GroupedActionListener<>( + ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure), + numParts + ); + final Queue queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams()); + final StreamPartProcessor processor = new StreamPartProcessor( + queue, + anyPartStreamFailed, + fileLocation, + groupedListener, + threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY), + rateLimiter + ); + for (int i = 0; i < Math.min(maxConcurrentStreams, queue.size()); i++) { + processor.process(queue.poll()); } } @@ -62,4 +84,79 @@ public void onResponse(ReadContext readContext) { public void onFailure(Exception e) { completionListener.onFailure(e); } + + private static class StreamPartProcessor { + private static final RuntimeException CANCELED_PART_EXCEPTION = new RuntimeException( + "Canceled part download due to previous failure" + ); + private final Queue queue; + private final AtomicBoolean anyPartStreamFailed; + private final Path fileLocation; + private final GroupedActionListener completionListener; + private final Executor executor; + private final UnaryOperator rateLimiter; + + private StreamPartProcessor( + Queue queue, + AtomicBoolean anyPartStreamFailed, + Path fileLocation, + GroupedActionListener completionListener, + Executor executor, + UnaryOperator rateLimiter + ) { + this.queue = queue; + this.anyPartStreamFailed = anyPartStreamFailed; + this.fileLocation = fileLocation; + this.completionListener = completionListener; + this.executor = executor; + this.rateLimiter = rateLimiter; + } + + private void process(ReadContext.StreamPartCreator supplier) { + if (supplier == null) { + return; + } + supplier.get().whenCompleteAsync((blobPartStreamContainer, throwable) -> { + if (throwable != null) { + processFailure(throwable instanceof Exception ? (Exception) throwable : new RuntimeException(throwable)); + } else if (anyPartStreamFailed.get()) { + processFailure(CANCELED_PART_EXCEPTION); + } else { + try { + FilePartWriter.write(fileLocation, blobPartStreamContainer, rateLimiter); + completionListener.onResponse(fileLocation.toString()); + + // Upon successfully completing a file part, pull another + // file part off the queue to trigger asynchronous processing + process(queue.poll()); + } catch (Exception e) { + processFailure(e); + } + } + }, executor); + } + + private void processFailure(Exception e) { + if (anyPartStreamFailed.getAndSet(true) == false) { + completionListener.onFailure(e); + + // Drain the queue of pending part downloads. These can be discarded + // since they haven't started any work yet, but the listener must be + // notified for each part. + Object item = queue.poll(); + while (item != null) { + completionListener.onFailure(CANCELED_PART_EXCEPTION); + item = queue.poll(); + } + } else { + completionListener.onFailure(e); + } + try { + Files.deleteIfExists(fileLocation); + } catch (IOException ex) { + // Die silently + logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex); + } + } + } } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index a8263285aaca5..8c3f50e87142d 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -288,6 +288,7 @@ public void apply(Settings value, Settings current, Settings previous) { RecoverySettings.INDICES_RECOVERY_INTERNAL_LONG_ACTION_TIMEOUT_SETTING, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING, RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING, + RecoverySettings.INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_INITIAL_PRIMARIES_RECOVERIES_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_INITIAL_REPLICAS_RECOVERIES_SETTING, ThrottlingAllocationDecider.CLUSTER_ROUTING_ALLOCATION_NODE_CONCURRENT_INCOMING_RECOVERIES_SETTING, diff --git a/server/src/main/java/org/opensearch/index/IndexService.java b/server/src/main/java/org/opensearch/index/IndexService.java index fdda8d4ce2497..df8e8070b8e03 100644 --- a/server/src/main/java/org/opensearch/index/IndexService.java +++ b/server/src/main/java/org/opensearch/index/IndexService.java @@ -89,6 +89,7 @@ import org.opensearch.index.shard.ShardNotInPrimaryModeException; import org.opensearch.index.shard.ShardPath; import org.opensearch.index.similarity.SimilarityService; +import org.opensearch.index.store.RemoteSegmentStoreDirectoryFactory; import org.opensearch.index.store.Store; import org.opensearch.index.translog.Translog; import org.opensearch.index.translog.TranslogFactory; @@ -520,7 +521,8 @@ public synchronized IndexShard createShard( remoteStore, remoteStoreStatsTrackerFactory, clusterRemoteTranslogBufferIntervalSupplier, - nodeEnv.nodeId() + nodeEnv.nodeId(), + (RemoteSegmentStoreDirectoryFactory) remoteDirectoryFactory ); eventListener.indexShardStateChanged(indexShard, null, indexShard.state(), "shard created"); eventListener.afterIndexShardCreated(indexShard); diff --git a/server/src/main/java/org/opensearch/index/shard/IndexShard.java b/server/src/main/java/org/opensearch/index/shard/IndexShard.java index 5818b2d866854..4f08411c19b55 100644 --- a/server/src/main/java/org/opensearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/opensearch/index/shard/IndexShard.java @@ -62,7 +62,6 @@ import org.opensearch.action.admin.indices.flush.FlushRequest; import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest; import org.opensearch.action.admin.indices.upgrade.post.UpgradeRequest; -import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.replication.PendingReplicationActions; import org.opensearch.action.support.replication.ReplicationResponse; @@ -162,6 +161,7 @@ import org.opensearch.index.shard.PrimaryReplicaSyncer.ResyncTask; import org.opensearch.index.similarity.SimilarityService; import org.opensearch.index.store.RemoteSegmentStoreDirectory; +import org.opensearch.index.store.RemoteSegmentStoreDirectoryFactory; import org.opensearch.index.store.Store; import org.opensearch.index.store.Store.MetadataSnapshot; import org.opensearch.index.store.StoreFileMetadata; @@ -341,6 +341,7 @@ Runnable getGlobalCheckpointSyncer() { private final RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory; private final List internalRefreshListener = new ArrayList<>(); + private final RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory; public IndexShard( final ShardRouting shardRouting, @@ -368,7 +369,11 @@ public IndexShard( @Nullable final Store remoteStore, final RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory, final Supplier clusterRemoteTranslogBufferIntervalSupplier, - final String nodeId + final String nodeId, + // Wiring a directory factory here breaks some intended abstractions, but this remote directory + // factory is used not as a Lucene directory but instead to copy files from a remote store when + // restoring a shallow snapshot. + @Nullable final RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory ) throws IOException { super(shardRouting.shardId(), indexSettings); assert shardRouting.initializing(); @@ -464,6 +469,7 @@ public boolean shouldCache(Query query) { ? false : mapperService.documentMapper().mappers().containsTimeStampField(); this.remoteStoreStatsTrackerFactory = remoteStoreStatsTrackerFactory; + this.remoteSegmentStoreDirectoryFactory = remoteSegmentStoreDirectoryFactory; } public ThreadPool getThreadPool() { @@ -2696,7 +2702,7 @@ public void restoreFromRemoteStore(ActionListener listener) { public void restoreFromSnapshotAndRemoteStore( Repository repository, - RepositoriesService repositoriesService, + RemoteSegmentStoreDirectoryFactory remoteSegmentStoreDirectoryFactory, ActionListener listener ) { try { @@ -2704,7 +2710,7 @@ public void restoreFromSnapshotAndRemoteStore( assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.SNAPSHOT : "invalid recovery type: " + recoveryState.getRecoverySource(); StoreRecovery storeRecovery = new StoreRecovery(shardId, logger); - storeRecovery.recoverFromSnapshotAndRemoteStore(this, repository, repositoriesService, listener, threadPool); + storeRecovery.recoverFromSnapshotAndRemoteStore(this, repository, remoteSegmentStoreDirectoryFactory, listener); } catch (Exception e) { listener.onFailure(e); } @@ -3544,7 +3550,7 @@ public void startRecovery( "from snapshot and remote store", recoveryState, recoveryListener, - l -> restoreFromSnapshotAndRemoteStore(repositoriesService.repository(repo), repositoriesService, l) + l -> restoreFromSnapshotAndRemoteStore(repositoriesService.repository(repo), remoteSegmentStoreDirectoryFactory, l) ); // indicesService.indexService(shardRouting.shardId().getIndex()).addMetadataListener(); } else { @@ -4921,24 +4927,17 @@ private void downloadSegments( RemoteSegmentStoreDirectory targetRemoteDirectory, Set toDownloadSegments, final Runnable onFileSync - ) { - final PlainActionFuture completionListener = PlainActionFuture.newFuture(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> null), - toDownloadSegments.size() - ); - - final ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, fileName -> { + ) throws IOException { + final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); + for (String segment : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + sourceRemoteDirectory.copyTo(segment, storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); onFileSync.run(); if (targetRemoteDirectory != null) { - targetRemoteDirectory.copyFrom(storeDirectory, fileName, fileName, IOContext.DEFAULT); + targetRemoteDirectory.copyFrom(storeDirectory, segment, segment, IOContext.DEFAULT); } - return null; - }); - - final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); - toDownloadSegments.forEach(file -> { sourceRemoteDirectory.copyTo(file, storeDirectory, indexPath, segmentsDownloadListener); }); - completionListener.actionGet(); + } } private boolean localDirectoryContains(Directory localDirectory, String file, long checksum) { diff --git a/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java b/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java index c0211e1257c8e..762aab51469d0 100644 --- a/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java +++ b/server/src/main/java/org/opensearch/index/shard/StoreRecovery.java @@ -70,9 +70,7 @@ import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.indices.replication.common.ReplicationLuceneIndex; import org.opensearch.repositories.IndexId; -import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.nio.channels.FileChannel; @@ -362,9 +360,8 @@ void recoverFromRepository(final IndexShard indexShard, Repository repository, A void recoverFromSnapshotAndRemoteStore( final IndexShard indexShard, Repository repository, - RepositoriesService repositoriesService, - ActionListener listener, - ThreadPool threadPool + RemoteSegmentStoreDirectoryFactory directoryFactory, + ActionListener listener ) { try { if (canRecover(indexShard)) { @@ -392,10 +389,6 @@ void recoverFromSnapshotAndRemoteStore( remoteStoreRepository = shallowCopyShardMetadata.getRemoteStoreRepository(); } - RemoteSegmentStoreDirectoryFactory directoryFactory = new RemoteSegmentStoreDirectoryFactory( - () -> repositoriesService, - threadPool - ); RemoteSegmentStoreDirectory sourceRemoteDirectory = (RemoteSegmentStoreDirectory) directoryFactory.newDirectory( remoteStoreRepository, indexUUID, diff --git a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java index 594b7f99cd85a..eb75c39532d71 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteDirectory.java @@ -62,9 +62,9 @@ public class RemoteDirectory extends Directory { protected final BlobContainer blobContainer; private static final Logger logger = LogManager.getLogger(RemoteDirectory.class); - protected final UnaryOperator uploadRateLimiter; + private final UnaryOperator uploadRateLimiter; - protected final UnaryOperator downloadRateLimiter; + private final UnaryOperator downloadRateLimiter; /** * Number of bytes in the segment file to store checksum @@ -333,6 +333,10 @@ public boolean copyFrom( return false; } + protected UnaryOperator getDownloadRateLimiter() { + return downloadRateLimiter; + } + private void uploadBlob( Directory from, String src, diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java index 21a84f2b8c903..a97b22360716c 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java @@ -25,6 +25,7 @@ import org.apache.lucene.util.Version; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.logging.Loggers; @@ -37,6 +38,7 @@ import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadata; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadataHandler; +import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.threadpool.ThreadPool; @@ -90,6 +92,8 @@ public final class RemoteSegmentStoreDirectory extends FilterDirectory implement private final ThreadPool threadPool; + private final RecoverySettings recoverySettings; + /** * Keeps track of local segment filename to uploaded filename along with other attributes like checksum. * This map acts as a cache layer for uploaded segment filenames which helps avoid calling listAll() each time. @@ -122,13 +126,15 @@ public RemoteSegmentStoreDirectory( RemoteDirectory remoteMetadataDirectory, RemoteStoreLockManager mdLockManager, ThreadPool threadPool, - ShardId shardId + ShardId shardId, + RecoverySettings recoverySettings ) throws IOException { super(remoteDataDirectory); this.remoteDataDirectory = remoteDataDirectory; this.remoteMetadataDirectory = remoteMetadataDirectory; this.mdLockManager = mdLockManager; this.threadPool = threadPool; + this.recoverySettings = recoverySettings; this.logger = Loggers.getLogger(getClass(), shardId); init(); } @@ -488,7 +494,15 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) { final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer(); final Path destinationFilePath = destinationPath.resolve(source); - blobContainer.asyncBlobDownload(blobName, destinationFilePath, threadPool, fileCompletionListener); + final ReadContextListener readContextListener = new ReadContextListener( + blobName, + destinationFilePath, + fileCompletionListener, + threadPool, + remoteDataDirectory.getDownloadRateLimiter(), + recoverySettings.getMaxConcurrentRemoteStoreStreams() + ); + blobContainer.readBlobAsync(blobName, readContextListener); } else { // Fallback to older mechanism of downloading the file try { diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java index 490b07e441702..cc55380894ecd 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactory.java @@ -15,6 +15,7 @@ import org.opensearch.index.shard.ShardPath; import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; import org.opensearch.index.store.lockmanager.RemoteStoreLockManagerFactory; +import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.plugins.IndexStorePlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.Repository; @@ -34,12 +35,18 @@ public class RemoteSegmentStoreDirectoryFactory implements IndexStorePlugin.Dire private static final String SEGMENTS = "segments"; private final Supplier repositoriesService; + private final RecoverySettings recoverySettings; private final ThreadPool threadPool; - public RemoteSegmentStoreDirectoryFactory(Supplier repositoriesService, ThreadPool threadPool) { + public RemoteSegmentStoreDirectoryFactory( + Supplier repositoriesService, + ThreadPool threadPool, + RecoverySettings recoverySettings + ) { this.repositoriesService = repositoriesService; this.threadPool = threadPool; + this.recoverySettings = recoverySettings; } @Override @@ -71,13 +78,9 @@ public Directory newDirectory(String repositoryName, String indexUUID, ShardId s String.valueOf(shardId.id()) ); - return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, mdLockManager, threadPool, shardId); + return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, mdLockManager, threadPool, shardId, recoverySettings); } catch (RepositoryMissingException e) { throw new IllegalArgumentException("Repository should be created before creating index with remote_store enabled setting", e); } } - - private RemoteDirectory createRemoteDirectory(BlobStoreRepository repository, BlobPath commonBlobPath, String extension) { - return new RemoteDirectory(repository.blobStore().blobContainer(commonBlobPath.add(extension))); - } } diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java index e2346ae078339..ed9755bf824ea 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java @@ -84,6 +84,17 @@ public class RecoverySettings { Property.NodeScope ); + /** + * Controls the maximum number of streams that can be started concurrently when downloading from the remote store. + */ + public static final Setting INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING = Setting.intSetting( + "indices.recovery.max_concurrent_remote_store_streams", + 20, + 1, + Property.Dynamic, + Property.NodeScope + ); + /** * how long to wait before retrying after issues cause by cluster state syncing between nodes * i.e., local node is not yet known on remote node, remote shard not yet started etc. @@ -149,6 +160,7 @@ public class RecoverySettings { private volatile ByteSizeValue maxBytesPerSec; private volatile int maxConcurrentFileChunks; private volatile int maxConcurrentOperations; + private volatile int maxConcurrentRemoteStoreStreams; private volatile SimpleRateLimiter rateLimiter; private volatile TimeValue retryDelayStateSync; private volatile TimeValue retryDelayNetwork; @@ -163,6 +175,7 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { this.retryDelayStateSync = INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING.get(settings); this.maxConcurrentFileChunks = INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING.get(settings); this.maxConcurrentOperations = INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING.get(settings); + this.maxConcurrentRemoteStoreStreams = INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING.get(settings); // doesn't have to be fast as nodes are reconnected every 10s by default (see InternalClusterService.ReconnectToNodes) // and we want to give the cluster-manager time to remove a faulty node this.retryDelayNetwork = INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING.get(settings); @@ -184,6 +197,10 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, this::setMaxBytesPerSec); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING, this::setMaxConcurrentFileChunks); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING, this::setMaxConcurrentOperations); + clusterSettings.addSettingsUpdateConsumer( + INDICES_RECOVERY_MAX_CONCURRENT_REMOTE_STORE_STREAMS_SETTING, + this::setMaxConcurrentRemoteStoreStreams + ); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING, this::setRetryDelayStateSync); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING, this::setRetryDelayNetwork); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_INTERNAL_ACTION_TIMEOUT_SETTING, this::setInternalActionTimeout); @@ -279,4 +296,12 @@ public int getMaxConcurrentOperations() { private void setMaxConcurrentOperations(int maxConcurrentOperations) { this.maxConcurrentOperations = maxConcurrentOperations; } + + public int getMaxConcurrentRemoteStoreStreams() { + return this.maxConcurrentRemoteStoreStreams; + } + + private void setMaxConcurrentRemoteStoreStreams(int maxConcurrentRemoteStoreStreams) { + this.maxConcurrentRemoteStoreStreams = maxConcurrentRemoteStoreStreams; + } } diff --git a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java index aeb690465905f..e17c5293c38ac 100644 --- a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java @@ -14,7 +14,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.util.Version; -import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.core.action.ActionListener; import org.opensearch.index.shard.IndexShard; @@ -141,14 +141,12 @@ private void downloadSegments( ActionListener completionListener ) { final Path indexPath = shardPath == null ? null : shardPath.resolveIndex(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> new GetSegmentFilesResponse(toDownloadSegments)), - toDownloadSegments.size() - ); - ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, result -> null); - toDownloadSegments.forEach( - fileMetadata -> remoteStoreDirectory.copyTo(fileMetadata.name(), storeDirectory, indexPath, segmentsDownloadListener) - ); + for (StoreFileMetadata storeFileMetadata : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + remoteStoreDirectory.copyTo(storeFileMetadata.name(), storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); + } + completionListener.onResponse(new GetSegmentFilesResponse(toDownloadSegments)); } @Override diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 02668f28878c5..9e49ca37babdc 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -758,9 +758,12 @@ protected Node( rerouteServiceReference.set(rerouteService); clusterService.setRerouteService(rerouteService); + final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); + final IndexStorePlugin.DirectoryFactory remoteDirectoryFactory = new RemoteSegmentStoreDirectoryFactory( repositoriesServiceReference::get, - threadPool + threadPool, + recoverySettings ); final SearchRequestStats searchRequestStats = new SearchRequestStats(); @@ -952,7 +955,6 @@ protected Node( transportService.getTaskManager() ); - final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); RepositoriesModule repositoriesModule = new RepositoriesModule( this.environment, pluginsService.filterPlugins(RepositoryPlugin.class), diff --git a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java index 3481e43cf4c72..41ad357eaeed9 100644 --- a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java @@ -1126,7 +1126,8 @@ private void executeStaleShardDelete( // see https://github.com/opensearch-project/OpenSearch/issues/8469 new RemoteSegmentStoreDirectoryFactory( remoteStoreLockManagerFactory.getRepositoriesService(), - threadPool + threadPool, + recoverySettings ).newDirectory( remoteStoreRepoForIndex, indexUUID, @@ -1596,7 +1597,8 @@ private void executeOneStaleIndexDelete( // see https://github.com/opensearch-project/OpenSearch/issues/8469 new RemoteSegmentStoreDirectoryFactory( remoteStoreLockManagerFactory.getRepositoriesService(), - threadPool + threadPool, + recoverySettings ).newDirectory( remoteStoreRepoForIndex, indexUUID, diff --git a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java index 8375ac34972af..ecb5b2cef58ac 100644 --- a/server/src/main/java/org/opensearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/opensearch/threadpool/ThreadPool.java @@ -115,6 +115,7 @@ public static class Names { public static final String TRANSLOG_SYNC = "translog_sync"; public static final String REMOTE_PURGE = "remote_purge"; public static final String REMOTE_REFRESH_RETRY = "remote_refresh_retry"; + public static final String REMOTE_RECOVERY = "remote_recovery"; public static final String INDEX_SEARCHER = "index_searcher"; } @@ -184,6 +185,7 @@ public static ThreadPoolType fromType(String type) { map.put(Names.TRANSLOG_SYNC, ThreadPoolType.FIXED); map.put(Names.REMOTE_PURGE, ThreadPoolType.SCALING); map.put(Names.REMOTE_REFRESH_RETRY, ThreadPoolType.SCALING); + map.put(Names.REMOTE_RECOVERY, ThreadPoolType.SCALING); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { map.put(Names.INDEX_SEARCHER, ThreadPoolType.RESIZABLE); } @@ -269,6 +271,10 @@ public ThreadPool( Names.REMOTE_REFRESH_RETRY, new ScalingExecutorBuilder(Names.REMOTE_REFRESH_RETRY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) ); + builders.put( + Names.REMOTE_RECOVERY, + new ScalingExecutorBuilder(Names.REMOTE_RECOVERY, 1, halfProcMaxAt10, TimeValue.timeValueMinutes(5)) + ); if (FeatureFlags.isEnabled(FeatureFlags.CONCURRENT_SEGMENT_SEARCH)) { builders.put( Names.INDEX_SEARCHER, diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java index 947a4f9b1c9ab..1780819390052 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -20,6 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -51,10 +52,12 @@ public void testReadBlobAsync() throws Exception { // Objects needed for API call final byte[] data = new byte[size]; Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); @@ -76,7 +79,7 @@ public void testReadBlobAsync() throws Exception { assertEquals(1, response.getNumberOfParts()); assertEquals(size, response.getBlobSize()); - InputStreamContainer responseContainer = response.getPartStreams().get(0); + InputStreamContainer responseContainer = response.getPartStreams().get(0).get().join(); assertEquals(0, responseContainer.getOffset()); assertEquals(size, responseContainer.getContentLength()); assertEquals(100, responseContainer.getInputStream().available()); @@ -99,7 +102,8 @@ public void testReadBlobAsyncException() throws Exception { final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(() -> streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java deleted file mode 100644 index fa13d90f42fa6..0000000000000 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.common.blobstore.stream.read.listener; - -import org.opensearch.test.OpenSearchTestCase; - -import java.io.IOException; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; - -public class FileCompletionListenerTests extends OpenSearchTestCase { - - public void testFileCompletionListener() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - for (int stream = 0; stream < numStreams; stream++) { - // Ensure completion listener called only when all streams are completed - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getResponseCount()); - assertEquals(fileName, completionListener.getResponse()); - } - - public void testFileCompletionListenerFailure() { - int numStreams = 10; - String fileName = "test_segment_file"; - CountingCompletionListener completionListener = new CountingCompletionListener(); - FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); - - // Fail the listener initially - IOException exception = new IOException(); - fileCompletionListener.onFailure(exception); - - for (int stream = 0; stream < numStreams - 1; stream++) { - assertEquals(0, completionListener.getResponseCount()); - fileCompletionListener.onResponse(null); - } - - assertEquals(1, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - assertEquals(0, completionListener.getResponseCount()); - - fileCompletionListener.onFailure(exception); - assertEquals(2, completionListener.getFailureCount()); - assertEquals(exception, completionListener.getException()); - } -} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java index 811566eb5767b..f2a758b9bbe10 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java @@ -13,14 +13,11 @@ import org.junit.Before; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.util.UUID; -import java.util.concurrent.atomic.AtomicBoolean; - -import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; +import java.util.function.UnaryOperator; public class FilePartWriterTests extends OpenSearchTestCase { @@ -34,130 +31,37 @@ public void init() throws Exception { public void testFilePartWriter() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterWithOffset() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 100; int offset = 10; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength + offset, Files.size(segmentFilePath)); - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); } public void testFilePartWriterLargeInput() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); int contentLength = 20 * 1024 * 1024; - int partNumber = 1; InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter.write(segmentFilePath, inputStreamContainer, UnaryOperator.identity()); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); - - assertEquals(1, fileCompletionListener.getResponseCount()); - assertEquals(Integer.valueOf(partNumber), fileCompletionListener.getResponse()); - } - - public void testFilePartWriterException() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; - int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - IOException ioException = new IOException(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - assertFalse(anyStreamFailed.get()); - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - // Fail stream again to simulate another stream failure for same file - filePartWriter.processFailure(ioException); - - assertTrue(anyStreamFailed.get()); - assertFalse(Files.exists(segmentFilePath)); - - assertEquals(0, fileCompletionListener.getResponseCount()); - assertEquals(1, fileCompletionListener.getFailureCount()); - assertEquals(ioException, fileCompletionListener.getException()); - } - - public void testFilePartWriterStreamFailed() throws Exception { - Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; - int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); - AtomicBoolean anyStreamFailed = new AtomicBoolean(true); - CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); - - assertFalse(Files.exists(segmentFilePath)); - assertEquals(0, fileCompletionListener.getResponseCount()); } } diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 21b7b47390a9b..7e4c96cbadcda 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -29,7 +29,9 @@ import java.util.ArrayList; import java.util.List; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.function.UnaryOperator; import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; @@ -46,6 +48,7 @@ public class ReadContextListenerTests extends OpenSearchTestCase { private static final int NUMBER_OF_PARTS = 5; private static final int PART_SIZE = 10; private static final String TEST_SEGMENT_FILE = "test_segment_file"; + private static final int MAX_CONCURRENT_STREAMS = 10; @BeforeClass public static void setup() { @@ -64,10 +67,17 @@ public void init() throws Exception { public void testReadContextListener() throws InterruptedException, IOException { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -79,10 +89,17 @@ public void testReadContextListener() throws InterruptedException, IOException { public void testReadContextListenerFailure() throws Exception { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); InputStream badInputStream = new InputStream() { @Override @@ -101,7 +118,13 @@ public int available() { } }; - blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS)); + blobPartStreams.add( + NUMBER_OF_PARTS, + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), + threadPool.generic() + ) + ); ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -112,18 +135,31 @@ public int available() { public void testReadContextListenerException() { Path fileLocation = path.resolve(UUID.randomUUID().toString()); CountingCompletionListener listener = new CountingCompletionListener(); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + listener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); IOException exception = new IOException(); readContextListener.onFailure(exception); assertEquals(1, listener.getFailureCount()); assertEquals(exception, listener.getException()); } - private List initializeBlobPartStreams() { - List blobPartStreams = new ArrayList<>(); + private List initializeBlobPartStreams() { + List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)); - blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE)); + int finalPartNumber = partNumber; + blobPartStreams.add( + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(testStream, PART_SIZE, (long) finalPartNumber * PART_SIZE), + threadPool.generic() + ) + ); } return blobPartStreams; } diff --git a/server/src/test/java/org/opensearch/index/IndexModuleTests.java b/server/src/test/java/org/opensearch/index/IndexModuleTests.java index a1d6be84c9926..bbd73bcf97aab 100644 --- a/server/src/test/java/org/opensearch/index/IndexModuleTests.java +++ b/server/src/test/java/org/opensearch/index/IndexModuleTests.java @@ -105,6 +105,7 @@ import org.opensearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason; import org.opensearch.indices.fielddata.cache.IndicesFieldDataCache; import org.opensearch.indices.mapper.MapperRegistry; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.RecoveryState; import org.opensearch.plugins.IndexStorePlugin; import org.opensearch.repositories.RepositoriesService; @@ -257,7 +258,7 @@ private IndexService newIndexService(IndexModule module) throws IOException { writableRegistry(), () -> false, null, - new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool), + new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool, DefaultRecoverySettings.INSTANCE), translogFactorySupplier, () -> IndexSettings.DEFAULT_REFRESH_INTERVAL, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL diff --git a/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java b/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java index 5a13f57db2c87..941f2f48e71af 100644 --- a/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java +++ b/server/src/test/java/org/opensearch/index/shard/RemoteStoreRefreshListenerTests.java @@ -33,6 +33,7 @@ import org.opensearch.index.store.RemoteSegmentStoreDirectory.MetadataFilenameUtils; import org.opensearch.index.store.Store; import org.opensearch.index.store.lockmanager.RemoteStoreLockManager; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.threadpool.ThreadPool; @@ -155,7 +156,8 @@ public void testRemoteDirectoryInitThrowsException() throws IOException { remoteMetadataDirectory, mock(RemoteStoreLockManager.class), mock(ThreadPool.class), - shardId + shardId, + DefaultRecoverySettings.INSTANCE ); FilterDirectory remoteStoreFilterDirectory = new RemoteStoreRefreshListenerTests.TestFilterDirectory( new RemoteStoreRefreshListenerTests.TestFilterDirectory(remoteSegmentStoreDirectory) diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java index cad5e47531cc6..78c7fe64cebd9 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryFactoryTests.java @@ -20,6 +20,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexSettings; import org.opensearch.index.shard.ShardPath; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.RepositoryMissingException; import org.opensearch.repositories.blobstore.BlobStoreRepository; @@ -57,7 +58,11 @@ public void setup() { repositoriesService = mock(RepositoriesService.class); threadPool = mock(ThreadPool.class); when(repositoriesServiceSupplier.get()).thenReturn(repositoriesService); - remoteSegmentStoreDirectoryFactory = new RemoteSegmentStoreDirectoryFactory(repositoriesServiceSupplier, threadPool); + remoteSegmentStoreDirectoryFactory = new RemoteSegmentStoreDirectoryFactory( + repositoriesServiceSupplier, + threadPool, + DefaultRecoverySettings.INSTANCE + ); } public void testNewDirectory() throws IOException { diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java index f154dddb0e7cc..b574ccaac55e1 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java @@ -25,12 +25,13 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.UUIDs; import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; +import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; +import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.io.VersionedCodecStreamWrapper; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lucene.store.ByteArrayIndexInput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.FeatureFlags; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; @@ -41,6 +42,7 @@ import org.opensearch.index.store.lockmanager.RemoteStoreMetadataLockManager; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadata; import org.opensearch.index.store.remote.metadata.RemoteSegmentMetadataHandler; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.threadpool.ThreadPool; import org.junit.After; @@ -56,9 +58,11 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -145,13 +149,16 @@ public void setup() throws IOException { remoteMetadataDirectory, mdLockManager, threadPool, - indexShard.shardId() + indexShard.shardId(), + DefaultRecoverySettings.INSTANCE ); try (Store store = indexShard.store()) { segmentInfos = store.readLastCommittedSegmentsInfo(); } + when(remoteDataDirectory.getDownloadRateLimiter()).thenReturn(UnaryOperator.identity()); when(threadPool.executor(ThreadPool.Names.REMOTE_PURGE)).thenReturn(executorService); + when(threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY)).thenReturn(executorService); } @After @@ -562,9 +569,6 @@ public void onFailure(Exception e) {} } public void testCopyFilesToMultipart() throws Exception { - Settings settings = Settings.builder().build(); - FeatureFlags.initializeFeatureFlags(settings); - String filename = "_0.cfe"; populateMetadata(); remoteSegmentStoreDirectory.init(); @@ -574,13 +578,15 @@ public void testCopyFilesToMultipart() throws Exception { when(remoteDataDirectory.getBlobContainer()).thenReturn(blobContainer); Mockito.doAnswer(invocation -> { - ActionListener completionListener = invocation.getArgument(3); - completionListener.onResponse(invocation.getArgument(0)); + ActionListener completionListener = invocation.getArgument(1); + final CompletableFuture future = new CompletableFuture<>(); + future.complete(new InputStreamContainer(new ByteArrayInputStream(new byte[] { 42 }), 0, 1)); + completionListener.onResponse(new ReadContext(1, List.of(() -> future), "")); return null; - }).when(blobContainer).asyncBlobDownload(any(), any(), any(), any()); + }).when(blobContainer).readBlobAsync(any(), any()); CountDownLatch downloadLatch = new CountDownLatch(1); - ActionListener completionListener = new ActionListener() { + ActionListener completionListener = new ActionListener<>() { @Override public void onResponse(String unused) { downloadLatch.countDown(); @@ -592,7 +598,7 @@ public void onFailure(Exception e) {} Path path = createTempDir(); remoteSegmentStoreDirectory.copyTo(filename, storeDirectory, path, completionListener); assertTrue(downloadLatch.await(5000, TimeUnit.SECONDS)); - verify(blobContainer, times(1)).asyncBlobDownload(contains(filename), eq(path.resolve(filename)), any(), any()); + verify(blobContainer, times(1)).readBlobAsync(contains(filename), any()); verify(storeDirectory, times(0)).copyFrom(any(), any(), any(), any()); } @@ -678,7 +684,8 @@ public void testCopyFilesFromMultipartIOException() throws Exception { remoteMetadataDirectory, mdLockManager, threadPool, - indexShard.shardId() + indexShard.shardId(), + DefaultRecoverySettings.INSTANCE ); populateMetadata(); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 09f5c1bea1a5e..80731b378f369 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -195,6 +195,7 @@ import org.opensearch.indices.analysis.AnalysisModule; import org.opensearch.indices.cluster.IndicesClusterStateService; import org.opensearch.indices.mapper.MapperRegistry; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.PeerRecoverySourceService; import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoverySettings; @@ -2066,7 +2067,7 @@ public void onFailure(final Exception e) { emptyMap(), null, emptyMap(), - new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool), + new RemoteSegmentStoreDirectoryFactory(() -> repositoriesService, threadPool, DefaultRecoverySettings.INSTANCE), repositoriesServiceReference::get, fileCacheCleaner, null, diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java index 466c00d0648dc..186c1c7e78f6b 100644 --- a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java @@ -119,6 +119,7 @@ import org.opensearch.indices.IndicesService; import org.opensearch.indices.breaker.HierarchyCircuitBreakerService; import org.opensearch.indices.recovery.AsyncRecoveryTarget; +import org.opensearch.indices.recovery.DefaultRecoverySettings; import org.opensearch.indices.recovery.PeerRecoveryTargetService; import org.opensearch.indices.recovery.RecoveryFailedException; import org.opensearch.indices.recovery.RecoveryResponse; @@ -640,7 +641,7 @@ protected IndexShard newShard( Collections.emptyList(), clusterSettings ); - Store remoteStore = null; + Store remoteStore; RemoteStoreStatsTrackerFactory remoteStoreStatsTrackerFactory = null; RepositoriesService mockRepoSvc = mock(RepositoriesService.class); @@ -659,6 +660,8 @@ protected IndexShard newShard( remoteStoreStatsTrackerFactory = new RemoteStoreStatsTrackerFactory(clusterService, indexSettings.getSettings()); BlobStoreRepository repo = createRepository(remotePath); when(mockRepoSvc.repository(any())).thenAnswer(invocationOnMock -> repo); + } else { + remoteStore = null; } final BiFunction translogFactorySupplier = (settings, shardRouting) -> { @@ -698,7 +701,8 @@ protected IndexShard newShard( remoteStore, remoteStoreStatsTrackerFactory, () -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL, - "dummy-node" + "dummy-node", + null ); indexShard.addShardFailureCallback(DEFAULT_SHARD_FAILURE_HANDLER); if (remoteStoreStatsTrackerFactory != null) { @@ -785,7 +789,14 @@ protected RemoteSegmentStoreDirectory createRemoteSegmentStoreDirectory(ShardId RemoteStoreLockManager remoteStoreLockManager = new RemoteStoreMetadataLockManager( new RemoteBufferedOutputDirectory(getBlobContainer(remoteShardPath.resolveIndex())) ); - return new RemoteSegmentStoreDirectory(dataDirectory, metadataDirectory, remoteStoreLockManager, threadPool, shardId); + return new RemoteSegmentStoreDirectory( + dataDirectory, + metadataDirectory, + remoteStoreLockManager, + threadPool, + shardId, + DefaultRecoverySettings.INSTANCE + ); } private RemoteDirectory newRemoteDirectory(Path f) throws IOException { diff --git a/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java b/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java new file mode 100644 index 0000000000000..359668f5dad71 --- /dev/null +++ b/test/framework/src/main/java/org/opensearch/indices/recovery/DefaultRecoverySettings.java @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.indices.recovery; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; + +/** + * Utility to provide a {@link RecoverySettings} instance containing all defaults + */ +public final class DefaultRecoverySettings { + private DefaultRecoverySettings() {} + + public static final RecoverySettings INSTANCE = new RecoverySettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); +}