Skip to content

Commit

Permalink
Added certificate support
Browse files Browse the repository at this point in the history
  • Loading branch information
szysas committed Jun 7, 2022
1 parent fdaf2cd commit f03525c
Show file tree
Hide file tree
Showing 14 changed files with 413 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/compile-mbedtls.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:
paths:
- 'compileMbedtls.sh'
- '.github/workflows/compile-mbedtls.yml'

- '**/src/**/*.c'
jobs:
compile:

Expand Down
1 change: 1 addition & 0 deletions kotlin-mbedtls/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies {
testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:5.8.2")
testImplementation("org.awaitility:awaitility-kotlin:4.2.0")
testImplementation("ch.qos.logback:logback-classic:1.2.11")
testImplementation("org.bouncycastle:bcpkix-jdk15on:1.70")
}

java {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@

package org.opencoap.ssl

import kotlin.random.Random

fun interface CidSupplier {
fun next(): ByteArray
}

object EmptyCidSupplier : CidSupplier {
override fun next(): ByteArray = byteArrayOf()
}

class RandomCidSupplier(private val size: Int) : CidSupplier {

override fun next(): ByteArray = Random.nextBytes(size)
}
13 changes: 13 additions & 0 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ internal object MbedtlsApi {
external fun mbedtls_ssl_get_peer_cid(sslContext: Pointer, enabled: Pointer, peerCid: Pointer, peerCidLen: Pointer): Int
external fun mbedtls_ssl_context_save(sslContext: Pointer, buf: Pointer, bufLen: Int, outputLen: Pointer): Int
external fun mbedtls_ssl_context_load(sslContext: Pointer, buf: Pointer, len: Int): Int
external fun mbedtls_ssl_conf_ca_chain(sslConfig: Pointer, caChain: Pointer, caCrl: Pointer?)
external fun mbedtls_ssl_conf_own_cert(sslConfig: Pointer, ownCert: Memory, pkKey: Pointer): Int
external fun mbedtls_ssl_set_mtu(sslContext: Pointer, mtu: Int)

const val MBEDTLS_ERR_SSL_TIMEOUT = -0x6800
const val MBEDTLS_ERR_SSL_WANT_READ = -0x6900
Expand Down Expand Up @@ -95,6 +98,16 @@ internal object MbedtlsApi {
// mbedtls/debug.h
external fun mbedtls_debug_set_threshold(threshold: Int)

// mbedtls/x509_crt.h
external fun mbedtls_x509_crt_init(cert: Pointer)
external fun mbedtls_x509_crt_free(cert: Pointer)
external fun mbedtls_x509_crt_parse_der(chain: Pointer, buf: ByteArray, len: Int): Int

// mbedtls/pk.h
external fun mbedtls_pk_init(ctx: Pointer)
external fun mbedtls_pk_free(ctx: Pointer)
external fun mbedtls_pk_parse_key(ctx: Pointer, key: ByteArray, keyLen: Int, pwd: Pointer?, pwdLen: Int, fRng: Pointer, pRbg: Pointer): Int

// -------------------------

internal fun Int.verify(): Int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ internal object MbedtlsSizeOf {
const val mbedtls_entropy_context = 1032L
const val mbedtls_ctr_drbg_context = 344L
const val mbedtls_ssl_context = 552L
const val mbedtls_pk_context = 16L
const val mbedtls_x509_crt = 616L
}
76 changes: 66 additions & 10 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import org.opencoap.ssl.MbedtlsApi.mbedtls_ctr_drbg_free
import org.opencoap.ssl.MbedtlsApi.mbedtls_ctr_drbg_random
import org.opencoap.ssl.MbedtlsApi.mbedtls_ctr_drbg_seed
import org.opencoap.ssl.MbedtlsApi.mbedtls_entropy_free
import org.opencoap.ssl.MbedtlsApi.mbedtls_pk_free
import org.opencoap.ssl.MbedtlsApi.mbedtls_pk_parse_key
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_authmode
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_ca_chain
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_cid
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_ciphersuites
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_dbg
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_dtls_cookies
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_min_version
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_own_cert
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_psk
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_conf_rng
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_config_defaults
Expand All @@ -37,15 +41,22 @@ import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_context_load
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_get_ciphersuite_id
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_set_bio
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_set_cid
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_set_mtu
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_set_timer_cb
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_setup
import org.opencoap.ssl.MbedtlsApi.mbedtls_x509_crt_free
import org.opencoap.ssl.MbedtlsApi.mbedtls_x509_crt_parse_der
import org.opencoap.ssl.MbedtlsApi.verify
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.security.Key
import java.security.PrivateKey
import java.security.cert.X509Certificate

class SslConfig(
private val conf: Memory,
private val cidSupplier: CidSupplier,
private val mtu: Int,
private val close: Closeable
) : Closeable by close {
private val logger = LoggerFactory.getLogger(javaClass)
Expand All @@ -58,7 +69,7 @@ class SslConfig(

val cid = cidSupplier.next()
mbedtls_ssl_set_cid(sslContext, 1, cid, cid.size).verify()

mbedtls_ssl_set_mtu(sslContext, mtu)
mbedtls_ssl_set_bio(sslContext, Pointer.NULL, SendCallback, null, ReceiveCallback)

return SslHandshakeContext(this, sslContext, cid)
Expand All @@ -83,25 +94,46 @@ class SslConfig(
@JvmStatic
@JvmOverloads
fun client(pskId: ByteArray, pskSecret: ByteArray, cipherSuites: List<String> = emptyList(), cidSupplier: CidSupplier = EmptyCidSupplier): SslConfig {
return create(false, pskId, pskSecret, cipherSuites, cidSupplier)
return create(false, pskId, pskSecret, cipherSuites, cidSupplier, listOf(), null, listOf(), true, 0)
}

@JvmStatic
@JvmOverloads
fun server(pskId: ByteArray, pskSecret: ByteArray, cipherSuites: List<String> = emptyList(), cidSupplier: CidSupplier = EmptyCidSupplier): SslConfig {
return create(true, pskId, pskSecret, cipherSuites, cidSupplier)
return create(true, pskId, pskSecret, cipherSuites, cidSupplier, listOf(), null, listOf(), true, 0)
}

@JvmStatic
@JvmOverloads
fun client(ownCertChain: List<X509Certificate> = listOf(), privateKey: PrivateKey? = null, trustedCerts: List<X509Certificate>, cipherSuites: List<String> = listOf(), cidSupplier: CidSupplier = EmptyCidSupplier, mtu: Int = 0): SslConfig {
return create(false, null, null, cipherSuites, cidSupplier, ownCertChain, privateKey, trustedCerts, true, mtu)
}

@JvmStatic
@JvmOverloads
fun server(ownCertChain: List<X509Certificate>, privateKey: PrivateKey, trustedCerts: List<X509Certificate> = listOf(), reqAuthentication: Boolean = true, cidSupplier: CidSupplier = EmptyCidSupplier, mtu: Int = 0): SslConfig {
return create(true, null, null, listOf(), cidSupplier, ownCertChain, privateKey, trustedCerts, reqAuthentication, mtu)
}

private fun create(
isServer: Boolean = false,
pskId: ByteArray,
pskSecret: ByteArray,
isServer: Boolean,
pskId: ByteArray?,
pskSecret: ByteArray?,
cipherSuites: List<String>,
cidSupplier: CidSupplier
cidSupplier: CidSupplier,
ownCertChain: List<X509Certificate>,
privateKey: Key?,
trustedCerts: List<X509Certificate>,
requiredAuthMode: Boolean = true,
mtu: Int,
): SslConfig {

val sslConfig = Memory(MbedtlsSizeOf.mbedtls_ssl_config).also(MbedtlsApi::mbedtls_ssl_config_init)
val entropy = Memory(MbedtlsSizeOf.mbedtls_entropy_context).also(MbedtlsApi::mbedtls_entropy_init)
val ctrDrbg = Memory(MbedtlsSizeOf.mbedtls_ctr_drbg_context).also(MbedtlsApi::mbedtls_ctr_drbg_init)
val ownCert = Memory(MbedtlsSizeOf.mbedtls_x509_crt).also(MbedtlsApi::mbedtls_x509_crt_init)
val caCert = Memory(MbedtlsSizeOf.mbedtls_x509_crt).also(MbedtlsApi::mbedtls_x509_crt_init)
val pkey = Memory(MbedtlsSizeOf.mbedtls_pk_context).also(MbedtlsApi::mbedtls_pk_init)

val endpointType = if (isServer) MbedtlsApi.MBEDTLS_SSL_IS_SERVER else MbedtlsApi.MBEDTLS_SSL_IS_CLIENT
mbedtls_ssl_config_defaults(sslConfig, endpointType, MbedtlsApi.MBEDTLS_SSL_TRANSPORT_DATAGRAM, MbedtlsApi.MBEDTLS_SSL_PRESET_DEFAULT).verify()
Expand All @@ -112,22 +144,46 @@ class SslConfig(
mbedtls_ssl_conf_dtls_cookies(sslConfig, null, null, null)

// PSK
mbedtls_ssl_conf_psk(sslConfig, pskSecret, pskSecret.size, pskId, pskId.size).verify()
mbedtls_ssl_conf_authmode(sslConfig, MbedtlsApi.MBEDTLS_SSL_VERIFY_REQUIRED)
if (pskSecret != null && pskId != null) {
mbedtls_ssl_conf_psk(sslConfig, pskSecret, pskSecret.size, pskId, pskId.size).verify()
}

mbedtls_ssl_conf_authmode(sslConfig, if (requiredAuthMode) MbedtlsApi.MBEDTLS_SSL_VERIFY_REQUIRED else MbedtlsApi.MBEDTLS_SSL_VERIFY_NONE)
if (cipherSuites.isNotEmpty()) {
mbedtls_ssl_conf_ciphersuites(sslConfig, mapCipherSuites(cipherSuites)).verify()
}

if (cidSupplier != EmptyCidSupplier) {
mbedtls_ssl_conf_cid(sslConfig, cidSupplier.next().size, 0)
}

// Trusted certificates
for (cert in trustedCerts) {
val certDer = cert.encoded
mbedtls_x509_crt_parse_der(caCert, certDer, certDer.size).verify()
}
mbedtls_ssl_conf_ca_chain(sslConfig, caCert, Pointer.NULL)

// Own certificate
for (cert in ownCertChain) {
val certDer = cert.encoded
mbedtls_x509_crt_parse_der(ownCert, certDer, certDer.size).verify()
}
if (privateKey != null) {
mbedtls_pk_parse_key(pkey, privateKey.encoded, privateKey.encoded.size, Pointer.NULL, 0, mbedtls_ctr_drbg_random, ctrDrbg)
mbedtls_ssl_conf_own_cert(sslConfig, ownCert, pkey)
}

// Logging
mbedtls_ssl_conf_dbg(sslConfig, LogCallback, Pointer.NULL)

return SslConfig(sslConfig, cidSupplier) {
return SslConfig(sslConfig, cidSupplier, mtu) {
mbedtls_ssl_config_free(sslConfig)
mbedtls_entropy_free(entropy)
mbedtls_ctr_drbg_free(ctrDrbg)
mbedtls_pk_free(pkey)
mbedtls_x509_crt_free(ownCert)
mbedtls_x509_crt_free(caCert)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,33 @@ import org.opencoap.ssl.SslConfig
import org.opencoap.ssl.SslContext
import org.opencoap.ssl.SslHandshakeContext
import org.opencoap.ssl.SslSession
import org.slf4j.LoggerFactory
import java.io.Closeable
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.DatagramChannel
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger

/*
DTLS transmitter based on DatagramChannel. Uses blocking calls.
*/
class DtlsTransmitter private constructor(
internal val channel: DatagramChannel,
private val sslSession: SslSession,
private val executor: ExecutorService = Executors.newSingleThreadExecutor()
) {
private val executor: ExecutorService
) : Closeable {
companion object {
private val logger = LoggerFactory.getLogger(DtlsTransmitter::class.java)

private val threadIndex = AtomicInteger(0)
private fun newSingleExecutor(): ExecutorService {
return Executors.newSingleThreadExecutor { Thread(it, "dtls-" + threadIndex.getAndIncrement()) }
}

fun connect(server: DtlsServer, conf: SslConfig): CompletableFuture<DtlsTransmitter> {
return connect(InetSocketAddress(InetAddress.getLocalHost(), server.localPort()), conf)
}
Expand All @@ -50,15 +60,18 @@ class DtlsTransmitter private constructor(
}

fun connect(dest: InetSocketAddress, conf: SslConfig, channel: DatagramChannel): CompletableFuture<DtlsTransmitter> {
val executor = Executors.newSingleThreadExecutor()
val executor = newSingleExecutor()
return executor.supply {
val sslSession = handshake(conf.newContext(), channel, dest)
DtlsTransmitter(channel, sslSession, executor)
}
}

private fun handshake(handshakeCtx: SslHandshakeContext, channel: DatagramChannel, dest: InetSocketAddress): SslSession {
val send: (ByteBuffer) -> Unit = { channel.send(it, dest) }
val send: (ByteBuffer) -> Unit = {
channel.send(it, dest)
logger.debug("[{}] DTLS handshake sent {} bytes", dest, it.position())
}

val buffer: ByteBuffer = ByteBuffer.allocateDirect(16384)
buffer.clear().flip()
Expand All @@ -68,6 +81,7 @@ class DtlsTransmitter private constructor(
buffer.clear()
channel.receive(buffer)
buffer.flip()
logger.debug("[{}] DTLS handshake recv {} bytes", dest, buffer.remaining())
sslContext = handshakeCtx.step(buffer, send)
}
return sslContext as SslSession
Expand All @@ -78,11 +92,11 @@ class DtlsTransmitter private constructor(
.bind(InetSocketAddress("0.0.0.0", bindPort))
.connect(dest)

return DtlsTransmitter(channel, sslSession, Executors.newSingleThreadExecutor())
return DtlsTransmitter(channel, sslSession, newSingleExecutor())
}
}

fun close() {
override fun close() {
channel.close()
executor.supply(sslSession::close).join()
}
Expand Down
2 changes: 2 additions & 0 deletions kotlin-mbedtls/src/test/c/mbedtls_sizeof_generator.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ int main( void )
printf(" const val mbedtls_entropy_context = %ldL\n", sizeof(mbedtls_entropy_context) );
printf(" const val mbedtls_ctr_drbg_context = %ldL\n", sizeof(mbedtls_ctr_drbg_context) );
printf(" const val mbedtls_ssl_context = %ldL\n", sizeof(mbedtls_ssl_context) );
printf(" const val mbedtls_pk_context = %ldL\n", sizeof(mbedtls_pk_context) );
printf(" const val mbedtls_x509_crt = %ldL\n", sizeof(mbedtls_x509_crt) );
printf("}\n");

}
39 changes: 39 additions & 0 deletions kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/Certs.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.opencoap.ssl.transport

import org.opencoap.ssl.util.Certificate

internal object Certs {
val root = Certificate.createRootEC("root-ca")
val server = root.signNew("server", false)
val serverChain = listOf(server, root).map(Certificate::asX509)

val rootRsa = Certificate.createRootRSA("root-ca2")
val int1 = rootRsa.signNew("intermediate-1", true)
val int2 = int1.signNew("intermediate-2", true)
val server2 = int2.signNew("server2", false)
val serverLongChain = listOf(server2, int2, int1, rootRsa).map(Certificate::asX509)

val int1a = rootRsa.signNew("intermediate-1a", true)

val dev01 = root.signNew("device01", false)
val dev01Chain = listOf(dev01.asX509(), root.asX509())

val dev99 = Certificate.createRootEC("device99")
val dev99Chain = listOf(dev99.asX509())
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.opencoap.ssl.RandomCidSupplier
import org.opencoap.ssl.SslConfig
import org.opencoap.ssl.SslException
import org.opencoap.ssl.util.await
Expand All @@ -35,6 +36,7 @@ class DtlsServerTest {

private val psk = Pair("dupa".encodeToByteArray(), byteArrayOf(1))
private val conf: SslConfig = SslConfig.server(psk.first, psk.second)
private val certConf = SslConfig.server(Certs.serverChain, Certs.server.privateKey, reqAuthentication = false, cidSupplier = RandomCidSupplier(16))
private val clientConfig = SslConfig.client(psk.first, psk.second)
private lateinit var server: DtlsServer
private val echoHandler: (InetSocketAddress, ByteArray) -> Unit = { adr: InetSocketAddress, packet: ByteArray ->
Expand Down Expand Up @@ -150,4 +152,29 @@ class DtlsServerTest {
assertEquals(0, cliChannel.read("aaa".toByteBuffer()))
cliChannel.close()
}

@Test
fun `should successfully handshake with certificate`() {
server = DtlsServer.create(certConf).listen(echoHandler)
val clientConf = SslConfig.client(trustedCerts = listOf(Certs.root.asX509()))

// when
val client = DtlsTransmitter.connect(server, clientConf).await()
client.send("12345")

// then
assertEquals("12345:resp", client.receiveString())
}

@Test
fun `should fail handshake when non trusted certificate`() {
server = DtlsServer.create(certConf).listen(echoHandler)
val clientConf = SslConfig.client(trustedCerts = listOf(Certs.rootRsa.asX509()))

// when
val result = runCatching { DtlsTransmitter.connect(server, clientConf).await() }

// then
assertEquals("X509 - Certificate verification failed, e.g. CRL, CA or signature check failed [-9984]", result.exceptionOrNull()?.cause?.message)
}
}
Loading

0 comments on commit f03525c

Please sign in to comment.