Skip to content

Commit

Permalink
Fix AWS SigV4 on delete requests when using AWS SDK's Apache client
Browse files Browse the repository at this point in the history
The AWS SDK's Apache client implementation does not send the `Content-Length` header on DELETE requests, but the header is being set before calculating the signature.
This causes the Amazon OpenSearch Service to report an incorrect signature as it does not receive the header value needed to validate the signature.

`Content-Length` is somewhat unreliable to include in the signature calculation, but the AWS SDK doesn't allow configuring which headers to ignore in signature calculation.
As such we must move setting the header to after the signature is calculated.

Additionally moves to the supported `AwsV4HttpSigner` as `Aws4Signer` is now deprecated: https://github.com/aws/aws-sdk-java-v2/blob/88abec27e7d5d35b21545c7e05875a7cc3d0f46e/core/auth/src/main/java/software/amazon/awssdk/auth/signer/Aws4Signer.java

Signed-off-by: Thomas Farr <[email protected]>
  • Loading branch information
Xtansia committed Nov 1, 2024
1 parent b962e89 commit ac70834
Show file tree
Hide file tree
Showing 5 changed files with 411 additions and 79 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Fixed
- Fixed deserializing `GeoBoundsAggregate` when `bounds` is not returned by OpenSearch ([#1238](https://github.com/opensearch-project/opensearch-java/pull/1238))
- Fixed AWS SigV4 on delete requests when using AWS SDK's Apache client ([#]())

### Security

Expand Down
24 changes: 15 additions & 9 deletions java-client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ val integrationTest = task<Test>("integrationTest") {
val opensearchVersion = "3.0.0-SNAPSHOT"

dependencies {

val jacksonVersion = "2.17.0"
val jacksonDatabindVersion = "2.17.0"

Expand Down Expand Up @@ -210,21 +209,25 @@ dependencies {
implementation("jakarta.annotation", "jakarta.annotation-api", "1.3.5")

// Apache 2.0

implementation("com.fasterxml.jackson.core", "jackson-core", jacksonVersion)
implementation("com.fasterxml.jackson.core", "jackson-databind", jacksonDatabindVersion)
testImplementation("com.fasterxml.jackson.datatype", "jackson-datatype-jakarta-jsonp", jacksonVersion)

// For AwsSdk2Transport
"awsSdk2SupportCompileOnly"("software.amazon.awssdk","sdk-core","[2.15,3.0)")
"awsSdk2SupportCompileOnly"("software.amazon.awssdk","auth","[2.15,3.0)")
testImplementation("software.amazon.awssdk","sdk-core","[2.15,3.0)")
testImplementation("software.amazon.awssdk","auth","[2.15,3.0)")
testImplementation("software.amazon.awssdk","aws-crt-client","[2.15,3.0)")
testImplementation("software.amazon.awssdk","apache-client","[2.15,3.0)")
testImplementation("software.amazon.awssdk","sts","[2.15,3.0)")
"awsSdk2SupportCompileOnly"("software.amazon.awssdk", "sdk-core", "[2.21,3.0)")
"awsSdk2SupportCompileOnly"("software.amazon.awssdk", "auth", "[2.21,3.0)")
"awsSdk2SupportCompileOnly"("software.amazon.awssdk", "http-auth-aws", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "sdk-core", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "auth", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "http-auth-aws", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "aws-crt-client", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "apache-client", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "netty-nio-client", "[2.21,3.0)")
testImplementation("software.amazon.awssdk", "sts", "[2.21,3.0)")

testImplementation("org.apache.logging.log4j", "log4j-api","[2.17.1,3.0)")
testImplementation("org.apache.logging.log4j", "log4j-core","[2.17.1,3.0)")

// EPL-2.0 OR BSD-3-Clause
// https://eclipse-ee4j.github.io/yasson/
implementation("org.eclipse", "yasson", "2.0.2")
Expand All @@ -236,6 +239,9 @@ dependencies {
testImplementation("junit", "junit" , "4.13.2") {
exclude(group = "org.hamcrest")
}

// Apache 2.0
testImplementation("org.wiremock", "wiremock", "3.9.2")
}

licenseReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.time.Clock;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import javax.annotation.CheckForNull;
Expand All @@ -52,18 +53,19 @@
import org.opensearch.client.util.OpenSearchRequestBodyBuffer;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.Header;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.IoUtils;
import software.amazon.awssdk.utils.SdkAutoCloseable;
Expand All @@ -85,6 +87,7 @@ public class AwsSdk2Transport implements OpenSearchTransport {
private final String signingServiceName;
private final Region signingRegion;
private final JsonpMapper defaultMapper;
@Nonnull
private final AwsSdk2TransportOptions transportOptions;

/**
Expand Down Expand Up @@ -195,12 +198,11 @@ public <RequestT, ResponseT, ErrorT> ResponseT performRequest(
Endpoint<RequestT, ResponseT, ErrorT> endpoint,
@Nullable TransportOptions options
) throws IOException {

OpenSearchRequestBodyBuffer requestBody = prepareRequestBody(request, endpoint, options);
SdkHttpFullRequest clientReq = prepareRequest(request, endpoint, options, requestBody);
SdkHttpRequest clientReq = prepareRequest(request, endpoint, options, requestBody);

if (httpClient instanceof SdkHttpClient) {
return executeSync((SdkHttpClient) httpClient, clientReq, endpoint, options);
return executeSync((SdkHttpClient) httpClient, clientReq, requestBody, endpoint, options);
} else if (httpClient instanceof SdkAsyncHttpClient) {
try {
return executeAsync((SdkAsyncHttpClient) httpClient, clientReq, requestBody, endpoint, options).get();
Expand Down Expand Up @@ -229,11 +231,11 @@ public <RequestT, ResponseT, ErrorT> CompletableFuture<ResponseT> performRequest
) {
try {
OpenSearchRequestBodyBuffer requestBody = prepareRequestBody(request, endpoint, options);
SdkHttpFullRequest clientReq = prepareRequest(request, endpoint, options, requestBody);
SdkHttpRequest clientReq = prepareRequest(request, endpoint, options, requestBody);
if (httpClient instanceof SdkAsyncHttpClient) {
return executeAsync((SdkAsyncHttpClient) httpClient, clientReq, requestBody, endpoint, options);
} else if (httpClient instanceof SdkHttpClient) {
ResponseT result = executeSync((SdkHttpClient) httpClient, clientReq, endpoint, options);
ResponseT result = executeSync((SdkHttpClient) httpClient, clientReq, requestBody, endpoint, options);
return CompletableFuture.completedFuture(result);
} else {
throw new IOException("invalid httpClient: " + httpClient);
Expand Down Expand Up @@ -265,16 +267,12 @@ private <RequestT> OpenSearchRequestBodyBuffer prepareRequestBody(
TransportOptions options
) throws IOException {
if (endpoint.hasRequestBody()) {
final JsonpMapper mapper = Optional.ofNullable(options)
.map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null)
.map(AwsSdk2TransportOptions::mapper)
.orElse(defaultMapper);
final int maxUncompressedSize = or(
Optional.ofNullable(options)
.map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null)
.map(AwsSdk2TransportOptions::requestCompressionSize),
() -> Optional.ofNullable(transportOptions.requestCompressionSize())
).orElse(DEFAULT_REQUEST_COMPRESSION_SIZE);
final JsonpMapper mapper = Optional.ofNullable(
options instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) options) : null
).map(AwsSdk2TransportOptions::mapper).orElse(defaultMapper);
final int maxUncompressedSize = getOption(options, AwsSdk2TransportOptions::requestCompressionSize).orElse(
DEFAULT_REQUEST_COMPRESSION_SIZE
);

OpenSearchRequestBodyBuffer buffer = new OpenSearchRequestBodyBuffer(mapper, maxUncompressedSize);
buffer.addContent(request);
Expand All @@ -284,7 +282,7 @@ private <RequestT> OpenSearchRequestBodyBuffer prepareRequestBody(
return null;
}

private <RequestT> SdkHttpFullRequest prepareRequest(
private <RequestT> SdkHttpRequest prepareRequest(
RequestT request,
Endpoint<RequestT, ?, ?> endpoint,
@CheckForNull TransportOptions options,
Expand Down Expand Up @@ -315,46 +313,57 @@ private <RequestT> SdkHttpFullRequest prepareRequest(
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Invalid request URI: " + url.toString());
}

ContentStreamProvider bodyProvider = body != null ? ContentStreamProvider.fromInputStreamSupplier(body::getInputStream) : null;

applyHeadersPreSigning(req, options, body);

final AwsCredentialsProvider credentials = getOption(options, AwsSdk2TransportOptions::credentials).orElseGet(
DefaultCredentialsProvider::create
);

final Clock signingClock = getOption(options, AwsSdk2TransportOptions::signingClock).orElse(null);

SdkHttpRequest.Builder signedReq = AwsV4HttpSigner.create()
.sign(
b -> b.identity(credentials.resolveCredentials())
.request(req.build())
.payload(bodyProvider)
.putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, this.signingServiceName)
.putProperty(AwsV4HttpSigner.REGION_NAME, this.signingRegion.id())
.putProperty(AwsV4HttpSigner.SIGNING_CLOCK, signingClock)
)
.request()
.toBuilder();

applyHeadersPostSigning(signedReq, body);

return signedReq.build();
}

private void applyHeadersPreSigning(SdkHttpRequest.Builder req, TransportOptions options, OpenSearchRequestBodyBuffer body) {
applyOptionsHeaders(req, transportOptions);
applyOptionsHeaders(req, options);
if (endpoint.hasRequestBody() && body != null) {

if (body != null) {
req.putHeader("Content-Type", body.getContentType());
String encoding = body.getContentEncoding();
if (encoding != null) {
req.putHeader("Content-Encoding", encoding);
}
req.putHeader("Content-Length", String.valueOf(body.getContentLength()));
req.contentStreamProvider(body::getInputStream);
// To add the "X-Amz-Content-Sha256" header, it needs to set as required.
// It is a required header for Amazon OpenSearch Serverless.
req.putHeader("x-amz-content-sha256", "required");
}

boolean responseCompression = or(
Optional.ofNullable(options)
.map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null)
.map(AwsSdk2TransportOptions::responseCompression),
() -> Optional.ofNullable(transportOptions.responseCompression())
).orElse(Boolean.TRUE);
if (responseCompression) {
if (getOption(options, AwsSdk2TransportOptions::responseCompression).orElse(Boolean.TRUE)) {
req.putHeader("Accept-Encoding", "gzip");
} else {
req.removeHeader("Accept-Encoding");
}
}

final AwsCredentialsProvider credentials = or(
Optional.ofNullable(options)
.map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null)
.map(AwsSdk2TransportOptions::credentials),
() -> Optional.ofNullable(transportOptions.credentials())
).orElse(DefaultCredentialsProvider.create());

Aws4SignerParams signerParams = Aws4SignerParams.builder()
.awsCredentials(credentials.resolveCredentials())
.signingName(this.signingServiceName)
.signingRegion(signingRegion)
.build();
return Aws4Signer.create().sign(req.build(), signerParams);
private void applyHeadersPostSigning(SdkHttpRequest.Builder req, OpenSearchRequestBodyBuffer body) {
if (body != null) {
req.putHeader("Content-Length", String.valueOf(body.getContentLength()));
}
}

private void applyOptionsParams(StringBuilder url, TransportOptions options) throws UnsupportedEncodingException {
Expand All @@ -372,7 +381,7 @@ private void applyOptionsParams(StringBuilder url, TransportOptions options) thr
}
}

private void applyOptionsHeaders(SdkHttpFullRequest.Builder builder, TransportOptions options) {
private void applyOptionsHeaders(SdkHttpRequest.Builder builder, TransportOptions options) {
if (options == null) {
return;
}
Expand All @@ -386,14 +395,14 @@ private void applyOptionsHeaders(SdkHttpFullRequest.Builder builder, TransportOp

private <ResponseT> ResponseT executeSync(
SdkHttpClient syncHttpClient,
SdkHttpFullRequest httpRequest,
SdkHttpRequest httpRequest,
OpenSearchRequestBodyBuffer requestBody,
Endpoint<?, ResponseT, ?> endpoint,
TransportOptions options
) throws IOException {

HttpExecuteRequest.Builder executeRequest = HttpExecuteRequest.builder().request(httpRequest);
if (httpRequest.contentStreamProvider().isPresent()) {
executeRequest.contentStreamProvider(httpRequest.contentStreamProvider().get());
if (requestBody != null) {
executeRequest.contentStreamProvider(ContentStreamProvider.fromInputStreamSupplier(requestBody::getInputStream));
}
HttpExecuteResponse executeResponse = syncHttpClient.prepareRequest(executeRequest.build()).call();
AbortableInputStream bodyStream = null;
Expand All @@ -418,13 +427,12 @@ private <ResponseT> ResponseT executeSync(

private <ResponseT> CompletableFuture<ResponseT> executeAsync(
SdkAsyncHttpClient asyncHttpClient,
SdkHttpFullRequest httpRequest,
SdkHttpRequest httpRequest,
@CheckForNull OpenSearchRequestBodyBuffer requestBody,
Endpoint<?, ResponseT, ?> endpoint,
TransportOptions options
) {
byte[] requestBodyArray = requestBody == null ? NO_BYTES : requestBody.getByteArray();

final AsyncCapturingResponseHandler responseHandler = new AsyncCapturingResponseHandler();
AsyncExecuteRequest.Builder executeRequest = AsyncExecuteRequest.builder()
.request(httpRequest)
Expand Down Expand Up @@ -463,10 +471,9 @@ private <ResponseT, ErrorT> ResponseT parseResponse(
@Nonnull Endpoint<?, ResponseT, ErrorT> endpoint,
@CheckForNull TransportOptions options
) throws IOException {
final JsonpMapper mapper = Optional.ofNullable(options)
.map(o -> o instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) o) : null)
.map(AwsSdk2TransportOptions::mapper)
.orElse(defaultMapper);
final JsonpMapper mapper = Optional.ofNullable(
options instanceof AwsSdk2TransportOptions ? ((AwsSdk2TransportOptions) options) : null
).map(AwsSdk2TransportOptions::mapper).orElse(defaultMapper);

int statusCode = httpResponse.statusCode();
boolean isZipped = httpResponse.firstMatchingHeader("Content-Encoding").map(enc -> enc.contains("gzip")).orElse(Boolean.FALSE);
Expand Down Expand Up @@ -625,16 +632,15 @@ private <ResponseT, ErrorT> ResponseT decodeResponse(
}
}

private static <T> Optional<T> or(Optional<T> opt, Supplier<? extends Optional<? extends T>> supplier) {
Objects.requireNonNull(opt);
Objects.requireNonNull(supplier);
if (opt.isPresent()) {
return opt;
} else {
@SuppressWarnings("unchecked")
Optional<T> r = (Optional<T>) supplier.get();
return Objects.requireNonNull(r);
}
private <T> Optional<T> getOption(@Nullable TransportOptions options, @Nonnull Function<AwsSdk2TransportOptions, T> getter) {
Objects.requireNonNull(getter, "getter must not be null");

Function<AwsSdk2TransportOptions, Optional<T>> optGetter = o -> Optional.ofNullable(getter.apply(o));

Optional<T> opt = Optional.ofNullable(options instanceof AwsSdk2TransportOptions ? (AwsSdk2TransportOptions) options : null)
.flatMap(optGetter);

return opt.isPresent() ? opt : optGetter.apply(transportOptions);
}

private static ByteArrayInputStream toByteArrayInputStream(InputStream is) throws IOException {
Expand Down
Loading

0 comments on commit ac70834

Please sign in to comment.