Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Confidential Container MAA token parse and policy validation logic #140

Merged
merged 9 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@
<artifactId>google-cloud-logging</artifactId>
<version>3.13.7</version>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-security-attestation</artifactId>
<version>1.1.15</version>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-core-http-netty</artifactId>
<version>1.13.6</version>
</dependency>
<dependency>
<groupId>co.nstant.in</groupId>
<artifactId>cbor</artifactId>
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/com/uid2/shared/secure/JwtUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.uid2.shared.secure;

import java.util.Map;

public class JwtUtils {
public static<T> T tryGetField(Map payload, String key, Class<T> clazz){
if(payload == null){
return null;
}
var rawValue = payload.get(key);
return tryConvert(rawValue, clazz);
}

public static<T> T tryConvert(Object obj, Class<T> clazz){
if(obj == null){
return null;
}
try{
return clazz.cast(obj);
}
catch (ClassCastException e){
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.uid2.shared.secure.azurecc;

import com.azure.security.attestation.AttestationClientBuilder;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import com.uid2.shared.secure.AttestationException;

import java.security.PublicKey;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

// MAA certs are stored as x5c(X.509 certificate chain), not supported by Google auth lib.
// So we have to build a thin layer to fetch Azure public key.
public class AzurePublicKeyProvider implements IPublicKeyProvider {

private final LoadingCache<String, Map<String, PublicKey>> publicKeyCache;

public AzurePublicKeyProvider() {
this.publicKeyCache = CacheBuilder.newBuilder()
.expireAfterWrite(1L, TimeUnit.HOURS)
.build(new CacheLoader<>() {
@Override
public Map<String, PublicKey> load(String maaServerBaseUrl) throws AttestationException {
return loadPublicKeys(maaServerBaseUrl);
}
});
}

@Override
public PublicKey GetPublicKey(String maaServerBaseUrl, String kid) throws AttestationException {
PublicKey key;
try {
key = publicKeyCache.get(maaServerBaseUrl).get(kid);
}
catch (ExecutionException e){
throw new AttestationException(
String.format("Error fetching PublicKey from certificate location: %s, error: %s.", maaServerBaseUrl, e.getMessage())
);
}

if(key == null){
throw new AttestationException("Could not find PublicKey for provided keyId: " + kid);
}
return key;
}

// We don't want to reinvent the wheel. Leverage Azure Attestation client library to fetch certs.
private static Map<String, PublicKey> loadPublicKeys(String maaServerBaseUrl) throws AttestationException {
var attestationBuilder = new AttestationClientBuilder();
var client = attestationBuilder
.endpoint(maaServerBaseUrl)
.buildClient();

var signers = client.listAttestationSigners().getAttestationSigners();

ImmutableMap.Builder<String, PublicKey> keyCacheBuilder = new ImmutableMap.Builder();

for (var signer : signers){
var keyId = signer.getKeyId();
var certs = signer.getCertificates();
if(!certs.isEmpty()){
var publicKey = certs.get(0).getPublicKey();
lunwang-ttd marked this conversation as resolved.
Show resolved Hide resolved
keyCacheBuilder.put(keyId, publicKey);
}
}

var map = keyCacheBuilder.build();
if(map.isEmpty()){
throw new AttestationException("Fail to load certs from: " + maaServerBaseUrl);
}

return map;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationException;

public interface IMaaTokenSignatureValidator {
/**
* Validate token signature against authorized issuer.
*
* @param tokenString The raw MAA token string.
* @return Parsed token payload.
* @throws AttestationException
*/
MaaTokenPayload validate(String tokenString) throws AttestationException;
}
15 changes: 15 additions & 0 deletions src/main/java/com/uid2/shared/secure/azurecc/IPolicyValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationException;

public interface IPolicyValidator {
/**
* Validate token payload against defined policies.
*
* @param maaTokenPayload The parsed MAA token.
* @param publicKey The public key info to verify in payload runtime data.
* @return The enclave id.
* @throws AttestationException
*/
String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.uid2.shared.secure.azurecc;

import com.uid2.shared.secure.AttestationException;

import java.security.PublicKey;

public interface IPublicKeyProvider {
/**
* Get Public Key from a MAA server.
*
* @param maaServerBaseUrl The Base Url of MAA server.
* @param kid The key id.
* @return The public key.
* @throws AttestationException
*/
PublicKey GetPublicKey(String maaServerBaseUrl, String kid) throws AttestationException;
}
26 changes: 26 additions & 0 deletions src/main/java/com/uid2/shared/secure/azurecc/MaaTokenPayload.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.uid2.shared.secure.azurecc;

import lombok.Builder;
import lombok.Value;

@Value
@Builder(toBuilder = true)
public class MaaTokenPayload {
public static final String SEV_SNP_VM_TYPE = "sevsnpvm";
public static final String AZURE_COMPLIANT_UVM = "azure-compliant-uvm";

private String attestationType;
private String complianceStatus;
private boolean vmDebuggable;
private String ccePolicy;
lunwang-ttd marked this conversation as resolved.
Show resolved Hide resolved

private RuntimeData runtimeData;

public boolean isSevSnpVM(){
return SEV_SNP_VM_TYPE.equalsIgnoreCase(attestationType);
}

public boolean isUtilityVMCompliant(){
return AZURE_COMPLIANT_UVM.equalsIgnoreCase(complianceStatus);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.uid2.shared.secure.azurecc;

import com.google.api.client.json.gson.GsonFactory;
import com.google.api.client.json.webtoken.JsonWebSignature;
import com.google.api.client.util.Clock;
import com.google.auth.oauth2.TokenVerifier;
import com.google.common.base.Strings;
import com.uid2.shared.secure.AttestationException;

import java.io.IOException;
import java.util.Map;

import static com.uid2.shared.secure.JwtUtils.tryGetField;

public class MaaTokenSignatureValidator implements IMaaTokenSignatureValidator {

// set to true to facilitate local test.
public static final boolean BYPASS_SIGNATURE_CHECK = false;

// e.g. https://sharedeus.eus.attest.azure.net
private final String maaServerBaseUrl;

private final IPublicKeyProvider publicKeyProvider;

// used in UT
private final Clock clockOverride;

public MaaTokenSignatureValidator(String maaServerBaseUrl) {
this(maaServerBaseUrl, new AzurePublicKeyProvider(), null);
}

protected MaaTokenSignatureValidator(String maaServerBaseUrl, IPublicKeyProvider publicKeyProvider, Clock clockOverride) {
this.maaServerBaseUrl = maaServerBaseUrl;
this.publicKeyProvider = publicKeyProvider;
this.clockOverride = clockOverride;
}

private TokenVerifier buildTokenVerifier(String kid) throws AttestationException {
var verifierBuilder = TokenVerifier.newBuilder();

verifierBuilder.setPublicKey(publicKeyProvider.GetPublicKey(maaServerBaseUrl, kid));

if (clockOverride != null) {
verifierBuilder.setClock(clockOverride);
}

verifierBuilder.setIssuer(maaServerBaseUrl);

return verifierBuilder.build();
}

@Override
public MaaTokenPayload validate(String tokenString) throws AttestationException {
if (Strings.isNullOrEmpty(tokenString)) {
throw new IllegalArgumentException("tokenString can not be null or empty");
}

// Validate Signature
JsonWebSignature signature;
try {
signature = JsonWebSignature.parse(GsonFactory.getDefaultInstance(), tokenString);
if(!BYPASS_SIGNATURE_CHECK){
var kid = signature.getHeader().getKeyId();
var tokenVerifier = buildTokenVerifier(kid);
tokenVerifier.verify(tokenString);
}
} catch (TokenVerifier.VerificationException e) {
throw new AttestationException("Fail to validate the token signature, error: " + e.getMessage());
} catch (IOException e) {
throw new AttestationException("Fail to parse token, error: " + e.getMessage());
}

// Parse Payload
var rawPayload = signature.getPayload();

var tokenPayloadBuilder = MaaTokenPayload.builder();

tokenPayloadBuilder.attestationType(tryGetField(rawPayload, "x-ms-attestation-type", String.class));
tokenPayloadBuilder.complianceStatus(tryGetField(rawPayload, "x-ms-compliance-status", String.class));
tokenPayloadBuilder.vmDebuggable(tryGetField(rawPayload, "x-ms-sevsnpvm-is-debuggable", Boolean.class));
tokenPayloadBuilder.ccePolicy(tryGetField(rawPayload, "x-ms-sevsnpvm-hostdata", String.class));

var runtime = tryGetField(rawPayload, ("x-ms-runtime"), Map.class);

if(runtime != null){
var runtimeDataBuilder = RuntimeData.builder();
runtimeDataBuilder.location(tryGetField(runtime, "location", String.class));
runtimeDataBuilder.publicKey(tryGetField(runtime, "publicKey", String.class));
tokenPayloadBuilder.runtimeData(runtimeDataBuilder.build());
}

return tokenPayloadBuilder.build();
}
}
54 changes: 54 additions & 0 deletions src/main/java/com/uid2/shared/secure/azurecc/PolicyValidator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.uid2.shared.secure.azurecc;

import com.google.common.base.Strings;
import com.uid2.shared.secure.AttestationException;

public class PolicyValidator implements IPolicyValidator{
private final String LOCATION_CHINA = "china";
private final String LOCATION_EU = "europe";
@Override
public String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException {
verifyVM(maaTokenPayload);
verifyLocation(maaTokenPayload);
verifyPublicKey(maaTokenPayload, publicKey);
return maaTokenPayload.getCcePolicy();
}

private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException {
if(Strings.isNullOrEmpty(publicKey)){
throw new AttestationException("public key to check is null or empty");
}
var runtimePublicKey = maaTokenPayload.getRuntimeData().getPublicKey();
if(!publicKey.equals(runtimePublicKey)){
throw new AttestationException(
String.format("Public key in payload is not match expected value. More info: runtime(%s), expected(%s)",
runtimePublicKey,
publicKey
));
}
}

private void verifyVM(MaaTokenPayload maaTokenPayload) throws AttestationException {
if(!maaTokenPayload.isSevSnpVM()){
throw new AttestationException("Not in SevSnp VM");
}
if(!maaTokenPayload.isUtilityVMCompliant()){
throw new AttestationException("Not run in Azure Compliance Utility VM");
}
if(maaTokenPayload.isVmDebuggable()){
throw new AttestationException("The underlying harware should not run in debug mode");
}
}

private void verifyLocation(MaaTokenPayload maaTokenPayload) throws AttestationException {
var location = maaTokenPayload.getRuntimeData().getLocation();
if(Strings.isNullOrEmpty(location)){
throw new AttestationException("Location is not specified.");
}
var lowerCaseLocation = location.toLowerCase();
if(lowerCaseLocation.contains(LOCATION_CHINA) ||
lowerCaseLocation.contains(LOCATION_EU)){
throw new AttestationException("Location is not supported. Value: " + location);
}
}
}
11 changes: 11 additions & 0 deletions src/main/java/com/uid2/shared/secure/azurecc/RuntimeData.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.uid2.shared.secure.azurecc;

import lombok.Builder;
import lombok.Value;

@Value
@Builder(toBuilder = true)
public class RuntimeData {
private String location;
private String publicKey;
}
Loading