Skip to content

Commit

Permalink
feat: added Amazon KMS sign, verify and generate key pair
Browse files Browse the repository at this point in the history
  • Loading branch information
robertmathew committed Sep 4, 2024
1 parent f000710 commit c3029f5
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 11 deletions.
1 change: 1 addition & 0 deletions modules/amazon-kms/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies {
api(projects.modules.openapi)
implementation(platform("software.amazon.awssdk:bom:2.21.1"))
implementation("software.amazon.awssdk:kms")
implementation("io.ktor:ktor-serialization-kotlinx-json:2.3.11")
}

tasks.test {
Expand Down
86 changes: 86 additions & 0 deletions modules/amazon-kms/src/main/kotlin/AmazonKms.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package com.sphereon.oid.fed.kms.local

import com.sphereon.oid.fed.kms.amazon.extensions.toJwkAdminDto
import com.sphereon.oid.fed.openapi.models.JWTHeader
import com.sphereon.oid.fed.openapi.models.Jwk
import com.sphereon.oid.fed.openapi.models.JwkAdminDTO
import kotlinx.serialization.json.JsonObject
import software.amazon.awssdk.core.SdkBytes
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.kms.KmsClient
import software.amazon.awssdk.services.kms.model.*
import java.nio.charset.StandardCharsets
import java.util.*

class AmazonKms {

private val kmsClient = KmsClient.builder().region(Region.US_WEST_2) // Replace with your desired region
.build()

fun generateKey(): JwkAdminDTO {
val keyId = createKey()

val request =
GenerateDataKeyPairRequest.builder().keyId(keyId).keyPairSpec(DataKeyPairSpec.ECC_NIST_P256).build()
val response = kmsClient.generateDataKeyPair(request)

//TODO: Check this logic
val jwk = Jwk(kty = "EC", kid = response.keyId())
return jwk.toJwkAdminDto()
}

fun sign(header: JWTHeader, payload: JsonObject, keyId: String): String {
val encodedHeader = Base64.getUrlEncoder().withoutPadding().encodeToString(
header.toString().toByteArray(
StandardCharsets.UTF_8
)
)
val encodedPayload = Base64.getUrlEncoder().withoutPadding()
.encodeToString(payload.toString().toByteArray(StandardCharsets.UTF_8))

val messageBytes = (encodedHeader + "." + encodedPayload).toByteArray(StandardCharsets.UTF_8)

val signingRequest = SignRequest.builder().keyId(keyId).message(SdkBytes.fromByteArray(messageBytes))
.signingAlgorithm(SigningAlgorithmSpec.ECDSA_SHA_256) // Adjust if needed
.build()

val signingResponse = kmsClient.sign(signingRequest)
val signature =
Base64.getUrlEncoder().withoutPadding().encodeToString(signingResponse.signature().asByteArray())

return encodedHeader + "." + encodedPayload + "." + signature
}

fun verify(token: String, keyId: String): Boolean {
try {
val parts = token.split(".")
if (parts.size != 3) {
return false // Invalid token format
}

val header = parts[0]
val payload = parts[1]
val signature = parts[2]

val verificationRequest = VerifyRequest.builder().keyId(keyId)
.message(SdkBytes.fromString(header + "." + payload, StandardCharsets.UTF_8))
.signature(SdkBytes.fromByteArray(Base64.getUrlDecoder().decode(signature)))
.signingAlgorithm(SigningAlgorithmSpec.ECDSA_SHA_256) // Adjust if needed
.build()

val verificationResponse = kmsClient.verify(verificationRequest)

return verificationResponse.signatureValid()
} catch (e: Exception) {
return false
}
}

private fun createKey(): String {
val request = CreateKeyRequest.builder().keyUsage(KeyUsageType.SIGN_VERIFY) // Or adjust based on your needs
.build()

val response = kmsClient.createKey(request)
return response.keyMetadata().keyId()
}
}
5 changes: 0 additions & 5 deletions modules/amazon-kms/src/main/kotlin/Main.kt

This file was deleted.

20 changes: 20 additions & 0 deletions modules/amazon-kms/src/main/kotlin/extensions/JwkExtension.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.sphereon.oid.fed.kms.amazon.extensions

import com.sphereon.oid.fed.openapi.models.Jwk
import com.sphereon.oid.fed.openapi.models.JwkAdminDTO

fun Jwk.toJwkAdminDto(): JwkAdminDTO = JwkAdminDTO(
kid = this.kid,
use = this.use,
crv = this.crv,
n = this.n,
e = this.e,
x = this.x,
y = this.y,
kty = this.kty,
alg = this.alg,
x5u = this.x5u,
x5t = this.x5t,
x5c = this.x5c,
x5tHashS256 = this.x5tS256
)
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LocalKms {
return sign(header = mHeader, payload = payload, key = jwkObject)
}

fun verify(token: String, jwk: Jwk): Boolean {
return verify(jwt = token, key = jwk)
fun verify(token: String, keyId: String): Boolean {
return verify(jwt = token, key = Jwk(kty = keyId))
}
}
1 change: 1 addition & 0 deletions modules/services/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ kotlin {
api(projects.modules.persistence)
api(projects.modules.openidFederationCommon)
api(projects.modules.localKms)
api(projects.modules.amazonKms)
implementation("io.ktor:ktor-serialization-kotlinx-json:2.3.11")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.sphereon.oid.fed.services

import com.sphereon.oid.fed.kms.local.AmazonKms
import com.sphereon.oid.fed.openapi.models.JWTHeader
import com.sphereon.oid.fed.openapi.models.JwkAdminDTO
import kotlinx.serialization.json.JsonObject

class AmazonKmsClient : KmsClient {

private val amazonKms = AmazonKms()

override fun generateKeyPair(): JwkAdminDTO {
return amazonKms.generateKey()
}

override fun sign(header: JWTHeader, payload: JsonObject, keyId: String): String {
return amazonKms.sign(header, payload, keyId)
}

override fun verify(token: String, keyId: String): Boolean {
return amazonKms.verify(token, keyId)
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.sphereon.oid.fed.services

import com.sphereon.oid.fed.openapi.models.JWTHeader
import com.sphereon.oid.fed.openapi.models.Jwk
import com.sphereon.oid.fed.openapi.models.JwkAdminDTO
import kotlinx.serialization.json.JsonObject

Expand All @@ -10,6 +9,7 @@ object KmsService {

private val kmsClient: KmsClient = when (provider) {
"local" -> LocalKmsClient()
"amazon" -> AmazonKmsClient()
else -> throw IllegalArgumentException("Unsupported KMS provider: $provider")
}

Expand All @@ -19,5 +19,5 @@ object KmsService {
interface KmsClient {
fun generateKeyPair(): JwkAdminDTO
fun sign(header: JWTHeader, payload: JsonObject, keyId: String): String
fun verify(token: String, jwk: Jwk): Boolean
fun verify(token: String, keyId: String): Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class LocalKmsClient : KmsClient {
return localKms.sign(header, payload, keyId)
}

override fun verify(token: String, jwk: Jwk): Boolean {
return localKms.verify(token, jwk)
override fun verify(token: String, keyId: String): Boolean {
return localKms.verify(token, keyId)
}
}

0 comments on commit c3029f5

Please sign in to comment.