Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

differentiate 401 with 500 #159

Merged
merged 6 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.uid2.shared.secure;

public class AttestationClientException extends AttestationException
{
public AttestationClientException(Throwable cause) {
super(cause, true);
}

public AttestationClientException(String message) {
super(message, true);
}
}
19 changes: 18 additions & 1 deletion src/main/java/com/uid2/shared/secure/AttestationException.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
package com.uid2.shared.secure;

public class AttestationException extends Exception {
private final boolean isClientError;

public boolean IsClientError() {
return this.isClientError;
}

public AttestationException(Throwable cause, boolean isClientError) {
super(cause);
this.isClientError = isClientError;
}

public AttestationException(Throwable cause) {
this(cause, false);
}

public AttestationException(String cause, boolean isClientError) {
super(cause);
this.isClientError = isClientError;
}

public AttestationException(String message) {
super(message);
this(message, false);
}
}
14 changes: 14 additions & 0 deletions src/main/java/com/uid2/shared/secure/AttestationResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,27 @@ public class AttestationResult {

private final String enclaveId;

private final AttestationClientException attestationClientException;

public AttestationResult(AttestationFailure reasonToFail) {
this.failure = reasonToFail;
this.publicKey = null;
this.enclaveId = "Failed attestation, enclave Id unknown";
this.attestationClientException = null;
}

public AttestationResult(AttestationClientException exception) {
this.failure = AttestationFailure.UNKNOWN;
this.publicKey = null;
this.enclaveId = "Failed attestation, enclave Id unknown";
this.attestationClientException = exception;
}

public AttestationResult(byte[] publicKey, String enclaveId) {
this.failure = AttestationFailure.NONE;
this.publicKey = publicKey;
this.enclaveId = enclaveId;
this.attestationClientException = null;
}

public boolean isSuccess() {
Expand All @@ -25,6 +36,9 @@ public boolean isSuccess() {
public AttestationFailure getFailure() { return this.failure; }

public String getReason() {
if (this.attestationClientException != null) {
return this.attestationClientException.getMessage();
}
return this.failure.explain();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,15 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler<AsyncRes
log.info("Successfully attested azure-cc against registered enclaves, enclave id: " + enclaveId);
handler.handle(Future.succeededFuture(new AttestationResult(publicKey, enclaveId)));
} else {
throw new AttestationException("Unregistered enclave, enclave id: " + enclaveId);
log.warn("Got unsupported azure-cc enclave id: " + enclaveId);
handler.handle(Future.succeededFuture(new AttestationResult(AttestationFailure.FORBIDDEN_ENCLAVE)));
}
} catch (AttestationException ex) {
handler.handle(Future.failedFuture(ex));
}
catch (AttestationClientException ace){
handler.handle(Future.succeededFuture(new AttestationResult(ace)));
}
catch (AttestationException ae) {
handler.handle(Future.failedFuture(ae));
} catch (Exception ex) {
handler.handle(Future.failedFuture(new AttestationException(ex)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ public void attest(byte[] attestationRequest, byte[] publicKey, Handler<AsyncRes
if (enclaveId != null) {
handler.handle(Future.succeededFuture(new AttestationResult(publicKey, enclaveId)));
} else {
throw new AttestationException("unauthorized token");
LOGGER.warn("Can not find registered gcp-oidc enclave id.");
yishi-ttd marked this conversation as resolved.
Show resolved Hide resolved
handler.handle(Future.succeededFuture(new AttestationResult(AttestationFailure.FORBIDDEN_ENCLAVE)));
yishi-ttd marked this conversation as resolved.
Show resolved Hide resolved
}
}
catch (AttestationException ex){
handler.handle(Future.failedFuture(ex));
catch (AttestationClientException ace){
handler.handle(Future.succeededFuture(new AttestationResult(ace)));
}
catch (AttestationException ae){
handler.handle(Future.failedFuture(ae));
}
catch (Exception ex) {
handler.handle(Future.failedFuture(new AttestationException(ex)));
Expand Down Expand Up @@ -79,21 +83,29 @@ public Collection<String> getEnclaveAllowlist() {
// Returns
// null if validation failed
// enclaveId if validation succeed
private String validate(TokenPayload tokenPayload) {
private String validate(TokenPayload tokenPayload) throws Exception {
Exception lastException = null;
for (var policyValidator : supportedPolicyValidators) {
LOGGER.info("Validating policy... Validator version: " + policyValidator.getVersion());
try {
var enclaveId = policyValidator.validate(tokenPayload);
LOGGER.info("Validator version: " + policyValidator.getVersion() + ", result: " + enclaveId);

if (allowedEnclaveIds.contains(enclaveId)) {
LOGGER.info("Successfully attested OIDC against registered enclaves");
LOGGER.info("Successfully attested gcp-oidc against registered enclaves");
return enclaveId;
} else {
LOGGER.warn("Got unsupported gcp-oidc enclave id: " + enclaveId);
}
} catch (Exception ex) {
lastException = ex;
LOGGER.warn("Fail to validator version: " + policyValidator.getVersion() + ", error :" + ex.getMessage());
}
}

if(lastException != null){
throw lastException;
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.google.api.client.util.Clock;
import com.google.auth.oauth2.TokenVerifier;
import com.google.common.base.Strings;
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;

import java.io.IOException;
Expand Down Expand Up @@ -65,7 +66,7 @@ public MaaTokenPayload validate(String tokenString) throws AttestationException
tokenVerifier.verify(tokenString);
}
} catch (TokenVerifier.VerificationException e) {
throw new AttestationException("Fail to validate the token signature, error: " + e.getMessage());
throw new AttestationClientException("Fail to validate the token signature, error: " + e.getMessage());
thomasm-ttd marked this conversation as resolved.
Show resolved Hide resolved
} catch (IOException e) {
throw new AttestationException("Fail to parse token, error: " + e.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.uid2.shared.secure.azurecc;

import com.google.common.base.Strings;
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;

public class PolicyValidator implements IPolicyValidator{
Expand All @@ -16,11 +17,11 @@ public String validate(MaaTokenPayload maaTokenPayload, String publicKey) throws

private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey) throws AttestationException {
if(Strings.isNullOrEmpty(publicKey)){
throw new AttestationException("public key to check is null or empty");
throw new AttestationClientException("public key to check is null or empty");
}
var runtimePublicKey = maaTokenPayload.getRuntimeData().getPublicKey();
if(!publicKey.equals(runtimePublicKey)){
throw new AttestationException(
throw new AttestationClientException(
String.format("Public key in payload is not match expected value. More info: runtime(%s), expected(%s)",
runtimePublicKey,
publicKey
Expand All @@ -30,25 +31,25 @@ private void verifyPublicKey(MaaTokenPayload maaTokenPayload, String publicKey)

private void verifyVM(MaaTokenPayload maaTokenPayload) throws AttestationException {
if(!maaTokenPayload.isSevSnpVM()){
throw new AttestationException("Not in SevSnp VM");
throw new AttestationClientException("Not in SevSnp VM");
}
if(!maaTokenPayload.isUtilityVMCompliant()){
throw new AttestationException("Not run in Azure Compliance Utility VM");
throw new AttestationClientException("Not run in Azure Compliance Utility VM");
}
if(maaTokenPayload.isVmDebuggable()){
throw new AttestationException("The underlying harware should not run in debug mode");
throw new AttestationClientException("The underlying hardware should not run in debug mode");
}
}

private void verifyLocation(MaaTokenPayload maaTokenPayload) throws AttestationException {
var location = maaTokenPayload.getRuntimeData().getLocation();
if(Strings.isNullOrEmpty(location)){
throw new AttestationException("Location is not specified.");
throw new AttestationClientException("Location is not specified.");
}
var lowerCaseLocation = location.toLowerCase();
if(lowerCaseLocation.contains(LOCATION_CHINA) ||
lowerCaseLocation.contains(LOCATION_EU)){
throw new AttestationException("Location is not supported. Value: " + location);
throw new AttestationClientException("Location is not supported. Value: " + location);
}
}
}
19 changes: 10 additions & 9 deletions src/main/java/com/uid2/shared/secure/gcpoidc/PolicyValidator.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.uid2.shared.Utils;
import com.uid2.shared.secure.AttestationClientException;
import com.uid2.shared.secure.AttestationException;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
Expand Down Expand Up @@ -53,18 +54,18 @@ public String validate(TokenPayload payload) throws AttestationException {

private static boolean checkConfidentialSpace(TokenPayload payload) throws AttestationException{
if(!payload.isConfidentialSpaceSW()){
throw new AttestationException("Unexpected SW_NAME: " + payload.getSwName());
throw new AttestationClientException("Unexpected SW_NAME: " + payload.getSwName());
}
var isDebugMode = payload.isDebugMode();
if(!isDebugMode && !payload.isStableVersion()){
throw new AttestationException("Confidential space image version is not stable.");
throw new AttestationClientException("Confidential space image version is not stable.");
}
return isDebugMode;
}

private static String checkWorkload(TokenPayload payload) throws AttestationException{
if(!payload.isRestartPolicyNever()){
throw new AttestationException("Restart policy is not set to Never. Value: " + payload.getRestartPolicy());
throw new AttestationClientException("Restart policy is not set to Never. Value: " + payload.getRestartPolicy());
}
return payload.getWorkloadImageDigest();
}
Expand All @@ -75,35 +76,35 @@ private static String checkWorkload(TokenPayload payload) throws AttestationExce
private static String checkRegion(TokenPayload payload) throws AttestationException{
var region = payload.getGceZone();
if(Strings.isNullOrEmpty(region) || region.startsWith(EU_REGION_PREFIX)){
throw new AttestationException("Region is not supported. Value: " + region);
throw new AttestationClientException("Region is not supported. Value: " + region);
}
return region;
}

private static void checkCmdOverrides(TokenPayload payload) throws AttestationException{
if(!CollectionUtils.isEmpty(payload.getCmdOverrides())){
throw new AttestationException("Payload should not have cmd overrides");
throw new AttestationClientException("Payload should not have cmd overrides");
}
}

private Environment checkEnvOverrides(TokenPayload payload) throws AttestationException{
var envOverrides = payload.getEnvOverrides();
if(MapUtils.isEmpty(envOverrides)){
throw new AttestationException("env overrides should not be empty");
throw new AttestationClientException("env overrides should not be empty");
}
HashMap<String, String> envOverridesCopy = new HashMap(envOverrides);

// check all required env overrides
for(var envKey: REQUIRED_ENV_OVERRIDES){
if(Strings.isNullOrEmpty(envOverridesCopy.get(envKey))){
throw new AttestationException("Required env override is missing. key: " + envKey);
throw new AttestationClientException("Required env override is missing. key: " + envKey);
}
}

// env could be parsed
var env = Environment.fromString(envOverridesCopy.get(ENV_ENVIRONMENT));
if(env == null){
throw new AttestationException("Environment can not be parsed. " + envOverridesCopy.get(ENV_ENVIRONMENT));
throw new AttestationClientException("Environment can not be parsed. " + envOverridesCopy.get(ENV_ENVIRONMENT));
}

// make sure there's no unexpected overrides
Expand All @@ -118,7 +119,7 @@ private Environment checkEnvOverrides(TokenPayload payload) throws AttestationEx
}

if(!envOverridesCopy.isEmpty()){
throw new AttestationException("More env overrides than allowed. " + envOverridesCopy);
throw new AttestationClientException("More env overrides than allowed. " + envOverridesCopy);
}

return env;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@

import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -45,9 +44,7 @@ class AzureCCAttestationProviderTest {
@BeforeEach
public void setup() throws AttestationException {
when(alwaysPassTokenValidator.validate(any())).thenReturn(VALID_TOKEN_PAYLOAD);
when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("token signature validation failed"));
when(alwaysPassPolicyValidator.validate(any(), any())).thenReturn(ENCLAVE_ID);
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("policy validation failed"));
}

@Test
Expand All @@ -61,7 +58,21 @@ public void testHappyPath() throws AttestationException {
}

@Test
public void testSignatureCheckFailed() throws AttestationException {
public void testSignatureCheckFailed_ClientError() throws AttestationException {
var errorStr = "token signature validation failed";
when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationClientException(errorStr));
var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertTrue(ar.succeeded());
assertFalse(ar.result().isSuccess());
assertEquals(errorStr, ar.result().getReason());
});
}

@Test
public void testSignatureCheckFailed_ServerError() throws AttestationException {
when(alwaysFailTokenValidator.validate(any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
Expand All @@ -71,7 +82,21 @@ public void testSignatureCheckFailed() throws AttestationException {
}

@Test
public void testPolicyCheckFailed() throws AttestationException {
public void testPolicyCheckFailed_ClientError() throws AttestationException {
var errorStr = "policy validation failed";
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationClientException(errorStr));
var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysFailPolicyValidator);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
assertTrue(ar.succeeded());
assertFalse(ar.result().isSuccess());
assertEquals(errorStr, ar.result().getReason());
});
}

@Test
public void testPolicyCheckFailed_ServerError() throws AttestationException {
when(alwaysFailPolicyValidator.validate(any(), any())).thenThrow(new AttestationException("unknown server error"));
var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysFailPolicyValidator);
provider.registerEnclave(ENCLAVE_ID);
attest(provider, ar -> {
Expand All @@ -84,8 +109,9 @@ public void testPolicyCheckFailed() throws AttestationException {
public void testEnclaveNotRegistered() throws AttestationException {
var provider = new AzureCCAttestationProvider(alwaysFailTokenValidator, alwaysPassPolicyValidator);
attest(provider, ar -> {
assertFalse(ar.succeeded());
assertTrue(ar.cause() instanceof AttestationException);
assertTrue(ar.succeeded());
assertFalse(ar.result().isSuccess());
assertEquals(AttestationFailure.FORBIDDEN_ENCLAVE, ar.result().getFailure());
});
}

Expand Down
Loading