Skip to content

Commit

Permalink
Validate DTLS message before decryption (#59)
Browse files Browse the repository at this point in the history
* Validate DTLS message before decryption
  • Loading branch information
JuhaPekkaa authored Nov 6, 2024
1 parent fb0e60a commit 3d81000
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class DtlsChannelHandler @JvmOverloads constructor(

private fun loadSession(result: DtlsServer.ReceiveResult.CidSessionMissing, msg: DatagramPacket, ctx: ChannelHandlerContext) {
sessionStore.read(result.cid)
.thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid) }, ctx.executor())
.thenApplyAsync({ sessBuf -> dtlsServer.loadSession(sessBuf, msg.sender(), result.cid, msg.content().nioBuffer()) }, ctx.executor())
.whenComplete { isLoaded: Boolean?, _ ->
if (isLoaded == true) {
channelRead(ctx, msg)
Expand Down
2 changes: 2 additions & 0 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/MbedtlsApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ internal object MbedtlsApi {
external fun mbedtls_ssl_get_peer_cid(sslContext: Pointer, enabled: Pointer, peerCid: Pointer, peerCidLen: Pointer): Int
external fun mbedtls_ssl_context_save(sslContext: Pointer, buf: ByteArray, bufLen: Int, outputLen: ByteArray): Int
external fun mbedtls_ssl_context_load(sslContext: Pointer, buf: ByteArray, len: Int): Int
external fun mbedtls_ssl_check_record(sslContext: Pointer, buf: Memory, bufLen: Int): Int
external fun mbedtls_ssl_conf_ca_chain(sslConfig: Pointer, caChain: Pointer, caCrl: Pointer?)
external fun mbedtls_ssl_conf_own_cert(sslConfig: Pointer, ownCert: Memory, pkKey: Pointer): Int
external fun mbedtls_ssl_set_mtu(sslContext: Pointer, mtu: Int)
Expand All @@ -96,6 +97,7 @@ internal object MbedtlsApi {
const val MBEDTLS_SSL_TRANSPORT_DATAGRAM = 1
const val MBEDTLS_SSL_VERIFY_NONE = 0
const val MBEDTLS_SSL_VERIFY_REQUIRED = 2
const val MBEDTLS_ERR_SSL_UNEXPECTED_RECORD = -0x6700

// ----- net_sockets.h -----
val MBEDTLS_ERR_NET_RECV_FAILED = -0x004C
Expand Down
21 changes: 21 additions & 0 deletions kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/SslContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.opencoap.ssl

import com.sun.jna.Memory
import org.opencoap.ssl.MbedtlsApi.MBEDTLS_ERR_SSL_UNEXPECTED_RECORD
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_close_notify
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_context_save
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_free
Expand All @@ -27,6 +28,7 @@ import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_handshake
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_read
import org.opencoap.ssl.MbedtlsApi.mbedtls_ssl_write
import org.opencoap.ssl.MbedtlsApi.verify
import org.opencoap.ssl.transport.cloneToMemory
import org.opencoap.ssl.transport.toHex
import org.slf4j.LoggerFactory
import java.io.Closeable
Expand Down Expand Up @@ -164,6 +166,20 @@ class SslSession internal constructor(
plainBuffer.limit(size + plainBuffer.position())
}

fun checkRecord(encBuffer: ByteBuffer): VerificationResult {
val memory = encBuffer.cloneToMemory()
try {
val result = MbedtlsApi.mbedtls_ssl_check_record(sslContext, memory, memory.size().toInt())
return if (result == 0 || result != MBEDTLS_ERR_SSL_UNEXPECTED_RECORD) {
VerificationResult.Valid("Success")
} else {
VerificationResult.Invalid(SslException.from(result).localizedMessage)
}
} finally {
memory.close()
}
}

fun decrypt(encBuffer: ByteBuffer, send: (ByteBuffer) -> Unit): ByteBuffer {
val buf = ByteBuffer.allocate(encBuffer.remaining())
decrypt(encBuffer, buf, send)
Expand Down Expand Up @@ -215,4 +231,9 @@ class SslSession internal constructor(
override fun close() {
mbedtls_ssl_free(sslContext)
}

sealed interface VerificationResult {
data class Valid(val message: String) : VerificationResult
data class Invalid(val message: String) : VerificationResult
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.opencoap.ssl.transport

import com.sun.jna.Memory
import java.nio.ByteBuffer

internal fun ByteArray.toHex(): String {
Expand All @@ -29,6 +30,16 @@ fun ByteBuffer.copy(): ByteBuffer {
return bb
}

fun ByteBuffer.cloneToMemory(): Memory {
this.mark() // saves the original position
val remaining = this.remaining()
val memory = Memory(remaining.toLong())
val intermediateBuffer: ByteBuffer = memory.getByteBuffer(0, remaining.toLong())
intermediateBuffer.put(this)
this.reset()
return memory
}

fun ByteBuffer.isNotEmpty(): Boolean = this.hasRemaining()
fun ByteBuffer.isEmpty(): Boolean = !this.hasRemaining()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,12 @@ class DtlsServer(
private fun closeSession(addr: InetSocketAddress) {
sessions.remove(addr)?.apply {
storeAndClose()
logger.info("[{}] [CID:{}] DTLS session was stored", peerAddress, (this as? DtlsSession)?.sessionContext?.cid?.toHex() ?: "na")
logger.info(
"[{}] [CID:{}] DTLS session was stored",
peerAddress,
(this as? DtlsSession)?.sessionContext?.cid?.toHex()
?: "na"
)
}
}

Expand All @@ -138,18 +143,25 @@ class DtlsServer(
updateSessionAuthenticationContext(adr, ctx.authenticationContext)
}

fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray): Boolean {
fun loadSession(sessBuf: SessionWithContext?, adr: InetSocketAddress, cid: ByteArray, dtlsPacket: ByteBuffer): Boolean {
return try {
if (sessBuf == null) {
logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex())
reportMessageDrop(adr)
false
} else {
sessions[adr] = DtlsSession(sslConfig.loadSession(cid, sessBuf.sessionBlob, adr), adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp)
true
return false
}

val sslSession = sslConfig.loadSession(cid, sessBuf.sessionBlob, adr)
val verificationResult = sslSession.checkRecord(dtlsPacket)
if (verificationResult is SslSession.VerificationResult.Invalid) {
logger.warn("[{}] [CID:{}] Record verification failed: {}", adr, cid.toHex(), verificationResult.message)
reportMessageDrop(adr)
return false
}
} catch (ex: SslException) {
logger.warn("[{}] [CID:{}] Failed to load session: {}", adr, cid.toHex(), ex.message)
sessions[adr] = DtlsSession(sslSession, adr, sessBuf.authenticationContext, sessBuf.sessionStartTimestamp)
true
} catch (ex: Exception) {
logger.error("[{}] [CID:{}] DTLS failed to load session: {}", adr, cid.toHex(), ex.message)
reportMessageDrop(adr)
false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class DtlsServerTransport private constructor(
val copyBuf = buf.copy()

sessionStore.read(result.cid).thenApplyAsync(
{ sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid) },
{ sessBuf -> dtlsServer.loadSession(sessBuf, adr, result.cid, copyBuf) },
executor
).thenCompose { isLoaded ->
if (isLoaded) {
Expand Down
23 changes: 23 additions & 0 deletions kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/SslContextTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,29 @@ class SslContextTest {
assertEquals("perse", serverSession.decrypt(encryptedDtls2, noSend).decodeToString())
}

@Test
fun `should check record is valid authentic and decrypt`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684))

val encryptedDtls = clientSession.encrypt("auto".toByteBuffer())

val verificationResult = serverSession.checkRecord(encryptedDtls)
assertTrue(verificationResult is SslSession.VerificationResult.Valid)
assertEquals("auto", serverSession.decrypt(encryptedDtls, noSend).decodeToString())
}

@Test
fun `should check record is invalid when record is unexpected and replayed`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
val serverSession = serverConf.loadSession(byteArrayOf(), StoredSessionPair.srvSession, localAddress(1_5684))
val encryptedDtls = clientSession.encrypt("auto".toByteBuffer())

serverSession.decrypt(encryptedDtls, noSend)
val result = serverSession.checkRecord(encryptedDtls.rewind() as ByteBuffer)
assertTrue(result is SslSession.VerificationResult.Invalid)
}

@Test
fun `should exchange data with direct byte buffer`() {
val clientSession = clientConf.loadSession(byteArrayOf(), StoredSessionPair.cliSession, localAddress(2_5684))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.opencoap.ssl.transport

import com.sun.jna.Memory
import org.junit.jupiter.api.Assertions.assertArrayEquals
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -70,4 +72,24 @@ class BytesExtensionsTest {
buf.position(2)
assertEquals("dupa", buf.decodeToString())
}

@Test
fun `should clone buffer to memory`() {
val originalData = byteArrayOf(1, 2, 3, 4, 5)
val byteBuffer = ByteBuffer.wrap(originalData)

val originalPosition = byteBuffer.position()
val originalLimit = byteBuffer.limit()
val originalCapacity = byteBuffer.capacity()

val memory: Memory = byteBuffer.cloneToMemory()

val clonedData = ByteArray(originalData.size)
memory.read(0, clonedData, 0, originalData.size)
assertArrayEquals(originalData, clonedData)

assertEquals(originalPosition, byteBuffer.position(), "Buffer position should not change")
assertEquals(originalLimit, byteBuffer.limit(), "Buffer limit should not change")
assertEquals(originalCapacity, byteBuffer.capacity(), "Buffer capacity should not change")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class DtlsServerTest {
assertTrue(dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) is ReceiveResult.CidSessionMissing)

// when
dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex())
dtlsServer.loadSession(SessionWithContext(StoredSessionPair.srvSession, mapOf(), Instant.ofEpochSecond(123456789)), localAddress(2_5684), "f935adc57425e1b214f8640d56e0c733".decodeHex(), dtlsPacket)

// then
val dtlsPacketIn = (dtlsServer.handleReceived(localAddress(2_5684), dtlsPacket) as ReceiveResult.Decrypted).packet
Expand Down

0 comments on commit 3d81000

Please sign in to comment.