diff --git a/cdap-credential-ext-gcp-wi/src/main/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProvider.java b/cdap-credential-ext-gcp-wi/src/main/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProvider.java index ab380073b261..72df8739cf16 100644 --- a/cdap-credential-ext-gcp-wi/src/main/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProvider.java +++ b/cdap-credential-ext-gcp-wi/src/main/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProvider.java @@ -42,6 +42,7 @@ import java.io.BufferedReader; import java.io.DataOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; import java.net.ConnectException; import java.net.HttpURLConnection; @@ -298,17 +299,23 @@ private String exchangeTokenViaSts(String token, String scopes, String audience) Map headers = new HashMap<>(); headers.put(HttpHeaders.CONTENT_TYPE, "application/json"); - return executeHttpPostRequest(url, securityTokenServiceRequestJson, headers); + return executeHttpPostRequest(() -> (HttpURLConnection) url.openConnection(), + securityTokenServiceRequestJson, headers); + } + + @VisibleForTesting + interface ConnectionProvider { + HttpURLConnection getConnection() throws IOException; } /** * Executes a http post request with the specified parameters. */ @VisibleForTesting - String executeHttpPostRequest(URL url, String body, Map headers) - throws IOException { + String executeHttpPostRequest(ConnectionProvider connectionProvider, String body, + Map headers) throws IOException { - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + HttpURLConnection connection = connectionProvider.getConnection(); connection.setRequestMethod(HttpMethod.POST); connection.setUseCaches(false); for (Map.Entry entry : headers.entrySet()) { @@ -321,15 +328,26 @@ String executeHttpPostRequest(URL url, String body, Map headers) outputStream.writeBytes(body); outputStream.flush(); } + InputStream inputStream; + boolean errorResponse = false; + if (connection.getResponseCode() < 200 || connection.getResponseCode() >= 300) { + inputStream = connection.getErrorStream(); + errorResponse = true; + } else { + inputStream = connection.getInputStream(); + } StringBuilder response = new StringBuilder(); - try (BufferedReader in = new BufferedReader( - new InputStreamReader(connection.getInputStream()))) { + try (BufferedReader in = new BufferedReader(new InputStreamReader(inputStream))) { String inputLine; while ((inputLine = in.readLine()) != null) { response.append(inputLine); } } + if (errorResponse) { + throw new IOException(String.format("Failed to call URL %s with code; response code %d:\n%s", + connection.getURL(), connection.getResponseCode(), response.toString())); + } return response.toString(); } @@ -346,6 +364,7 @@ private String fetchIamServiceAccountToken(String token, String scopes, headers.put(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", token)); headers.put(HttpHeaders.CONTENT_TYPE, "application/json"); String generateAccessTokenRequestJson = GSON.toJson(credentialsGenerateAccessTokenRequest); - return executeHttpPostRequest(url, generateAccessTokenRequestJson, headers); + return executeHttpPostRequest(() -> (HttpURLConnection) url.openConnection(), + generateAccessTokenRequestJson, headers); } } diff --git a/cdap-credential-ext-gcp-wi/src/test/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProviderTest.java b/cdap-credential-ext-gcp-wi/src/test/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProviderTest.java index bbd79858e141..8ea355bbe95e 100644 --- a/cdap-credential-ext-gcp-wi/src/test/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProviderTest.java +++ b/cdap-credential-ext-gcp-wi/src/test/java/io/cdap/cdap/security/spi/credential/GcpWorkloadIdentityCredentialProviderTest.java @@ -16,6 +16,16 @@ package io.cdap.cdap.security.spi.credential; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import com.google.gson.Gson; import com.google.gson.GsonBuilder; import io.cdap.cdap.proto.BasicThrowable; @@ -31,8 +41,15 @@ import io.kubernetes.client.openapi.ApiResponse; import io.kubernetes.client.openapi.models.AuthenticationV1TokenRequest; import io.kubernetes.client.openapi.models.V1TokenRequestStatus; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.net.ConnectException; +import java.net.HttpURLConnection; import java.net.SocketTimeoutException; +import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.time.format.DateTimeFormatter; @@ -41,12 +58,12 @@ import java.util.Map; import org.junit.Assert; import org.junit.Test; -import org.mockito.Mockito; /** * Unit Tests for {@link GcpWorkloadIdentityCredentialProvider}. */ public class GcpWorkloadIdentityCredentialProviderTest { + private static final Gson GSON = new GsonBuilder().registerTypeAdapter(BasicThrowable.class, new BasicThrowableCodec()).create(); private static final String IAM_TOKEN = "iam-token"; @@ -63,17 +80,16 @@ public void testProvisioningCredentialWithRetries() throws Exception { gcpWorkloadIdentityCredentialProvider.initialize(credentialProviderContext); GcpWorkloadIdentityCredentialProvider mockedCredentialProvider = - Mockito.spy(gcpWorkloadIdentityCredentialProvider); + spy(gcpWorkloadIdentityCredentialProvider); - Mockito - .doThrow(new SocketTimeoutException()) + doThrow(new SocketTimeoutException()) .doThrow(new ConnectException()) .doReturn(getSecurityTokenServiceResponse()) .doReturn(getIamCredentialGenerateAccessTokenResponse()) - .when(mockedCredentialProvider).executeHttpPostRequest(Mockito.any(), Mockito.anyString(), - Mockito.any()); + .when(mockedCredentialProvider).executeHttpPostRequest(any(), anyString(), + any()); - Mockito.doReturn(getMockApiClient()).when(mockedCredentialProvider).getApiClient(); + doReturn(getMockApiClient()).when(mockedCredentialProvider).getApiClient(); CredentialProfile credentialProfile = new CredentialProfile( GcpWorkloadIdentityCredentialProvider.NAME, "profile", Collections.emptyMap()); @@ -92,7 +108,7 @@ public void testProvisioningCredentialWithRetries() throws Exception { Assert.assertEquals(credential.getExpiration().toString(), EXPIRES_IN); // twice per invocation of // {@link GcpWorkloadIdentityCredentialProvider#getProvisionedCredential} - Mockito.verify(mockedCredentialProvider.getApiClient(), Mockito.times(8)); + verify(mockedCredentialProvider.getApiClient(), times(8)); } private String getSecurityTokenServiceResponse() { @@ -109,13 +125,12 @@ private String getIamCredentialGenerateAccessTokenResponse() { } private ApiClient getMockApiClient() throws ApiException { - ApiClient mockApiClient = Mockito.mock(ApiClient.class); - Mockito.when(mockApiClient.escapeString(Mockito.anyString())).thenCallRealMethod(); + ApiClient mockApiClient = mock(ApiClient.class); + when(mockApiClient.escapeString(anyString())).thenCallRealMethod(); ApiResponse apiResponse = getAuthenticationTokenRequestResponse(); - Mockito - .doThrow(new ApiException(500, "Service is unavailable")) + doThrow(new ApiException(500, "Service is unavailable")) .doReturn(apiResponse) - .when(mockApiClient).execute(Mockito.any(), Mockito.any()); + .when(mockApiClient).execute(any(), any()); return mockApiClient; } @@ -128,8 +143,8 @@ private V1TokenRequestStatus getV1TokenRequestStatus() { private ApiResponse getAuthenticationTokenRequestResponse() { AuthenticationV1TokenRequest authenticationV1TokenRequest = - Mockito.mock(AuthenticationV1TokenRequest.class); - Mockito.doReturn(getV1TokenRequestStatus()).when(authenticationV1TokenRequest).getStatus(); + mock(AuthenticationV1TokenRequest.class); + doReturn(getV1TokenRequestStatus()).when(authenticationV1TokenRequest).getStatus(); return new ApiResponse(200, Collections.emptyMap(), authenticationV1TokenRequest); } @@ -162,4 +177,20 @@ public void testInvalidProfile() throws ProfileValidationException { "unknown-provider", "profile", Collections.emptyMap()); gcpWorkloadIdentityCredentialProvider.validateProfile(invalidCredentialProfile); } + + @Test(expected = IOException.class) + public void testExecuteHttpPostRequestHandlesHTTPErrorResponse() throws Exception { + InputStream errorMessageStream = new ByteArrayInputStream( + "Some error here".getBytes(StandardCharsets.UTF_8)); + OutputStream outputStream = new ByteArrayOutputStream(); + HttpURLConnection mockConnection = mock(HttpURLConnection.class); + when(mockConnection.getResponseCode()).thenReturn(500); + when(mockConnection.getOutputStream()).thenReturn(outputStream); + when(mockConnection.getInputStream()).thenThrow(new AssertionError("wrong exception")); + when(mockConnection.getErrorStream()).thenReturn(errorMessageStream); + GcpWorkloadIdentityCredentialProvider gcpWorkloadIdentityCredentialProvider = + new GcpWorkloadIdentityCredentialProvider(); + gcpWorkloadIdentityCredentialProvider + .executeHttpPostRequest(() -> mockConnection, "some body", new HashMap<>()); + } }