Skip to content

Commit

Permalink
refactor: refactor trust chain validation module
Browse files Browse the repository at this point in the history
  • Loading branch information
Zoe Maas committed Oct 10, 2024
1 parent 37f1759 commit a36581a
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 108 deletions.
3 changes: 0 additions & 3 deletions modules/openid-federation-client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ kotlin {
}
}
}
useEsModules()
generateTypeScriptDefinitions()
binaries.executable()
}

sourceSets {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.sphereon.oid.fed.client

import com.sphereon.oid.fed.client.httpclient.httpService
import com.sphereon.oid.fed.client.validation.trustChainValidationService

object OidFederationClientService {
val HTTP = httpService()
val TRUST_CHAIN_VALIDATION = trustChainValidationService()
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interface HttpClientCallbackService: ICallbackService<IHttpClientService>, IHttp

expect fun httpService(): HttpClientCallbackService

//FIXME Extract the implementation to a separate class
object OidFederationHttpClientObject: HttpClientCallbackService {

private val isRequestAuthenticated: Boolean = false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.sphereon.oid.fed.client.validation

import com.sphereon.oid.fed.common.logging.Logger

object TrustChainValidationConst {
val LOG_NAMESPACE = "sphereon:kmp:openid-federation-client"
val LOG = Logger.Static.tag(LOG_NAMESPACE)
val TRUST_CHAIN_VALIDATION_LITERAL = "TRUST_CHAIN_VALIDATION"
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,76 @@
package com.sphereon.oid.fed.client.validation

import com.sphereon.oid.fed.client.httpclient.OidFederationClient
import com.sphereon.oid.fed.client.ICallbackService
import com.sphereon.oid.fed.client.OidFederationClientService
import com.sphereon.oid.fed.common.jwt.JwtService
import com.sphereon.oid.fed.common.jwt.JwtVerifyInput
import com.sphereon.oid.fed.common.logging.Logger
import com.sphereon.oid.fed.common.mapper.JsonMapper
import com.sphereon.oid.fed.openapi.models.EntityConfigurationStatement
import com.sphereon.oid.fed.openapi.models.Jwk
import com.sphereon.oid.fed.openapi.models.SubordinateStatement
import io.ktor.client.engine.HttpClientEngine
import kotlinx.datetime.Clock
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlin.jvm.JvmStatic

class TrustChainValidationCommon(val jwtService: JwtService) {

interface ITrustChainValidationCallback {
suspend fun readAuthorityHints(
partyBId: String,
engine: HttpClientEngine,
): List<List<EntityConfigurationStatement>>

suspend fun fetchSubordinateStatements(
entityConfigurationStatementsList: List<List<EntityConfigurationStatement>>,
): List<List<String>>

suspend fun validateTrustChains(
jwts: List<List<String>>,
knownTrustChainIds: List<String>
): List<List<Any>>
}

interface TrustChainValidationCallback : ICallbackService<ITrustChainValidationCallback>,
ITrustChainValidationCallback

expect fun trustChainValidationService(): TrustChainValidationCallback

//FIXME Extract the actual implementation to a separate class
object TrustChainValidationObject : TrustChainValidationCallback {

@JvmStatic
private lateinit var jwtService: JwtService

@JvmStatic
private lateinit var httpService: OidFederationClientService

@JvmStatic
private lateinit var platformCallback: ITrustChainValidationCallback

private var disabled = false

override suspend fun readAuthorityHints(
partyBId: String,
): List<List<EntityConfigurationStatement>> {
if (!isEnabled()) {
TrustChainValidationConst.LOG.info("TRUST_CHAIN_VALIDATION readAuthorityHints has been disabled")
throw IllegalStateException("TRUST_CHAIN_VALIDATION service is disabled; cannot read authority hints")
} else if (!this::platformCallback.isInitialized) {
TrustChainValidationConst.LOG.error(
"TRUST_CHAIN_VALIDATION callback (JS) is not registered"
)
throw IllegalStateException("TRUST_CHAIN_VALIDATION have not been initialized. Please register your TrustChainValidationCallbacksServiceJS implementation, or register a default implementation")
}
return readAuthorityHintsImpl(partyBId)
}

private suspend fun readAuthorityHintsImpl(
partyBId: String,
trustChains: MutableList<List<EntityConfigurationStatement>> = mutableListOf(),
trustChain: MutableSet<EntityConfigurationStatement> = mutableSetOf()
): List<List<EntityConfigurationStatement>> {
OidFederationClient(engine).fetchEntityStatement(partyBId).run {
httpService.HTTP.fetchEntityStatement(partyBId).run {
JsonMapper().mapEntityConfigurationStatement(this).let {
if (it.authorityHints.isNullOrEmpty()) {
trustChain.add(it)
Expand All @@ -33,9 +79,8 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
} else {
it.authorityHints?.forEach { hint ->
trustChain.add(it)
readAuthorityHints(
readAuthorityHintsImpl(
hint,
engine,
trustChains,
trustChain
)
Expand All @@ -46,16 +91,24 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
return trustChains
}

suspend fun fetchSubordinateStatements(
override suspend fun fetchSubordinateStatements(
entityConfigurationStatementsList: List<List<EntityConfigurationStatement>>,
engine: HttpClientEngine
): List<List<String>> {
if (!isEnabled()) {
TrustChainValidationConst.LOG.info("TRUST_CHAIN_VALIDATION readAuthorityHints has been disabled")
throw IllegalStateException("TRUST_CHAIN_VALIDATION service is disabled; cannot read authority hints")
} else if (!this::platformCallback.isInitialized) {
TrustChainValidationConst.LOG.error(
"TRUST_CHAIN_VALIDATION callback (JS) is not registered"
)
throw IllegalStateException("TRUST_CHAIN_VALIDATION have not been initialized. Please register your TrustChainValidationCallbacksServiceJS implementation, or register a default implementation")
}
val trustChains: MutableList<List<String>> = mutableListOf()
val trustChain: MutableList<String> = mutableListOf()
entityConfigurationStatementsList.forEach { entityConfigurationStatements ->
entityConfigurationStatements.forEach { it ->
it.metadata?.jsonObject?.get("federation_entity")?.jsonObject?.get("federation_fetch_endpoint")?.jsonPrimitive?.content.let { url ->
OidFederationClient(engine).fetchEntityStatement(url.toString()).run {
httpService.HTTP.fetchEntityStatement(url.toString()).run {
trustChain.add(this)
}
}
Expand All @@ -66,22 +119,31 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
return trustChains
}

fun validateTrustChains(
override suspend fun validateTrustChains(
jwts: List<List<String>>,
knownTrustChainIds: List<String>
): List<List<Any>> {
if (!isEnabled()) {
TrustChainValidationConst.LOG.info("TRUST_CHAIN_VALIDATION readAuthorityHints has been disabled")
throw IllegalStateException("TRUST_CHAIN_VALIDATION service is disabled; cannot read authority hints")
} else if (!this::platformCallback.isInitialized) {
TrustChainValidationConst.LOG.error(
"TRUST_CHAIN_VALIDATION callback (JS) is not registered"
)
throw IllegalStateException("TRUST_CHAIN_VALIDATION have not been initialized. Please register your TrustChainValidationCallbacksServiceJS implementation, or register a default implementation")
}
val trustChains: MutableList<List<Any>> = mutableListOf()
for(it in jwts) {
for (it in jwts) {
try {
trustChains.add(validateTrustChain(it, knownTrustChainIds))
} catch (e: Exception) {
Logger.debug("TrustChainValidation", e.message.toString())
TrustChainValidationConst.LOG.error("Trust Chain Validation Error: ${e.message.toString()}")
}
}
return trustChains
}

private fun validateTrustChain(jwts: List<String>, knownTrustChainIds: List<String>): List<Any> {
private suspend fun validateTrustChain(jwts: List<String>, knownTrustChainIds: List<String>): List<Any> {
val entityStatements = jwts.toMutableList()
val firstEntityConfiguration =
entityStatements.removeFirst().let { JsonMapper().mapEntityConfigurationStatement(it) }
Expand All @@ -94,11 +156,13 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
}

if (firstEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
jwtService.JWT.verify(
input = JwtVerifyInput(
jwt = jwts[0],
key = retrieveJwk(it)
)) } == false) {
)
)
} == false) {
throw IllegalArgumentException("Invalid signature")
}

Expand All @@ -120,22 +184,27 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
jwtService.JWT.verify(
input = JwtVerifyInput(
jwt = jwts[index],
key = retrieveJwk(it)
)) } == false) {
)
)
} == false) {
throw IllegalArgumentException("Invalid signature")
}

is SubordinateStatement ->
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
jwtService.JWT.verify(
input = JwtVerifyInput(
jwt = jwts[index],
key = retrieveJwk(it)
)) } == false) {
)
)
} == false) {
throw IllegalArgumentException("Invalid signature")
}
}
Expand All @@ -145,10 +214,13 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to the Entity Identifier of the Trust Anchor")
}
if (lastEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
jwtService.JWT.verify(
input = JwtVerifyInput(
jwt = jwts[jwts.size - 1],
key = retrieveJwk(it))) } == false) {
key = retrieveJwk(it)
)
)
} == false) {
throw IllegalArgumentException("Invalid signature")
}

Expand All @@ -169,7 +241,27 @@ class TrustChainValidationCommon(val jwtService: JwtService) {
x = key["x"]?.jsonPrimitive?.content,
y = key["y"]?.jsonPrimitive?.content
)

else -> throw IllegalArgumentException("Invalid key")
}
}

override fun disable(): ITrustChainValidationCallback {
this.disabled = true
return this
}

override fun enable(): ITrustChainValidationCallback {
this.disabled = false
return this
}

override fun isEnabled(): Boolean {
return !this.disabled
}

override fun register(platformCallback: ITrustChainValidationCallback): ITrustChainValidationCallback {
this.platformCallback = platformCallback
return this
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.sphereon.oid.fed.client

import com.sphereon.oid.fed.client.httpclient.httpService
import com.sphereon.oid.fed.client.validation.trustChainValidationService


object OidFederationClientServiceJS {
val HTTP = httpService()
val TRUST_CHAIN_VALIDATION = trustChainValidationService()
}

/**
Expand Down
Loading

0 comments on commit a36581a

Please sign in to comment.