diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java new file mode 100644 index 0000000000000..258b00bde75f0 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3Reference.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.repositories.s3; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.concurrent.RefCountedReleasable; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Handles the shutdown of the wrapped {@link software.amazon.awssdk.services.s3.S3AsyncClient} using reference + * counting. + */ +public class AmazonAsyncS3Reference extends RefCountedReleasable<AmazonAsyncS3WithCredentials> { + + private static final Logger logger = LogManager.getLogger(AmazonAsyncS3Reference.class); + + AmazonAsyncS3Reference(AmazonAsyncS3WithCredentials client) { + super("AWS_S3_CLIENT", client, () -> { + client.client().close(); + client.priorityClient().close(); + AwsCredentialsProvider credentials = client.credentials(); + if (credentials instanceof Closeable) { + try { + ((Closeable) credentials).close(); + } catch (IOException e) { + logger.error("Exception while closing AwsCredentialsProvider", e); + } + } + }); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java new file mode 100644 index 0000000000000..15f104f51a067 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonAsyncS3WithCredentials.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.repositories.s3; + +import org.opensearch.common.Nullable; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.services.s3.S3AsyncClient; + +/** + * The holder of the AmazonS3 and AWSCredentialsProvider + */ +final class AmazonAsyncS3WithCredentials { + private final S3AsyncClient client; + private final S3AsyncClient priorityClient; + private final AwsCredentialsProvider credentials; + + private AmazonAsyncS3WithCredentials( + final S3AsyncClient client, + final S3AsyncClient priorityClient, + @Nullable final AwsCredentialsProvider credentials + ) { + this.client = client; + this.credentials = credentials; + this.priorityClient = priorityClient; + } + + S3AsyncClient client() { + return client; + } + + S3AsyncClient priorityClient() { + return priorityClient; + } + + AwsCredentialsProvider credentials() { + return credentials; + } + + static AmazonAsyncS3WithCredentials create( + final S3AsyncClient client, + final S3AsyncClient priorityClient, + @Nullable final AwsCredentialsProvider credentials + ) { + return new AmazonAsyncS3WithCredentials(client, priorityClient, credentials); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java new file mode 100644 index 0000000000000..653034ee9afde --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3AsyncService.java @@ -0,0 +1,430 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.repositories.s3; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.metadata.RepositoryMetadata; +import org.opensearch.common.Nullable; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.collect.MapBuilder; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.repositories.s3.S3ClientSettings.IrsaCredentials; +import org.opensearch.repositories.s3.async.AsyncExecutorContainer; +import org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SdkSystemSetting; +import software.amazon.awssdk.core.client.config.ClientAsyncConfiguration; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.client.config.SdkAdvancedAsyncClientOption; +import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.core.retry.backoff.BackoffStrategy; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.http.nio.netty.ProxyConfiguration; +import software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup; +import 
software.amazon.awssdk.profiles.ProfileFileSystemSetting; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.awssdk.services.sts.auth.StsWebIdentityTokenFileCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Map; + +import static java.util.Collections.emptyMap; + +class S3AsyncService implements Closeable { + private static final Logger logger = LogManager.getLogger(S3AsyncService.class); + + private static final String STS_ENDPOINT_OVERRIDE_SYSTEM_PROPERTY = "aws.stsEndpointOverride"; + + private static final String DEFAULT_S3_ENDPOINT = "s3.amazonaws.com"; + + private volatile Map<S3ClientSettings, AmazonAsyncS3Reference> clientsCache = emptyMap(); + + /** + * Client settings calculated from static configuration and settings in the keystore. + */ + private volatile Map<String, S3ClientSettings> staticClientSettings; + + /** + * Client settings derived from those in {@link #staticClientSettings} by combining them with settings + * in the {@link RepositoryMetadata}. + */ + private volatile Map<Settings, S3ClientSettings> derivedClientSettings = emptyMap(); + + S3AsyncService(final Path configPath) { + staticClientSettings = MapBuilder.<String, S3ClientSettings>newMapBuilder() + .put("default", S3ClientSettings.getClientSettings(Settings.EMPTY, "default", configPath)) + .immutableMap(); + } + + /** + * Refreshes the settings for the AmazonS3 clients and clears the cache of + * existing clients. New clients will be built using these new settings. Old + * clients are usable until released. On release they will be destroyed instead + * of being returned to the cache. + */ + public synchronized void refreshAndClearCache(Map<String, S3ClientSettings> clientsSettings) { + // shutdown all unused clients + // others will shutdown on their respective release + releaseCachedClients(); + this.staticClientSettings = MapBuilder.newMapBuilder(clientsSettings).immutableMap(); + derivedClientSettings = emptyMap(); + assert this.staticClientSettings.containsKey("default") : "always at least have 'default'"; + // clients are built lazily by {@link client} + } + + /** + * Attempts to retrieve a client by its repository metadata and settings from the cache. + * If the client does not exist it will be created.
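+ * <p> + * Illustrative usage sketch (the surrounding variables {@code s3AsyncService}, {@code repositoryMetadata} and the executor containers are assumptions for this example, not part of the change): + * <pre>{@code + * try (AmazonAsyncS3Reference reference = s3AsyncService.client(repositoryMetadata, priorityExecutors, normalExecutors)) { + *     S3AsyncClient s3AsyncClient = reference.get().client(); + *     // issue asynchronous requests here; closing the reference releases its ref count + * } + * }</pre>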
+ */ + public AmazonAsyncS3Reference client( + RepositoryMetadata repositoryMetadata, + AsyncExecutorContainer priorityExecutorBuilder, + AsyncExecutorContainer normalExecutorBuilder + ) { + final S3ClientSettings clientSettings = settings(repositoryMetadata); + { + final AmazonAsyncS3Reference clientReference = clientsCache.get(clientSettings); + if (clientReference != null && clientReference.tryIncRef()) { + return clientReference; + } + } + synchronized (this) { + final AmazonAsyncS3Reference existing = clientsCache.get(clientSettings); + if (existing != null && existing.tryIncRef()) { + return existing; + } + final AmazonAsyncS3Reference clientReference = new AmazonAsyncS3Reference( + buildClient(clientSettings, priorityExecutorBuilder, normalExecutorBuilder) + ); + clientReference.incRef(); + clientsCache = MapBuilder.newMapBuilder(clientsCache).put(clientSettings, clientReference).immutableMap(); + return clientReference; + } + } + + /** + * Either fetches {@link S3ClientSettings} for a given {@link RepositoryMetadata} from cached settings or creates them + * by overriding static client settings from {@link #staticClientSettings} with settings found in the repository metadata. + * @param repositoryMetadata Repository Metadata + * @return S3ClientSettings + */ + S3ClientSettings settings(RepositoryMetadata repositoryMetadata) { + final Settings settings = repositoryMetadata.settings(); + { + final S3ClientSettings existing = derivedClientSettings.get(settings); + if (existing != null) { + return existing; + } + } + final String clientName = S3Repository.CLIENT_NAME.get(settings); + final S3ClientSettings staticSettings = staticClientSettings.get(clientName); + if (staticSettings != null) { + synchronized (this) { + final S3ClientSettings existing = derivedClientSettings.get(settings); + if (existing != null) { + return existing; + } + final S3ClientSettings newSettings = staticSettings.refine(settings); + derivedClientSettings = MapBuilder.newMapBuilder(derivedClientSettings).put(settings, newSettings).immutableMap(); + return newSettings; + } + } + throw new IllegalArgumentException( + "Unknown s3 client name [" + + clientName + + "]. Existing client configs: " + + Strings.collectionToDelimitedString(staticClientSettings.keySet(), ",") + ); + } + + // proxy for testing + synchronized AmazonAsyncS3WithCredentials buildClient( + final S3ClientSettings clientSettings, + AsyncExecutorContainer priorityExecutorBuilder, + AsyncExecutorContainer normalExecutorBuilder + ) { + setDefaultAwsProfilePath(); + final S3AsyncClientBuilder builder = S3AsyncClient.builder(); + builder.overrideConfiguration(buildOverrideConfiguration(clientSettings)); + final AwsCredentialsProvider credentials = buildCredentials(logger, clientSettings); + builder.credentialsProvider(credentials); + + String endpoint = Strings.hasLength(clientSettings.endpoint) ? clientSettings.endpoint : DEFAULT_S3_ENDPOINT; + if ((endpoint.startsWith("http://") || endpoint.startsWith("https://")) == false) { + // Manually add the schema to the endpoint to work around https://github.com/aws/aws-sdk-java/issues/2274 + endpoint = clientSettings.protocol.toString() + "://" + endpoint; + } + logger.debug("using endpoint [{}] and region [{}]", endpoint, clientSettings.region); + + // If the endpoint configuration isn't set on the builder then the default behaviour is to try + // and work out what region we are in and use an appropriate endpoint - see AwsClientBuilder#setRegion. 
+ // In contrast, directly-constructed clients use s3.amazonaws.com unless otherwise instructed. We currently + // use a directly-constructed client, and need to keep the existing behaviour to avoid a breaking change, + // so to move to using the builder we must set it explicitly to keep the existing behaviour. + // + // We do this because directly constructing the client is deprecated (was already deprecated in 1.1.223 too) + // so this change removes that usage of a deprecated API. + builder.endpointOverride(URI.create(endpoint)); + builder.region(Region.of(clientSettings.region)); + if (clientSettings.pathStyleAccess) { + builder.forcePathStyle(true); + } + + builder.httpClient(buildHttpClient(clientSettings, priorityExecutorBuilder.getAsyncTransferEventLoopGroup())); + builder.asyncConfiguration( + ClientAsyncConfiguration.builder() + .advancedOption( + SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, + priorityExecutorBuilder.getFutureCompletionExecutor() + ) + .build() + ); + final S3AsyncClient priorityClient = SocketAccess.doPrivileged(builder::build); + + builder.httpClient(buildHttpClient(clientSettings, normalExecutorBuilder.getAsyncTransferEventLoopGroup())); + builder.asyncConfiguration( + ClientAsyncConfiguration.builder() + .advancedOption( + SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR, + normalExecutorBuilder.getFutureCompletionExecutor() + ) + .build() + ); + final S3AsyncClient client = SocketAccess.doPrivileged(builder::build); + + return AmazonAsyncS3WithCredentials.create(client, priorityClient, credentials); + } + + static ClientOverrideConfiguration buildOverrideConfiguration(final S3ClientSettings clientSettings) { + return ClientOverrideConfiguration.builder() + .retryPolicy( + RetryPolicy.builder() + .numRetries(clientSettings.maxRetries) + .throttlingBackoffStrategy( + clientSettings.throttleRetries ? BackoffStrategy.defaultThrottlingStrategy() : BackoffStrategy.none() + ) + .build() + ) + .apiCallAttemptTimeout(Duration.ofMillis(clientSettings.requestTimeoutMillis)) + .build(); + } + + // pkg private for tests + static SdkAsyncHttpClient buildHttpClient(S3ClientSettings clientSettings, AsyncTransferEventLoopGroup asyncTransferEventLoopGroup) { + // the response metadata cache is only there for diagnostics purposes, + // but can force objects from every response to the old generation. + NettyNioAsyncHttpClient.Builder clientBuilder = NettyNioAsyncHttpClient.builder(); + + if (clientSettings.proxySettings.getType() != ProxySettings.ProxyType.DIRECT) { + ProxyConfiguration.Builder proxyConfiguration = ProxyConfiguration.builder(); + proxyConfiguration.scheme(clientSettings.proxySettings.getType().toProtocol().toString()); + proxyConfiguration.host(clientSettings.proxySettings.getHostName()); + proxyConfiguration.port(clientSettings.proxySettings.getPort()); + proxyConfiguration.username(clientSettings.proxySettings.getUsername()); + proxyConfiguration.password(clientSettings.proxySettings.getPassword()); + clientBuilder.proxyConfiguration(proxyConfiguration.build()); + } + + // TODO: add max retry and UseThrottleRetry. 
Replace values with settings and put these in default settings + clientBuilder.connectionTimeout(Duration.ofMillis(clientSettings.connectionTimeoutMillis)); + clientBuilder.connectionAcquisitionTimeout(Duration.ofMillis(clientSettings.connectionAcquisitionTimeoutMillis)); + clientBuilder.maxPendingConnectionAcquires(10_000); + clientBuilder.maxConcurrency(clientSettings.maxConnections); + clientBuilder.eventLoopGroup(SdkEventLoopGroup.create(asyncTransferEventLoopGroup.getEventLoopGroup())); + clientBuilder.tcpKeepAlive(true); + + return clientBuilder.build(); + } + + // pkg private for tests + static AwsCredentialsProvider buildCredentials(Logger logger, S3ClientSettings clientSettings) { + final AwsCredentials basicCredentials = clientSettings.credentials; + final IrsaCredentials irsaCredentials = buildFromEnvironment(clientSettings.irsaCredentials); + + // If IAM Roles for Service Accounts (IRSA) credentials are configured, start with them first + if (irsaCredentials != null) { + logger.debug("Using IRSA credentials"); + + final Region region = Region.of(clientSettings.region); + StsClient stsClient = SocketAccess.doPrivileged(() -> { + StsClientBuilder builder = StsClient.builder().region(region); + + final String stsEndpoint = System.getProperty(STS_ENDPOINT_OVERRIDE_SYSTEM_PROPERTY); + if (stsEndpoint != null) { + builder = builder.endpointOverride(URI.create(stsEndpoint)); + } + + if (basicCredentials != null) { + builder = builder.credentialsProvider(StaticCredentialsProvider.create(basicCredentials)); + } else { + builder = builder.credentialsProvider(DefaultCredentialsProvider.create()); + } + + return builder.build(); + }); + + if (irsaCredentials.getIdentityTokenFile() == null) { + final StsAssumeRoleCredentialsProvider.Builder stsCredentialsProviderBuilder = StsAssumeRoleCredentialsProvider.builder() + .stsClient(stsClient) + .refreshRequest( + AssumeRoleRequest.builder() + .roleArn(irsaCredentials.getRoleArn()) + .roleSessionName(irsaCredentials.getRoleSessionName()) + .build() + ); + + final StsAssumeRoleCredentialsProvider stsCredentialsProvider = SocketAccess.doPrivileged( + stsCredentialsProviderBuilder::build + ); + + return new PrivilegedSTSAssumeRoleSessionCredentialsProvider<>(stsClient, stsCredentialsProvider); + } else { + final StsWebIdentityTokenFileCredentialsProvider.Builder stsCredentialsProviderBuilder = + StsWebIdentityTokenFileCredentialsProvider.builder() + .stsClient(stsClient) + .roleArn(irsaCredentials.getRoleArn()) + .roleSessionName(irsaCredentials.getRoleSessionName()) + .webIdentityTokenFile(Path.of(irsaCredentials.getIdentityTokenFile())); + + final StsWebIdentityTokenFileCredentialsProvider stsCredentialsProvider = SocketAccess.doPrivileged( + stsCredentialsProviderBuilder::build + ); + + return new PrivilegedSTSAssumeRoleSessionCredentialsProvider<>(stsClient, stsCredentialsProvider); + } + } else if (basicCredentials != null) { + logger.debug("Using basic key/secret credentials"); + return StaticCredentialsProvider.create(basicCredentials); + } else { + logger.debug("Using instance profile credentials"); + return new PrivilegedInstanceProfileCredentialsProvider(); + } + } + + // Aws v2 sdk tries to load a default profile from home path which is restricted. Hence, setting these to random + // valid paths. 
+ @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") + private static void setDefaultAwsProfilePath() { + if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { + System.setProperty(ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), System.getProperty("opensearch.path.conf")); + } + if (ProfileFileSystemSetting.AWS_CONFIG_FILE.getStringValue().isEmpty()) { + System.setProperty(ProfileFileSystemSetting.AWS_CONFIG_FILE.property(), System.getProperty("opensearch.path.conf")); + } + } + + private static IrsaCredentials buildFromEnvironment(IrsaCredentials defaults) { + if (defaults == null) { + return null; + } + + String webIdentityTokenFile = defaults.getIdentityTokenFile(); + if (webIdentityTokenFile == null) { + webIdentityTokenFile = System.getenv(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.environmentVariable()); + } + + String roleArn = defaults.getRoleArn(); + if (roleArn == null) { + roleArn = System.getenv(SdkSystemSetting.AWS_ROLE_ARN.environmentVariable()); + } + + String roleSessionName = defaults.getRoleSessionName(); + if (roleSessionName == null) { + roleSessionName = System.getenv(SdkSystemSetting.AWS_ROLE_SESSION_NAME.environmentVariable()); + } + + return new IrsaCredentials(webIdentityTokenFile, roleArn, roleSessionName); + } + + private synchronized void releaseCachedClients() { + // the clients will shut down once they are no longer used + for (final AmazonAsyncS3Reference clientReference : clientsCache.values()) { + clientReference.decRef(); + } + + // clear previously cached clients; they will be built lazily + clientsCache = emptyMap(); + derivedClientSettings = emptyMap(); + } + + static class PrivilegedInstanceProfileCredentialsProvider implements AwsCredentialsProvider { + private final AwsCredentialsProvider credentials; + + private PrivilegedInstanceProfileCredentialsProvider() { + this.credentials = initializeProvider(); + } + + private AwsCredentialsProvider initializeProvider() { + if (SdkSystemSetting.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI.getStringValue().isPresent() + || SdkSystemSetting.AWS_CONTAINER_CREDENTIALS_FULL_URI.getStringValue().isPresent()) { + + return ContainerCredentialsProvider.builder().asyncCredentialUpdateEnabled(true).build(); + } + // InstanceProfileCredentialsProvider as last item of chain + return InstanceProfileCredentialsProvider.builder().asyncCredentialUpdateEnabled(true).build(); + } + + @Override + public AwsCredentials resolveCredentials() { + return SocketAccess.doPrivileged(credentials::resolveCredentials); + } + } + + static class PrivilegedSTSAssumeRoleSessionCredentialsProvider
<P extends AwsCredentialsProvider & Closeable>
+ implements + AwsCredentialsProvider, + Closeable { + private final P credentials; + private final StsClient stsClient; + + private PrivilegedSTSAssumeRoleSessionCredentialsProvider(@Nullable final StsClient stsClient, final P credentials) { + this.stsClient = stsClient; + this.credentials = credentials; + } + + @Override + public void close() throws IOException { + SocketAccess.doPrivilegedIOException(() -> { + credentials.close(); + if (stsClient != null) { + stsClient.close(); + } + return null; + }); + } + + @Override + public AwsCredentials resolveCredentials() { + return SocketAccess.doPrivileged(credentials::resolveCredentials); + } + } + + @Override + public void close() { + releaseCachedClients(); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java index 5f6be6ac01e76..25e9797018102 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3ClientSettings.java @@ -172,6 +172,48 @@ final class S3ClientSettings { key -> Setting.timeSetting(key, TimeValue.timeValueMillis(50_000), Property.NodeScope) ); + /** The request timeout for connecting to s3. */ + static final Setting.AffixSetting<TimeValue> REQUEST_TIMEOUT_SETTING = Setting.affixKeySetting( + PREFIX, + "request_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueMinutes(2), Property.NodeScope) + ); + + /** The connection timeout for connecting to s3. */ + static final Setting.AffixSetting<TimeValue> CONNECTION_TIMEOUT_SETTING = Setting.affixKeySetting( + PREFIX, + "connection_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueSeconds(10), Property.NodeScope) + ); + + /** The connection TTL for connecting to s3. */ + static final Setting.AffixSetting<TimeValue> CONNECTION_TTL_SETTING = Setting.affixKeySetting( + PREFIX, + "connection_ttl", + key -> Setting.timeSetting(key, TimeValue.timeValueMillis(5000), Property.NodeScope) + ); + + /** The maximum connections to s3. */ + static final Setting.AffixSetting<Integer> MAX_CONNECTIONS_SETTING = Setting.affixKeySetting( + PREFIX, + "max_connections", + key -> Setting.intSetting(key, 100, Property.NodeScope) + ); + + /** Connection acquisition timeout for new connections to S3. */ + static final Setting.AffixSetting<TimeValue> CONNECTION_ACQUISITION_TIMEOUT = Setting.affixKeySetting( + PREFIX, + "connection_acquisition_timeout", + key -> Setting.timeSetting(key, TimeValue.timeValueMinutes(2), Property.NodeScope) + ); + + /** The maximum pending connections to S3. */ + static final Setting.AffixSetting<Integer> MAX_PENDING_CONNECTION_ACQUIRES = Setting.affixKeySetting( + PREFIX, + "max_pending_connection_acquires", + key -> Setting.intSetting(key, 10_000, Property.NodeScope) + ); + /** The number of retries to use when an s3 request fails. */ static final Setting.AffixSetting<Integer> MAX_RETRIES_SETTING = Setting.affixKeySetting( PREFIX, @@ -232,6 +274,21 @@ final class S3ClientSettings { /** The read timeout for the s3 client.
*/ final int readTimeoutMillis; + /** The request timeout for the s3 client */ + final int requestTimeoutMillis; + + /** The connection timeout for the s3 client */ + final int connectionTimeoutMillis; + + /** The connection TTL for the s3 client */ + final int connectionTTLMillis; + + /** The max number of connections for the s3 client */ + final int maxConnections; + + /** The connection acquisition timeout for the s3 async client */ + final int connectionAcquisitionTimeoutMillis; + /** The number of retries to use for the s3 client. */ final int maxRetries; @@ -256,6 +313,11 @@ private S3ClientSettings( String endpoint, Protocol protocol, int readTimeoutMillis, + int requestTimeoutMillis, + int connectionTimeoutMillis, + int connectionTTLMillis, + int maxConnections, + int connectionAcquisitionTimeoutMillis, int maxRetries, boolean throttleRetries, boolean pathStyleAccess, @@ -269,6 +331,11 @@ private S3ClientSettings( this.endpoint = endpoint; this.protocol = protocol; this.readTimeoutMillis = readTimeoutMillis; + this.requestTimeoutMillis = requestTimeoutMillis; + this.connectionTimeoutMillis = connectionTimeoutMillis; + this.connectionTTLMillis = connectionTTLMillis; + this.maxConnections = maxConnections; + this.connectionAcquisitionTimeoutMillis = connectionAcquisitionTimeoutMillis; this.maxRetries = maxRetries; this.throttleRetries = throttleRetries; this.pathStyleAccess = pathStyleAccess; @@ -300,6 +367,24 @@ S3ClientSettings refine(Settings repositorySettings) { final int newReadTimeoutMillis = Math.toIntExact( getRepoSettingOrDefault(READ_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(readTimeoutMillis)).millis() ); + final int newRequestTimeoutMillis = Math.toIntExact( + getRepoSettingOrDefault(REQUEST_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(requestTimeoutMillis)).millis() + ); + final int newConnectionTimeoutMillis = Math.toIntExact( + getRepoSettingOrDefault(CONNECTION_TIMEOUT_SETTING, normalizedSettings, TimeValue.timeValueMillis(connectionTimeoutMillis)) + .millis() + ); + final int newConnectionTTLMillis = Math.toIntExact( + getRepoSettingOrDefault(CONNECTION_TTL_SETTING, normalizedSettings, TimeValue.timeValueMillis(connectionTTLMillis)).millis() + ); + final int newConnectionAcquisitionTimeoutMillis = Math.toIntExact( + getRepoSettingOrDefault( + CONNECTION_ACQUISITION_TIMEOUT, + normalizedSettings, + TimeValue.timeValueMillis(connectionAcquisitionTimeoutMillis) + ).millis() + ); + final int newMaxConnections = Math.toIntExact(getRepoSettingOrDefault(MAX_CONNECTIONS_SETTING, normalizedSettings, maxConnections)); final int newMaxRetries = getRepoSettingOrDefault(MAX_RETRIES_SETTING, normalizedSettings, maxRetries); final boolean newThrottleRetries = getRepoSettingOrDefault(USE_THROTTLE_RETRIES_SETTING, normalizedSettings, throttleRetries); final boolean newPathStyleAccess = getRepoSettingOrDefault(USE_PATH_STYLE_ACCESS, normalizedSettings, pathStyleAccess); @@ -321,6 +406,11 @@ S3ClientSettings refine(Settings repositorySettings) { && Objects.equals(proxySettings.getHostName(), newProxyHost) && proxySettings.getPort() == newProxyPort && newReadTimeoutMillis == readTimeoutMillis + && newRequestTimeoutMillis == requestTimeoutMillis + && newConnectionTimeoutMillis == connectionTimeoutMillis + && newConnectionTTLMillis == connectionTTLMillis + && newMaxConnections == maxConnections + && newConnectionAcquisitionTimeoutMillis == connectionAcquisitionTimeoutMillis && maxRetries == newMaxRetries && newThrottleRetries == throttleRetries &&
Objects.equals(credentials, newCredentials) @@ -338,6 +428,11 @@ S3ClientSettings refine(Settings repositorySettings) { newEndpoint, newProtocol, newReadTimeoutMillis, + newRequestTimeoutMillis, + newConnectionTimeoutMillis, + newConnectionTTLMillis, + newMaxConnections, + newConnectionAcquisitionTimeoutMillis, newMaxRetries, newThrottleRetries, newPathStyleAccess, @@ -463,6 +558,11 @@ static S3ClientSettings getClientSettings(final Settings settings, final String getConfigValue(settings, clientName, ENDPOINT_SETTING), awsProtocol, Math.toIntExact(getConfigValue(settings, clientName, READ_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, REQUEST_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_TIMEOUT_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_TTL_SETTING).millis()), + Math.toIntExact(getConfigValue(settings, clientName, MAX_CONNECTIONS_SETTING)), + Math.toIntExact(getConfigValue(settings, clientName, CONNECTION_ACQUISITION_TIMEOUT).millis()), getConfigValue(settings, clientName, MAX_RETRIES_SETTING), getConfigValue(settings, clientName, USE_THROTTLE_RETRIES_SETTING), getConfigValue(settings, clientName, USE_PATH_STYLE_ACCESS), @@ -532,6 +632,11 @@ public boolean equals(final Object o) { } final S3ClientSettings that = (S3ClientSettings) o; return readTimeoutMillis == that.readTimeoutMillis + && requestTimeoutMillis == that.requestTimeoutMillis + && connectionTimeoutMillis == that.connectionTimeoutMillis + && connectionTTLMillis == that.connectionTTLMillis + && maxConnections == that.maxConnections + && connectionAcquisitionTimeoutMillis == that.connectionAcquisitionTimeoutMillis && maxRetries == that.maxRetries && throttleRetries == that.throttleRetries && Objects.equals(credentials, that.credentials) @@ -552,6 +657,11 @@ public int hashCode() { protocol, proxySettings, readTimeoutMillis, + requestTimeoutMillis, + connectionTimeoutMillis, + connectionTTLMillis, + maxConnections, + connectionAcquisitionTimeoutMillis, maxRetries, throttleRetries, disableChunkedEncoding, diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java index 0a6408764aeeb..4888764dbc720 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/SocketAccess.java @@ -46,7 +46,7 @@ * {@link SocketPermission} 'connect' to establish connections. This class wraps the operations requiring access in * {@link AccessController#doPrivileged(PrivilegedAction)} blocks. */ -final class SocketAccess { +public final class SocketAccess { private SocketAccess() {} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorContainer.java new file mode 100644 index 0000000000000..1ae1a15ad4010 --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncExecutorContainer.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.repositories.s3.async; + +import java.util.concurrent.ExecutorService; + +/** + * An encapsulation for the {@link AsyncTransferEventLoopGroup}, and the stream reader and future completion executor services + */ +public class AsyncExecutorContainer { + + private final ExecutorService futureCompletionExecutor; + private final ExecutorService streamReader; + private final AsyncTransferEventLoopGroup asyncTransferEventLoopGroup; + + /** + * Construct a new AsyncExecutorContainer object + * + * @param futureCompletionExecutor An {@link ExecutorService} to pass to {@link software.amazon.awssdk.services.s3.S3AsyncClient} for future completion + * @param streamReader An {@link ExecutorService} to read streams for upload + * @param asyncTransferEventLoopGroup An {@link AsyncTransferEventLoopGroup} which encapsulates the netty {@link io.netty.channel.EventLoopGroup} for async uploads + */ + public AsyncExecutorContainer( + ExecutorService futureCompletionExecutor, + ExecutorService streamReader, + AsyncTransferEventLoopGroup asyncTransferEventLoopGroup + ) { + this.asyncTransferEventLoopGroup = asyncTransferEventLoopGroup; + this.streamReader = streamReader; + this.futureCompletionExecutor = futureCompletionExecutor; + } + + public ExecutorService getFutureCompletionExecutor() { + return futureCompletionExecutor; + } + + public AsyncTransferEventLoopGroup getAsyncTransferEventLoopGroup() { + return asyncTransferEventLoopGroup; + } + + public ExecutorService getStreamReader() { + return streamReader; + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java new file mode 100644 index 0000000000000..b6af91a08ac2b --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncPartsHandler.java @@ -0,0 +1,183 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.repositories.s3.async; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.repositories.s3.SocketAccess; +import org.opensearch.repositories.s3.io.CheckedContainer; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReferenceArray; + +/** + * Responsible for handling parts of the original multipart request + */ +public class AsyncPartsHandler { + + private static Logger log = LogManager.getLogger(AsyncPartsHandler.class); + + /** + * Uploads parts of the upload multipart request. + * @param s3AsyncClient S3 client to use for upload + * @param executorService Thread pool for regular upload + * @param priorityExecutorService Thread pool for priority uploads + * @param uploadRequest request for upload + * @param streamContext Stream context used in supplying individual file parts + * @param uploadId Upload Id against which multi-part is being performed + * @param completedParts Reference of completed parts + * @param inputStreamContainers Checksum containers + * @return list of completable futures + * @throws IOException thrown in case of an IO error + */ + public static List<CompletableFuture<CompletedPart>> uploadParts( + S3AsyncClient s3AsyncClient, + ExecutorService executorService, + ExecutorService priorityExecutorService, + UploadRequest uploadRequest, + StreamContext streamContext, + String uploadId, + AtomicReferenceArray<CompletedPart> completedParts, + AtomicReferenceArray<CheckedContainer> inputStreamContainers + ) throws IOException { + List<CompletableFuture<CompletedPart>> futures = new ArrayList<>(); + for (int partIdx = 0; partIdx < streamContext.getNumberOfParts(); partIdx++) { + InputStreamContainer inputStreamContainer = streamContext.provideStream(partIdx); + inputStreamContainers.set(partIdx, new CheckedContainer(inputStreamContainer.getContentLength())); + UploadPartRequest.Builder uploadPartRequestBuilder = UploadPartRequest.builder() + .bucket(uploadRequest.getBucket()) + .partNumber(partIdx + 1) + .key(uploadRequest.getKey()) + .uploadId(uploadId) + .contentLength(inputStreamContainer.getContentLength()); + if (uploadRequest.doRemoteDataIntegrityCheck()) { + uploadPartRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32); + } + uploadPart( + s3AsyncClient, + executorService, + priorityExecutorService, + completedParts, + inputStreamContainers, + futures, + uploadPartRequestBuilder.build(), + inputStreamContainer, + uploadRequest + ); + } + + return futures; + } + + /** + * Cleans up parts of the original multipart request. + * @param s3AsyncClient s3 client to use + * @param uploadRequest upload request + * @param uploadId upload id against which multi-part was
carried out. + */ + public static void cleanUpParts(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, String uploadId) { + + AbortMultipartUploadRequest abortMultipartUploadRequest = AbortMultipartUploadRequest.builder() + .bucket(uploadRequest.getBucket()) + .key(uploadRequest.getKey()) + .uploadId(uploadId) + .build(); + SocketAccess.doPrivileged(() -> s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest).exceptionally(throwable -> { + log.warn( + () -> new ParameterizedMessage( + "Failed to abort previous multipart upload " + + "(id: {})" + + ". You may need to call " + + "S3AsyncClient#abortMultiPartUpload to " + + "free all storage consumed by" + + " all parts. ", + uploadId + ), + throwable + ); + return null; + })); + } + + private static void uploadPart( + S3AsyncClient s3AsyncClient, + ExecutorService executorService, + ExecutorService priorityExecutorService, + AtomicReferenceArray<CompletedPart> completedParts, + AtomicReferenceArray<CheckedContainer> inputStreamContainers, + List<CompletableFuture<CompletedPart>> futures, + UploadPartRequest uploadPartRequest, + InputStreamContainer inputStreamContainer, + UploadRequest uploadRequest + ) { + Integer partNumber = uploadPartRequest.partNumber(); + + ExecutorService streamReadExecutor = uploadRequest.getWritePriority() == WritePriority.HIGH + ? priorityExecutorService + : executorService; + CompletableFuture<UploadPartResponse> uploadPartResponseFuture = SocketAccess.doPrivileged( + () -> s3AsyncClient.uploadPart( + uploadPartRequest, + AsyncRequestBody.fromInputStream( + inputStreamContainer.getInputStream(), + inputStreamContainer.getContentLength(), + streamReadExecutor + ) + ) + ); + + CompletableFuture<CompletedPart> convertFuture = uploadPartResponseFuture.thenApply( + uploadPartResponse -> convertUploadPartResponse( + completedParts, + inputStreamContainers, + uploadPartResponse, + partNumber, + uploadRequest.doRemoteDataIntegrityCheck() + ) + ); + futures.add(convertFuture); + + CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartResponseFuture); + } + + private static CompletedPart convertUploadPartResponse( + AtomicReferenceArray<CompletedPart> completedParts, + AtomicReferenceArray<CheckedContainer> inputStreamContainers, + UploadPartResponse partResponse, + int partNumber, + boolean isRemoteDataIntegrityCheckEnabled + ) { + CompletedPart.Builder completedPartBuilder = CompletedPart.builder().eTag(partResponse.eTag()).partNumber(partNumber); + if (isRemoteDataIntegrityCheckEnabled) { + completedPartBuilder.checksumCRC32(partResponse.checksumCRC32()); + CheckedContainer inputStreamCRC32Container = inputStreamContainers.get(partNumber - 1); + inputStreamCRC32Container.setChecksum(partResponse.checksumCRC32()); + inputStreamContainers.set(partNumber - 1, inputStreamCRC32Container); + } + CompletedPart completedPart = completedPartBuilder.build(); + completedParts.set(partNumber - 1, completedPart); + return completedPart; + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferEventLoopGroup.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferEventLoopGroup.java new file mode 100644 index 0000000000000..381a9671d669a --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferEventLoopGroup.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.repositories.s3.async; + +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.util.concurrent.Future; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.repositories.s3.SocketAccess; + +import java.io.Closeable; +import java.util.concurrent.TimeUnit; + +/** + * AsyncTransferEventLoopGroup is an encapsulation for netty {@link EventLoopGroup} + */ +public class AsyncTransferEventLoopGroup implements Closeable { + private static final String THREAD_PREFIX = "s3-async-transfer-worker"; + private final Logger logger = LogManager.getLogger(AsyncTransferEventLoopGroup.class); + + private final EventLoopGroup eventLoopGroup; + + /** + * Construct a new AsyncTransferEventLoopGroup + * + * @param eventLoopThreads The number of event loop threads for this event loop group + */ + public AsyncTransferEventLoopGroup(int eventLoopThreads) { + // Epoll event loop incurs less GC and provides better performance than Nio loop. Therefore, + // using epoll wherever available is preferred. + this.eventLoopGroup = SocketAccess.doPrivileged( + () -> Epoll.isAvailable() + ? new EpollEventLoopGroup(eventLoopThreads, OpenSearchExecutors.daemonThreadFactory(THREAD_PREFIX)) + : new NioEventLoopGroup(eventLoopThreads, OpenSearchExecutors.daemonThreadFactory(THREAD_PREFIX)) + ); + } + + public EventLoopGroup getEventLoopGroup() { + return eventLoopGroup; + } + + @Override + public void close() { + Future<?> shutdownFuture = eventLoopGroup.shutdownGracefully(0, 5, TimeUnit.SECONDS); + shutdownFuture.awaitUninterruptibly(); + if (!shutdownFuture.isSuccess()) { + logger.warn(new ParameterizedMessage("Error closing {} netty event loop group", THREAD_PREFIX), shutdownFuture.cause()); + } + } + +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java new file mode 100644 index 0000000000000..5b43ae84c51dc --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java @@ -0,0 +1,354 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license.
+ */ + +package org.opensearch.repositories.s3.async; + +import com.jcraft.jzlib.JZlib; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.exception.CorruptFileException; +import org.opensearch.common.blobstore.stream.write.WritePriority; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.util.ByteUtils; +import org.opensearch.repositories.s3.io.CheckedContainer; +import org.opensearch.repositories.s3.SocketAccess; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.http.HttpStatusCode; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +/** + * A helper class that automatically uses multipart upload based on the size of the source object + */ +public final class AsyncTransferManager { + private static final Logger log = LogManager.getLogger(AsyncTransferManager.class); + private final ExecutorService executorService; + private final ExecutorService priorityExecutorService; + private final long minimumPartSize; + + /** + * The max number of parts on S3 side is 10,000 + */ + private static final long MAX_UPLOAD_PARTS = 10_000; + + /** + * Construct a new object of AsyncTransferManager + * + * @param minimumPartSize The minimum part size for parallel multipart uploads + * @param executorService The stream reader {@link ExecutorService} for normal priority uploads + * @param priorityExecutorService The stream reader {@link ExecutorService} for high priority uploads + */ + public AsyncTransferManager(long minimumPartSize, ExecutorService executorService, ExecutorService priorityExecutorService) { + this.executorService = executorService; + this.priorityExecutorService = priorityExecutorService; + this.minimumPartSize = minimumPartSize; + } + + /** + * Upload an object to S3 using the async client + * + * @param s3AsyncClient S3 client to use for upload + * @param uploadRequest The {@link UploadRequest} object encapsulating all relevant details for upload + * @param streamContext The {@link
StreamContext} to supply streams during upload + * @return A {@link CompletableFuture} to listen for upload completion + */ + public CompletableFuture<Void> uploadObject(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext) { + + CompletableFuture<Void> returnFuture = new CompletableFuture<>(); + try { + if (streamContext.getNumberOfParts() == 1) { + log.debug(() -> "Starting the upload as a single upload part request"); + uploadInOneChunk(s3AsyncClient, uploadRequest, streamContext.provideStream(0), returnFuture); + } else { + log.debug(() -> "Starting the upload as multipart upload request"); + uploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture); + } + } catch (Throwable throwable) { + returnFuture.completeExceptionally(throwable); + } + + return returnFuture; + } + + private void uploadInParts( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + StreamContext streamContext, + CompletableFuture<Void> returnFuture + ) { + + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder() + .bucket(uploadRequest.getBucket()) + .key(uploadRequest.getKey()); + if (uploadRequest.doRemoteDataIntegrityCheck()) { + createMultipartUploadRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32); + } + CompletableFuture<CreateMultipartUploadResponse> createMultipartUploadFuture = SocketAccess.doPrivileged( + () -> s3AsyncClient.createMultipartUpload(createMultipartUploadRequestBuilder.build()) + ); + + // Ensure cancellations are forwarded to the createMultipartUploadFuture future + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + } else { + log.debug(() -> "Initiated new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + doUploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, createMultipartUploadResponse.uploadId()); + } + }); + } + + private void doUploadInParts( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + StreamContext streamContext, + CompletableFuture<Void> returnFuture, + String uploadId + ) { + + // The list of completed parts must be sorted + AtomicReferenceArray<CompletedPart> completedParts = new AtomicReferenceArray<>(streamContext.getNumberOfParts()); + AtomicReferenceArray<CheckedContainer> inputStreamContainers = new AtomicReferenceArray<>(streamContext.getNumberOfParts()); + + List<CompletableFuture<CompletedPart>> futures; + try { + futures = AsyncPartsHandler.uploadParts( + s3AsyncClient, + executorService, + priorityExecutorService, + uploadRequest, + streamContext, + uploadId, + completedParts, + inputStreamContainers + ); + } catch (Exception ex) { + try { + AsyncPartsHandler.cleanUpParts(s3AsyncClient, uploadRequest, uploadId); + } finally { + returnFuture.completeExceptionally(ex); + } + return; + } + + CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(CompletableFuture[]::new)).thenApply(resp -> { + try { + uploadRequest.getUploadFinalizer().accept(true); + } catch (IOException e) { + throw new RuntimeException(e); + } + return resp; + }).thenApply(ignore -> { + if (uploadRequest.doRemoteDataIntegrityCheck()) { + mergeAndVerifyChecksum(inputStreamContainers, uploadRequest.getKey(), uploadRequest.getExpectedChecksum()); + } + return null; + }) + .thenCompose(ignore -> completeMultipartUpload(s3AsyncClient, uploadRequest, uploadId, completedParts)) +
.handle(handleExceptionOrResponse(s3AsyncClient, uploadRequest, returnFuture, uploadId)) + .exceptionally(throwable -> { + handleException(returnFuture, () -> "Unexpected exception occurred", throwable); + return null; + }); + } + + private void mergeAndVerifyChecksum( + AtomicReferenceArray<CheckedContainer> inputStreamContainers, + String fileName, + long expectedChecksum + ) { + long resultantChecksum = fromBase64String(inputStreamContainers.get(0).getChecksum()); + for (int index = 1; index < inputStreamContainers.length(); index++) { + long curChecksum = fromBase64String(inputStreamContainers.get(index).getChecksum()); + resultantChecksum = JZlib.crc32_combine(resultantChecksum, curChecksum, inputStreamContainers.get(index).getContentLength()); + } + + if (resultantChecksum != expectedChecksum) { + throw new RuntimeException(new CorruptFileException("File level checksums didn't match combined part checksums", fileName)); + } + } + + private BiFunction<CompleteMultipartUploadResponse, Throwable, Void> handleExceptionOrResponse( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + CompletableFuture<Void> returnFuture, + String uploadId + ) { + + return (response, throwable) -> { + if (throwable != null) { + AsyncPartsHandler.cleanUpParts(s3AsyncClient, uploadRequest, uploadId); + handleException(returnFuture, () -> "Failed to send multipart upload requests.", throwable); + } else { + returnFuture.complete(null); + } + + return null; + }; + } + + private CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUpload( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + String uploadId, + AtomicReferenceArray<CompletedPart> completedParts + ) { + + log.debug(() -> new ParameterizedMessage("Sending completeMultipartUploadRequest, uploadId: {}", uploadId)); + CompletedPart[] parts = IntStream.range(0, completedParts.length()).mapToObj(completedParts::get).toArray(CompletedPart[]::new); + CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() + .bucket(uploadRequest.getBucket()) + .key(uploadRequest.getKey()) + .uploadId(uploadId) + .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build()) + .build(); + + return SocketAccess.doPrivileged(() -> s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest)); + } + + private static String base64StringFromLong(Long val) { + return Base64.getEncoder().encodeToString(Arrays.copyOfRange(ByteUtils.toByteArrayBE(val), 4, 8)); + } + + private static long fromBase64String(String base64String) { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + if (decodedBytes.length != 4) { + throw new IllegalArgumentException("Invalid Base64 encoded CRC32 checksum"); + } + long result = 0; + for (int i = 0; i < 4; i++) { + result <<= 8; + result |= (decodedBytes[i] & 0xFF); + } + return result; + } + + private static void handleException(CompletableFuture<Void> returnFuture, Supplier<String> message, Throwable throwable) { + Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; + + if (cause instanceof Error) { + returnFuture.completeExceptionally(cause); + } else { + SdkClientException exception = SdkClientException.create(message.get(), cause); + returnFuture.completeExceptionally(exception); + } + } + + /** + * Calculates the optimal part size of each part request if the upload operation is carried out as multipart upload.
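+ * <p> + * Worked example (illustrative arithmetic): a 100 GiB source yields ceil(107,374,182,400 / 10,000) = 10,737,419 bytes per part, which is then raised to {@code minimumPartSize} when that is larger, while sources under 5 MB keep their full length and are uploaded as a single part.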
+ */ + public long calculateOptimalPartSize(long contentLengthOfSource) { + if (contentLengthOfSource < ByteSizeUnit.MB.toBytes(5)) { + return contentLengthOfSource; + } + double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; + optimalPartSize = Math.ceil(optimalPartSize); + return (long) Math.max(optimalPartSize, minimumPartSize); + } + + private void uploadInOneChunk( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + InputStreamContainer inputStreamContainer, + CompletableFuture<Void> returnFuture + ) { + PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() + .bucket(uploadRequest.getBucket()) + .key(uploadRequest.getKey()) + .contentLength(uploadRequest.getContentLength()); + if (uploadRequest.doRemoteDataIntegrityCheck()) { + putObjectRequestBuilder.checksumAlgorithm(ChecksumAlgorithm.CRC32); + putObjectRequestBuilder.checksumCRC32(base64StringFromLong(uploadRequest.getExpectedChecksum())); + } + ExecutorService streamReadExecutor = uploadRequest.getWritePriority() == WritePriority.HIGH + ? priorityExecutorService + : executorService; + CompletableFuture<Void> putObjectFuture = SocketAccess.doPrivileged( + () -> s3AsyncClient.putObject( + putObjectRequestBuilder.build(), + AsyncRequestBody.fromInputStream( + inputStreamContainer.getInputStream(), + inputStreamContainer.getContentLength(), + streamReadExecutor + ) + ).handle((resp, throwable) -> { + if (throwable != null) { + Throwable unwrappedThrowable = ExceptionsHelper.unwrap(throwable, S3Exception.class); + if (unwrappedThrowable != null) { + S3Exception s3Exception = (S3Exception) unwrappedThrowable; + if (s3Exception.statusCode() == HttpStatusCode.BAD_REQUEST + && "BadDigest".equals(s3Exception.awsErrorDetails().errorCode())) { + throw new RuntimeException(new CorruptFileException(s3Exception, uploadRequest.getKey())); + } + } + returnFuture.completeExceptionally(throwable); + } else { + try { + uploadRequest.getUploadFinalizer().accept(true); + } catch (IOException e) { + throw new RuntimeException(e); + } + returnFuture.complete(null); + } + + return null; + }).handle((resp, throwable) -> { + if (throwable != null) { + deleteUploadedObject(s3AsyncClient, uploadRequest); + returnFuture.completeExceptionally(throwable); + } + + return null; + }) + ); + + CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectFuture); + CompletableFutureUtils.forwardResultTo(putObjectFuture, returnFuture); + } + + private void deleteUploadedObject(S3AsyncClient s3AsyncClient, UploadRequest uploadRequest) { + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder() + .bucket(uploadRequest.getBucket()) + .key(uploadRequest.getKey()) + .build(); + + SocketAccess.doPrivileged(() -> s3AsyncClient.deleteObject(deleteObjectRequest)).exceptionally(throwable -> { + log.error(() -> new ParameterizedMessage("Failed to delete uploaded object of key {}", uploadRequest.getKey()), throwable); + return null; + }); + } +} diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java new file mode 100644 index 0000000000000..3804c8417eb9f --- /dev/null +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source
diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java
new file mode 100644
index 0000000000000..3804c8417eb9f
--- /dev/null
+++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java
@@ -0,0 +1,84 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import org.opensearch.common.CheckedConsumer;
+import org.opensearch.common.blobstore.stream.write.WritePriority;
+
+import java.io.IOException;
+
+/**
+ * A model encapsulating all details for an upload to S3
+ */
+public class UploadRequest {
+    private final String bucket;
+    private final String key;
+    private final long contentLength;
+    private final WritePriority writePriority;
+    private final CheckedConsumer<Boolean, IOException> uploadFinalizer;
+    private final boolean doRemoteDataIntegrityCheck;
+    private final Long expectedChecksum;
+
+    /**
+     * Construct a new UploadRequest object
+     *
+     * @param bucket                     The name of the S3 bucket
+     * @param key                        Key of the file needed to be uploaded
+     * @param contentLength              Total content length of the file for upload
+     * @param writePriority              The priority of this upload
+     * @param uploadFinalizer            An upload finalizer to call once all parts are uploaded
+     * @param doRemoteDataIntegrityCheck A boolean to inform vendor plugins whether remote data integrity checks need to be done
+     * @param expectedChecksum           Checksum of the file being uploaded for remote data integrity check
+     */
+    public UploadRequest(
+        String bucket,
+        String key,
+        long contentLength,
+        WritePriority writePriority,
+        CheckedConsumer<Boolean, IOException> uploadFinalizer,
+        boolean doRemoteDataIntegrityCheck,
+        Long expectedChecksum
+    ) {
+        this.bucket = bucket;
+        this.key = key;
+        this.contentLength = contentLength;
+        this.writePriority = writePriority;
+        this.uploadFinalizer = uploadFinalizer;
+        this.doRemoteDataIntegrityCheck = doRemoteDataIntegrityCheck;
+        this.expectedChecksum = expectedChecksum;
+    }
+
+    public String getBucket() {
+        return bucket;
+    }
+
+    public String getKey() {
+        return key;
+    }
+
+    public long getContentLength() {
+        return contentLength;
+    }
+
+    public WritePriority getWritePriority() {
+        return writePriority;
+    }
+
+    public CheckedConsumer<Boolean, IOException> getUploadFinalizer() {
+        return uploadFinalizer;
+    }
+
+    public boolean doRemoteDataIntegrityCheck() {
+        return doRemoteDataIntegrityCheck;
+    }
+
+    public Long getExpectedChecksum() {
+        return expectedChecksum;
+    }
+}
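As a usage sketch (not from this patch; bucket and key names are hypothetical), a caller in the repository layer might assemble the request like this:

// Illustrative only: constructing an UploadRequest for a 16 MB blob. The
// no-op finalizer stands in for whatever bookkeeping the real caller does
// after a successful upload.
UploadRequest request = new UploadRequest(
    "snapshot-bucket",                // bucket (hypothetical)
    "indices/0/0/segments_1",         // key (hypothetical)
    ByteSizeUnit.MB.toBytes(16),      // contentLength
    WritePriority.HIGH,               // writePriority, as used by the tests below
    uploadSuccess -> { /* no-op */ }, // uploadFinalizer
    false,                            // doRemoteDataIntegrityCheck
    null                              // expectedChecksum (none without the check)
);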
+ */
+
+package org.opensearch.repositories.s3.io;
+
+public class CheckedContainer {
+
+    private String checksum;
+    private long contentLength;
+
+    public CheckedContainer(long contentLength) {
+        this.contentLength = contentLength;
+    }
+
+    public void setChecksum(String checksum) {
+        this.checksum = checksum;
+    }
+
+    public String getChecksum() {
+        return checksum;
+    }
+
+    public long getContentLength() {
+        return contentLength;
+    }
+}
diff --git a/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy b/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
index d74d3048addbf..106103d45e7eb 100644
--- a/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
+++ b/plugins/repository-s3/src/main/plugin-metadata/plugin-security.policy
@@ -35,6 +35,7 @@ grant {
   // TODO: get these fixed in aws sdk
   permission java.lang.RuntimePermission "accessDeclaredMembers";
   permission java.lang.RuntimePermission "getClassLoader";
+  permission java.lang.RuntimePermission "setContextClassLoader";
   // Needed because of problems in AmazonS3Client:
   // When no region is set on a AmazonS3Client instance, the
   // AWS SDK loads all known partitions from a JSON file and
@@ -56,10 +57,10 @@ grant {
   // only for tests : org.opensearch.repositories.s3.S3RepositoryPlugin
   permission java.util.PropertyPermission "opensearch.allow_insecure_settings", "read,write";
-  permission java.util.PropertyPermission "aws.sharedCredentialsFile", "read,write";
   permission java.util.PropertyPermission "aws.configFile", "read,write";
   permission java.util.PropertyPermission "opensearch.path.conf", "read,write";
-  permission java.io.FilePermission "config", "read";
+
+  permission java.lang.RuntimePermission "accessDeclaredMembers";
 };
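The new "setContextClassLoader" grant is exercised from privileged blocks; a minimal sketch of the JDK pattern this policy line enables (the wrapped action is illustrative, not the plugin's actual code):

// Sketch: privileged context-classloader swap, which under a security manager
// requires RuntimePermission("setContextClassLoader") granted above.
import java.security.AccessController;
import java.security.PrivilegedAction;

final class ContextClassLoaderSketch {
    static void runWithPluginClassLoader(ClassLoader pluginLoader, Runnable task) {
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            Thread current = Thread.currentThread();
            ClassLoader original = current.getContextClassLoader();
            current.setContextClassLoader(pluginLoader); // guarded operation
            try {
                task.run();
            } finally {
                current.setContextClassLoader(original); // always restore
            }
            return null;
        });
    }
}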
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java
new file mode 100644
index 0000000000000..a401ba06728d7
--- /dev/null
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3AsyncServiceTests.java
@@ -0,0 +1,95 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3;
+
+import org.junit.Before;
+import org.opensearch.cli.SuppressForbidden;
+import org.opensearch.cluster.metadata.RepositoryMetadata;
+import org.opensearch.common.settings.MockSecureSettings;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.repositories.s3.async.AsyncExecutorContainer;
+import org.opensearch.repositories.s3.async.AsyncTransferEventLoopGroup;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Map;
+import java.util.concurrent.Executors;
+
+public class S3AsyncServiceTests extends OpenSearchTestCase implements ConfigPathSupport {
+
+    @Override
+    @Before
+    @SuppressForbidden(reason = "Need to set opensearch.path.conf for async client")
+    public void setUp() throws Exception {
+        SocketAccess.doPrivileged(() -> System.setProperty("opensearch.path.conf", configPath().toString()));
+        super.setUp();
+    }
+
+    public void testCachedClientsAreReleased() {
+        final S3AsyncService s3AsyncService = new S3AsyncService(configPath());
+        final Settings settings = Settings.builder().put("endpoint", "http://first").put("region", "us-east-2").build();
+        final RepositoryMetadata metadata1 = new RepositoryMetadata("first", "s3", settings);
+        final RepositoryMetadata metadata2 = new RepositoryMetadata("second", "s3", settings);
+        final AsyncExecutorContainer asyncExecutorContainer = new AsyncExecutorContainer(
+            Executors.newSingleThreadExecutor(),
+            Executors.newSingleThreadExecutor(),
+            new AsyncTransferEventLoopGroup(1)
+        );
+        final S3ClientSettings clientSettings = s3AsyncService.settings(metadata2);
+        final S3ClientSettings otherClientSettings = s3AsyncService.settings(metadata2);
+        assertSame(clientSettings, otherClientSettings);
+        final AmazonAsyncS3Reference reference = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorContainer, asyncExecutorContainer)
+        );
+        reference.close();
+        s3AsyncService.close();
+        final AmazonAsyncS3Reference referenceReloaded = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorContainer, asyncExecutorContainer)
+        );
+        assertNotSame(referenceReloaded, reference);
+        referenceReloaded.close();
+        s3AsyncService.close();
+        final S3ClientSettings clientSettingsReloaded = s3AsyncService.settings(metadata1);
+        assertNotSame(clientSettings, clientSettingsReloaded);
+    }
+
+    public void testCachedClientsWithCredentialsAreReleased() {
+        final MockSecureSettings secureSettings = new MockSecureSettings();
+        secureSettings.setString("s3.client.default.role_arn", "role");
+        final Map<String, S3ClientSettings> defaults = S3ClientSettings.load(
+            Settings.builder().setSecureSettings(secureSettings).put("s3.client.default.identity_token_file", "file").build(),
+            configPath()
+        );
+        final S3AsyncService s3AsyncService = new S3AsyncService(configPath());
+        s3AsyncService.refreshAndClearCache(defaults);
+        final Settings settings = Settings.builder().put("endpoint", "http://first").put("region", "us-east-2").build();
+        final RepositoryMetadata metadata1 = new RepositoryMetadata("first", "s3", settings);
+        final RepositoryMetadata metadata2 = new RepositoryMetadata("second", "s3", settings);
+        final AsyncExecutorContainer asyncExecutorContainer = new AsyncExecutorContainer(
+            Executors.newSingleThreadExecutor(),
+            Executors.newSingleThreadExecutor(),
+            new AsyncTransferEventLoopGroup(1)
+        );
+        final S3ClientSettings clientSettings = s3AsyncService.settings(metadata2);
+        final S3ClientSettings otherClientSettings = s3AsyncService.settings(metadata2);
+        assertSame(clientSettings, otherClientSettings);
+        final AmazonAsyncS3Reference reference = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorContainer, asyncExecutorContainer)
+        );
+        reference.close();
+        s3AsyncService.close();
+        final AmazonAsyncS3Reference referenceReloaded = SocketAccess.doPrivileged(
+            () -> s3AsyncService.client(metadata1, asyncExecutorContainer, asyncExecutorContainer)
+        );
+        assertNotSame(referenceReloaded, reference);
+        referenceReloaded.close();
+        s3AsyncService.close();
+        final S3ClientSettings clientSettingsReloaded = s3AsyncService.settings(metadata1);
+        assertNotSame(clientSettings, clientSettingsReloaded);
+    }
+}
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
index 9f5ebc5afe017..57d0387c96095 100644
--- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java
@@ -56,6 +56,7 @@
 import software.amazon.awssdk.core.exception.SdkClientException;
 import software.amazon.awssdk.core.io.SdkDigestInputStream;
 import software.amazon.awssdk.utils.internal.Base16;
+import org.opensearch.repositories.blobstore.ZeroInputStream;
 
 import java.io.ByteArrayInputStream;
 import java.io.FilterInputStream;
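The tests above exercise the reference-counted client lifecycle; a caller-side sketch of that contract follows. Class and method names come from this patch, but the flow is an assumption, including that RefCountedReleasable exposes the wrapped holder via get(); plugin code additionally wraps the client() call in SocketAccess.doPrivileged as the tests do.

// Sketch of the ref-counted lifecycle verified by S3AsyncServiceTests.
static void withClient(S3AsyncService service, RepositoryMetadata metadata, AsyncExecutorContainer container) {
    AmazonAsyncS3Reference reference = service.client(metadata, container, container);
    try {
        S3AsyncClient client = reference.get().client(); // wrapped AmazonAsyncS3WithCredentials
        // ... issue async requests against 'client' ...
    } finally {
        reference.close(); // drop this caller's reference; at zero refs the client shuts down
    }
}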
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java
new file mode 100644
index 0000000000000..596291a1d94fb
--- /dev/null
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java
@@ -0,0 +1,238 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.s3.async;
+
+import org.junit.Before;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.common.StreamContext;
+import org.opensearch.common.blobstore.exception.CorruptFileException;
+import org.opensearch.common.blobstore.stream.write.WritePriority;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.common.unit.ByteSizeUnit;
+import org.opensearch.repositories.blobstore.ZeroInputStream;
+import org.opensearch.test.OpenSearchTestCase;
+import software.amazon.awssdk.awscore.exception.AwsErrorDetails;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.http.HttpStatusCode;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
+import software.amazon.awssdk.services.s3.model.DeleteObjectResponse;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.PutObjectResponse;
+import software.amazon.awssdk.services.s3.model.S3Exception;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class AsyncTransferManagerTests extends OpenSearchTestCase {
+
+    private AsyncTransferManager asyncTransferManager;
+    private S3AsyncClient s3AsyncClient;
+
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        s3AsyncClient = mock(S3AsyncClient.class);
+        asyncTransferManager = new AsyncTransferManager(
+            ByteSizeUnit.MB.toBytes(5),
+            Executors.newSingleThreadExecutor(),
+            Executors.newSingleThreadExecutor()
+        );
+        super.setUp();
+    }
+
+    public void testOneChunkUpload() {
+        CompletableFuture<PutObjectResponse> putObjectResponseCompletableFuture = new CompletableFuture<>();
+        putObjectResponseCompletableFuture.complete(PutObjectResponse.builder().build());
+        when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))).thenReturn(
+            putObjectResponseCompletableFuture
+        );
+
+        CompletableFuture<Void> resultFuture = asyncTransferManager.uploadObject(
+            s3AsyncClient,
+            new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(1), WritePriority.HIGH, uploadSuccess -> {
+                // do nothing
+            }, false, null),
+            new StreamContext(
+                (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position),
+                ByteSizeUnit.MB.toBytes(1),
+                ByteSizeUnit.MB.toBytes(1),
+                1
+            )
+        );
+
+        try {
+            resultFuture.get();
+        } catch (ExecutionException | InterruptedException e) {
+            fail("did not expect resultFuture to fail");
+        }
+
+        verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class));
+    }
+
+    public void testOneChunkUploadCorruption() {
+        CompletableFuture<PutObjectResponse> putObjectResponseCompletableFuture = new CompletableFuture<>();
+        putObjectResponseCompletableFuture.completeExceptionally(
+            S3Exception.builder()
+                .statusCode(HttpStatusCode.BAD_REQUEST)
+                .awsErrorDetails(AwsErrorDetails.builder().errorCode("BadDigest").build())
+                .build()
+        );
+        when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))).thenReturn(
+            putObjectResponseCompletableFuture
+        );
+
+        CompletableFuture<DeleteObjectResponse> deleteObjectResponseCompletableFuture = new CompletableFuture<>();
+        deleteObjectResponseCompletableFuture.complete(DeleteObjectResponse.builder().build());
+        when(s3AsyncClient.deleteObject(any(DeleteObjectRequest.class))).thenReturn(deleteObjectResponseCompletableFuture);
+
+        CompletableFuture<Void> resultFuture = asyncTransferManager.uploadObject(
+            s3AsyncClient,
+            new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(1), WritePriority.HIGH, uploadSuccess -> {
+                // do nothing
+            }, false, null),
+            new StreamContext(
+                (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position),
+                ByteSizeUnit.MB.toBytes(1),
+                ByteSizeUnit.MB.toBytes(1),
+                1
+            )
+        );
+
+        try {
+            resultFuture.get();
+            fail("did not expect resultFuture to pass");
+        } catch (ExecutionException | InterruptedException e) {
+            Throwable throwable = ExceptionsHelper.unwrap(e, CorruptFileException.class);
+            assertNotNull(throwable);
+            assertTrue(throwable instanceof CorruptFileException);
+        }
+
+        verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class));
+        verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class));
+    }
+
+    public void testMultipartUpload() {
+        CompletableFuture<CreateMultipartUploadResponse> createMultipartUploadRequestCompletableFuture = new CompletableFuture<>();
+        createMultipartUploadRequestCompletableFuture.complete(CreateMultipartUploadResponse.builder().uploadId("uploadId").build());
+        when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(
+            createMultipartUploadRequestCompletableFuture
+        );
+
+        CompletableFuture<UploadPartResponse> uploadPartResponseCompletableFuture = new CompletableFuture<>();
+        uploadPartResponseCompletableFuture.complete(UploadPartResponse.builder().checksumCRC32("pzjqHA==").build());
+        when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(
+            uploadPartResponseCompletableFuture
+        );
+
+        CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUploadResponseCompletableFuture = new CompletableFuture<>();
+        completeMultipartUploadResponseCompletableFuture.complete(CompleteMultipartUploadResponse.builder().build());
+        when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(
+            completeMultipartUploadResponseCompletableFuture
+        );
+
+        CompletableFuture<AbortMultipartUploadResponse> abortMultipartUploadResponseCompletableFuture = new CompletableFuture<>();
+        abortMultipartUploadResponseCompletableFuture.complete(AbortMultipartUploadResponse.builder().build());
+        when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn(
+            abortMultipartUploadResponseCompletableFuture
+        );
+
+        CompletableFuture<Void> resultFuture = asyncTransferManager.uploadObject(
+            s3AsyncClient,
+            new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(5), WritePriority.HIGH, uploadSuccess -> {
+                // do nothing
+            }, true, 3376132981L),
+            new StreamContext(
+                (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position),
+                ByteSizeUnit.MB.toBytes(1),
+                ByteSizeUnit.MB.toBytes(1),
+                5
+            )
+        );
+
+        try {
+            resultFuture.get();
+        } catch (ExecutionException | InterruptedException e) {
+            fail("did not expect resultFuture to fail");
+        }
+
+        verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class));
+        verify(s3AsyncClient, times(5)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class));
+        verify(s3AsyncClient, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+        verify(s3AsyncClient, times(0)).abortMultipartUpload(any(AbortMultipartUploadRequest.class));
+    }
+
+    public void testMultipartUploadCorruption() {
+        CompletableFuture<CreateMultipartUploadResponse> createMultipartUploadRequestCompletableFuture = new CompletableFuture<>();
+        createMultipartUploadRequestCompletableFuture.complete(CreateMultipartUploadResponse.builder().uploadId("uploadId").build());
+        when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(
+            createMultipartUploadRequestCompletableFuture
+        );
+
+        CompletableFuture<UploadPartResponse> uploadPartResponseCompletableFuture = new CompletableFuture<>();
+        uploadPartResponseCompletableFuture.complete(UploadPartResponse.builder().checksumCRC32("pzjqHA==").build());
+        when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(
+            uploadPartResponseCompletableFuture
+        );
+
+        CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUploadResponseCompletableFuture = new CompletableFuture<>();
+        completeMultipartUploadResponseCompletableFuture.complete(CompleteMultipartUploadResponse.builder().build());
+        when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(
+            completeMultipartUploadResponseCompletableFuture
+        );
+
+        CompletableFuture<AbortMultipartUploadResponse> abortMultipartUploadResponseCompletableFuture = new CompletableFuture<>();
+        abortMultipartUploadResponseCompletableFuture.complete(AbortMultipartUploadResponse.builder().build());
+        when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn(
+            abortMultipartUploadResponseCompletableFuture
+        );
+
+        CompletableFuture<Void> resultFuture = asyncTransferManager.uploadObject(
+            s3AsyncClient,
+            new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(5), WritePriority.HIGH, uploadSuccess -> {
+                // do nothing
+            }, true, 0L),
+            new StreamContext(
+                (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position),
+                ByteSizeUnit.MB.toBytes(1),
+                ByteSizeUnit.MB.toBytes(1),
+                5
+            )
+        );
+
+        try {
+            resultFuture.get();
+            fail("did not expect resultFuture to pass");
+        } catch (ExecutionException | InterruptedException e) {
+            Throwable throwable = ExceptionsHelper.unwrap(e, CorruptFileException.class);
+            assertNotNull(throwable);
+            assertTrue(throwable instanceof CorruptFileException);
+        }
+
+        verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class));
+        verify(s3AsyncClient, times(5)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class));
+        verify(s3AsyncClient, times(0)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+        verify(s3AsyncClient, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class));
+    }
+}
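The checksum fixtures above (per-part "pzjqHA==", whole-object 3376132981L) can presumably be derived offline, on the assumption that they are plain CRC32 values over the zero-filled bytes the tests stream, transported the way S3 encodes CRC32 checksums (Base64 of the four big-endian bytes). A sketch under that assumption; it computes values rather than asserting equality:

// Sketch: deriving CRC32 fixture values for zero-filled content.
import java.util.Base64;
import java.util.zip.CRC32;

final class Crc32Fixtures {
    static long crc32OfZeros(long length) {
        CRC32 crc = new CRC32();
        byte[] chunk = new byte[8192]; // freshly allocated, so already all zeros
        for (long remaining = length; remaining > 0; remaining -= chunk.length) {
            crc.update(chunk, 0, (int) Math.min(chunk.length, remaining));
        }
        return crc.getValue();
    }

    static String base64Crc32(long checksum) {
        // Encode the 32-bit value as four big-endian bytes, then Base64.
        byte[] bytes = { (byte) (checksum >>> 24), (byte) (checksum >>> 16), (byte) (checksum >>> 8), (byte) checksum };
        return Base64.getEncoder().encodeToString(bytes);
    }

    public static void main(String[] args) {
        long perPart = crc32OfZeros(1024 * 1024); // one 1 MB part of zeros
        System.out.println(perPart + " -> " + base64Crc32(perPart));
    }
}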
diff --git a/test/framework/src/main/java/org/opensearch/repositories/blobstore/AbstractBlobContainerRetriesTestCase.java b/test/framework/src/main/java/org/opensearch/repositories/blobstore/AbstractBlobContainerRetriesTestCase.java
index adaf95ae67a8e..0361652eb2457 100644
--- a/test/framework/src/main/java/org/opensearch/repositories/blobstore/AbstractBlobContainerRetriesTestCase.java
+++ b/test/framework/src/main/java/org/opensearch/repositories/blobstore/AbstractBlobContainerRetriesTestCase.java
@@ -58,8 +58,6 @@
 import java.util.Arrays;
 import java.util.Locale;
 import java.util.Optional;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
@@ -409,85 +407,4 @@ protected void sendIncompleteContent(HttpExchange exchange, byte[] bytes) throws
             exchange.getResponseBody().flush();
         }
     }
-
-    /**
-     * A resettable InputStream that only serves zeros.
-     **/
-    public static class ZeroInputStream extends InputStream {
-
-        private final AtomicBoolean closed = new AtomicBoolean(false);
-        private final long length;
-        private final AtomicLong reads;
-        private volatile long mark;
-
-        public ZeroInputStream(final long length) {
-            this.length = length;
-            this.reads = new AtomicLong(0);
-            this.mark = -1;
-        }
-
-        @Override
-        public int read() throws IOException {
-            ensureOpen();
-            return (reads.incrementAndGet() <= length) ? 0 : -1;
-        }
-
-        @Override
-        public int read(byte[] b, int off, int len) throws IOException {
-            ensureOpen();
-            if (len == 0) {
-                return 0;
-            }
-
-            final int available = available();
-            if (available == 0) {
-                return -1;
-            }
-
-            final int toCopy = Math.min(len, available);
-            Arrays.fill(b, off, off + toCopy, (byte) 0);
-            reads.addAndGet(toCopy);
-            return toCopy;
-        }
-
-        @Override
-        public boolean markSupported() {
-            return true;
-        }
-
-        @Override
-        public synchronized void mark(int readlimit) {
-            mark = reads.get();
-        }
-
-        @Override
-        public synchronized void reset() throws IOException {
-            ensureOpen();
-            reads.set(mark);
-        }
-
-        @Override
-        public int available() throws IOException {
-            ensureOpen();
-            if (reads.get() >= length) {
-                return 0;
-            }
-            try {
-                return Math.toIntExact(length - reads.get());
-            } catch (ArithmeticException e) {
-                return Integer.MAX_VALUE;
-            }
-        }
-
-        @Override
-        public void close() {
-            closed.set(true);
-        }
-
-        private void ensureOpen() throws IOException {
-            if (closed.get()) {
-                throw new IOException("Stream closed");
-            }
-        }
-    }
 }
diff --git a/test/framework/src/main/java/org/opensearch/repositories/blobstore/ZeroInputStream.java b/test/framework/src/main/java/org/opensearch/repositories/blobstore/ZeroInputStream.java
new file mode 100644
index 0000000000000..f299c17f16a06
--- /dev/null
+++ b/test/framework/src/main/java/org/opensearch/repositories/blobstore/ZeroInputStream.java
@@ -0,0 +1,96 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.repositories.blobstore;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * A resettable InputStream that only serves zeros.
+ **/
+public class ZeroInputStream extends InputStream {
+
+    private final AtomicBoolean closed = new AtomicBoolean(false);
+    private final long length;
+    private final AtomicLong reads;
+    private volatile long mark;
+
+    public ZeroInputStream(final long length) {
+        this.length = length;
+        this.reads = new AtomicLong(0);
+        this.mark = -1;
+    }
+
+    @Override
+    public int read() throws IOException {
+        ensureOpen();
+        return (reads.incrementAndGet() <= length) ? 0 : -1;
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        ensureOpen();
+        if (len == 0) {
+            return 0;
+        }
+
+        final int available = available();
+        if (available == 0) {
+            return -1;
+        }
+
+        final int toCopy = Math.min(len, available);
+        Arrays.fill(b, off, off + toCopy, (byte) 0);
+        reads.addAndGet(toCopy);
+        return toCopy;
+    }
+
+    @Override
+    public boolean markSupported() {
+        return true;
+    }
+
+    @Override
+    public synchronized void mark(int readlimit) {
+        mark = reads.get();
+    }
+
+    @Override
+    public synchronized void reset() throws IOException {
+        ensureOpen();
+        reads.set(mark);
+    }
+
+    @Override
+    public int available() throws IOException {
+        ensureOpen();
+        if (reads.get() >= length) {
+            return 0;
+        }
+        try {
+            return Math.toIntExact(length - reads.get());
+        } catch (ArithmeticException e) {
+            return Integer.MAX_VALUE;
+        }
+    }
+
+    @Override
+    public void close() {
+        closed.set(true);
+    }
+
+    private void ensureOpen() throws IOException {
+        if (closed.get()) {
+            throw new IOException("Stream closed");
+        }
+    }
+}
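A quick standalone demo of the relocated helper (not part of the patch), showing the read and mark/reset behavior defined above:

// Reads 16 zero bytes, rewinds via mark/reset, and reads them again.
import org.opensearch.repositories.blobstore.ZeroInputStream;

public class ZeroInputStreamDemo {
    public static void main(String[] args) throws Exception {
        try (ZeroInputStream in = new ZeroInputStream(16)) {
            byte[] buf = new byte[16];
            in.mark(16);                      // remember position 0
            int first = in.read(buf, 0, 16);  // 16 zero bytes
            in.reset();                       // rewind to the mark
            int second = in.read(buf, 0, 16); // the same 16 bytes again
            System.out.println(first + " " + second + " " + in.read()); // 16 16 -1
        }
    }
}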