diff --git a/src/main/java/com/uid2/shared/secure/AttestationClientException.java b/src/main/java/com/uid2/shared/secure/AttestationClientException.java new file mode 100644 index 00000000..26c76d77 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/AttestationClientException.java @@ -0,0 +1,12 @@ +package com.uid2.shared.secure; + +public class AttestationClientException extends AttestationException +{ + public AttestationClientException(Throwable cause) { + super(cause, true); + } + + public AttestationClientException(String message) { + super(message, true); + } +} diff --git a/src/main/java/com/uid2/shared/secure/AttestationException.java b/src/main/java/com/uid2/shared/secure/AttestationException.java index 7bff4c8e..e6aa0077 100644 --- a/src/main/java/com/uid2/shared/secure/AttestationException.java +++ b/src/main/java/com/uid2/shared/secure/AttestationException.java @@ -1,10 +1,27 @@ package com.uid2.shared.secure; public class AttestationException extends Exception { + private final boolean isClientError; + + public boolean IsClientError() { + return this.isClientError; + } + + public AttestationException(Throwable cause, boolean isClientError) { + super(cause); + this.isClientError = isClientError; + } + public AttestationException(Throwable cause) { + this(cause, false); + } + + public AttestationException(String cause, boolean isClientError) { super(cause); + this.isClientError = isClientError; } + public AttestationException(String message) { - super(message); + this(message, false); } } diff --git a/src/main/java/com/uid2/shared/secure/AttestationResult.java b/src/main/java/com/uid2/shared/secure/AttestationResult.java index d58ff18e..2e3f239d 100644 --- a/src/main/java/com/uid2/shared/secure/AttestationResult.java +++ b/src/main/java/com/uid2/shared/secure/AttestationResult.java @@ -6,16 +6,27 @@ public class AttestationResult { private final String enclaveId; + private final AttestationClientException attestationClientException; + public AttestationResult(AttestationFailure reasonToFail) { this.failure = reasonToFail; this.publicKey = null; this.enclaveId = "Failed attestation, enclave Id unknown"; + this.attestationClientException = null; + } + + public AttestationResult(AttestationClientException exception) { + this.failure = AttestationFailure.UNKNOWN; + this.publicKey = null; + this.enclaveId = "Failed attestation, enclave Id unknown"; + this.attestationClientException = exception; } public AttestationResult(byte[] publicKey, String enclaveId) { this.failure = AttestationFailure.NONE; this.publicKey = publicKey; this.enclaveId = enclaveId; + this.attestationClientException = null; } public boolean isSuccess() { @@ -25,6 +36,9 @@ public boolean isSuccess() { public AttestationFailure getFailure() { return this.failure; } public String getReason() { + if (this.attestationClientException != null) { + return this.attestationClientException.getMessage(); + } return this.failure.explain(); } diff --git a/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java b/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java index 7e405282..49a84c24 100644 --- a/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java +++ b/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java @@ -52,10 +52,15 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler getEnclaveAllowlist() { // Returns // null if validation failed // enclaveId if validation succeed - private String validate(TokenPayload tokenPayload) { + private String validate(TokenPayload tokenPayload) throws Exception { + Exception lastException = null; for (var policyValidator : supportedPolicyValidators) { LOGGER.info("Validating policy... Validator version: " + policyValidator.getVersion()); try { @@ -87,13 +92,20 @@ private String validate(TokenPayload tokenPayload) { LOGGER.info("Validator version: " + policyValidator.getVersion() + ", result: " + enclaveId); if (allowedEnclaveIds.contains(enclaveId)) { - LOGGER.info("Successfully attested OIDC against registered enclaves"); + LOGGER.info("Successfully attested gcp-oidc against registered enclaves"); return enclaveId; + } else { + LOGGER.warn("Got unsupported gcp-oidc enclave id: " + enclaveId); } } catch (Exception ex) { + lastException = ex; LOGGER.warn("Fail to validator version: " + policyValidator.getVersion() + ", error :" + ex.getMessage()); } } + + if(lastException != null){ + throw lastException; + } return null; } } diff --git a/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java index d9d349e9..619cfc91 100644 --- a/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java +++ b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java @@ -5,6 +5,7 @@ import com.google.api.client.util.Clock; import com.google.auth.oauth2.TokenVerifier; import com.google.common.base.Strings; +import com.uid2.shared.secure.AttestationClientException; import com.uid2.shared.secure.AttestationException; import java.io.IOException; @@ -65,7 +66,7 @@ public MaaTokenPayload validate(String tokenString) throws AttestationException tokenVerifier.verify(tokenString); } } catch (TokenVerifier.VerificationException e) { - throw new AttestationException("Fail to validate the token signature, error: " + e.getMessage()); + throw new AttestationClientException("Fail to validate the token signature, error: " + e.getMessage()); } catch (IOException e) { throw new AttestationException("Fail to parse token, error: " + e.getMessage()); } diff --git a/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java index 482d76f7..7948339b 100644 --- a/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java +++ b/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java @@ -1,6 +1,7 @@ package com.uid2.shared.secure.azurecc; import com.google.common.base.Strings; +import com.uid2.shared.secure.AttestationClientException; import com.uid2.shared.secure.AttestationException; public class PolicyValidator implements IPolicyValidator{ @@ -16,11 +17,11 @@ public String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException { if(Strings.isNullOrEmpty(publicKey)){ - throw new AttestationException("public key to check is null or empty"); + throw new AttestationClientException("public key to check is null or empty"); } var runtimePublicKey = maaTokenPayload.getRuntimeData().getPublicKey(); if(!publicKey.equals(runtimePublicKey)){ - throw new AttestationException( + throw new AttestationClientException( String.format("Public key in payload is not match expected value. More info: runtime(%s), expected(%s)", runtimePublicKey, publicKey @@ -30,25 +31,25 @@ private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey) private void verifyVM(MaaTokenPayload maaTokenPayload) throws AttestationException { if(!maaTokenPayload.isSevSnpVM()){ - throw new AttestationException("Not in SevSnp VM"); + throw new AttestationClientException("Not in SevSnp VM"); } if(!maaTokenPayload.isUtilityVMCompliant()){ - throw new AttestationException("Not run in Azure Compliance Utility VM"); + throw new AttestationClientException("Not run in Azure Compliance Utility VM"); } if(maaTokenPayload.isVmDebuggable()){ - throw new AttestationException("The underlying harware should not run in debug mode"); + throw new AttestationClientException("The underlying hardware should not run in debug mode"); } } private void verifyLocation(MaaTokenPayload maaTokenPayload) throws AttestationException { var location = maaTokenPayload.getRuntimeData().getLocation(); if(Strings.isNullOrEmpty(location)){ - throw new AttestationException("Location is not specified."); + throw new AttestationClientException("Location is not specified."); } var lowerCaseLocation = location.toLowerCase(); if(lowerCaseLocation.contains(LOCATION_CHINA) || lowerCaseLocation.contains(LOCATION_EU)){ - throw new AttestationException("Location is not supported. Value: " + location); + throw new AttestationClientException("Location is not supported. Value: " + location); } } } diff --git a/src/main/java/com/uid2/shared/secure/gcpoidc/PolicyValidator.java b/src/main/java/com/uid2/shared/secure/gcpoidc/PolicyValidator.java index 98dd751a..df62fe24 100644 --- a/src/main/java/com/uid2/shared/secure/gcpoidc/PolicyValidator.java +++ b/src/main/java/com/uid2/shared/secure/gcpoidc/PolicyValidator.java @@ -4,6 +4,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.uid2.shared.Utils; +import com.uid2.shared.secure.AttestationClientException; import com.uid2.shared.secure.AttestationException; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; @@ -53,18 +54,18 @@ public String validate(TokenPayload payload) throws AttestationException { private static boolean checkConfidentialSpace(TokenPayload payload) throws AttestationException{ if(!payload.isConfidentialSpaceSW()){ - throw new AttestationException("Unexpected SW_NAME: " + payload.getSwName()); + throw new AttestationClientException("Unexpected SW_NAME: " + payload.getSwName()); } var isDebugMode = payload.isDebugMode(); if(!isDebugMode && !payload.isStableVersion()){ - throw new AttestationException("Confidential space image version is not stable."); + throw new AttestationClientException("Confidential space image version is not stable."); } return isDebugMode; } private static String checkWorkload(TokenPayload payload) throws AttestationException{ if(!payload.isRestartPolicyNever()){ - throw new AttestationException("Restart policy is not set to Never. Value: " + payload.getRestartPolicy()); + throw new AttestationClientException("Restart policy is not set to Never. Value: " + payload.getRestartPolicy()); } return payload.getWorkloadImageDigest(); } @@ -75,35 +76,35 @@ private static String checkWorkload(TokenPayload payload) throws AttestationExce private static String checkRegion(TokenPayload payload) throws AttestationException{ var region = payload.getGceZone(); if(Strings.isNullOrEmpty(region) || region.startsWith(EU_REGION_PREFIX)){ - throw new AttestationException("Region is not supported. Value: " + region); + throw new AttestationClientException("Region is not supported. Value: " + region); } return region; } private static void checkCmdOverrides(TokenPayload payload) throws AttestationException{ if(!CollectionUtils.isEmpty(payload.getCmdOverrides())){ - throw new AttestationException("Payload should not have cmd overrides"); + throw new AttestationClientException("Payload should not have cmd overrides"); } } private Environment checkEnvOverrides(TokenPayload payload) throws AttestationException{ var envOverrides = payload.getEnvOverrides(); if(MapUtils.isEmpty(envOverrides)){ - throw new AttestationException("env overrides should not be empty"); + throw new AttestationClientException("env overrides should not be empty"); } HashMap envOverridesCopy = new HashMap(envOverrides); // check all required env overrides for(var envKey: REQUIRED_ENV_OVERRIDES){ if(Strings.isNullOrEmpty(envOverridesCopy.get(envKey))){ - throw new AttestationException("Required env override is missing. key: " + envKey); + throw new AttestationClientException("Required env override is missing. key: " + envKey); } } // env could be parsed var env = Environment.fromString(envOverridesCopy.get(ENV_ENVIRONMENT)); if(env == null){ - throw new AttestationException("Environment can not be parsed. " + envOverridesCopy.get(ENV_ENVIRONMENT)); + throw new AttestationClientException("Environment can not be parsed. " + envOverridesCopy.get(ENV_ENVIRONMENT)); } // make sure there's no unexpected overrides @@ -118,7 +119,7 @@ private Environment checkEnvOverrides(TokenPayload payload) throws AttestationEx } if(!envOverridesCopy.isEmpty()){ - throw new AttestationException("More env overrides than allowed. " + envOverridesCopy); + throw new AttestationClientException("More env overrides than allowed. " + envOverridesCopy); } return env; diff --git a/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java b/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java index 480c5684..a70606de 100644 --- a/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java +++ b/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java @@ -15,8 +15,7 @@ import java.nio.charset.StandardCharsets; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -45,9 +44,7 @@ class AzureCCAttestationProviderTest { @BeforeEach public void setup() throws AttestationException { when(alwaysPassTokenValidator.validate(any())).thenReturn(VALID_TOKEN_PAYLOAD); - when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("token signature validation failed")); when(alwaysPassPolicyValidator.validate(any(), any())).thenReturn(ENCLAVE_ID); - when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("policy validation failed")); } @Test @@ -61,7 +58,21 @@ public void testHappyPath() throws AttestationException { } @Test - public void testSignatureCheckFailed() throws AttestationException { + public void testSignatureCheckFailed_ClientError() throws AttestationException { + var errorStr = "token signature validation failed"; + when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationClientException(errorStr)); + var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator); + provider.registerEnclave(ENCLAVE_ID); + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(errorStr, ar.result().getReason()); + }); + } + + @Test + public void testSignatureCheckFailed_ServerError() throws AttestationException { + when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("unknown server error")); var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator); provider.registerEnclave(ENCLAVE_ID); attest(provider, ar -> { @@ -71,7 +82,21 @@ public void testSignatureCheckFailed() throws AttestationException { } @Test - public void testPolicyCheckFailed() throws AttestationException { + public void testPolicyCheckFailed_ClientError() throws AttestationException { + var errorStr = "policy validation failed"; + when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr)); + var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysFailPolicyValidator); + provider.registerEnclave(ENCLAVE_ID); + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(errorStr, ar.result().getReason()); + }); + } + + @Test + public void testPolicyCheckFailed_ServerError() throws AttestationException { + when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error")); var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysFailPolicyValidator); provider.registerEnclave(ENCLAVE_ID); attest(provider, ar -> { @@ -84,8 +109,9 @@ public void testPolicyCheckFailed() throws AttestationException { public void testEnclaveNotRegistered() throws AttestationException { var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator); attest(provider, ar -> { - assertFalse(ar.succeeded()); - assertTrue(ar.cause() instanceof AttestationException); + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(AttestationFailure.FORBIDDEN_ENCLAVE, ar.result().getFailure()); }); } diff --git a/src/test/java/com/uid2/shared/secure/GcpOidcAttestationProviderTest.java b/src/test/java/com/uid2/shared/secure/GcpOidcAttestationProviderTest.java index 7757fe1b..7a590f64 100644 --- a/src/test/java/com/uid2/shared/secure/GcpOidcAttestationProviderTest.java +++ b/src/test/java/com/uid2/shared/secure/GcpOidcAttestationProviderTest.java @@ -16,8 +16,7 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -51,37 +50,63 @@ public class GcpOidcAttestationProviderTest { @BeforeEach public void setup() throws AttestationException { when(alwaysPassTokenValidator.validate(any())).thenReturn(VALID_TOKEN_PAYLOAD); - when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("token signature validation failed")); when(alwaysPassPolicyValidator1.validate(any())).thenReturn(ENCLAVE_ID_1); when(alwaysPassPolicyValidator2.validate(any())).thenReturn(ENCLAVE_ID_2); - when(alwaysFailPolicyValidator.validate(any())).thenThrow(new AttestationException("policy validation failed")); } @Test public void testHappyPath() throws AttestationException { var provider = new GcpOidcAttestationProvider(alwaysPassTokenValidator, Arrays.asList(alwaysPassPolicyValidator1)); provider.registerEnclave(ENCLAVE_ID_1); - attest(provider, ar ->{ + attest(provider, ar -> { assertTrue(ar.succeeded()); assertTrue(ar.result().isSuccess()); }); } @Test - public void testSignatureCheckFailed() throws AttestationException { + public void testSignatureCheckFailed_ClientError() throws AttestationException { + var errorStr = "signature validation failed"; + when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationClientException(errorStr)); var provider = new GcpOidcAttestationProvider(alwaysFailTokenValidator, Arrays.asList(alwaysPassPolicyValidator1)); provider.registerEnclave(ENCLAVE_ID_1); - attest(provider, ar ->{ + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(errorStr, ar.result().getReason()); + }); + } + + @Test + public void testSignatureCheckFailed_ServerError() throws AttestationException { + when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("unknown server error")); + var provider = new GcpOidcAttestationProvider(alwaysFailTokenValidator, Arrays.asList(alwaysPassPolicyValidator1)); + provider.registerEnclave(ENCLAVE_ID_1); + attest(provider, ar -> { assertFalse(ar.succeeded()); assertTrue(ar.cause() instanceof AttestationException); }); } @Test - public void testPolicyCheckFailed() throws AttestationException { + public void testPolicyCheckFailed_ClientError() throws AttestationException { + var errorStr = "policy validation failed"; + when(alwaysFailPolicyValidator.validate(any())).thenThrow(new AttestationClientException(errorStr)); var provider = new GcpOidcAttestationProvider(alwaysPassTokenValidator, Arrays.asList(alwaysFailPolicyValidator)); provider.registerEnclave(ENCLAVE_ID_1); - attest(provider, ar ->{ + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(errorStr, ar.result().getReason()); + }); + } + + @Test + public void testPolicyCheckFailed_ServerError() throws AttestationException { + when(alwaysFailPolicyValidator.validate(any())).thenThrow(new AttestationException("unknown server error")); + var provider = new GcpOidcAttestationProvider(alwaysPassTokenValidator, Arrays.asList(alwaysFailPolicyValidator)); + provider.registerEnclave(ENCLAVE_ID_1); + attest(provider, ar -> { assertFalse(ar.succeeded()); assertTrue(ar.cause() instanceof AttestationException); }); @@ -91,9 +116,10 @@ public void testPolicyCheckFailed() throws AttestationException { public void testNoPolicyConfigured() throws AttestationException { var provider = new GcpOidcAttestationProvider(alwaysPassTokenValidator, Arrays.asList()); provider.registerEnclave(ENCLAVE_ID_1); - attest(provider, ar ->{ - assertFalse(ar.succeeded()); - assertTrue(ar.cause() instanceof AttestationException); + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertFalse(ar.result().isSuccess()); + assertEquals(AttestationFailure.FORBIDDEN_ENCLAVE, ar.result().getFailure()); }); } @@ -101,7 +127,7 @@ public void testNoPolicyConfigured() throws AttestationException { public void testMultiplePolicyValidators() throws AttestationException { var provider = new GcpOidcAttestationProvider(alwaysPassTokenValidator, Arrays.asList(alwaysPassPolicyValidator1, alwaysFailPolicyValidator, alwaysPassPolicyValidator2)); provider.registerEnclave(ENCLAVE_ID_2); - attest(provider, ar ->{ + attest(provider, ar -> { assertTrue(ar.succeeded()); assertTrue(ar.result().isSuccess()); });