Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for DTLS session authentication context updates via outbound dgram #63

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ class DtlsChannelHandler @JvmOverloads constructor(
when (msg) {
is DatagramPacketWithContext -> {
write(msg, promise, ctx)
if (msg.sessionContext.sessionExpirationHint) {
promise.toCompletableFuture().thenAccept {
dtlsServer.closeSession(msg.recipient())
}
}
handleDtlsContext(msg, promise)
}
is DatagramPacket -> write(msg, promise, ctx)
is SessionAuthenticationContext -> {
Expand All @@ -114,6 +110,22 @@ class DtlsChannelHandler @JvmOverloads constructor(
}
}

private fun handleDtlsContext(msg: DatagramPacketWithContext, promise: ChannelPromise) {
val sessCtx = msg.sessionContext
if (sessCtx.sessionSuspensionHint) {
promise.toCompletableFuture().thenAccept {
dtlsServer.closeSession(msg.recipient())
}
}
if (sessCtx.authenticationContext.isNotEmpty()) {
promise.toCompletableFuture().thenAccept {
szysas marked this conversation as resolved.
Show resolved Hide resolved
sessCtx.authenticationContext.forEach { (key, value) ->
dtlsServer.putSessionAuthenticationContext(msg.recipient(), key, value)
}
}
}
}

private fun write(msg: DatagramPacket, promise: ChannelPromise, ctx: ChannelHandlerContext) {
msg.useAndRelease {
val plainContent = msg.content().nioBuffer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class EchoHandler : ChannelInboundHandlerAdapter() {
val authContext = (sessionContext.authenticationContext["AUTH"] ?: "")
val dgramContent = dgram.content().toByteArray()
val goToSleep = dgramContent.toString(Charset.defaultCharset()).endsWith(":sleep")
val newAuthContext = dgramContent.toString(Charset.defaultCharset())
.takeIf { it.startsWith("auth:") }
?.substringAfter(":")

val reply = ctx.alloc().buffer(dgramContent.size + 20)
reply.writeBytes(echoPrefix)
Expand All @@ -43,7 +46,10 @@ class EchoHandler : ChannelInboundHandlerAdapter() {
reply,
dgram.sender(),
null,
sessionContext.copy(sessionExpirationHint = goToSleep)
sessionContext.copy(
authenticationContext = newAuthContext?.let { mapOf("AUTH" to it) } ?: emptyMap(),
sessionSuspensionHint = goToSleep
)
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ class NettyTest {
client.close()
}

@Test
fun `should forward authentication context passed inside outbound datagram`() {
// connect and handshake
val client = NettyTransportAdapter.connect(clientConf, srvAddress).mapToString()

assertTrue(client.send("hi").await())
assertEquals("ECHO:hi", client.receive(5.seconds).await())

// when
assertTrue(client.send("auth:007:").await())
assertEquals("ECHO:auth:007:", client.receive(5.seconds).await())

// then
assertTrue(client.send("hi").await())
assertEquals("ECHO:007:hi", client.receive(5.seconds).await())

client.close()
}

@Test
fun `should fail to forward authentication context for non existing client`() {
assertThatThrownBy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,21 @@ class DtlsServerTransport private constructor(

when {
encPacket == null -> completedFuture(false)
packet.sessionContext.sessionExpirationHint -> {
transport.send(encPacket).thenApply { isSuccess ->
if (isSuccess) {
dtlsServer.closeSession(packet.peerAddress)
}
isSuccess
else -> transport.send(encPacket)
}.thenApply { isSuccess ->
if (!isSuccess) return@thenApply false

val sessCtx = packet.sessionContext
if (sessCtx.sessionSuspensionHint) {
dtlsServer.closeSession(packet.peerAddress)
}
if (sessCtx.authenticationContext.isNotEmpty()) {
sessCtx.authenticationContext.forEach { (key, value) ->
dtlsServer.putSessionAuthenticationContext(packet.peerAddress, key, value)
}
}
szysas marked this conversation as resolved.
Show resolved Hide resolved
else -> transport.send(encPacket)

true
}
}.thenCompose(Function.identity())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
val peerCertificateSubject: String? = null,
val cid: ByteArray? = null,
val sessionStartTimestamp: Instant? = null,
val sessionExpirationHint: Boolean = false
val sessionSuspensionHint: Boolean = false
) {
companion object {
@JvmField
Expand All @@ -45,7 +45,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
if (!cid.contentEquals(other.cid)) return false
} else if (other.cid != null) return false
if (sessionStartTimestamp != other.sessionStartTimestamp) return false
if (sessionExpirationHint != other.sessionExpirationHint) return false
if (sessionSuspensionHint != other.sessionSuspensionHint) return false

return true
}
Expand All @@ -55,7 +55,7 @@ data class DtlsSessionContext @JvmOverloads constructor(
result = 31 * result + (peerCertificateSubject?.hashCode() ?: 0)
result = 31 * result + (cid?.contentHashCode() ?: 0)
result = 31 * result + (sessionStartTimestamp?.hashCode() ?: 0)
result = 31 * result + (sessionExpirationHint.hashCode())
result = 31 * result + (sessionSuspensionHint.hashCode())
return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class DtlsServerTransportTest {
} else if (msg.startsWith("Authenticate:")) {
server.putSessionAuthenticationContext(packet.peerAddress, "auth", msg.substring(12))
server.send(Packet("OK".toByteBuffer(), packet.peerAddress))
} else if (msg.startsWith("AuthenticateWithContext:")) {
server.send(
Packet(
"OK".toByteBuffer(),
packet.peerAddress,
DtlsSessionContext(authenticationContext = mapOf("auth" to msg.substring(23)))
)
)
} else {
val ctx = (packet.sessionContext.authenticationContext["auth"] ?: "")
server.send(packet.map { "$msg:resp$ctx".toByteBuffer() })
Expand Down Expand Up @@ -479,6 +487,19 @@ class DtlsServerTransportTest {
client.close()
}

@Test
fun `should set and use session context passed inside outbound datagram`() {
server = DtlsServerTransport.create(conf, expireAfter = 100.millis, sessionStore = sessionStore, lifecycleCallbacks = sslLifecycleCallbacks).listen(echoHandler)
// client connected
val client = DtlsTransmitter.connect(server, clientConfig).await()
client.send("AuthenticateWithContext:dev-007")
assertEquals("OK", client.receiveString())
client.send("hi")
assertEquals("hi:resp:dev-007", client.receiveString())

client.close()
}

@Test
fun `server should store session if hinted to do so`() {
// given
Expand All @@ -491,7 +512,7 @@ class DtlsServerTransportTest {
assertEquals("dupa", client.receive(1.seconds).await())

client.send("sleep")
server.send(Packet("sleep".toByteBuffer(), serverReceived.await().peerAddress, sessionContext = DtlsSessionContext(sessionExpirationHint = true)))
server.send(Packet("sleep".toByteBuffer(), serverReceived.await().peerAddress, sessionContext = DtlsSessionContext(sessionSuspensionHint = true)))
assertEquals("sleep", client.receive(1.seconds).await())

await.atMost(5.seconds).untilAsserted {
Expand Down