Skip to content

Commit

Permalink
Merge pull request #15430 from cdapio/cherrypick/CDAP-20897-610
Browse files Browse the repository at this point in the history
[CDAP-20897] Cherry-pick 6.10
  • Loading branch information
dli357 authored Nov 16, 2023
2 parents e0900cf + fac4b29 commit 87294ce
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.ConnectException;
import java.net.HttpURLConnection;
Expand Down Expand Up @@ -298,17 +299,23 @@ private String exchangeTokenViaSts(String token, String scopes, String audience)

Map<String, String> headers = new HashMap<>();
headers.put(HttpHeaders.CONTENT_TYPE, "application/json");
return executeHttpPostRequest(url, securityTokenServiceRequestJson, headers);
return executeHttpPostRequest(() -> (HttpURLConnection) url.openConnection(),
securityTokenServiceRequestJson, headers);
}

@VisibleForTesting
interface ConnectionProvider {
HttpURLConnection getConnection() throws IOException;
}

/**
* Executes a http post request with the specified parameters.
*/
@VisibleForTesting
String executeHttpPostRequest(URL url, String body, Map<String, String> headers)
throws IOException {
String executeHttpPostRequest(ConnectionProvider connectionProvider, String body,
Map<String, String> headers) throws IOException {

HttpURLConnection connection = (HttpURLConnection) url.openConnection();
HttpURLConnection connection = connectionProvider.getConnection();
connection.setRequestMethod(HttpMethod.POST);
connection.setUseCaches(false);
for (Map.Entry<String, String> entry : headers.entrySet()) {
Expand All @@ -321,15 +328,26 @@ String executeHttpPostRequest(URL url, String body, Map<String, String> headers)
outputStream.writeBytes(body);
outputStream.flush();
}
InputStream inputStream;
boolean errorResponse = false;
if (connection.getResponseCode() < 200 || connection.getResponseCode() >= 300) {
inputStream = connection.getErrorStream();
errorResponse = true;
} else {
inputStream = connection.getInputStream();
}

StringBuilder response = new StringBuilder();
try (BufferedReader in = new BufferedReader(
new InputStreamReader(connection.getInputStream()))) {
try (BufferedReader in = new BufferedReader(new InputStreamReader(inputStream))) {
String inputLine;
while ((inputLine = in.readLine()) != null) {
response.append(inputLine);
}
}
if (errorResponse) {
throw new IOException(String.format("Failed to call URL %s with code; response code %d:\n%s",
connection.getURL(), connection.getResponseCode(), response.toString()));
}
return response.toString();
}

Expand All @@ -346,6 +364,7 @@ private String fetchIamServiceAccountToken(String token, String scopes,
headers.put(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", token));
headers.put(HttpHeaders.CONTENT_TYPE, "application/json");
String generateAccessTokenRequestJson = GSON.toJson(credentialsGenerateAccessTokenRequest);
return executeHttpPostRequest(url, generateAccessTokenRequestJson, headers);
return executeHttpPostRequest(() -> (HttpURLConnection) url.openConnection(),
generateAccessTokenRequestJson, headers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@

package io.cdap.cdap.security.spi.credential;

import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.proto.BasicThrowable;
Expand All @@ -31,8 +41,15 @@
import io.kubernetes.client.openapi.ApiResponse;
import io.kubernetes.client.openapi.models.AuthenticationV1TokenRequest;
import io.kubernetes.client.openapi.models.V1TokenRequestStatus;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ConnectException;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.time.format.DateTimeFormatter;
Expand All @@ -41,12 +58,12 @@
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
* Unit Tests for {@link GcpWorkloadIdentityCredentialProvider}.
*/
public class GcpWorkloadIdentityCredentialProviderTest {

private static final Gson GSON = new GsonBuilder().registerTypeAdapter(BasicThrowable.class,
new BasicThrowableCodec()).create();
private static final String IAM_TOKEN = "iam-token";
Expand All @@ -63,17 +80,15 @@ public void testProvisioningCredentialWithRetries() throws Exception {
gcpWorkloadIdentityCredentialProvider.initialize(credentialProviderContext);

GcpWorkloadIdentityCredentialProvider mockedCredentialProvider =
Mockito.spy(gcpWorkloadIdentityCredentialProvider);
spy(gcpWorkloadIdentityCredentialProvider);

Mockito
.doThrow(new SocketTimeoutException())
doThrow(new SocketTimeoutException())
.doThrow(new ConnectException())
.doReturn(getSecurityTokenServiceResponse())
.doReturn(getIamCredentialGenerateAccessTokenResponse())
.when(mockedCredentialProvider).executeHttpPostRequest(Mockito.any(), Mockito.anyString(),
Mockito.any());
.when(mockedCredentialProvider).executeHttpPostRequest(any(), anyString(), any());

Mockito.doReturn(getMockApiClient()).when(mockedCredentialProvider).getApiClient();
doReturn(getMockApiClient()).when(mockedCredentialProvider).getApiClient();

CredentialProfile credentialProfile = new CredentialProfile(
GcpWorkloadIdentityCredentialProvider.NAME, "profile", Collections.emptyMap());
Expand All @@ -92,7 +107,7 @@ public void testProvisioningCredentialWithRetries() throws Exception {
Assert.assertEquals(credential.getExpiration().toString(), EXPIRES_IN);
// twice per invocation of
// {@link GcpWorkloadIdentityCredentialProvider#getProvisionedCredential}
Mockito.verify(mockedCredentialProvider.getApiClient(), Mockito.times(8));
verify(mockedCredentialProvider, times(4)).getApiClient();
}

private String getSecurityTokenServiceResponse() {
Expand All @@ -109,13 +124,12 @@ private String getIamCredentialGenerateAccessTokenResponse() {
}

private ApiClient getMockApiClient() throws ApiException {
ApiClient mockApiClient = Mockito.mock(ApiClient.class);
Mockito.when(mockApiClient.escapeString(Mockito.anyString())).thenCallRealMethod();
ApiClient mockApiClient = mock(ApiClient.class);
when(mockApiClient.escapeString(anyString())).thenCallRealMethod();
ApiResponse<Object> apiResponse = getAuthenticationTokenRequestResponse();
Mockito
.doThrow(new ApiException(500, "Service is unavailable"))
doThrow(new ApiException(500, "Service is unavailable"))
.doReturn(apiResponse)
.when(mockApiClient).execute(Mockito.any(), Mockito.any());
.when(mockApiClient).execute(any(), any());
return mockApiClient;
}

Expand All @@ -128,8 +142,8 @@ private V1TokenRequestStatus getV1TokenRequestStatus() {

private ApiResponse<Object> getAuthenticationTokenRequestResponse() {
AuthenticationV1TokenRequest authenticationV1TokenRequest =
Mockito.mock(AuthenticationV1TokenRequest.class);
Mockito.doReturn(getV1TokenRequestStatus()).when(authenticationV1TokenRequest).getStatus();
mock(AuthenticationV1TokenRequest.class);
doReturn(getV1TokenRequestStatus()).when(authenticationV1TokenRequest).getStatus();
return new
ApiResponse<Object>(200, Collections.emptyMap(), authenticationV1TokenRequest);
}
Expand Down Expand Up @@ -162,4 +176,20 @@ public void testInvalidProfile() throws ProfileValidationException {
"unknown-provider", "profile", Collections.emptyMap());
gcpWorkloadIdentityCredentialProvider.validateProfile(invalidCredentialProfile);
}

@Test(expected = IOException.class)
public void testExecuteHttpPostRequestHandlesHTTPErrorResponse() throws Exception {
InputStream errorMessageStream = new ByteArrayInputStream(
"Some error here".getBytes(StandardCharsets.UTF_8));
OutputStream outputStream = new ByteArrayOutputStream();
HttpURLConnection mockConnection = mock(HttpURLConnection.class);
when(mockConnection.getResponseCode()).thenReturn(500);
when(mockConnection.getOutputStream()).thenReturn(outputStream);
when(mockConnection.getInputStream()).thenThrow(new AssertionError("wrong exception"));
when(mockConnection.getErrorStream()).thenReturn(errorMessageStream);
GcpWorkloadIdentityCredentialProvider gcpWorkloadIdentityCredentialProvider =
new GcpWorkloadIdentityCredentialProvider();
gcpWorkloadIdentityCredentialProvider
.executeHttpPostRequest(() -> mockConnection, "some body", new HashMap<>());
}
}

0 comments on commit 87294ce

Please sign in to comment.