diff --git a/pom.xml b/pom.xml index 7cbd8dab..58dce090 100644 --- a/pom.xml +++ b/pom.xml @@ -147,6 +147,16 @@ google-cloud-logging 3.13.7 + + com.azure + azure-security-attestation + 1.1.15 + + + com.azure + azure-core-http-netty + 1.13.6 + co.nstant.in cbor diff --git a/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java b/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java new file mode 100644 index 00000000..7e405282 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/AzureCCAttestationProvider.java @@ -0,0 +1,86 @@ +package com.uid2.shared.secure; + +import com.uid2.shared.Utils; +import com.uid2.shared.secure.azurecc.IMaaTokenSignatureValidator; +import com.uid2.shared.secure.azurecc.IPolicyValidator; +import com.uid2.shared.secure.azurecc.MaaTokenSignatureValidator; +import com.uid2.shared.secure.azurecc.PolicyValidator; +import io.vertx.core.AsyncResult; +import io.vertx.core.Future; +import io.vertx.core.Handler; +import lombok.extern.slf4j.Slf4j; + +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; + +// CC stands for Confidential Container +@Slf4j +public class AzureCCAttestationProvider implements IAttestationProvider { + + private final Set allowedEnclaveIds = new HashSet<>(); + + private final IMaaTokenSignatureValidator tokenSignatureValidator; + + private final IPolicyValidator policyValidator; + + public AzureCCAttestationProvider(String maaServerBaseUrl) { + this(new MaaTokenSignatureValidator(maaServerBaseUrl), new PolicyValidator()); + } + + // used in UT + protected AzureCCAttestationProvider(IMaaTokenSignatureValidator tokenSignatureValidator, IPolicyValidator policyValidator) { + this.tokenSignatureValidator = tokenSignatureValidator; + this.policyValidator = policyValidator; + } + + @Override + public void attest(byte[] attestationRequest, byte[] publicKey, Handler> handler) { + try { + var tokenString = new String(attestationRequest, StandardCharsets.US_ASCII); + + log.debug("Validating signature..."); + var tokenPayload = tokenSignatureValidator.validate(tokenString); + + log.debug("Validating policy..."); + var encodedPublicKey = Utils.toBase64String(publicKey); + + var enclaveId = policyValidator.validate(tokenPayload, encodedPublicKey); + + if (allowedEnclaveIds.contains(enclaveId)) { + log.info("Successfully attested azure-cc against registered enclaves, enclave id: " + enclaveId); + handler.handle(Future.succeededFuture(new AttestationResult(publicKey, enclaveId))); + } else { + throw new AttestationException("Unregistered enclave, enclave id: " + enclaveId); + } + } catch (AttestationException ex) { + handler.handle(Future.failedFuture(ex)); + } catch (Exception ex) { + handler.handle(Future.failedFuture(new AttestationException(ex))); + } + } + + @Override + public void registerEnclave(String encodedIdentifier) throws AttestationException { + try { + allowedEnclaveIds.add(encodedIdentifier); + } catch (Exception e) { + throw new AttestationException(e); + } + } + + @Override + public void unregisterEnclave(String encodedIdentifier) throws AttestationException { + try { + allowedEnclaveIds.remove(encodedIdentifier); + } catch (Exception e) { + throw new AttestationException(e); + } + } + + @Override + public Collection getEnclaveAllowlist() { + return allowedEnclaveIds; + } +} diff --git a/src/main/java/com/uid2/shared/secure/GcpOidcAttestationProvider.java b/src/main/java/com/uid2/shared/secure/GcpOidcAttestationProvider.java index 981962d6..73e4c501 100644 --- a/src/main/java/com/uid2/shared/secure/GcpOidcAttestationProvider.java +++ b/src/main/java/com/uid2/shared/secure/GcpOidcAttestationProvider.java @@ -37,10 +37,10 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler getEnclaveAllowlist() { } // Pass as long as one of supported policy validator check pass. - private boolean Validate(TokenPayload tokenPayload) { + // Returns + // null if validation failed + // enclaveId if validation succeed + private String validate(TokenPayload tokenPayload) { for (var policyValidator : supportedPolicyValidators) { LOGGER.info("Validating policy... Validator version: " + policyValidator.getVersion()); try { @@ -85,12 +88,12 @@ private boolean Validate(TokenPayload tokenPayload) { if (allowedEnclaveIds.contains(enclaveId)) { LOGGER.info("Successfully attested OIDC against registered enclaves"); - return true; + return enclaveId; } } catch (Exception ex) { LOGGER.warn("Fail to validator version: " + policyValidator.getVersion() + ", error :" + ex.getMessage()); } } - return false; + return null; } } diff --git a/src/main/java/com/uid2/shared/secure/JwtUtils.java b/src/main/java/com/uid2/shared/secure/JwtUtils.java new file mode 100644 index 00000000..23d80218 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/JwtUtils.java @@ -0,0 +1,25 @@ +package com.uid2.shared.secure; + +import java.util.Map; + +public class JwtUtils { + public static T tryGetField(Map payload, String key, Class clazz){ + if(payload == null){ + return null; + } + var rawValue = payload.get(key); + return tryConvert(rawValue, clazz); + } + + public static T tryConvert(Object obj, Class clazz){ + if(obj == null){ + return null; + } + try{ + return clazz.cast(obj); + } + catch (ClassCastException e){ + return null; + } + } +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/AzurePublicKeyProvider.java b/src/main/java/com/uid2/shared/secure/azurecc/AzurePublicKeyProvider.java new file mode 100644 index 00000000..90693f70 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/AzurePublicKeyProvider.java @@ -0,0 +1,79 @@ +package com.uid2.shared.secure.azurecc; + +import com.azure.security.attestation.AttestationClientBuilder; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableMap; +import com.uid2.shared.secure.AttestationException; + +import java.security.PublicKey; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +// MAA certs are stored as x5c(X.509 certificate chain), not supported by Google auth lib. +// So we have to build a thin layer to fetch Azure public key. +public class AzurePublicKeyProvider implements IPublicKeyProvider { + + private final LoadingCache> publicKeyCache; + + public AzurePublicKeyProvider() { + this.publicKeyCache = CacheBuilder.newBuilder() + .expireAfterWrite(1L, TimeUnit.HOURS) + .build(new CacheLoader<>() { + @Override + public Map load(String maaServerBaseUrl) throws AttestationException { + return loadPublicKeys(maaServerBaseUrl); + } + }); + } + + @Override + public PublicKey GetPublicKey(String maaServerBaseUrl, String kid) throws AttestationException { + PublicKey key; + try { + key = publicKeyCache.get(maaServerBaseUrl).get(kid); + } + catch (ExecutionException e){ + throw new AttestationException( + String.format("Error fetching PublicKey from certificate location: %s, error: %s.", maaServerBaseUrl, e.getMessage()) + ); + } + + if(key == null){ + throw new AttestationException("Could not find PublicKey for provided keyId: " + kid); + } + return key; + } + + // We don't want to reinvent the wheel. Leverage Azure Attestation client library to fetch certs. + private static Map loadPublicKeys(String maaServerBaseUrl) throws AttestationException { + var attestationBuilder = new AttestationClientBuilder(); + var client = attestationBuilder + .endpoint(maaServerBaseUrl) + .buildClient(); + + var signers = client.listAttestationSigners().getAttestationSigners(); + + ImmutableMap.Builder keyCacheBuilder = new ImmutableMap.Builder(); + + for (var signer : signers){ + var keyId = signer.getKeyId(); + var certs = signer.getCertificates(); + + // It's possible that there's a certificate chain. We will use the public key of Leaf Certificate here. + if(!certs.isEmpty()){ + var publicKey = certs.get(0).getPublicKey(); + keyCacheBuilder.put(keyId, publicKey); + } + } + + var map = keyCacheBuilder.build(); + if(map.isEmpty()){ + throw new AttestationException("Fail to load certs from: " + maaServerBaseUrl); + } + + return map; + } +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/IMaaTokenSignatureValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/IMaaTokenSignatureValidator.java new file mode 100644 index 00000000..50041c19 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/IMaaTokenSignatureValidator.java @@ -0,0 +1,14 @@ +package com.uid2.shared.secure.azurecc; + +import com.uid2.shared.secure.AttestationException; + +public interface IMaaTokenSignatureValidator { + /** + * Validate token signature against authorized issuer. + * + * @param tokenString The raw MAA token string. + * @return Parsed token payload. + * @throws AttestationException + */ + MaaTokenPayload validate(String tokenString) throws AttestationException; +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/IPolicyValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/IPolicyValidator.java new file mode 100644 index 00000000..b2869a7e --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/IPolicyValidator.java @@ -0,0 +1,15 @@ +package com.uid2.shared.secure.azurecc; + +import com.uid2.shared.secure.AttestationException; + +public interface IPolicyValidator { + /** + * Validate token payload against defined policies. + * + * @param maaTokenPayload The parsed MAA token. + * @param publicKey The public key info to verify in payload runtime data. + * @return The enclave id. + * @throws AttestationException + */ + String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException; +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/IPublicKeyProvider.java b/src/main/java/com/uid2/shared/secure/azurecc/IPublicKeyProvider.java new file mode 100644 index 00000000..8954203f --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/IPublicKeyProvider.java @@ -0,0 +1,17 @@ +package com.uid2.shared.secure.azurecc; + +import com.uid2.shared.secure.AttestationException; + +import java.security.PublicKey; + +public interface IPublicKeyProvider { + /** + * Get Public Key from a MAA server. + * + * @param maaServerBaseUrl The Base Url of MAA server. + * @param kid The key id. + * @return The public key. + * @throws AttestationException + */ + PublicKey GetPublicKey(String maaServerBaseUrl, String kid) throws AttestationException; +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenPayload.java b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenPayload.java new file mode 100644 index 00000000..a4eceaa3 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenPayload.java @@ -0,0 +1,26 @@ +package com.uid2.shared.secure.azurecc; + +import lombok.Builder; +import lombok.Value; + +@Value +@Builder(toBuilder = true) +public class MaaTokenPayload { + public static final String SEV_SNP_VM_TYPE = "sevsnpvm"; + public static final String AZURE_COMPLIANT_UVM = "azure-compliant-uvm"; + + private String attestationType; + private String complianceStatus; + private boolean vmDebuggable; + private String ccePolicyDigest; + + private RuntimeData runtimeData; + + public boolean isSevSnpVM(){ + return SEV_SNP_VM_TYPE.equalsIgnoreCase(attestationType); + } + + public boolean isUtilityVMCompliant(){ + return AZURE_COMPLIANT_UVM.equalsIgnoreCase(complianceStatus); + } +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java new file mode 100644 index 00000000..d9d349e9 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidator.java @@ -0,0 +1,94 @@ +package com.uid2.shared.secure.azurecc; + +import com.google.api.client.json.gson.GsonFactory; +import com.google.api.client.json.webtoken.JsonWebSignature; +import com.google.api.client.util.Clock; +import com.google.auth.oauth2.TokenVerifier; +import com.google.common.base.Strings; +import com.uid2.shared.secure.AttestationException; + +import java.io.IOException; +import java.util.Map; + +import static com.uid2.shared.secure.JwtUtils.tryGetField; + +public class MaaTokenSignatureValidator implements IMaaTokenSignatureValidator { + + // set to true to facilitate local test. + public static final boolean BYPASS_SIGNATURE_CHECK = false; + + // e.g. https://sharedeus.eus.attest.azure.net + private final String maaServerBaseUrl; + + private final IPublicKeyProvider publicKeyProvider; + + // used in UT + private final Clock clockOverride; + + public MaaTokenSignatureValidator(String maaServerBaseUrl) { + this(maaServerBaseUrl, new AzurePublicKeyProvider(), null); + } + + protected MaaTokenSignatureValidator(String maaServerBaseUrl, IPublicKeyProvider publicKeyProvider, Clock clockOverride) { + this.maaServerBaseUrl = maaServerBaseUrl; + this.publicKeyProvider = publicKeyProvider; + this.clockOverride = clockOverride; + } + + private TokenVerifier buildTokenVerifier(String kid) throws AttestationException { + var verifierBuilder = TokenVerifier.newBuilder(); + + verifierBuilder.setPublicKey(publicKeyProvider.GetPublicKey(maaServerBaseUrl, kid)); + + if (clockOverride != null) { + verifierBuilder.setClock(clockOverride); + } + + verifierBuilder.setIssuer(maaServerBaseUrl); + + return verifierBuilder.build(); + } + + @Override + public MaaTokenPayload validate(String tokenString) throws AttestationException { + if (Strings.isNullOrEmpty(tokenString)) { + throw new IllegalArgumentException("tokenString can not be null or empty"); + } + + // Validate Signature + JsonWebSignature signature; + try { + signature = JsonWebSignature.parse(GsonFactory.getDefaultInstance(), tokenString); + if(!BYPASS_SIGNATURE_CHECK){ + var kid = signature.getHeader().getKeyId(); + var tokenVerifier = buildTokenVerifier(kid); + tokenVerifier.verify(tokenString); + } + } catch (TokenVerifier.VerificationException e) { + throw new AttestationException("Fail to validate the token signature, error: " + e.getMessage()); + } catch (IOException e) { + throw new AttestationException("Fail to parse token, error: " + e.getMessage()); + } + + // Parse Payload + var rawPayload = signature.getPayload(); + + var tokenPayloadBuilder = MaaTokenPayload.builder(); + + tokenPayloadBuilder.attestationType(tryGetField(rawPayload, "x-ms-attestation-type", String.class)); + tokenPayloadBuilder.complianceStatus(tryGetField(rawPayload, "x-ms-compliance-status", String.class)); + tokenPayloadBuilder.vmDebuggable(tryGetField(rawPayload, "x-ms-sevsnpvm-is-debuggable", Boolean.class)); + tokenPayloadBuilder.ccePolicyDigest(tryGetField(rawPayload, "x-ms-sevsnpvm-hostdata", String.class)); + + var runtime = tryGetField(rawPayload, ("x-ms-runtime"), Map.class); + + if(runtime != null){ + var runtimeDataBuilder = RuntimeData.builder(); + runtimeDataBuilder.location(tryGetField(runtime, "location", String.class)); + runtimeDataBuilder.publicKey(tryGetField(runtime, "publicKey", String.class)); + tokenPayloadBuilder.runtimeData(runtimeDataBuilder.build()); + } + + return tokenPayloadBuilder.build(); + } +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java b/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java new file mode 100644 index 00000000..482d76f7 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java @@ -0,0 +1,54 @@ +package com.uid2.shared.secure.azurecc; + +import com.google.common.base.Strings; +import com.uid2.shared.secure.AttestationException; + +public class PolicyValidator implements IPolicyValidator{ + private static final String LOCATION_CHINA = "china"; + private static final String LOCATION_EU = "europe"; + @Override + public String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException { + verifyVM(maaTokenPayload); + verifyLocation(maaTokenPayload); + verifyPublicKey(maaTokenPayload, publicKey); + return maaTokenPayload.getCcePolicyDigest(); + } + + private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException { + if(Strings.isNullOrEmpty(publicKey)){ + throw new AttestationException("public key to check is null or empty"); + } + var runtimePublicKey = maaTokenPayload.getRuntimeData().getPublicKey(); + if(!publicKey.equals(runtimePublicKey)){ + throw new AttestationException( + String.format("Public key in payload is not match expected value. More info: runtime(%s), expected(%s)", + runtimePublicKey, + publicKey + )); + } + } + + private void verifyVM(MaaTokenPayload maaTokenPayload) throws AttestationException { + if(!maaTokenPayload.isSevSnpVM()){ + throw new AttestationException("Not in SevSnp VM"); + } + if(!maaTokenPayload.isUtilityVMCompliant()){ + throw new AttestationException("Not run in Azure Compliance Utility VM"); + } + if(maaTokenPayload.isVmDebuggable()){ + throw new AttestationException("The underlying harware should not run in debug mode"); + } + } + + private void verifyLocation(MaaTokenPayload maaTokenPayload) throws AttestationException { + var location = maaTokenPayload.getRuntimeData().getLocation(); + if(Strings.isNullOrEmpty(location)){ + throw new AttestationException("Location is not specified."); + } + var lowerCaseLocation = location.toLowerCase(); + if(lowerCaseLocation.contains(LOCATION_CHINA) || + lowerCaseLocation.contains(LOCATION_EU)){ + throw new AttestationException("Location is not supported. Value: " + location); + } + } +} diff --git a/src/main/java/com/uid2/shared/secure/azurecc/RuntimeData.java b/src/main/java/com/uid2/shared/secure/azurecc/RuntimeData.java new file mode 100644 index 00000000..e1c67685 --- /dev/null +++ b/src/main/java/com/uid2/shared/secure/azurecc/RuntimeData.java @@ -0,0 +1,11 @@ +package com.uid2.shared.secure.azurecc; + +import lombok.Builder; +import lombok.Value; + +@Value +@Builder(toBuilder = true) +public class RuntimeData { + private String location; + private String publicKey; +} diff --git a/src/main/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidator.java b/src/main/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidator.java index 0de4fb4f..2d0499b2 100644 --- a/src/main/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidator.java +++ b/src/main/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidator.java @@ -13,6 +13,9 @@ import java.util.List; import java.util.Map; +import static com.uid2.shared.secure.JwtUtils.tryConvert; +import static com.uid2.shared.secure.JwtUtils.tryGetField; + public class TokenSignatureValidator implements ITokenSignatureValidator { private static final String PUBLIC_CERT_LOCATION = "https://www.googleapis.com/service_accounts/v1/metadata/jwk/signer@confidentialspace-sign.iam.gserviceaccount.com"; @@ -72,34 +75,34 @@ public TokenPayload validate(String tokenString) throws AttestationException { var tokenPayloadBuilder = TokenPayload.builder(); - tokenPayloadBuilder.dbgStat(TryGetField(rawPayload, "dbgstat", String.class)); - tokenPayloadBuilder.swName(TryGetField(rawPayload, "swname", String.class)); - var swVersion = TryGetField(rawPayload, "swversion", List.class); + tokenPayloadBuilder.dbgStat(tryGetField(rawPayload, "dbgstat", String.class)); + tokenPayloadBuilder.swName(tryGetField(rawPayload, "swname", String.class)); + var swVersion = tryGetField(rawPayload, "swversion", List.class); if(swVersion != null && !swVersion.isEmpty()){ - tokenPayloadBuilder.swVersion(TryConvert(swVersion.get(0), String.class)); + tokenPayloadBuilder.swVersion(tryConvert(swVersion.get(0), String.class)); } - var subModsDetails = TryGetField(rawPayload,"submods", Map.class); + var subModsDetails = tryGetField(rawPayload,"submods", Map.class); if(subModsDetails != null){ - var confidential_space = TryGetField(subModsDetails, "confidential_space", Map.class); + var confidential_space = tryGetField(subModsDetails, "confidential_space", Map.class); if(confidential_space != null){ - tokenPayloadBuilder.csSupportedAttributes(TryGetField(confidential_space, "support_attributes", List.class)); + tokenPayloadBuilder.csSupportedAttributes(tryGetField(confidential_space, "support_attributes", List.class)); } - var container = TryGetField(subModsDetails, "container", Map.class); + var container = tryGetField(subModsDetails, "container", Map.class); if(container != null){ - tokenPayloadBuilder.workloadImageReference(TryGetField(container, "image_reference", String.class)); - tokenPayloadBuilder.workloadImageDigest(TryGetField(container, "image_digest", String.class)); - tokenPayloadBuilder.restartPolicy(TryGetField(container, "restart_policy", String.class)); + tokenPayloadBuilder.workloadImageReference(tryGetField(container, "image_reference", String.class)); + tokenPayloadBuilder.workloadImageDigest(tryGetField(container, "image_digest", String.class)); + tokenPayloadBuilder.restartPolicy(tryGetField(container, "restart_policy", String.class)); - tokenPayloadBuilder.cmdOverrides(TryGetField(container, "cmd_override", ArrayList.class)); - tokenPayloadBuilder.envOverrides(TryGetField(container, "env_override", Map.class)); + tokenPayloadBuilder.cmdOverrides(tryGetField(container, "cmd_override", ArrayList.class)); + tokenPayloadBuilder.envOverrides(tryGetField(container, "env_override", Map.class)); } - var gce = TryGetField(subModsDetails, "gce", Map.class); + var gce = tryGetField(subModsDetails, "gce", Map.class); if(gce != null){ - var gceZone = TryGetField(gce, "zone", String.class); + var gceZone = tryGetField(gce, "zone", String.class); tokenPayloadBuilder.gceZone(gceZone); } } @@ -107,23 +110,5 @@ public TokenPayload validate(String tokenString) throws AttestationException { return tokenPayloadBuilder.build(); } - private static T TryGetField(Map payload, String key, Class clazz){ - if(payload == null){ - return null; - } - var rawValue = payload.get(key); - return TryConvert(rawValue, clazz); - } - private static T TryConvert(Object obj, Class clazz){ - if(obj == null){ - return null; - } - try{ - return clazz.cast(obj); - } - catch (ClassCastException e){ - return null; - } - } } diff --git a/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java b/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java new file mode 100644 index 00000000..480c5684 --- /dev/null +++ b/src/test/java/com/uid2/shared/secure/AzureCCAttestationProviderTest.java @@ -0,0 +1,98 @@ +package com.uid2.shared.secure; + +import com.uid2.shared.secure.azurecc.IMaaTokenSignatureValidator; +import com.uid2.shared.secure.azurecc.IPolicyValidator; +import com.uid2.shared.secure.azurecc.MaaTokenPayload; +import io.vertx.core.AsyncResult; +import io.vertx.core.Handler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +class AzureCCAttestationProviderTest { + private static final String ATTESTATION_REQUEST = "test-attestation-request"; + + private static final String PUBLIC_KEY = "test-public-key"; + + private static final String ENCLAVE_ID = "test-enclave"; + + private static final MaaTokenPayload VALID_TOKEN_PAYLOAD = MaaTokenPayload.builder().build(); + @Mock + private IMaaTokenSignatureValidator alwaysPassTokenValidator; + + @Mock + private IMaaTokenSignatureValidator alwaysFailTokenValidator; + + @Mock + private IPolicyValidator alwaysPassPolicyValidator; + + @Mock + private IPolicyValidator alwaysFailPolicyValidator; + + @BeforeEach + public void setup() throws AttestationException { + when(alwaysPassTokenValidator.validate(any())).thenReturn(VALID_TOKEN_PAYLOAD); + when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("token signature validation failed")); + when(alwaysPassPolicyValidator.validate(any(), any())).thenReturn(ENCLAVE_ID); + when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("policy validation failed")); + } + + @Test + public void testHappyPath() throws AttestationException { + var provider = new AzureCCAttestationProvider(alwaysPassTokenValidator, alwaysPassPolicyValidator); + provider.registerEnclave(ENCLAVE_ID); + attest(provider, ar -> { + assertTrue(ar.succeeded()); + assertTrue(ar.result().isSuccess()); + }); + } + + @Test + public void testSignatureCheckFailed() throws AttestationException { + var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator); + provider.registerEnclave(ENCLAVE_ID); + attest(provider, ar -> { + assertFalse(ar.succeeded()); + assertTrue(ar.cause() instanceof AttestationException); + }); + } + + @Test + public void testPolicyCheckFailed() throws AttestationException { + var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysFailPolicyValidator); + provider.registerEnclave(ENCLAVE_ID); + attest(provider, ar -> { + assertFalse(ar.succeeded()); + assertTrue(ar.cause() instanceof AttestationException); + }); + } + + @Test + public void testEnclaveNotRegistered() throws AttestationException { + var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator); + attest(provider, ar -> { + assertFalse(ar.succeeded()); + assertTrue(ar.cause() instanceof AttestationException); + }); + } + + private static void attest(IAttestationProvider provider, Handler> handler) { + provider.attest( + ATTESTATION_REQUEST.getBytes(StandardCharsets.UTF_8), + PUBLIC_KEY.getBytes(StandardCharsets.UTF_8), + handler); + } +} diff --git a/src/test/java/com/uid2/shared/secure/gcpoidc/TestClock.java b/src/test/java/com/uid2/shared/secure/TestClock.java similarity index 88% rename from src/test/java/com/uid2/shared/secure/gcpoidc/TestClock.java rename to src/test/java/com/uid2/shared/secure/TestClock.java index e198e74d..e2d7b892 100644 --- a/src/test/java/com/uid2/shared/secure/gcpoidc/TestClock.java +++ b/src/test/java/com/uid2/shared/secure/TestClock.java @@ -1,4 +1,4 @@ -package com.uid2.shared.secure.gcpoidc; +package com.uid2.shared.secure; import com.google.api.client.util.Clock; diff --git a/src/test/java/com/uid2/shared/secure/gcpoidc/TestUtils.java b/src/test/java/com/uid2/shared/secure/TestUtils.java similarity index 69% rename from src/test/java/com/uid2/shared/secure/gcpoidc/TestUtils.java rename to src/test/java/com/uid2/shared/secure/TestUtils.java index 72ef2a89..83f93953 100644 --- a/src/test/java/com/uid2/shared/secure/gcpoidc/TestUtils.java +++ b/src/test/java/com/uid2/shared/secure/TestUtils.java @@ -1,4 +1,4 @@ -package com.uid2.shared.secure.gcpoidc; +package com.uid2.shared.secure; import com.google.api.client.json.gson.GsonFactory; import com.google.api.client.json.webtoken.JsonWebSignature; @@ -9,6 +9,8 @@ import com.google.gson.Gson; import com.google.gson.JsonObject; import com.uid2.shared.Const; +import com.uid2.shared.secure.gcpoidc.TokenPayload; +import com.uid2.shared.secure.gcpoidc.TokenSignatureValidator; import java.io.IOException; import java.security.KeyPairGenerator; @@ -16,23 +18,6 @@ import java.security.SecureRandom; public class TestUtils { - public static TokenPayload validateAndParseToken(JsonObject payload, Clock clock) throws Exception{ - var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass); - gen.initialize(2048, new SecureRandom()); - var keyPair = gen.generateKeyPair(); - var privateKey = keyPair.getPrivate(); - var publicKey = keyPair.getPublic(); - - // generate token - var token = generateJwt(payload, privateKey); - - // init TokenSignatureValidator - var tokenVerifier = new TokenSignatureValidator(publicKey, clock); - - // validate token - return tokenVerifier.validate(token); - } - public static String generateJwt(JsonObject payload, PrivateKey privateKey) throws Exception { var jsonFactory = new GsonFactory(); var header = new JsonWebSignature.Header(); diff --git a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java new file mode 100644 index 00000000..de8fe584 --- /dev/null +++ b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenSignatureValidatorTest.java @@ -0,0 +1,42 @@ +package com.uid2.shared.secure.azurecc; + +import com.uid2.shared.secure.AttestationException; +import com.uid2.shared.secure.TestClock; +import org.junit.Ignore; +import org.junit.jupiter.api.Test; + +import static com.uid2.shared.secure.TestUtils.loadFromJson; +import static com.uid2.shared.secure.azurecc.MaaTokenUtils.validateAndParseToken; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MaaTokenSignatureValidatorTest { + @Test + public void testPayload() throws Exception { + // expire at 1695313895 + var payloadPath = "/com.uid2.shared/test/secure/azurecc/jwt_payload.json"; + var payload = loadFromJson(payloadPath); + var clock = new TestClock(); + clock.setCurrentTimeMs(1695313893000L); + + var expectedCcePolicy = "fef932e0103f6132437e8a1223f32efc4bea63342f893b5124645224ef29ba73"; + var expectedLocation = "East US"; + var expectedPublicKey = "abc"; + + var tokenPayload = validateAndParseToken(payload, clock); + assertEquals(true, tokenPayload.isSevSnpVM()); + assertEquals(true, tokenPayload.isUtilityVMCompliant()); + assertEquals(false, tokenPayload.isVmDebuggable()); + assertEquals(expectedCcePolicy, tokenPayload.getCcePolicyDigest()); + assertEquals(expectedLocation, tokenPayload.getRuntimeData().getLocation()); + assertEquals(expectedPublicKey, tokenPayload.getRuntimeData().getPublicKey()); + } + + @Ignore + // replace below Placeholder with real MAA token to run E2E verification. + public void testE2E() throws AttestationException { + var maaToken = ""; + var maaServerUrl = "https://sharedeus.eus.attest.azure.net"; + var validator = new MaaTokenSignatureValidator(maaServerUrl); + var token = validator.validate(maaToken); + } +} diff --git a/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java new file mode 100644 index 00000000..48d1335e --- /dev/null +++ b/src/test/java/com/uid2/shared/secure/azurecc/MaaTokenUtils.java @@ -0,0 +1,49 @@ +package com.uid2.shared.secure.azurecc; + +import com.google.api.client.util.Clock; +import com.google.gson.JsonObject; +import com.uid2.shared.Const; +import com.uid2.shared.secure.AttestationException; + +import java.security.KeyPairGenerator; +import java.security.PublicKey; +import java.security.SecureRandom; + +import static com.uid2.shared.secure.TestUtils.generateJwt; + +public class MaaTokenUtils { + public static final String MAA_BASE_URL = "https://sharedeus.eus.attest.azure.net"; + + public static MaaTokenPayload validateAndParseToken(JsonObject payload, Clock clock) throws Exception{ + var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass); + gen.initialize(2048, new SecureRandom()); + var keyPair = gen.generateKeyPair(); + var privateKey = keyPair.getPrivate(); + var publicKey = keyPair.getPublic(); + + // generate token + var token = generateJwt(payload, privateKey); + + var keyProvider = new MockKeyProvider(publicKey); + + // init TokenSignatureValidator + var tokenVerifier = new MaaTokenSignatureValidator(MAA_BASE_URL, keyProvider, clock); + + // validate token + return tokenVerifier.validate(token); + } + + private static class MockKeyProvider implements IPublicKeyProvider { + + private final PublicKey publicKey; + + MockKeyProvider(PublicKey publicKey){ + this.publicKey = publicKey; + } + + @Override + public PublicKey GetPublicKey(String maaServerBaseUrl, String kid) throws AttestationException { + return this.publicKey; + } + } +} diff --git a/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java b/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java new file mode 100644 index 00000000..3a0d4e68 --- /dev/null +++ b/src/test/java/com/uid2/shared/secure/azurecc/PolicyValidatorTest.java @@ -0,0 +1,95 @@ +package com.uid2.shared.secure.azurecc; + +import com.uid2.shared.secure.AttestationException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class PolicyValidatorTest { + private static final String PUBLIC_KEY = "public_key"; + private static final String CCE_POLICY_DIGEST = "digest"; + + @Test + public void testValidationSuccess() throws AttestationException { + var validator = new PolicyValidator(); + var payload = generateBasicPayload(); + var enclaveId = validator.validate(payload, PUBLIC_KEY); + assertEquals(CCE_POLICY_DIGEST, enclaveId); + } + + @Test + public void testValidationFailure_VMInfo() throws AttestationException { + var validator = new PolicyValidator(); + var newPayload = generateBasicPayload() + .toBuilder() + .attestationType("dummy") + .build(); + assertThrows(AttestationException.class, ()-> validator.validate(newPayload, PUBLIC_KEY)); + } + + @Test + public void testValidationFailure_UVMInfo() throws AttestationException { + var validator = new PolicyValidator(); + var newPayload = generateBasicPayload() + .toBuilder() + .complianceStatus("dummy") + .build(); + assertThrows(AttestationException.class, ()-> validator.validate(newPayload, PUBLIC_KEY)); + } + + @Test + public void testValidationFailure_VMDebug() throws AttestationException { + var validator = new PolicyValidator(); + var newPayload = generateBasicPayload() + .toBuilder() + .vmDebuggable(true) + .build(); + assertThrows(AttestationException.class, ()-> validator.validate(newPayload, PUBLIC_KEY)); + } + + @Test + public void testValidationFailure_PublicKeyNotMatch() throws AttestationException { + var newRunTimeData = generateBasicRuntimeData() + .toBuilder() + .publicKey("dummy") + .build(); + var validator = new PolicyValidator(); + var newPayload = generateBasicPayload() + .toBuilder() + .runtimeData(newRunTimeData) + .build(); + assertThrows(AttestationException.class, ()-> validator.validate(newPayload, PUBLIC_KEY)); + } + + @Test + public void testValidationFailure_LocationNotSupported() throws AttestationException { + var newRunTimeData = generateBasicRuntimeData() + .toBuilder() + .location("West Europe") + .build(); + var validator = new PolicyValidator(); + var newPayload = generateBasicPayload() + .toBuilder() + .runtimeData(newRunTimeData) + .build(); + assertThrows(AttestationException.class, ()-> validator.validate(newPayload, PUBLIC_KEY)); + } + + private MaaTokenPayload generateBasicPayload() { + return MaaTokenPayload.builder() + .attestationType("sevsnpvm") + .complianceStatus("azure-compliant-uvm") + .vmDebuggable(false) + .runtimeData(generateBasicRuntimeData()) + .ccePolicyDigest(CCE_POLICY_DIGEST) + .build(); + } + + private RuntimeData generateBasicRuntimeData(){ + return RuntimeData.builder() + .publicKey(PUBLIC_KEY) + .location("East US") + .build(); + } +} diff --git a/src/test/java/com/uid2/shared/secure/gcpoidc/OidcPayloadValidationTest.java b/src/test/java/com/uid2/shared/secure/gcpoidc/OidcPayloadValidationTest.java index 6ab3a8aa..c3549ba4 100644 --- a/src/test/java/com/uid2/shared/secure/gcpoidc/OidcPayloadValidationTest.java +++ b/src/test/java/com/uid2/shared/secure/gcpoidc/OidcPayloadValidationTest.java @@ -1,9 +1,10 @@ package com.uid2.shared.secure.gcpoidc; +import com.uid2.shared.secure.TestClock; import org.junit.jupiter.api.Test; -import static com.uid2.shared.secure.gcpoidc.TestUtils.loadFromJson; -import static com.uid2.shared.secure.gcpoidc.TestUtils.validateAndParseToken; +import static com.uid2.shared.secure.TestUtils.loadFromJson; +import static com.uid2.shared.secure.gcpoidc.OidcTokenUtils.validateAndParseToken; public class OidcPayloadValidationTest { // E2E to help prevent regression. diff --git a/src/test/java/com/uid2/shared/secure/gcpoidc/OidcTokenUtils.java b/src/test/java/com/uid2/shared/secure/gcpoidc/OidcTokenUtils.java new file mode 100644 index 00000000..4885b371 --- /dev/null +++ b/src/test/java/com/uid2/shared/secure/gcpoidc/OidcTokenUtils.java @@ -0,0 +1,29 @@ +package com.uid2.shared.secure.gcpoidc; + +import com.google.api.client.util.Clock; +import com.google.gson.JsonObject; +import com.uid2.shared.Const; + +import java.security.KeyPairGenerator; +import java.security.SecureRandom; + +import static com.uid2.shared.secure.TestUtils.generateJwt; + +public class OidcTokenUtils { + public static TokenPayload validateAndParseToken(JsonObject payload, Clock clock) throws Exception{ + var gen = KeyPairGenerator.getInstance(Const.Name.AsymetricEncryptionKeyClass); + gen.initialize(2048, new SecureRandom()); + var keyPair = gen.generateKeyPair(); + var privateKey = keyPair.getPrivate(); + var publicKey = keyPair.getPublic(); + + // generate token + var token = generateJwt(payload, privateKey); + + // init TokenSignatureValidator + var tokenVerifier = new TokenSignatureValidator(publicKey, clock); + + // validate token + return tokenVerifier.validate(token); + } +} diff --git a/src/test/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidatorTest.java b/src/test/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidatorTest.java index 3ec64da2..c94d57ec 100644 --- a/src/test/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidatorTest.java +++ b/src/test/java/com/uid2/shared/secure/gcpoidc/TokenSignatureValidatorTest.java @@ -1,12 +1,13 @@ package com.uid2.shared.secure.gcpoidc; import com.uid2.shared.secure.AttestationException; +import com.uid2.shared.secure.TestClock; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; import org.junit.jupiter.api.Test; -import static com.uid2.shared.secure.gcpoidc.TestUtils.loadFromJson; -import static com.uid2.shared.secure.gcpoidc.TestUtils.validateAndParseToken; +import static com.uid2.shared.secure.TestUtils.loadFromJson; +import static com.uid2.shared.secure.gcpoidc.OidcTokenUtils.validateAndParseToken; import static org.junit.jupiter.api.Assertions.*; public class TokenSignatureValidatorTest { diff --git a/src/test/resources/com.uid2.shared/test/secure/azurecc/jwt_payload.json b/src/test/resources/com.uid2.shared/test/secure/azurecc/jwt_payload.json new file mode 100644 index 00000000..26e3f2d7 --- /dev/null +++ b/src/test/resources/com.uid2.shared/test/secure/azurecc/jwt_payload.json @@ -0,0 +1,41 @@ +{ + "exp": 1695313895, + "iat": 1695285095, + "iss": "https://sharedeus.eus.attest.azure.net", + "jti": "3b16f2ab4492417aae4cc9a5e6506ca2519659c0d8fdc2bf442fe01aa9b8e46c", + "nbf": 1695285095, + "nonce": "7394904505194784658", + "x-ms-attestation-type": "sevsnpvm", + "x-ms-compliance-status": "azure-compliant-uvm", + "x-ms-policy-hash": "9NY0VnTQ-IiBriBplVUpFbczcDaEBUwsiFYAzHu_gco", + "x-ms-runtime": { + "location": "East US", + "publicKey": "abc" + }, + "x-ms-sevsnpvm-authorkeydigest": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "x-ms-sevsnpvm-bootloader-svn": 3, + "x-ms-sevsnpvm-familyId": "01000000000000000000000000000000", + "x-ms-sevsnpvm-guestsvn": 2, + "x-ms-sevsnpvm-hostdata": "fef932e0103f6132437e8a1223f32efc4bea63342f893b5124645224ef29ba73", + "x-ms-sevsnpvm-idkeydigest": "ebeeeabce075eeaba3d9ea24d8495137a2877c0d20ac6ea73fc6d2f8aeb50de132150e0a0752664919bcebbf2e8c5807", + "x-ms-sevsnpvm-imageId": "02000000000000000000000000000000", + "x-ms-sevsnpvm-is-debuggable": false, + "x-ms-sevsnpvm-launchmeasurement": "03fea02823189b25d0623a5c81f97c8ba4d2fbc48c914a55ce525f90454ddcec303743dac2fc013f0846912d1412f6df", + "x-ms-sevsnpvm-microcode-svn": 115, + "x-ms-sevsnpvm-migration-allowed": false, + "x-ms-sevsnpvm-reportdata": "4e7d4a413745ddea79f05d20d9ac7add3659ac783ef24684127bbbb3e50fc63c0000000000000000000000000000000000000000000000000000000000000000", + "x-ms-sevsnpvm-reportid": "d137a83c2d42d81dd42d39ad95ef9023de63216ddaaf2c368a8c41a636ddb2a9", + "x-ms-sevsnpvm-smt-allowed": true, + "x-ms-sevsnpvm-snpfw-svn": 8, + "x-ms-sevsnpvm-tee-svn": 0, + "x-ms-sevsnpvm-uvm-endorsement": { + "x-ms-sevsnpvm-guestsvn": "100", + "x-ms-sevsnpvm-launchmeasurement": "03fea02823189b25d0623a5c81f97c8ba4d2fbc48c914a55ce525f90454ddcec303743dac2fc013f0846912d1412f6df" + }, + "x-ms-sevsnpvm-uvm-endorsement-headers": { + "feed": "ContainerPlat-AMD-UVM", + "iss": "did:x509:0:sha256:I__iuL25oXEVFdTP_aBLx_eT1RPHbCQ_ECBQfYZpt9s::eku:1.3.6.1.4.1.311.76.59.1.2" + }, + "x-ms-sevsnpvm-vmpl": 0, + "x-ms-ver": "1.0" +}