Skip to content

Commit

Permalink
refactor: Created JWT service that accepts callbacks and adjusted the…
Browse files Browse the repository at this point in the history
… code.
  • Loading branch information
Zoe Maas committed Oct 4, 2024
1 parent a419a02 commit 1c90e8d
Show file tree
Hide file tree
Showing 13 changed files with 422 additions and 485 deletions.
4 changes: 2 additions & 2 deletions modules/openid-federation-client/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ kotlin {
dependencies {
implementation("io.ktor:ktor-client-core-jvm:$ktorVersion")
runtimeOnly("io.ktor:ktor-client-cio-jvm:$ktorVersion")
implementation("com.nimbusds:nimbus-jose-jwt:9.40")
implementation(project(":modules:openid-federation-common"))
}
}
val jvmTest by getting {
dependencies {
implementation(kotlin("test-junit"))
implementation("com.nimbusds:nimbus-jose-jwt:9.40")
}
}
// TODO Should be placed back at a later point in time: https://sphereon.atlassian.net/browse/OIDF-50
Expand Down Expand Up @@ -141,7 +141,6 @@ kotlin {
runtimeOnly("io.ktor:ktor-client-core-js:$ktorVersion")
runtimeOnly("io.ktor:ktor-client-js:$ktorVersion")
implementation(npm("typescript", "5.5.3"))
implementation(npm("jose", "5.6.3"))
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.7.1")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.9.0-RC")
implementation(project(":modules:openid-federation-common"))
Expand All @@ -151,6 +150,7 @@ kotlin {
val jsTest by getting {
dependencies {
implementation(kotlin("test-js"))
implementation(npm("jose", "5.6.3"))
implementation(kotlin("test-annotations-common"))
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.9.0-RC")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
package com.sphereon.oid.fed.client.validation

import com.sphereon.oid.fed.client.httpclient.OidFederationClient
import com.sphereon.oid.fed.common.jwt.JwtService
import com.sphereon.oid.fed.common.jwt.JwtVerifyInput
import com.sphereon.oid.fed.common.logging.Logger
import com.sphereon.oid.fed.common.mapper.JsonMapper
import com.sphereon.oid.fed.openapi.models.EntityConfigurationStatement
import com.sphereon.oid.fed.openapi.models.Jwk
import com.sphereon.oid.fed.openapi.models.SubordinateStatement
import io.ktor.client.engine.HttpClientEngine
import kotlinx.datetime.Clock
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlin.js.ExperimentalJsExport
import kotlin.js.JsExport

expect class TrustChainValidation {
fun validateTrustChains(
jwts: List<List<String>>,
knownTrustChainIds: List<String>
): List<List<Any>>
}

@ExperimentalJsExport
@JsExport
class TrustChainValidationCommon {
class TrustChainValidationCommon(val jwtService: JwtService) {

suspend fun readAuthorityHints(
partyBId: String,
Expand Down Expand Up @@ -71,7 +66,101 @@ class TrustChainValidationCommon {
return trustChains
}

fun retrieveJwk(key: JsonElement): Jwk {
fun validateTrustChains(
jwts: List<List<String>>,
knownTrustChainIds: List<String>
): List<List<Any>> {
val trustChains: MutableList<List<Any>> = mutableListOf()
for(it in jwts) {
try {
trustChains.add(validateTrustChain(it, knownTrustChainIds))
} catch (e: Exception) {
Logger.debug("TrustChainValidation", e.message.toString())
}
}
return trustChains
}

private fun validateTrustChain(jwts: List<String>, knownTrustChainIds: List<String>): List<Any> {
val entityStatements = jwts.toMutableList()
val firstEntityConfiguration =
entityStatements.removeFirst().let { JsonMapper().mapEntityConfigurationStatement(it) }
val lastEntityConfiguration =
entityStatements.removeLast().let { JsonMapper().mapEntityConfigurationStatement(it) }
val subordinateStatements = entityStatements.map { JsonMapper().mapSubordinateStatement(it) }

if (firstEntityConfiguration.iss != firstEntityConfiguration.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
}

if (firstEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
input = JwtVerifyInput(
jwt = jwts[0],
key = retrieveJwk(it)
)) } == false) {
throw IllegalArgumentException("Invalid signature")
}

subordinateStatements.forEachIndexed { index, current ->
val next =
if (index < subordinateStatements.size - 1) subordinateStatements[index + 1] else lastEntityConfiguration
val now = Clock.System.now().epochSeconds.toInt()

if (current.iat > now) {
throw IllegalArgumentException("Invalid iat")
}

if (current.exp < now) {
throw IllegalArgumentException("Invalid exp")
}

when (next) {
is EntityConfigurationStatement ->
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
input = JwtVerifyInput(
jwt = jwts[index],
key = retrieveJwk(it)
)) } == false) {
throw IllegalArgumentException("Invalid signature")
}
is SubordinateStatement ->
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
input = JwtVerifyInput(
jwt = jwts[index],
key = retrieveJwk(it)
)) } == false) {
throw IllegalArgumentException("Invalid signature")
}
}
}

if (!knownTrustChainIds.contains(lastEntityConfiguration.iss)) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to the Entity Identifier of the Trust Anchor")
}
if (lastEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any {
jwtService.verify(
input = JwtVerifyInput(
jwt = jwts[jwts.size - 1],
key = retrieveJwk(it))) } == false) {
throw IllegalArgumentException("Invalid signature")
}

val validTrustChain = mutableListOf<Any>()
validTrustChain.add(firstEntityConfiguration)
validTrustChain.addAll(subordinateStatements)
validTrustChain.add(lastEntityConfiguration)

return validTrustChain
}

private fun retrieveJwk(key: JsonElement): Jwk {
return when (key) {
is JsonObject -> Jwk(
kid = key["kid"]?.jsonPrimitive?.content,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
package com.sphereon.oid.fed.client.validation

import com.sphereon.oid.fed.common.jwt.verify
import com.sphereon.oid.fed.common.logging.Logger
import com.sphereon.oid.fed.common.mapper.JsonMapper
import com.sphereon.oid.fed.common.jwt.JwtService
import com.sphereon.oid.fed.openapi.models.EntityConfigurationStatement
import com.sphereon.oid.fed.openapi.models.SubordinateStatement
import io.ktor.client.engine.*
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.promise
import kotlinx.datetime.Clock
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlin.js.Promise

@ExperimentalJsExport
@JsExport
actual class TrustChainValidation {
class TrustChainValidation(val jwtService: JwtService) {

private val NAME = "TrustChainValidation"

Expand All @@ -26,7 +20,7 @@ actual class TrustChainValidation {
trustChains: MutableList<List<EntityConfigurationStatement>> = mutableListOf(),
trustChain: MutableSet<EntityConfigurationStatement> = mutableSetOf()
): Promise<List<List<EntityConfigurationStatement>>> = CoroutineScope(context = CoroutineName(NAME)).promise {
TrustChainValidationCommon()
TrustChainValidationCommon(jwtService)
.readAuthorityHints(
partyBId = partyBId,
engine = engine,
Expand All @@ -39,90 +33,22 @@ actual class TrustChainValidation {
entityConfigurationStatementsList: List<List<EntityConfigurationStatement>>,
engine: HttpClientEngine
): Promise<List<List<String>>> = CoroutineScope(context = CoroutineName(NAME)).promise {
TrustChainValidationCommon()
TrustChainValidationCommon(jwtService)
.fetchSubordinateStatements(
entityConfigurationStatementsList = entityConfigurationStatementsList,
engine = engine
)
}

actual fun validateTrustChains(
fun validateTrustChains(
jwts: List<List<String>>,
knownTrustChainIds: List<String>
): List<List<Any>> {
val trustChains: MutableList<List<Any>> = mutableListOf()
for(it in jwts) {
try {
trustChains.add(validateTrustChain(it, knownTrustChainIds))
} catch (e: Exception) {
Logger.debug("TrustChainValidation", e.message.toString())
}
}
return trustChains
}

@OptIn(ExperimentalJsExport::class)
private fun validateTrustChain(jwts: List<String>, knownTrustChainIds: List<String>): List<Any> {
val entityStatements = jwts.toMutableList()
val firstEntityConfiguration =
entityStatements.removeFirst().let { JsonMapper().mapEntityConfigurationStatement(it) }
val lastEntityConfiguration =
entityStatements.removeLast().let { JsonMapper().mapEntityConfigurationStatement(it) }
val subordinateStatements = entityStatements.map { JsonMapper().mapSubordinateStatement(it) }

if (firstEntityConfiguration.iss != firstEntityConfiguration.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
}

if (firstEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any { verify(jwts[0],
TrustChainValidationCommon().retrieveJwk(it)) } == false) {
throw IllegalArgumentException("Invalid signature")
}

subordinateStatements.forEachIndexed { index, current ->
val next =
if (index < subordinateStatements.size - 1) subordinateStatements[index + 1] else lastEntityConfiguration
val now = (Clock.System.now().toEpochMilliseconds() / 1000).toInt()

if (current.iat > now) {
throw IllegalArgumentException("Invalid iat")
}

if (current.exp < now) {
throw IllegalArgumentException("Invalid exp")
}

when (next) {
is EntityConfigurationStatement ->
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any { verify(jwts[0],
TrustChainValidationCommon().retrieveJwk(it)) } == false) {
throw IllegalArgumentException("Invalid signature")
}
is SubordinateStatement ->
if (current.iss != next.sub) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to sub")
} else if (next.jwks.jsonObject["keys"]?.jsonArray?.any { verify(jwts[0],
TrustChainValidationCommon().retrieveJwk(it)) } == false) {
throw IllegalArgumentException("Invalid signature")
}
}
}

if (!knownTrustChainIds.contains(lastEntityConfiguration.iss)) {
throw IllegalArgumentException("Entity Configuration of the Trust Chain subject requires that iss is equal to the Entity Identifier of the Trust Anchor")
}
if (lastEntityConfiguration.jwks.jsonObject["keys"]?.jsonArray?.any { verify(jwts[jwts.size - 1],
TrustChainValidationCommon().retrieveJwk(it)) } == false) {
throw IllegalArgumentException("Invalid signature")
}

val validTrustChain = mutableListOf<Any>()
validTrustChain.add(firstEntityConfiguration)
validTrustChain.addAll(subordinateStatements)
validTrustChain.add(lastEntityConfiguration)

return validTrustChain
}
): Promise<List<List<Any>>> =
Promise.resolve(
TrustChainValidationCommon(jwtService)
.validateTrustChains(
jwts = jwts,
knownTrustChainIds = knownTrustChainIds
)
)
}
Loading

0 comments on commit 1c90e8d

Please sign in to comment.