diff --git a/src/main/java/com/firebolt/jdbc/client/query/StatementClientImpl.java b/src/main/java/com/firebolt/jdbc/client/query/StatementClientImpl.java index 7933768d6..c58bc9765 100644 --- a/src/main/java/com/firebolt/jdbc/client/query/StatementClientImpl.java +++ b/src/main/java/com/firebolt/jdbc/client/query/StatementClientImpl.java @@ -75,15 +75,13 @@ public InputStream executeSqlStatement(@NonNull StatementInfoWrapper statementIn return executeSqlStatementWithRetryOnUnauthorized(statementInfoWrapper, connectionProperties, formattedStatement, uri); } catch (FireboltException e) { throw e; + } catch (StreamResetException e) { + String errorMessage = format("Error executing statement with id %s: %s", statementInfoWrapper.getId(), formattedStatement); + throw new FireboltException(errorMessage, e, ExceptionType.CANCELED); } catch (Exception e) { - String errorMessage = format("Error executing statement with id %s: %s", - statementInfoWrapper.getId(), formattedStatement); - if (e instanceof StreamResetException) { - throw new FireboltException(errorMessage, e, ExceptionType.CANCELED); - } + String errorMessage = format("Error executing statement with id %s: %s", statementInfoWrapper.getId(), formattedStatement); throw new FireboltException(errorMessage, e); } - } private InputStream executeSqlStatementWithRetryOnUnauthorized(@NonNull StatementInfoWrapper statementInfoWrapper, diff --git a/src/test/java/com/firebolt/jdbc/client/query/StatementClientImplTest.java b/src/test/java/com/firebolt/jdbc/client/query/StatementClientImplTest.java index f951cd9ba..903538225 100644 --- a/src/test/java/com/firebolt/jdbc/client/query/StatementClientImplTest.java +++ b/src/test/java/com/firebolt/jdbc/client/query/StatementClientImplTest.java @@ -8,10 +8,16 @@ import com.firebolt.jdbc.statement.StatementInfoWrapper; import com.firebolt.jdbc.statement.StatementUtil; import lombok.NonNull; -import okhttp3.*; +import okhttp3.Call; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; import okio.Buffer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -29,7 +35,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class StatementClientImplTest { @@ -134,6 +144,25 @@ void shouldNotRetryNoMoreThanOnceOnUnauthorized() throws IOException, FireboltEx verify(connection, times(2)).removeExpiredTokens(); } + @ParameterizedTest + @CsvSource({ + "java.io.IOException, ERROR", + "okhttp3.internal.http2.StreamResetException, CANCELED", + "java.lang.IllegalArgumentException, ERROR", + }) + void shouldThrowIOException(Class exceptionClass, ExceptionType exceptionType) throws IOException { + FireboltProperties fireboltProperties = FireboltProperties.builder().database("db1").compress(true) + .host("firebolt1").port(555).build(); + Call call = mock(Call.class); + when(call.execute()).thenThrow(exceptionClass); + when(okHttpClient.newCall(any())).thenReturn(call); + StatementClient statementClient = new StatementClientImpl(okHttpClient, mock(ObjectMapper.class), connection, "", ""); + StatementInfoWrapper statementInfoWrapper = StatementUtil.parseToStatementInfoWrappers("select 1").get(0); + FireboltException ex = assertThrows(FireboltException.class, () -> statementClient.executeSqlStatement(statementInfoWrapper, fireboltProperties, false, 5, true)); + assertEquals(exceptionType, ex.getType()); + assertEquals(exceptionClass, ex.getCause().getClass()); + } + private Call getMockedCallWithResponse(int statusCode) throws IOException { Call call = mock(Call.class); Response response = mock(Response.class);