Skip to content

Commit

Permalink
Fix for AwsSdk2Transport error handling (#1068)
Browse files Browse the repository at this point in the history
Signed-off-by: Wesley Workman <[email protected]>
  • Loading branch information
workmanw committed Jul 17, 2024
1 parent 0cf11db commit 6f2ec50
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 68 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ This section is for maintaining a changelog for all breaking changes for the cli

### Fixed
- Deserialize aggregation containing a parent aggregation ([#706](https://github.com/opensearch-project/opensearch-java/pull/706))
- Deserialize error response when using AwsSdk2Transport ([#1068](https://github.com/opensearch-project/opensearch-java/pull/1068))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@

import jakarta.json.JsonObject;
import jakarta.json.stream.JsonParser;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.io.*;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.net.URI;
Expand Down Expand Up @@ -48,6 +45,7 @@
import org.opensearch.client.transport.TransportOptions;
import org.opensearch.client.transport.endpoints.BooleanEndpoint;
import org.opensearch.client.transport.endpoints.BooleanResponse;
import org.opensearch.client.util.MissingRequiredPropertyException;
import org.opensearch.client.util.OpenSearchRequestBodyBuffer;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
Expand Down Expand Up @@ -534,10 +532,17 @@ private <ResponseT, ErrorT> ResponseT parseResponse(
if (errorDeserializer == null || bodyStream == null) {
throw new TransportException("Request failed with status code '" + statusCode + "'");
}

// We may have to reset if there is a parse deserialization exception
bodyStream = toByteArrayInputStream(bodyStream);

try {
try (JsonParser parser = mapper.jsonProvider().createParser(bodyStream)) {
ErrorT error = errorDeserializer.deserialize(parser, mapper);
throw new OpenSearchException((ErrorResponse) error);
} catch (MissingRequiredPropertyException errorEx) {
bodyStream.reset();
return decodeResponse(uri, method, protocol, httpResponse, bodyStream, endpoint, mapper);
}
} catch (OpenSearchException e) {
throw e;
Expand All @@ -551,57 +556,68 @@ private <ResponseT, ErrorT> ResponseT parseResponse(
}
}
} else {
if (endpoint instanceof BooleanEndpoint) {
BooleanEndpoint<?> bep = (BooleanEndpoint<?>) endpoint;
@SuppressWarnings("unchecked")
ResponseT response = (ResponseT) new BooleanResponse(bep.getResult(statusCode));
return response;
} else if (endpoint instanceof JsonEndpoint) {
JsonEndpoint<?, ResponseT, ?> jsonEndpoint = (JsonEndpoint<?, ResponseT, ?>) endpoint;
// Successful response
ResponseT response = null;
JsonpDeserializer<ResponseT> responseParser = jsonEndpoint.responseDeserializer();
if (responseParser != null) {
// Expecting a body
if (bodyStream == null) {
throw new TransportException("Expecting a response body, but none was sent");
}
try (JsonParser parser = mapper.jsonProvider().createParser(bodyStream)) {
try {
response = responseParser.deserialize(parser, mapper);
} catch (NullPointerException e) {
response = responseParser.deserialize(parser, mapper);
}
}
;
}
return response;
} else if (endpoint instanceof GenericEndpoint) {
@SuppressWarnings("unchecked")
final GenericEndpoint<?, ResponseT> rawEndpoint = (GenericEndpoint<?, ResponseT>) endpoint;
return decodeResponse(uri, method, protocol, httpResponse, bodyStream, endpoint, mapper);
}
}

String contentType = null;
if (bodyStream != null) {
contentType = httpResponse.firstMatchingHeader(Header.CONTENT_TYPE).orElse(null);
private <ResponseT, ErrorT> ResponseT decodeResponse(
URI uri,
@Nonnull SdkHttpMethod method,
String protocol,
@Nonnull SdkHttpResponse httpResponse,
@CheckForNull InputStream bodyStream,
@Nonnull Endpoint<?, ResponseT, ErrorT> endpoint,
JsonpMapper mapper
) throws IOException {
if (endpoint instanceof BooleanEndpoint) {
BooleanEndpoint<?> bep = (BooleanEndpoint<?>) endpoint;
@SuppressWarnings("unchecked")
ResponseT response = (ResponseT) new BooleanResponse(bep.getResult(httpResponse.statusCode()));
return response;
} else if (endpoint instanceof JsonEndpoint) {
JsonEndpoint<?, ResponseT, ?> jsonEndpoint = (JsonEndpoint<?, ResponseT, ?>) endpoint;
// Successful response
ResponseT response = null;
JsonpDeserializer<ResponseT> responseParser = jsonEndpoint.responseDeserializer();
if (responseParser != null) {
// Expecting a body
if (bodyStream == null) {
throw new TransportException("Expecting a response body, but none was sent");
}
try (JsonParser parser = mapper.jsonProvider().createParser(bodyStream)) {
try {
response = responseParser.deserialize(parser, mapper);
} catch (NullPointerException e) {
response = responseParser.deserialize(parser, mapper);
}
}
}
return response;
} else if (endpoint instanceof GenericEndpoint) {
@SuppressWarnings("unchecked")
final GenericEndpoint<?, ResponseT> rawEndpoint = (GenericEndpoint<?, ResponseT>) endpoint;

return rawEndpoint.responseDeserializer(
uri.toString(),
method.name(),
protocol,
httpResponse.statusCode(),
httpResponse.statusText().orElse(null),
httpResponse.headers()
.entrySet()
.stream()
.map(h -> new AbstractMap.SimpleEntry<String, String>(h.getKey(), Objects.toString(h.getValue())))
.collect(Collectors.toList()),
contentType,
bodyStream
);
} else {
throw new TransportException("Unhandled endpoint type: '" + endpoint.getClass().getName() + "'");
String contentType = null;
if (bodyStream != null) {
contentType = httpResponse.firstMatchingHeader(Header.CONTENT_TYPE).orElse(null);
}

return rawEndpoint.responseDeserializer(
uri.toString(),
method.name(),
protocol,
httpResponse.statusCode(),
httpResponse.statusText().orElse(null),
httpResponse.headers()
.entrySet()
.stream()
.map(h -> new AbstractMap.SimpleEntry<String, String>(h.getKey(), Objects.toString(h.getValue())))
.collect(Collectors.toList()),
contentType,
bodyStream
);
} else {
throw new TransportException("Unhandled endpoint type: '" + endpoint.getClass().getName() + "'");
}
}

Expand All @@ -617,6 +633,22 @@ private static <T> Optional<T> or(Optional<T> opt, Supplier<? extends Optional<?
}
}

private static ByteArrayInputStream toByteArrayInputStream(InputStream is) throws IOException {
// Optimization to avoid copying when applicable. `executeAsync` will produce `ByteArrayInputStream`.
if (is instanceof ByteArrayInputStream) {
return (ByteArrayInputStream) is;
}

ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[1024];
int len;
while ((len = is.read(buffer)) > -1) {
baos.write(buffer, 0, len);
}
baos.flush();
return new ByteArrayInputStream(baos.toByteArray());
}

/**
* Wrap the exception so the caller's signature shows up in the stack trace, taking care to copy the original type and message
* where possible so async and sync code don't have to check different exceptions.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch.integTest.aws;

import java.util.concurrent.CompletableFuture;
import org.junit.Assert;
import org.junit.Test;
import org.opensearch.client.opensearch.OpenSearchAsyncClient;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch.core.*;

public class AwsSdk2GetRequestIT extends AwsSdk2TransportTestCase {
@Test
public void testSyncGetRequest() throws Exception {
resetTestIndex(false);
final OpenSearchClient client = getClient(false, null, null);

SimplePojo doc1 = new SimplePojo("Document 1", "The text of document 1");
addDoc(client, "id1", doc1);

Thread.sleep(1000);

GetRequest doc1Request = new GetRequest.Builder().index(TEST_INDEX).id("id1").build();
GetResponse<SimplePojo> doc1Response = client.get(doc1Request, SimplePojo.class);
Assert.assertTrue(doc1Response.found());

GetRequest doc2Request = new GetRequest.Builder().index(TEST_INDEX).id("does-not-exist").build();
GetResponse<SimplePojo> doc2Response = client.get(doc2Request, SimplePojo.class);
Assert.assertFalse(doc2Response.found());
}

@Test
public void testAsyncGetRequest() throws Exception {
resetTestIndex(false);
final OpenSearchAsyncClient client = getAsyncClient(false, null, null);

SimplePojo doc1 = new SimplePojo("Document 1", "The text of document 1");
addDoc(client, "id1", doc1).join();

Thread.sleep(1000);

GetRequest doc1Request = new GetRequest.Builder().index(TEST_INDEX).id("id1").build();
CompletableFuture<GetResponse<SimplePojo>> doc1ResponseFuture = client.get(doc1Request, SimplePojo.class);

GetRequest doc2Request = new GetRequest.Builder().index(TEST_INDEX).id("does-not-exist").build();
CompletableFuture<GetResponse<SimplePojo>> doc2ResponseFuture = client.get(doc2Request, SimplePojo.class);

Assert.assertTrue(doc1ResponseFuture.join().found());
Assert.assertFalse(doc2ResponseFuture.join().found());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.client.opensearch.OpenSearchAsyncClient;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.OpenSearchException;
import org.opensearch.client.opensearch.core.IndexRequest;
import org.opensearch.client.opensearch.core.IndexResponse;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.indices.CreateIndexRequest;
Expand Down Expand Up @@ -105,22 +104,6 @@ void testClientAsync(boolean async) throws Exception {
Assert.assertEquals(doc1, response.hits().hits().get(0).source());
}

private void addDoc(OpenSearchClient client, String id, SimplePojo doc) throws Exception {
IndexRequest.Builder<SimplePojo> req = new IndexRequest.Builder<SimplePojo>().index(TEST_INDEX).document(doc).id(id);
client.index(req.build());
}

private CompletableFuture<IndexResponse> addDoc(OpenSearchAsyncClient client, String id, SimplePojo doc) {
IndexRequest.Builder<SimplePojo> req = new IndexRequest.Builder<SimplePojo>().index(TEST_INDEX).document(doc).id(id);
try {
return client.index(req.build());
} catch (Exception e) {
final CompletableFuture<IndexResponse> failed = new CompletableFuture<>();
failed.completeExceptionally(e);
return failed;
}
}

@Test
public void testDoubleWrappedException() throws Exception {
// ensure the test index exists
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.opensearch.client.opensearch._types.SortOptions;
import org.opensearch.client.opensearch._types.SortOrder;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch.core.IndexRequest;
import org.opensearch.client.opensearch.core.IndexResponse;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.indices.CreateIndexRequest;
Expand Down Expand Up @@ -215,6 +217,22 @@ public void resetTestIndex(boolean async) throws Exception {
client.create(req.build());
}

protected void addDoc(OpenSearchClient client, String id, SimplePojo doc) throws Exception {
IndexRequest.Builder<SimplePojo> req = new IndexRequest.Builder<SimplePojo>().index(TEST_INDEX).document(doc).id(id);
client.index(req.build());
}

protected CompletableFuture<IndexResponse> addDoc(OpenSearchAsyncClient client, String id, SimplePojo doc) {
IndexRequest.Builder<SimplePojo> req = new IndexRequest.Builder<SimplePojo>().index(TEST_INDEX).document(doc).id(id);
try {
return client.index(req.build());
} catch (Exception e) {
final CompletableFuture<IndexResponse> failed = new CompletableFuture<>();
failed.completeExceptionally(e);
return failed;
}
}

protected SearchResponse<SimplePojo> query(OpenSearchClient client, String title, String text) throws Exception {
final Query query = Query.of(qb -> {
if (title != null) {
Expand Down

0 comments on commit 6f2ec50

Please sign in to comment.