Skip to content

Commit

Permalink
DtlsServer: Session store
Browse files Browse the repository at this point in the history
  • Loading branch information
szysas committed Jun 24, 2022
1 parent af15a45 commit 0c527c1
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 24 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ jobs:
uses: gradle/gradle-build-action@v2
with:
arguments: build -i
cache-read-only: false
6 changes: 3 additions & 3 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ import java.time.Duration.ofSeconds

class SslConfig(
private val conf: Memory,
private val cidSupplier: CidSupplier,
val cidSupplier: CidSupplier,
private val mtu: Int,
private val close: Closeable
) : Closeable by close {
Expand All @@ -88,7 +88,7 @@ class SslConfig(
return SslHandshakeContext(this, sslContext, cid, peerAddress)
}

fun loadSession(cid: ByteArray, session: ByteArray): SslSession {
fun loadSession(cid: ByteArray, session: ByteArray, peerAddress: InetSocketAddress): SslSession {
val sslContext = Memory(MbedtlsSizeOf.mbedtls_ssl_context).apply(MbedtlsApi::mbedtls_ssl_init)

mbedtls_ssl_setup(sslContext, conf).verify()
Expand All @@ -98,7 +98,7 @@ class SslConfig(
mbedtls_ssl_set_bio(sslContext, Pointer.NULL, SendCallback, null, ReceiveCallback)

return SslSession(this, sslContext, cid).also {
logger.info("Reconnected {}", it)
logger.info("[{}] Reconnected {}", peerAddress, it)
}
}

Expand Down
28 changes: 25 additions & 3 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,29 @@ import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.time.Duration

sealed interface SslContext : Closeable
sealed interface SslContext : Closeable {
companion object {
fun peekCID(size: Int, encBuffer: ByteBuffer): ByteArray? {
val pos = encBuffer.position()
if (encBuffer.remaining() < 11 + size) {
// too short
return null
}
if ((encBuffer.int shr 8) != 0x19fefd) {
// not a dtls+cid packet
encBuffer.position(pos)
return null
}

val cid = ByteArray(size)

encBuffer.position(pos + 11)
encBuffer.get(cid)
encBuffer.position(pos)
return cid
}
}
}

class SslHandshakeContext internal constructor(
private val conf: SslConfig, // keep in memory to prevent GC
Expand Down Expand Up @@ -86,8 +108,8 @@ class SslSession internal constructor(
private val cid: ByteArray?,
) : SslContext, Closeable {

val peerCid: ByteArray? by lazy { readPeerCid() }
val ownCid: ByteArray? by lazy { if (peerCid != null) cid else null }
val peerCid: ByteArray? = readPeerCid()
val ownCid: ByteArray? = if (peerCid != null) cid else null

private fun readPeerCid(): ByteArray? {
val mem = Memory(16 + 64 /* max cid len */)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.opencoap.ssl.transport
import org.opencoap.ssl.CloseNotifyException
import org.opencoap.ssl.HelloVerifyRequired
import org.opencoap.ssl.SslConfig
import org.opencoap.ssl.SslContext
import org.opencoap.ssl.SslException
import org.opencoap.ssl.SslHandshakeContext
import org.opencoap.ssl.SslSession
Expand All @@ -43,28 +44,35 @@ class DtlsServer private constructor(
private val channel: DatagramChannel,
private val sslConfig: SslConfig,
private val expireAfter: Duration,
private val executor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor { Thread(it, "dtls-srv-" + threadIndex.getAndIncrement()) }
private val sessionStore: SessionStore,
) {

init {
executor.execute { localActionPromises.set(hashMapOf()) }
}

companion object {
private val threadIndex = AtomicInteger(0)

fun create(config: SslConfig, listenPort: Int = 0, expireAfter: Duration = Duration.ofSeconds(60)): DtlsServer {
fun create(
config: SslConfig,
listenPort: Int = 0,
expireAfter: Duration = Duration.ofSeconds(60),
sessionStore: SessionStore = NoOpsSessionStore
): DtlsServer {
val channel = DatagramChannel.open().bind(InetSocketAddress("0.0.0.0", listenPort))
return DtlsServer(channel, config, expireAfter)
return DtlsServer(channel, config, expireAfter, sessionStore)
}
}

private val logger = LoggerFactory.getLogger(javaClass)
val executor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor { Thread(it, "dtls-srv-" + threadIndex.getAndIncrement()) }

// note: must be used only from executor
private val localActionPromises = ThreadLocal<MutableMap<InetSocketAddress, CompletableFuture<Action>>>()
private val actionPromises: MutableMap<InetSocketAddress, CompletableFuture<Action>>
get() = localActionPromises.get()
private val cidSize = sslConfig.cidSupplier.next().size

init {
executor.execute { localActionPromises.set(hashMapOf()) }
}

fun listen(handler: Handler): DtlsServer {
val bufPool: BlockingQueue<ByteBuffer> = ArrayBlockingQueue(1)
Expand All @@ -78,8 +86,13 @@ class DtlsServer private constructor(
if (promise != null) {
promise.complete(DecryptAction(buf))
} else {
DtlsHandshake(sslConfig.newContext(adr), adr, nonThrowingHandler)
.invoke(DecryptAction(buf))
val cid = SslContext.peekCID(cidSize, buf)
if (cid != null) {
loadSession(buf.copyDirect(), cid, adr, handler)
} else {
DtlsHandshake(sslConfig.newContext(adr), adr, nonThrowingHandler)
.invoke(DecryptAction(buf))
}
}

bufPool.put(buf) // return buffer to the pool
Expand Down Expand Up @@ -123,6 +136,25 @@ class DtlsServer private constructor(
return promise
}

private fun loadSession(encBuf: ByteBuffer, cid: ByteArray, adr: InetSocketAddress, handler: Handler) {
sessionStore.read(cid)
.thenAcceptAsync({ sessBuf ->
if (sessBuf == null) {
logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex())
} else {
DtlsSession(sslConfig.loadSession(cid, sessBuf, adr), adr, handler)
.invoke(DecryptAction(encBuf))
}
}, executor)
.whenComplete { _, ex ->
when (ex) {
null -> Unit // no error
is SslException -> logger.warn("[{}] [CID:{}] Failed to load session", ex.message)
else -> logger.error(ex.message, ex)
}
}
}

private fun Handler.decorateWithCatcher(): Handler {
return { adr: InetSocketAddress, packet: ByteArray ->
try {
Expand Down Expand Up @@ -179,19 +211,19 @@ class DtlsServer private constructor(
}

private inner class DtlsSession(private val ctx: SslSession, private val peerAddress: InetSocketAddress, private val handler: Handler) {
operator fun invoke(action: Action?, err: Throwable?) {
operator fun invoke(action: Action?, err: Throwable? = null) {
try {
when (action) {
null -> throw err!!
is DecryptAction -> decrypt(action.encPacket)
is EncryptAction -> encrypt(action.plainPacket)
is CloseAction -> {
logger.info("[{}] DTLS connection closed", peerAddress)
ctx.close()
close()
}
is TimeoutAction -> {
logger.info("[{}] DTLS connection expired", peerAddress)
ctx.close()
close()
}
}
} catch (ex: CloseNotifyException) {
Expand All @@ -206,6 +238,14 @@ class DtlsServer private constructor(
}
}

private fun close() {
if (ctx.ownCid != null) {
sessionStore.write(ctx.ownCid, ctx.saveAndClose())
} else {
ctx.close()
}
}

private fun decrypt(encPacket: ByteBuffer) {
val plainBuf = ctx.decrypt(encPacket)
receive(peerAddress).whenComplete(::invoke)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.opencoap.ssl.transport

import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletableFuture.completedFuture
import java.util.concurrent.ConcurrentHashMap

interface SessionStore {
fun read(cid: ByteArray): CompletableFuture<ByteArray?>
fun write(cid: ByteArray, session: ByteArray)
}

object NoOpsSessionStore : SessionStore {
override fun read(cid: ByteArray): CompletableFuture<ByteArray?> = completedFuture(null)
override fun write(cid: ByteArray, session: ByteArray) = Unit
}

class HashMapSessionStore : SessionStore {
private val map = ConcurrentHashMap<String, ByteArray>()

override fun read(cid: ByteArray): CompletableFuture<ByteArray?> =
completedFuture(map.remove(cid.toHex()))

override fun write(cid: ByteArray, session: ByteArray) {
map.put(cid.toHex(), session)
}

fun clear() = map.clear()
fun size() = map.size
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ internal fun <T> Executor.supply(supplier: Supplier<T>): CompletableFuture<T> {
internal fun ByteArray.toHex(): String {
return joinToString(separator = "") { eachByte -> "%02x".format(eachByte) }
}

fun ByteBuffer.copyDirect(): ByteBuffer {
val bb = ByteBuffer.allocateDirect(this.remaining())
bb.put(this)
bb.flip()
return bb
}
68 changes: 68 additions & 0 deletions kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls)
* SPDX-License-Identifier: Apache-2.0
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.opencoap.ssl

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Test
import org.opencoap.ssl.transport.toHex
import org.opencoap.ssl.util.asByteBuffer
import org.opencoap.ssl.util.decodeHex

class SslContextTest {

@Test
fun `should peek CID from DTLS Packet`() {
val dtlsPacket =
"19fefd0001000000000001db04684e33424e42801f0e38023d243800280001000000000001a7eddd3aa34f5164499ca1fcaede85f9e77036ad66c2affb2ae9c97c5a78adb9"
.decodeHex().asByteBuffer()

val cid = SslContext.peekCID(16, dtlsPacket)

assertEquals("db04684e33424e42801f0e38023d2438", cid?.toHex())
assertEquals(0, dtlsPacket.position())
}

@Test
fun `should peek CID from DTLS Packet with different sizes`() {
assertEquals(
"db",
SslContext.peekCID(1, "19fefd0301000000000003db04684e3342".decodeHex().asByteBuffer())?.toHex()
)
assertEquals(
"db04684e",
SslContext.peekCID(4, "19fefdf001000000000001db04684e3342".decodeHex().asByteBuffer())?.toHex()
)
}

@Test
fun `should return null when not DTLS Packet`() {
assertNull(SslContext.peekCID(4, "17fefd0001000000000001db04684e3342".decodeHex().asByteBuffer()))
assertNull(SslContext.peekCID(4, "19f0fd0001000000000001db04684e3342".decodeHex().asByteBuffer()))
assertNull(SslContext.peekCID(4, "19fef00001000000000001db04684e3342".decodeHex().asByteBuffer()))
}

@Test
fun `should return null when too short DTLS Packet`() {
assertNull(
SslContext.peekCID(7, "19fefdf001000000000001db04684e3342".decodeHex().asByteBuffer())?.toHex()
)
assertNull(
SslContext.peekCID(2, "19fefd".decodeHex().asByteBuffer())?.toHex()
)
}
}
Loading

0 comments on commit 0c527c1

Please sign in to comment.