From ac70834c13bef12c34b8c5cc667051f1e736df1d Mon Sep 17 00:00:00 2001 From: Thomas Farr Date: Fri, 1 Nov 2024 15:51:15 +1300 Subject: [PATCH] Fix AWS SigV4 on delete requests when using AWS SDK's Apache client The AWS SDK's Apache client implementation does not send the `Content-Length` header on DELETE requests, but the header is being set before calculating the signature. This causes the Amazon OpenSearch Service to report an incorrect signature as it does not receive the header value needed to validate the signature. `Content-Length` is somewhat unreliable to include in the signature calculation, but the AWS SDK doesn't allow configuring which headers to ignore in signature calculation. As such we must move setting the header to after the signature is calculated. Additionally moves to the supported `AwsV4HttpSigner` as `Aws4Signer` is now deprecated: https://github.com/aws/aws-sdk-java-v2/blob/88abec27e7d5d35b21545c7e05875a7cc3d0f46e/core/auth/src/main/java/software/amazon/awssdk/auth/signer/Aws4Signer.java Signed-off-by: Thomas Farr --- CHANGELOG.md | 1 + java-client/build.gradle.kts | 24 +- .../transport/aws/AwsSdk2Transport.java | 138 +++++---- .../aws/AwsSdk2TransportOptions.java | 38 ++- .../transport/aws/AwsSdk2TransportTests.java | 289 ++++++++++++++++++ 5 files changed, 411 insertions(+), 79 deletions(-) create mode 100644 java-client/src/test/java/org/opensearch/client/transport/aws/AwsSdk2TransportTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 3630053b08..5be3c1af37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ This section is for maintaining a changelog for all breaking changes for the cli ### Fixed - Fixed deserializing `GeoBoundsAggregate` when `bounds` is not returned by OpenSearch ([#1238](https://github.com/opensearch-project/opensearch-java/pull/1238)) +- Fixed AWS SigV4 on delete requests when using AWS SDK's Apache client ([#]()) ### Security diff --git a/java-client/build.gradle.kts b/java-client/build.gradle.kts index 56cbee28ec..d82d57bf72 100644 --- a/java-client/build.gradle.kts +++ b/java-client/build.gradle.kts @@ -173,7 +173,6 @@ val integrationTest = task("integrationTest") { val opensearchVersion = "3.0.0-SNAPSHOT" dependencies { - val jacksonVersion = "2.17.0" val jacksonDatabindVersion = "2.17.0" @@ -210,21 +209,25 @@ dependencies { implementation("jakarta.annotation", "jakarta.annotation-api", "1.3.5") // Apache 2.0 - implementation("com.fasterxml.jackson.core", "jackson-core", jacksonVersion) implementation("com.fasterxml.jackson.core", "jackson-databind", jacksonDatabindVersion) testImplementation("com.fasterxml.jackson.datatype", "jackson-datatype-jakarta-jsonp", jacksonVersion) // For AwsSdk2Transport - "awsSdk2SupportCompileOnly"("software.amazon.awssdk","sdk-core","[2.15,3.0)") - "awsSdk2SupportCompileOnly"("software.amazon.awssdk","auth","[2.15,3.0)") - testImplementation("software.amazon.awssdk","sdk-core","[2.15,3.0)") - testImplementation("software.amazon.awssdk","auth","[2.15,3.0)") - testImplementation("software.amazon.awssdk","aws-crt-client","[2.15,3.0)") - testImplementation("software.amazon.awssdk","apache-client","[2.15,3.0)") - testImplementation("software.amazon.awssdk","sts","[2.15,3.0)") + "awsSdk2SupportCompileOnly"("software.amazon.awssdk", "sdk-core", "[2.21,3.0)") + "awsSdk2SupportCompileOnly"("software.amazon.awssdk", "auth", "[2.21,3.0)") + "awsSdk2SupportCompileOnly"("software.amazon.awssdk", "http-auth-aws", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "sdk-core", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "auth", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "http-auth-aws", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "aws-crt-client", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "apache-client", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "netty-nio-client", "[2.21,3.0)") + testImplementation("software.amazon.awssdk", "sts", "[2.21,3.0)") + testImplementation("org.apache.logging.log4j", "log4j-api","[2.17.1,3.0)") testImplementation("org.apache.logging.log4j", "log4j-core","[2.17.1,3.0)") + // EPL-2.0 OR BSD-3-Clause // https://eclipse-ee4j.github.io/yasson/ implementation("org.eclipse", "yasson", "2.0.2") @@ -236,6 +239,9 @@ dependencies { testImplementation("junit", "junit" , "4.13.2") { exclude(group = "org.hamcrest") } + + // Apache 2.0 + testImplementation("org.wiremock", "wiremock", "3.9.2") } licenseReport { diff --git a/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2Transport.java b/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2Transport.java index 310d936e4a..83e27297ce 100644 --- a/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2Transport.java +++ b/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2Transport.java @@ -19,6 +19,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.URLEncoder; +import java.time.Clock; import java.util.AbstractMap; import java.util.Collection; import java.util.Map; @@ -26,7 +27,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.function.Supplier; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import javax.annotation.CheckForNull; @@ -52,18 +53,19 @@ import org.opensearch.client.util.OpenSearchRequestBodyBuffer; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.auth.signer.params.Aws4SignerParams; import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.Header; import software.amazon.awssdk.http.HttpExecuteRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.SdkHttpRequest; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.http.async.AsyncExecuteRequest; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.utils.IoUtils; import software.amazon.awssdk.utils.SdkAutoCloseable; @@ -85,6 +87,7 @@ public class AwsSdk2Transport implements OpenSearchTransport { private final String signingServiceName; private final Region signingRegion; private final JsonpMapper defaultMapper; + @Nonnull private final AwsSdk2TransportOptions transportOptions; /** @@ -195,12 +198,11 @@ public ResponseT performRequest( Endpoint endpoint, @Nullable TransportOptions options ) throws IOException { - OpenSearchRequestBodyBuffer requestBody = prepareRequestBody(request, endpoint, options); - SdkHttpFullRequest clientReq = prepareRequest(request, endpoint, options, requestBody); + SdkHttpRequest clientReq = prepareRequest(request, endpoint, options, requestBody); if (httpClient instanceof SdkHttpClient) { - return executeSync((SdkHttpClient) httpClient, clientReq, endpoint, options); + return executeSync((SdkHttpClient) httpClient, clientReq, requestBody, endpoint, options); } else if (httpClient instanceof SdkAsyncHttpClient) { try { return executeAsync((SdkAsyncHttpClient) httpClient, clientReq, requestBody, endpoint, options).get(); @@ -229,11 +231,11 @@ public CompletableFuture performRequest ) { try { OpenSearchRequestBodyBuffer requestBody = prepareRequestBody(request, endpoint, options); - SdkHttpFullRequest clientReq = prepareRequest(request, endpoint, options, requestBody); + SdkHttpRequest clientReq = prepareRequest(request, endpoint, options, requestBody); if (httpClient instanceof SdkAsyncHttpClient) { return executeAsync((SdkAsyncHttpClient) httpClient, clientReq, requestBody, endpoint, options); } else if (httpClient instanceof SdkHttpClient) { - ResponseT result = executeSync((SdkHttpClient) httpClient, clientReq, endpoint, options); + ResponseT result = executeSync((SdkHttpClient) httpClient, clientReq, requestBody, endpoint, options); return CompletableFuture.completedFuture(result); } else { throw new IOException("invalid httpClient: " + httpClient); @@ -265,16 +267,12 @@ private OpenSearchRequestBodyBuffer prepareRequestBody( TransportOptions options ) throws IOException { if (endpoint.hasRequestBody()) { - final JsonpMapper mapper = Optional.ofNullable(options) - .map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null) - .map(AwsSdk2TransportOptions::mapper) - .orElse(defaultMapper); - final int maxUncompressedSize = or( - Optional.ofNullable(options) - .map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null) - .map(AwsSdk2TransportOptions::requestCompressionSize), - () -> Optional.ofNullable(transportOptions.requestCompressionSize()) - ).orElse(DEFAULT_REQUEST_COMPRESSION_SIZE); + final JsonpMapper mapper = Optional.ofNullable( + options instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) options) : null + ).map(AwsSdk2TransportOptions::mapper).orElse(defaultMapper); + final int maxUncompressedSize = getOption(options, AwsSdk2TransportOptions::requestCompressionSize).orElse( + DEFAULT_REQUEST_COMPRESSION_SIZE + ); OpenSearchRequestBodyBuffer buffer = new OpenSearchRequestBodyBuffer(mapper, maxUncompressedSize); buffer.addContent(request); @@ -284,7 +282,7 @@ private OpenSearchRequestBodyBuffer prepareRequestBody( return null; } - private SdkHttpFullRequest prepareRequest( + private SdkHttpRequest prepareRequest( RequestT request, Endpoint endpoint, @CheckForNull TransportOptions options, @@ -315,46 +313,57 @@ private SdkHttpFullRequest prepareRequest( } catch (URISyntaxException e) { throw new IllegalArgumentException("Invalid request URI: " + url.toString()); } + + ContentStreamProvider bodyProvider = body != null ? ContentStreamProvider.fromInputStreamSupplier(body::getInputStream) : null; + + applyHeadersPreSigning(req, options, body); + + final AwsCredentialsProvider credentials = getOption(options, AwsSdk2TransportOptions::credentials).orElseGet( + DefaultCredentialsProvider::create + ); + + final Clock signingClock = getOption(options, AwsSdk2TransportOptions::signingClock).orElse(null); + + SdkHttpRequest.Builder signedReq = AwsV4HttpSigner.create() + .sign( + b -> b.identity(credentials.resolveCredentials()) + .request(req.build()) + .payload(bodyProvider) + .putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, this.signingServiceName) + .putProperty(AwsV4HttpSigner.REGION_NAME, this.signingRegion.id()) + .putProperty(AwsV4HttpSigner.SIGNING_CLOCK, signingClock) + ) + .request() + .toBuilder(); + + applyHeadersPostSigning(signedReq, body); + + return signedReq.build(); + } + + private void applyHeadersPreSigning(SdkHttpRequest.Builder req, TransportOptions options, OpenSearchRequestBodyBuffer body) { applyOptionsHeaders(req, transportOptions); applyOptionsHeaders(req, options); - if (endpoint.hasRequestBody() && body != null) { + + if (body != null) { req.putHeader("Content-Type", body.getContentType()); String encoding = body.getContentEncoding(); if (encoding != null) { req.putHeader("Content-Encoding", encoding); } - req.putHeader("Content-Length", String.valueOf(body.getContentLength())); - req.contentStreamProvider(body::getInputStream); - // To add the "X-Amz-Content-Sha256" header, it needs to set as required. - // It is a required header for Amazon OpenSearch Serverless. - req.putHeader("x-amz-content-sha256", "required"); } - boolean responseCompression = or( - Optional.ofNullable(options) - .map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null) - .map(AwsSdk2TransportOptions::responseCompression), - () -> Optional.ofNullable(transportOptions.responseCompression()) - ).orElse(Boolean.TRUE); - if (responseCompression) { + if (getOption(options, AwsSdk2TransportOptions::responseCompression).orElse(Boolean.TRUE)) { req.putHeader("Accept-Encoding", "gzip"); } else { req.removeHeader("Accept-Encoding"); } + } - final AwsCredentialsProvider credentials = or( - Optional.ofNullable(options) - .map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null) - .map(AwsSdk2TransportOptions::credentials), - () -> Optional.ofNullable(transportOptions.credentials()) - ).orElse(DefaultCredentialsProvider.create()); - - Aws4SignerParams signerParams = Aws4SignerParams.builder() - .awsCredentials(credentials.resolveCredentials()) - .signingName(this.signingServiceName) - .signingRegion(signingRegion) - .build(); - return Aws4Signer.create().sign(req.build(), signerParams); + private void applyHeadersPostSigning(SdkHttpRequest.Builder req, OpenSearchRequestBodyBuffer body) { + if (body != null) { + req.putHeader("Content-Length", String.valueOf(body.getContentLength())); + } } private void applyOptionsParams(StringBuilder url, TransportOptions options) throws UnsupportedEncodingException { @@ -372,7 +381,7 @@ private void applyOptionsParams(StringBuilder url, TransportOptions options) thr } } - private void applyOptionsHeaders(SdkHttpFullRequest.Builder builder, TransportOptions options) { + private void applyOptionsHeaders(SdkHttpRequest.Builder builder, TransportOptions options) { if (options == null) { return; } @@ -386,14 +395,14 @@ private void applyOptionsHeaders(SdkHttpFullRequest.Builder builder, TransportOp private ResponseT executeSync( SdkHttpClient syncHttpClient, - SdkHttpFullRequest httpRequest, + SdkHttpRequest httpRequest, + OpenSearchRequestBodyBuffer requestBody, Endpoint endpoint, TransportOptions options ) throws IOException { - HttpExecuteRequest.Builder executeRequest = HttpExecuteRequest.builder().request(httpRequest); - if (httpRequest.contentStreamProvider().isPresent()) { - executeRequest.contentStreamProvider(httpRequest.contentStreamProvider().get()); + if (requestBody != null) { + executeRequest.contentStreamProvider(ContentStreamProvider.fromInputStreamSupplier(requestBody::getInputStream)); } HttpExecuteResponse executeResponse = syncHttpClient.prepareRequest(executeRequest.build()).call(); AbortableInputStream bodyStream = null; @@ -418,13 +427,12 @@ private ResponseT executeSync( private CompletableFuture executeAsync( SdkAsyncHttpClient asyncHttpClient, - SdkHttpFullRequest httpRequest, + SdkHttpRequest httpRequest, @CheckForNull OpenSearchRequestBodyBuffer requestBody, Endpoint endpoint, TransportOptions options ) { byte[] requestBodyArray = requestBody == null ? NO_BYTES : requestBody.getByteArray(); - final AsyncCapturingResponseHandler responseHandler = new AsyncCapturingResponseHandler(); AsyncExecuteRequest.Builder executeRequest = AsyncExecuteRequest.builder() .request(httpRequest) @@ -463,10 +471,9 @@ private ResponseT parseResponse( @Nonnull Endpoint endpoint, @CheckForNull TransportOptions options ) throws IOException { - final JsonpMapper mapper = Optional.ofNullable(options) - .map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null) - .map(AwsSdk2TransportOptions::mapper) - .orElse(defaultMapper); + final JsonpMapper mapper = Optional.ofNullable( + options instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) options) : null + ).map(AwsSdk2TransportOptions::mapper).orElse(defaultMapper); int statusCode = httpResponse.statusCode(); boolean isZipped = httpResponse.firstMatchingHeader("Content-Encoding").map(enc -> enc.contains("gzip")).orElse(Boolean.FALSE); @@ -625,16 +632,15 @@ private ResponseT decodeResponse( } } - private static Optional or(Optional opt, Supplier> supplier) { - Objects.requireNonNull(opt); - Objects.requireNonNull(supplier); - if (opt.isPresent()) { - return opt; - } else { - @SuppressWarnings("unchecked") - Optional r = (Optional) supplier.get(); - return Objects.requireNonNull(r); - } + private Optional getOption(@Nullable TransportOptions options, @Nonnull Function getter) { + Objects.requireNonNull(getter, "getter must not be null"); + + Function> optGetter = o -> Optional.ofNullable(getter.apply(o)); + + Optional opt = Optional.ofNullable(options instanceof AwsSdk2TransportOptions ? (AwsSdk2TransportOptions) options : null) + .flatMap(optGetter); + + return opt.isPresent() ? opt : optGetter.apply(transportOptions); } private static ByteArrayInputStream toByteArrayInputStream(InputStream is) throws IOException { diff --git a/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2TransportOptions.java b/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2TransportOptions.java index 1d10f9c424..29f3b687b1 100644 --- a/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2TransportOptions.java +++ b/java-client/src/main/java/org/opensearch/client/transport/aws/AwsSdk2TransportOptions.java @@ -8,6 +8,7 @@ package org.opensearch.client.transport.aws; +import java.time.Clock; import java.util.List; import java.util.function.Function; import org.opensearch.client.json.JsonpMapper; @@ -71,6 +72,18 @@ public interface AwsSdk2TransportOptions extends TransportOptions { */ JsonpMapper mapper(); + /** + * Get the clock used for signing requests. + *

+ * If this is null, then a default will be used -- either a value specified + * in a more general {@link AwsSdk2TransportOptions} that applies to the request, or + * {@link Clock#systemUTC()} if there is none. + *

+ * + * @return A clock or null + */ + Clock signingClock(); + AwsSdk2TransportOptions.Builder toBuilder(); static AwsSdk2TransportOptions.Builder builder() { @@ -92,6 +105,8 @@ interface Builder extends TransportOptions.Builder { Builder setMapper(JsonpMapper mapper); + Builder setSigningClock(Clock clock); + AwsSdk2TransportOptions build(); } @@ -101,6 +116,7 @@ class BuilderImpl extends TransportOptions.BuilderImpl implements Builder { protected Integer requestCompressionSize; protected Boolean responseCompression; protected JsonpMapper mapper; + protected Clock signingClock; public BuilderImpl() {} @@ -110,6 +126,7 @@ public BuilderImpl(AwsSdk2TransportOptions src) { requestCompressionSize = src.requestCompressionSize(); responseCompression = src.responseCompression(); mapper = src.mapper(); + signingClock = src.signingClock(); } @Override @@ -154,6 +171,12 @@ public Builder setResponseCompression(Boolean enabled) { return this; } + @Override + public Builder setSigningClock(Clock clock) { + this.signingClock = clock; + return this; + } + @Override public AwsSdk2TransportOptions build() { return new DefaultImpl(this); @@ -162,10 +185,11 @@ public AwsSdk2TransportOptions build() { class DefaultImpl extends TransportOptions.DefaultImpl implements AwsSdk2TransportOptions { - private AwsCredentialsProvider credentials; - private Integer requestCompressionSize; - private Boolean responseCompression; - private JsonpMapper mapper; + private final AwsCredentialsProvider credentials; + private final Integer requestCompressionSize; + private final Boolean responseCompression; + private final JsonpMapper mapper; + private final Clock signingClock; DefaultImpl(AwsSdk2TransportOptions.BuilderImpl builder) { super(builder); @@ -173,6 +197,7 @@ class DefaultImpl extends TransportOptions.DefaultImpl implements AwsSdk2Transpo requestCompressionSize = builder.requestCompressionSize; responseCompression = builder.responseCompression; mapper = builder.mapper; + signingClock = builder.signingClock; } @Override @@ -195,6 +220,11 @@ public JsonpMapper mapper() { return mapper; } + @Override + public Clock signingClock() { + return signingClock; + } + @Override public AwsSdk2TransportOptions.Builder toBuilder() { return new AwsSdk2TransportOptions.BuilderImpl(this); diff --git a/java-client/src/test/java/org/opensearch/client/transport/aws/AwsSdk2TransportTests.java b/java-client/src/test/java/org/opensearch/client/transport/aws/AwsSdk2TransportTests.java new file mode 100644 index 0000000000..cc19d77b7d --- /dev/null +++ b/java-client/src/test/java/org/opensearch/client/transport/aws/AwsSdk2TransportTests.java @@ -0,0 +1,289 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.client.transport.aws; + +import static com.github.tomakehurst.wiremock.client.WireMock.any; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.delete; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.put; +import static com.github.tomakehurst.wiremock.client.WireMock.serviceUnavailable; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathTemplate; +import static com.github.tomakehurst.wiremock.common.ContentTypes.APPLICATION_JSON; +import static com.github.tomakehurst.wiremock.common.ContentTypes.CONTENT_LENGTH; +import static com.github.tomakehurst.wiremock.common.ContentTypes.CONTENT_TYPE; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import com.github.tomakehurst.wiremock.http.RequestMethod; +import com.github.tomakehurst.wiremock.junit.WireMockRule; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; +import java.net.URI; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.opensearch.client.opensearch.OpenSearchClient; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpConfigurationOption; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; +import software.amazon.awssdk.http.crt.AwsCrtHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.utils.AttributeMap; + +@RunWith(Parameterized.class) +public class AwsSdk2TransportTests { + private static final Region TEST_REGION = Region.AP_SOUTHEAST_2; + + @Rule + public final WireMockRule wireMockRule = new WireMockRule(wireMockConfig().dynamicPort().enableBrowserProxying(true)); + private final SdkHttpClientType sdkHttpClientType; + private final String serviceName; + private final String serviceHostName; + + public AwsSdk2TransportTests(SdkHttpClientType sdkHttpClientType, String serviceName) { + this.sdkHttpClientType = sdkHttpClientType; + this.serviceName = serviceName; + this.serviceHostName = "aaabbbcccddd111222333." + TEST_REGION.toString() + "." + serviceName + ".amazonaws.com"; + } + + @Parameterized.Parameters(name = "sdkHttpClientType: {0}, serviceName: {1}") + public static Collection getParameters() { + var serviceNames = List.of("aoss", "es", "arbitrary"); + return Arrays.stream(SdkHttpClientType.values()) + .flatMap(sdkHttpClientType -> serviceNames.stream().map(serviceName -> new Object[] { sdkHttpClientType, serviceName })) + .collect(Collectors.toList()); + } + + public enum SdkHttpClientType { + AWS_CRT, + AWS_CRT_ASYNC, + APACHE, + NETTY_NIO_ASYNC + } + + @Before + public void setup() { + stubFor(any(anyUrl()).atPriority(10).willReturn(serviceUnavailable())); + + stubFor( + put(urlPathTemplate("/{index}")).atPriority(1) + .willReturn( + okJson("{\"acknowledged\": true,\"shards_acknowledged\": true,\"index\": \"{{request.path.index}}\"}").withTransformers( + "response-template" + ) + ) + ); + + stubFor(delete(urlPathEqualTo("/_search/scroll")).atPriority(1).willReturn(okJson("{\"succeeded\": true,\"num_freed\": 1}"))); + + stubFor( + delete(urlPathEqualTo("/_search/point_in_time")).atPriority(1) + .willReturn(okJson("{\"pits\": [{\"pit_id\": \"pit1\", \"successful\": true}]}")) + ); + } + + private OpenSearchClient getTestClient() throws Exception { + AwsSdk2TransportOptions options = AwsSdk2TransportOptions.builder() + .setCredentials(() -> AwsBasicCredentials.builder().accessKeyId("test-access-key").secretAccessKey("test-secret-key").build()) + .setSigningClock(Clock.fixed(Instant.ofEpochSecond(1673626117), ZoneId.of("UTC"))) // 2023-01-13 16:08:37 +0000 + .setResponseCompression(false) + .build(); + + AttributeMap sdkConfig = AttributeMap.builder().put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, true).build(); + + SdkHttpClient sdkHttpClient = null; + SdkAsyncHttpClient sdkAsyncHttpClient = null; + switch (sdkHttpClientType) { + case AWS_CRT: + sdkHttpClient = AwsCrtHttpClient.builder() + .proxyConfiguration(p -> p.scheme("http").host("localhost").port(wireMockRule.port())) + .buildWithDefaults(sdkConfig); + break; + case AWS_CRT_ASYNC: + sdkAsyncHttpClient = AwsCrtAsyncHttpClient.builder() + .proxyConfiguration(p -> p.scheme("http").host("localhost").port(wireMockRule.port())) + .buildWithDefaults(sdkConfig); + break; + case APACHE: + var proxyConfig = software.amazon.awssdk.http.apache.ProxyConfiguration.builder() + .endpoint(new URI("http://localhost:" + wireMockRule.port())) + .build(); + sdkHttpClient = ApacheHttpClient.builder().proxyConfiguration(proxyConfig).buildWithDefaults(sdkConfig); + break; + case NETTY_NIO_ASYNC: + var nettyProxyConfig = software.amazon.awssdk.http.nio.netty.ProxyConfiguration.builder() + .scheme("http") + .host("localhost") + .port(wireMockRule.port()) + .build(); + sdkAsyncHttpClient = NettyNioAsyncHttpClient.builder().proxyConfiguration(nettyProxyConfig).buildWithDefaults(sdkConfig); + break; + default: + throw new IllegalArgumentException("Unknown SdkHttpClientType: " + sdkHttpClientType); + } + + AwsSdk2Transport transport; + if (sdkAsyncHttpClient != null) { + transport = new AwsSdk2Transport(sdkAsyncHttpClient, serviceHostName, serviceName, TEST_REGION, options); + } else { + transport = new AwsSdk2Transport(sdkHttpClient, serviceHostName, serviceName, TEST_REGION, options); + } + return new OpenSearchClient(transport); + } + + private LoggedRequest getReceivedRequest() { + var serveEvents = wireMockRule.getAllServeEvents(); + assertEquals(1, serveEvents.size()); + return serveEvents.get(0).getRequest(); + } + + @Test + public void testSigV4PutIndex() throws Exception { + String expectedSignature = null; + switch (serviceName) { + case "aoss": + expectedSignature = "29123ccbcbd9af71fce384a1ed6d64b8c70f660e55a16de05405cac5fbebf18b"; + break; + case "es": + expectedSignature = "ff12e7b3e5e0f96fa25f13b3e95606dd18e3f1314dea6b7d6a9159f0aa51c21c"; + break; + case "arbitrary": + expectedSignature = "dbddbed28a34c0c380cd31567491a240294ef58755f9370e237d66f10d20d2df"; + break; + } + + OpenSearchClient client = getTestClient(); + + var resp = client.indices() + .create( + b -> b.index("sample-index1") + .aliases("sample-alias1", a -> a) + .mappings(m -> m.properties("age", p -> p.integer(i -> i))) + .settings(s -> s.index(i -> i.numberOfReplicas("1").numberOfShards("2"))) + ); + assertEquals("sample-index1", resp.index()); + assertEquals(Boolean.TRUE, resp.acknowledged()); + + var req = getReceivedRequest(); + + assertEquals(RequestMethod.PUT, req.getMethod()); + assertEquals(APPLICATION_JSON, req.getHeader(CONTENT_TYPE)); + assertEquals("156", req.getHeader(CONTENT_LENGTH)); + assertEquals(serviceHostName, req.getHeader("Host")); + assertEquals("20230113T160837Z", req.getHeader("x-amz-date")); + assertEquals("381bb92a04d397cab611362eb3ac3e075db11ac08272d64763de2279e2b5604d", req.getHeader("x-amz-content-sha256")); + assertEquals( + "AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/" + + serviceName + + "/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, Signature=" + + expectedSignature, + req.getHeader("Authorization") + ); + } + + @Test + public void testSigV4ClearScroll() throws Exception { + String expectedSignature = null; + switch (serviceName) { + case "aoss": + expectedSignature = "8c5d3d990f038e1d980a7d1b1611fa55f9b9b29a018a89ec84a6b9286e0e782d"; + break; + case "es": + expectedSignature = "f423dc8dce53a90d9f8e0701a8a721e54119b97201366438796d74ca0265f08d"; + break; + case "arbitrary": + expectedSignature = "63dd431cb3d4e2ba9e0aaf183975b1d19528de23bd68ee0c4269000008545922"; + break; + } + + OpenSearchClient client = getTestClient(); + + client.clearScroll(); + + var req = getReceivedRequest(); + + assertEquals(RequestMethod.DELETE, req.getMethod()); + assertEquals(APPLICATION_JSON, req.getHeader(CONTENT_TYPE)); + var contentLength = req.getHeader(CONTENT_LENGTH); + if (sdkHttpClientType != SdkHttpClientType.APACHE) { + assertEquals("2", contentLength); + } else { + // Apache client does not set content-length for DELETE requests + assertNull(contentLength); + } + assertEquals(serviceHostName, req.getHeader("Host")); + assertEquals("20230113T160837Z", req.getHeader("x-amz-date")); + assertEquals("44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", req.getHeader("x-amz-content-sha256")); + assertEquals( + "AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/" + + serviceName + + "/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, Signature=" + + expectedSignature, + req.getHeader("Authorization") + ); + } + + @Test + public void testSigV4DeletePit() throws Exception { + String expectedSignature = null; + switch (serviceName) { + case "aoss": + expectedSignature = "82cb4f441ca313047542597cd54bdb3139ce111e269fe3bade5d59a1b2cd00a0"; + break; + case "es": + expectedSignature = "6abef10fb828cfc62683f38fbaa93894885308b0516bbe7b5485ae99e16b51bb"; + break; + case "arbitrary": + expectedSignature = "59697fbb5f10b197a1abea0264e7380d34db3c99b428bfa3781c0b665242f420"; + break; + } + + OpenSearchClient client = getTestClient(); + + client.deletePit(d -> d.pitId(List.of("pit1"))); + + var req = getReceivedRequest(); + + assertEquals(RequestMethod.DELETE, req.getMethod()); + assertEquals(APPLICATION_JSON, req.getHeader(CONTENT_TYPE)); + var contentLength = req.getHeader(CONTENT_LENGTH); + if (sdkHttpClientType != SdkHttpClientType.APACHE) { + assertEquals("19", contentLength); + } else { + // Apache client does not set content-length for DELETE requests + assertNull(contentLength); + } + assertEquals(serviceHostName, req.getHeader("Host")); + assertEquals("20230113T160837Z", req.getHeader("x-amz-date")); + assertEquals("daaa6af55a9cfe622f46de69ebc3b4df84703f320b839346b7fb4cf94bdbd766", req.getHeader("x-amz-content-sha256")); + assertEquals( + "AWS4-HMAC-SHA256 Credential=test-access-key/20230113/ap-southeast-2/" + + serviceName + + "/aws4_request, SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date, Signature=" + + expectedSignature, + req.getHeader("Authorization") + ); + } +}