From 356122507f9c2b5e5679cadfaea39396090da6ba Mon Sep 17 00:00:00 2001 From: Szymon Sasin Date: Wed, 12 Oct 2022 09:37:44 +0300 Subject: [PATCH] Use simple Transport interface --- .../ssl/transport/DatagramChannelAdapter.kt | 88 +++++ .../org/opencoap/ssl/transport/DtlsServer.kt | 305 +++++++++--------- .../opencoap/ssl/transport/DtlsTransmitter.kt | 220 +++++-------- .../org/opencoap/ssl/transport/Packet.kt | 33 ++ .../org/opencoap/ssl/transport/Transport.kt | 64 ++++ .../org/opencoap/ssl/transport/extensions.kt | 39 +-- .../transport/DatagramChannelAdapterTest.kt | 97 ++++++ .../opencoap/ssl/transport/DtlsServerTest.kt | 77 +++-- .../ssl/transport/DtlsTransmitterCertTest.kt | 29 +- .../ssl/transport/DtlsTransmitterTest.kt | 16 +- .../kotlin/org/opencoap/ssl/util/Utils.kt | 4 + 11 files changed, 600 insertions(+), 372 deletions(-) create mode 100644 kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapter.kt create mode 100644 kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Packet.kt create mode 100644 kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Transport.kt create mode 100644 kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapterTest.kt diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapter.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapter.kt new file mode 100644 index 00000000..dc61e8a1 --- /dev/null +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapter.kt @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls) + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencoap.ssl.transport + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.nio.channels.DatagramChannel +import java.nio.channels.SelectionKey +import java.nio.channels.Selector +import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletableFuture.completedFuture +import java.util.concurrent.Executors + +class DatagramChannelAdapter( + private val channel: DatagramChannel, + private val buffer: ByteBuffer = ByteBuffer.allocateDirect(16384) +) : Transport { + + companion object { + fun open(port: Int = 0): Transport { + val datagramChannel = DatagramChannel.open().bind(InetSocketAddress("0.0.0.0", port)) + return DatagramChannelAdapter(datagramChannel) + } + + fun connect(dest: InetSocketAddress, listenPort: Int = 0): Transport { + val channel: DatagramChannel = DatagramChannel.open() + if (listenPort > 0) channel.bind(InetSocketAddress("0.0.0.0", listenPort)) + channel.connect(dest) + + return DatagramChannelAdapter(channel).map(ByteBufferPacket::buffer) { ByteBufferPacket(it, dest) } + } + } + + private val selector: Selector = Selector.open() + private val port get() = (channel.localAddress as InetSocketAddress).port + private val executor = Executors.newSingleThreadExecutor { Thread(it, "udp-io (:$port)") } + + init { + channel.configureBlocking(false) + channel.register(selector, SelectionKey.OP_READ) + } + + override fun receive(timeout: Duration): CompletableFuture { + return executor.supply { + selector.select(timeout.toMillis()) + buffer.clear() + + val sourceAddress = channel.receive(buffer) + if (sourceAddress == null) { + Packet.EmptyByteBufferPacket + } else { + buffer.flip() + Packet(buffer, sourceAddress as InetSocketAddress) + } + } + } + + override fun send(packet: Packet): CompletableFuture { + return try { + completedFuture(channel.send(packet.buffer, packet.peerAddress) > 0) + } catch (ex: Exception) { + CompletableFuture().also { it.completeExceptionally(ex) } + } + } + + override fun close() { + channel.close() + selector.close() + executor.shutdown() + } + + override fun localPort() = (channel.localAddress as InetSocketAddress).port +} diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt index 9dc1c7fe..b936dec5 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsServer.kt @@ -26,12 +26,10 @@ import org.opencoap.ssl.SslSession import org.slf4j.LoggerFactory import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.nio.channels.DatagramChannel import java.time.Duration -import java.util.concurrent.ArrayBlockingQueue -import java.util.concurrent.BlockingQueue import java.util.concurrent.CompletableFuture -import java.util.concurrent.CompletionStage +import java.util.concurrent.CompletableFuture.completedFuture +import java.util.concurrent.ScheduledFuture import java.util.concurrent.ScheduledThreadPoolExecutor import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger @@ -40,13 +38,15 @@ import java.util.concurrent.atomic.AtomicInteger Single threaded dtls server on top of DatagramChannel. */ class DtlsServer private constructor( - private val channel: DatagramChannel, + private val transport: Transport, private val sslConfig: SslConfig, private val expireAfter: Duration, private val sessionStore: SessionStore, -) { +) : Transport { companion object { + private val EMPTY_BUFFER = ByteBuffer.allocate(0) + private val threadIndex = AtomicInteger(0) @JvmStatic @@ -57,7 +57,7 @@ class DtlsServer private constructor( expireAfter: Duration = Duration.ofSeconds(60), sessionStore: SessionStore = NoOpsSessionStore ): DtlsServer { - val channel = DatagramChannel.open().bind(InetSocketAddress("0.0.0.0", listenPort)) + val channel = DatagramChannelAdapter.open(listenPort) return DtlsServer(channel, config, expireAfter, sessionStore) } } @@ -66,214 +66,209 @@ class DtlsServer private constructor( val executor = ScheduledThreadPoolExecutor(1) { r: Runnable -> Thread(r, "dtls-srv-" + threadIndex.getAndIncrement()) } // note: must be used only from executor - private val localActionPromises = ThreadLocal>>() - private val actionPromises: MutableMap> - get() = localActionPromises.get() + private val sessions = mutableMapOf() private val cidSize = sslConfig.cidSupplier.next().size - init { - executor.execute { localActionPromises.set(hashMapOf()) } + override fun receive(timeout: Duration): CompletableFuture { + return transport.receive(timeout).thenComposeAsync({ packet -> + if (packet == Packet.EmptyByteBufferPacket) return@thenComposeAsync completedFuture(Packet.EmptyBytesPacket) + + val adr: InetSocketAddress = packet.peerAddress + val buf: ByteBuffer = packet.buffer + + handleReceived(adr, buf, timeout) + }, executor) } - fun listen(handler: Handler): DtlsServer { - val bufPool: BlockingQueue = ArrayBlockingQueue(1) - bufPool.put(ByteBuffer.allocateDirect(16384)) - - val nonThrowingHandler = handler.decorateWithCatcher() - channel.listen(bufPool) { adr: InetSocketAddress, buf: ByteBuffer -> - // need to handle incoming message in executor for thread safety - executor.execute { - val promise = actionPromises.remove(adr) - if (promise != null) { - promise.complete(DecryptAction(buf)) - } else { - val cid = SslContext.peekCID(cidSize, buf) - if (cid != null) { - loadSession(buf.copyDirect(), cid, adr, handler) + private fun handleReceived(adr: InetSocketAddress, buf: ByteBuffer, timeout: Duration): CompletableFuture { + val cid by lazy { SslContext.peekCID(cidSize, buf) } + val dtlsState = sessions[adr] + + return when { + dtlsState is DtlsHandshake -> { + dtlsState.step(buf) + receive(timeout) + } + + dtlsState is DtlsSession -> { + val plainBytes = dtlsState.decrypt(buf) + if (plainBytes.isNotEmpty()) + completedFuture(Packet(plainBytes, adr)) + else + receive(timeout) + } + + // no session, but dtls packet contains CID + cid != null -> { + val copyBuf = buf.copyDirect() + loadSession(cid!!, adr).thenCompose { isLoaded -> + if (isLoaded) { + handleReceived(adr, copyBuf, timeout) } else { - DtlsHandshake(sslConfig.newContext(adr), adr, nonThrowingHandler) - .invoke(DecryptAction(buf)) + receive(timeout) } } + } - bufPool.put(buf) // return buffer to the pool + // new handshake + else -> { + sessions[adr] = DtlsHandshake(sslConfig.newContext(adr), adr) + handleReceived(adr, buf, timeout) } } - return this } - fun send(data: ByteArray, target: InetSocketAddress): CompletableFuture = executor.supply { - val promise = actionPromises.remove(target) - promise?.complete(EncryptAction(data)) ?: false + override fun send(packet: Packet): CompletableFuture = executor.supply { + when (val dtlsState = sessions[packet.peerAddress]) { + is DtlsSession -> { + transport.send(packet.map(dtlsState::encrypt)) + true + } + + else -> false + } } - fun numberOfSessions(): Int = executor.supply { actionPromises.size }.join() - val localAddress: InetSocketAddress - get() = channel.localAddress as InetSocketAddress + fun numberOfSessions(): Int = executor.supply { sessions.size }.join() - fun localPort() = localAddress.port + override fun localPort() = transport.localPort() - fun close() { + override fun close() { executor.supply { - channel.close() - val iterator = actionPromises.iterator() + transport.close() + + val iterator = sessions.iterator() while (iterator.hasNext()) { - val promise = iterator.next().value + val dtlsState = iterator.next().value iterator.remove() - promise.complete(CloseAction) + dtlsState.storeAndClose() } }.get(30, TimeUnit.SECONDS) executor.shutdown() } - private fun receive(peerAddress: InetSocketAddress, timeout: Duration = expireAfter, timeoutAction: Action = TimeoutAction): CompletionStage { - val timeoutMillis = if (timeout.isZero) expireAfter.toMillis() else timeout.toMillis() - - val promise = CompletableFuture() - val scheduledFuture = executor.schedule({ promise.complete(timeoutAction) }, timeoutMillis, TimeUnit.MILLISECONDS) - actionPromises.put(peerAddress, promise)?.cancel(false) - - promise.thenRun { - scheduledFuture.cancel(true) - actionPromises.remove(peerAddress, promise) - } - return promise - } - - private fun loadSession(encBuf: ByteBuffer, cid: ByteArray, adr: InetSocketAddress, handler: Handler) { - sessionStore.read(cid) - .thenAcceptAsync({ sessBuf -> - if (sessBuf == null) { - logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex()) - } else { - DtlsSession(sslConfig.loadSession(cid, sessBuf, adr), adr, handler) - .invoke(DecryptAction(encBuf)) + private fun loadSession(cid: ByteArray, adr: InetSocketAddress): CompletableFuture { + return sessionStore.read(cid) + .thenApplyAsync({ sessBuf -> + try { + if (sessBuf == null) { + logger.warn("[{}] [CID:{}] DTLS session not found", adr, cid.toHex()) + false + } else { + sessions[adr] = DtlsSession(sslConfig.loadSession(cid, sessBuf, adr), adr) + true + } + } catch (ex: SslException) { + logger.warn("[{}] [CID:{}] Failed to load session: {}", adr, cid.toHex(), ex.message) + false } }, executor) - .whenComplete { _, ex -> - when (ex) { - null -> Unit // no error - is SslException -> logger.warn("[{}] [CID:{}] Failed to load session: {}", adr, cid.toHex(), ex.message) - else -> logger.error(ex.message, ex) - } - } } - private fun Handler.decorateWithCatcher(): Handler { - return object : Handler { - override fun invoke(adr: InetSocketAddress, packet: ByteArray) { - try { - this@decorateWithCatcher(adr, packet) - } catch (ex: Exception) { - logger.error(ex.toString(), ex) - } - } - } + private fun DtlsState.closeAndRemove() { + sessions.remove(this.peerAddress, this) + } + + private sealed class DtlsState(val peerAddress: InetSocketAddress) { + protected var scheduledTask: ScheduledFuture<*>? = null + + abstract fun storeAndClose() } - private inner class DtlsHandshake(private val ctx: SslHandshakeContext, private val peerAddress: InetSocketAddress, private val handler: Handler) { + private inner class DtlsHandshake( + private val ctx: SslHandshakeContext, + peerAddress: InetSocketAddress + ) : DtlsState(peerAddress) { + private fun send(buf: ByteBuffer) { - channel.send(buf, peerAddress) + transport.send(Packet(buf, peerAddress)) } - operator fun invoke(action: Action?, err: Throwable? = null) { + private fun retryStep() = step(EMPTY_BUFFER) + + fun step(encPacket: ByteBuffer) { + scheduledTask?.cancel(false) + try { - when (action) { - is DecryptAction -> stepHandshake(action.encPacket) - is EncryptAction -> return - is CloseAction -> ctx.close() - is TimeoutAction -> { - logger.warn("[{}] DTLS handshake expired", peerAddress) - ctx.close() + when (val newCtx = ctx.step(encPacket, ::send)) { + is SslHandshakeContext -> { + scheduledTask = if (!newCtx.readTimeout.isZero) { + executor.schedule(::retryStep, newCtx.readTimeout) + } else { + executor.schedule(::timeout, expireAfter) + } } - null -> throw err!! + + is SslSession -> + sessions[peerAddress] = DtlsSession(newCtx, peerAddress) } } catch (ex: HelloVerifyRequired) { - ctx.close() + closeAndRemove() } catch (ex: SslException) { logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message) - ctx.close() + closeAndRemove() } catch (ex: Exception) { logger.error(ex.toString(), ex) - ctx.close() + closeAndRemove() } } - private fun stepHandshake(encPacket: ByteBuffer) { - when (val newCtx = ctx.step(encPacket, ::send)) { - is SslHandshakeContext -> { - if (!newCtx.readTimeout.isZero) - receive(peerAddress, newCtx.readTimeout, EmptyDecryptAction).whenComplete(::invoke) - else - receive(peerAddress).whenComplete(::invoke) - } + fun timeout() { + closeAndRemove() + logger.warn("[{}] DTLS handshake expired", peerAddress) + } + + override fun storeAndClose() { + ctx.close() + } + } + + private inner class DtlsSession( + private val ctx: SslSession, + peerAddress: InetSocketAddress + ) : DtlsState(peerAddress) { - is SslSession -> { - val dtlsSession = DtlsSession(newCtx, peerAddress, handler) - receive(peerAddress).whenComplete(dtlsSession::invoke) + override fun storeAndClose() { + if (ctx.ownCid != null) { + try { + sessionStore.write(ctx.ownCid, ctx.saveAndClose()) + } catch (ex: SslException) { + logger.warn("[{}] DTLS failed to store session: {}", peerAddress, ex.message) } + } else { + ctx.close() } } - } - private inner class DtlsSession(private val ctx: SslSession, private val peerAddress: InetSocketAddress, private val handler: Handler) { - operator fun invoke(action: Action?, err: Throwable? = null) { + fun decrypt(encPacket: ByteBuffer): ByteArray { + scheduledTask?.cancel(false) try { - when (action) { - null -> throw err!! - is DecryptAction -> decrypt(action.encPacket) - is EncryptAction -> encrypt(action.plainPacket) - is CloseAction -> { - logger.info("[{}] DTLS connection closed", peerAddress) - close() - } - is TimeoutAction -> { - logger.info("[{}] DTLS connection expired", peerAddress) - close() - } - } + val plainBuf = ctx.decrypt(encPacket) + scheduledTask = executor.schedule(::timeout, expireAfter) + return plainBuf } catch (ex: CloseNotifyException) { logger.info("[{}] DTLS received close notify", peerAddress) - ctx.close() } catch (ex: SslException) { logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message) - ctx.close() - } catch (ex: Throwable) { - logger.error(ex.message, ex) - ctx.close() } + closeAndRemove() + return byteArrayOf() } - private fun close() { - if (ctx.ownCid != null) { - sessionStore.write(ctx.ownCid, ctx.saveAndClose()) - } else { - ctx.close() + fun encrypt(plainPacket: ByteArray): ByteBuffer { + try { + return ctx.encrypt(plainPacket) + } catch (ex: SslException) { + logger.warn("[{}] DTLS failed: {}", peerAddress, ex.message) + closeAndRemove() + throw ex } } - private fun decrypt(encPacket: ByteBuffer) { - val plainBuf = ctx.decrypt(encPacket) - receive(peerAddress).whenComplete(::invoke) - handler(peerAddress, plainBuf) - } - - private fun encrypt(plainPacket: ByteArray) { - val encBuf = ctx.encrypt(plainPacket) - receive(peerAddress).whenComplete(::invoke) - channel.send(encBuf, peerAddress) + fun timeout() { + sessions.remove(peerAddress, this) + logger.info("[{}] DTLS connection expired", peerAddress) + storeAndClose() } } - - private sealed interface Action - private open class DecryptAction(val encPacket: ByteBuffer) : Action - private object EmptyDecryptAction : DecryptAction(ByteBuffer.allocate(0)) - private class EncryptAction(val plainPacket: ByteArray) : Action - private object CloseAction : Action - private object TimeoutAction : Action -} - -interface Handler { - @Throws(Exception::class) - operator fun invoke(adr: InetSocketAddress, packet: ByteArray) } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsTransmitter.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsTransmitter.kt index 15c084a3..bf713bf6 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsTransmitter.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/DtlsTransmitter.kt @@ -18,129 +18,54 @@ package org.opencoap.ssl.transport import org.opencoap.ssl.HelloVerifyRequired import org.opencoap.ssl.SslConfig -import org.opencoap.ssl.SslContext import org.opencoap.ssl.SslHandshakeContext import org.opencoap.ssl.SslSession -import java.io.Closeable import java.net.InetAddress import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.nio.channels.DatagramChannel -import java.nio.channels.SelectionKey -import java.nio.channels.Selector import java.time.Duration import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletableFuture.completedFuture import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger /* -DTLS transmitter based on DatagramChannel. Uses blocking calls. +Single DTLS connection, transmitter. Can be server or client mode. */ class DtlsTransmitter private constructor( - internal val cnnTrans: ConnectedDatagramTransmitter, + val remoteAddress: InetSocketAddress, + internal val transport: Transport, private val sslSession: SslSession, - private val executor: ExecutorService -) : Closeable { - companion object { - private val threadIndex = AtomicInteger(0) - internal fun newSingleExecutor(): ExecutorService { - return Executors.newSingleThreadExecutor { Thread(it, "dtls-" + threadIndex.getAndIncrement()) } - } - - @JvmStatic - @JvmOverloads - fun connect(server: DtlsServer, conf: SslConfig): CompletableFuture { - return connect(InetSocketAddress(InetAddress.getLocalHost(), server.localPort()), conf) - } - - @JvmStatic - @JvmOverloads - fun connect(peerCnnTrans: ConnectedDatagramTransmitter, conf: SslConfig, bindPort: Int = 0): CompletableFuture { - return connect(peerCnnTrans.localAddress(), conf, bindPort) - } - - @JvmStatic - @JvmOverloads - fun connect(dest: InetSocketAddress, conf: SslConfig, bindPort: Int = 0): CompletableFuture { - return connect(conf, ConnectedDatagramTransmitter.connect(dest, bindPort)) - } - - @JvmStatic - @JvmOverloads - fun connect(conf: SslConfig, channel: ConnectedDatagramTransmitter, executor: ExecutorService = newSingleExecutor()): CompletableFuture { - return executor.supply { - connect0(conf, channel, executor) - } - } - - private fun connect0(conf: SslConfig, trans: ConnectedDatagramTransmitter, executor: ExecutorService): DtlsTransmitter { - val sslHandshakeContext = conf.newContext(trans.remoteAddress()) - return try { - val sslSession = handshake(sslHandshakeContext, trans) - DtlsTransmitter(trans, sslSession, executor) - } catch (ex: HelloVerifyRequired) { - sslHandshakeContext.close() - connect0(conf, trans, executor) - } catch (ex: Exception) { - sslHandshakeContext.close() - trans.close() - throw ex - } - } - - private fun handshake(handshakeCtx: SslHandshakeContext, trans: ConnectedDatagramTransmitter): SslSession { - val buffer: ByteBuffer = ByteBuffer.allocateDirect(16384) + private val executor: ExecutorService, +) : Transport { - var sslContext: SslContext = handshakeCtx.step(trans::send) - while (sslContext is SslHandshakeContext) { - trans.receive(buffer, sslContext.readTimeout) - sslContext = handshakeCtx.step(buffer, trans::send) - } - return sslContext as SslSession - } - - @JvmStatic - @JvmOverloads - fun create(dest: InetSocketAddress, sslSession: SslSession, bindPort: Int = 0): DtlsTransmitter { - return create(sslSession, ConnectedDatagramTransmitter.connect(dest, bindPort)) - } - - @JvmStatic - @JvmOverloads - fun create(sslSession: SslSession, cnnTransmitter: ConnectedDatagramTransmitter): DtlsTransmitter { - return DtlsTransmitter(cnnTransmitter, sslSession, newSingleExecutor()) - } + override fun send(packet: ByteArray): CompletableFuture { + return executor + .supply { sslSession.encrypt(packet) } + .thenCompose(transport::send) } - override fun close() { - cnnTrans.close() - executor.supply(sslSession::close).join() - } + fun send(text: String) = send(text.encodeToByteArray()) - fun send(data: ByteArray) { - executor.supply { cnnTrans.send(sslSession.encrypt(data)) }.join() + override fun receive(timeout: Duration): CompletableFuture { + return transport.receive(timeout).thenApplyAsync({ + if (it.remaining() == 0) byteArrayOf() else sslSession.decrypt(it) + }, executor) } - fun send(text: String) = send(text.encodeToByteArray()) - - @JvmOverloads - fun receive(timeout: Duration = Duration.ofSeconds(30)): ByteArray { - val buffer: ByteBuffer = ByteBuffer.allocateDirect(16384) + fun receive() = receive(Duration.ofSeconds(30)) - cnnTrans.receive(buffer, timeout) - if (!buffer.hasRemaining()) { - return byteArrayOf() - } + override fun localPort() = transport.localPort() - return executor.supply { sslSession.decrypt(buffer) }.join() + override fun close() { + transport.close() + executor.supply(sslSession::close).join() } - fun receiveString(): String = receive().decodeToString() - fun closeNotify() { executor.supply { - cnnTrans.send(sslSession.closeNotify()) + transport.send(sslSession.closeNotify()) }.join() close() } @@ -149,53 +74,80 @@ class DtlsTransmitter private constructor( fun getPeerCid() = sslSession.peerCid fun getOwnCid() = sslSession.ownCid fun saveSession() = sslSession.saveAndClose() - val remoteAddress: InetSocketAddress - get() = cnnTrans.remoteAddress() -} - -interface ConnectedDatagramTransmitter : Closeable { - fun send(buf: ByteBuffer) - fun receive(buf: ByteBuffer, timeout: Duration) - fun localAddress(): InetSocketAddress - fun remoteAddress(): InetSocketAddress companion object { + private val threadIndex = AtomicInteger(0) + internal fun newSingleExecutor(): ExecutorService { + return Executors.newSingleThreadExecutor { Thread(it, "dtls-" + threadIndex.getAndIncrement()) } + } + @JvmStatic @JvmOverloads - fun connect(dest: InetSocketAddress, listenPort: Int = 0): ConnectedDatagramTransmitter { - val channel: DatagramChannel = DatagramChannel.open() - if (listenPort > 0) channel.bind(InetSocketAddress("0.0.0.0", listenPort)) - channel.connect(dest) - channel.configureBlocking(false) - val selector: Selector = Selector.open() - channel.register(selector, SelectionKey.OP_READ) - - return ConnectedDatagramTransmitterImpl(channel, selector) + fun connect(server: Transport<*>, conf: SslConfig, bindPort: Int = 0): CompletableFuture { + return connect(InetSocketAddress(InetAddress.getLocalHost(), server.localPort()), conf, bindPort) } - } -} -class ConnectedDatagramTransmitterImpl( - private val channel: DatagramChannel, - private val selector: Selector -) : ConnectedDatagramTransmitter { - init { - require(channel.isConnected) - } + @JvmStatic + @JvmOverloads + fun connect(dest: InetSocketAddress, conf: SslConfig, bindPort: Int = 0): CompletableFuture { + return connect(dest, conf, DatagramChannelAdapter.connect(dest, bindPort)) + } - override fun send(buf: ByteBuffer) { - channel.write(buf) - } + @JvmStatic + @JvmOverloads + fun connect(dest: InetSocketAddress, conf: SslConfig, trans: Transport, executor: ExecutorService = newSingleExecutor()): CompletableFuture { + val promise = CompletableFuture() + val sslHandshakeContext = conf.newContext(dest) + val send: (ByteBuffer) -> Unit = { trans.send(it) } + + fun handleReceive(buffer: ByteBuffer): CompletableFuture { + val newSslContext = sslHandshakeContext.step(buffer, send) + + return when (newSslContext) { + is SslSession -> completedFuture(newSslContext) + is SslHandshakeContext -> { + val timeout = if (newSslContext.readTimeout.isZero) Duration.ofSeconds(1) else newSslContext.readTimeout + trans.receive(timeout).thenComposeAsync(::handleReceive, executor) + } + } + } - override fun receive(buf: ByteBuffer, timeout: Duration) { - channel.receive(buf, selector, timeout) - } + val sslContext: SslHandshakeContext = sslHandshakeContext.step(send) as SslHandshakeContext + trans.receive(sslContext.readTimeout) + .thenComposeAsync(::handleReceive, executor) + .whenComplete { sslSession, ex -> + when (ex?.cause) { + null -> promise.complete(DtlsTransmitter(dest, trans, sslSession, executor)) + + is HelloVerifyRequired -> { + sslHandshakeContext.close() + connect(dest, conf, trans, executor).whenComplete { t, ex2 -> + if (ex2 != null) promise.completeExceptionally(ex2) + else promise.complete(t) + } + } + + else -> { + sslHandshakeContext.close() + trans.close() + promise.completeExceptionally(ex) + } + } + } + + return promise + } - override fun localAddress() = channel.localAddress as InetSocketAddress - override fun remoteAddress() = channel.remoteAddress as InetSocketAddress + @JvmStatic + @JvmOverloads + fun create(dest: InetSocketAddress, sslSession: SslSession, bindPort: Int = 0): DtlsTransmitter { + return create(dest, sslSession, DatagramChannelAdapter.connect(dest, bindPort)) + } - override fun close() { - selector.close() - channel.close() + @JvmStatic + @JvmOverloads + fun create(dest: InetSocketAddress, sslSession: SslSession, cnnTransmitter: Transport): DtlsTransmitter { + return DtlsTransmitter(dest, cnnTransmitter, sslSession, newSingleExecutor()) + } } } diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Packet.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Packet.kt new file mode 100644 index 00000000..e7b29680 --- /dev/null +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Packet.kt @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls) + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencoap.ssl.transport + +import java.net.InetSocketAddress +import java.nio.ByteBuffer + +data class Packet(val buffer: T, val peerAddress: InetSocketAddress) { + fun map(f: (T) -> T2): Packet = Packet(f(buffer), peerAddress) + + companion object { + val EmptyByteBufferPacket: ByteBufferPacket = Packet(ByteBuffer.allocate(0), InetSocketAddress.createUnresolved("", 0)) + val EmptyBytesPacket: BytesPacket = Packet(byteArrayOf(), InetSocketAddress.createUnresolved("", 0)) + } +} + +typealias ByteBufferPacket = Packet + +typealias BytesPacket = Packet diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Transport.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Transport.kt new file mode 100644 index 00000000..b9b6adc5 --- /dev/null +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/Transport.kt @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls) + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencoap.ssl.transport + +import org.slf4j.LoggerFactory +import java.io.Closeable +import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.function.Consumer + +interface Transport

: Closeable { + fun receive(timeout: Duration): CompletableFuture

+ fun send(packet: P): CompletableFuture + fun localPort(): Int + + fun map(f: (P) -> P2, f2: (P2) -> P): Transport { + val underlying = this + return object : Transport { + override fun receive(timeout: Duration) = underlying.receive(timeout).thenApply(f::invoke) + override fun send(packet: P2): CompletableFuture = underlying.send(f2.invoke(packet)) + override fun localPort(): Int = underlying.localPort() + override fun close() = underlying.close() + } + } +} + +fun > T.listen(handler: Consumer

): T { + val logger = LoggerFactory.getLogger(javaClass) + + fun handle(packet: P?, err: Throwable?) { + if (err != null) { + logger.warn("Listener stopped: {}", err.message) + return + } + + if (packet != null) { + try { + handler.accept(packet) + } catch (ex: Exception) { + logger.error(ex.toString(), ex) + } + } + // continue + receive(Duration.ofSeconds(5)).whenComplete(::handle) + } + + // start loop + receive(Duration.ofSeconds(5)).whenComplete(::handle) + return this +} diff --git a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/extensions.kt b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/extensions.kt index e721731c..ff0e69b3 100644 --- a/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/extensions.kt +++ b/kotlin-mbedtls/src/main/kotlin/org/opencoap/ssl/transport/extensions.kt @@ -16,48 +16,23 @@ package org.opencoap.ssl.transport -import org.slf4j.LoggerFactory -import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.nio.channels.ClosedChannelException -import java.nio.channels.DatagramChannel -import java.nio.channels.Selector import java.time.Duration -import java.util.concurrent.BlockingQueue import java.util.concurrent.CompletableFuture import java.util.concurrent.Executor +import java.util.concurrent.ScheduledExecutorService +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit import java.util.function.Supplier -internal fun DatagramChannel.listen(bufPool: BlockingQueue, handler: (InetSocketAddress, ByteBuffer) -> Unit) { - val task = Runnable { - try { - while (this.isOpen) { - val buffer = bufPool.take() - buffer.clear() - val peerAddress = this.receive(buffer) as InetSocketAddress - buffer.flip() - handler.invoke(peerAddress, buffer) - } - } catch (ex: Exception) { - if (ex !is ClosedChannelException) LoggerFactory.getLogger(javaClass).error(ex.toString(), ex) - } - } - - Thread(task, "udp-io (:" + (localAddress as InetSocketAddress).port + ")").start() -} - -internal fun DatagramChannel.receive(buffer: ByteBuffer, selector: Selector, timeout: Duration): InetSocketAddress? { - buffer.clear() - selector.select(timeout.toMillis()) - val sourceAddress = this.receive(buffer) as? InetSocketAddress - buffer.flip() - return sourceAddress -} - internal fun Executor.supply(supplier: Supplier): CompletableFuture { return CompletableFuture.supplyAsync(supplier, this) } +internal fun ScheduledExecutorService.schedule(task: Runnable, delay: Duration): ScheduledFuture<*> { + return this.schedule(task, delay.toMillis(), TimeUnit.MILLISECONDS) +} + internal fun ByteArray.toHex(): String { return joinToString(separator = "") { eachByte -> "%02x".format(eachByte) } } diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapterTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapterTest.kt new file mode 100644 index 00000000..c218dd03 --- /dev/null +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DatagramChannelAdapterTest.kt @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2022 kotlin-mbedtls contributors (https://github.com/open-coap/kotlin-mbedtls) + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.opencoap.ssl.transport + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.opencoap.ssl.util.await +import org.opencoap.ssl.util.decodeToString +import org.opencoap.ssl.util.localAddress +import org.opencoap.ssl.util.millis +import org.opencoap.ssl.util.seconds +import org.opencoap.ssl.util.toByteBuffer +import java.nio.ByteBuffer +import java.util.concurrent.CompletionException +import java.util.concurrent.RejectedExecutionException + +class DatagramChannelAdapterTest { + + @Test + fun sendAndReceive() { + val trans1 = DatagramChannelAdapter.open() + val trans2 = DatagramChannelAdapter.open() + + // when + assertTrue(trans1.send(Packet("dupa".toByteBuffer(), trans2.localAddress())).join()) + + // then + val resp = trans2.receive(1.seconds).join() + assertEquals("dupa", resp?.buffer?.decodeToString()) + + trans1.close() + trans2.close() + } + + @Test + fun cancelWhenClosing() { + val trans = DatagramChannelAdapter.open() + val received = trans.receive(1.seconds) + assertFalse(received.isDone) + + // when + trans.close() + + // then + assertThrows { received.join() } + assertThrows { trans.receive(1.seconds) } + } + + @Test + fun timeout() { + val trans = DatagramChannelAdapter.open() + + // when + val received = trans.receive(1.millis) + + // then + assertEquals(Packet.EmptyByteBufferPacket, received.await()) + trans.close() + } + + @Test + fun listen() { + val trans = DatagramChannelAdapter.open() + trans.listen { packet -> + trans.send(packet.map(ByteBuffer::decodeToString).map { "echo:$it" }.map(String::toByteBuffer)) + } + val cli = DatagramChannelAdapter.open() + + for (i in 1..10) { + cli.send(Packet("$i:dupa".toByteBuffer(), trans.localAddress())) + } + + for (i in 1..10) { + assertEquals("echo:$i:dupa", cli.receive(1.seconds).await()!!.buffer.decodeToString()) + } + + trans.close() + cli.close() + } +} diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt index cf5c6535..4d569376 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsServerTest.kt @@ -19,6 +19,7 @@ package org.opencoap.ssl.transport import org.awaitility.kotlin.await import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test import org.opencoap.ssl.EmptyCidSupplier @@ -36,8 +37,11 @@ import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.DatagramChannel import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletableFuture.completedFuture import java.util.concurrent.ScheduledThreadPoolExecutor import java.util.concurrent.TimeUnit +import java.util.function.Consumer import kotlin.random.Random class DtlsServerTest { @@ -53,13 +57,11 @@ class DtlsServerTest { private lateinit var server: DtlsServer - private val echoHandler: Handler = object : Handler { - override fun invoke(adr: InetSocketAddress, packet: ByteArray) { - if (packet.decodeToString() == "error") { - throw Exception("error") - } else { - server.send(packet.plus(":resp".encodeToByteArray()), adr) - } + private val echoHandler: Consumer = Consumer { packet -> + if (packet.buffer.decodeToString() == "error") { + throw Exception("error") + } else { + server.send(packet.map { it.plus(":resp".encodeToByteArray()) }) } } @@ -75,13 +77,19 @@ class DtlsServerTest { @Test fun testSingleConnection() { - server = DtlsServer.create(conf).listen(echoHandler) + server = DtlsServer.create(conf) + val receive = server.receive(2.seconds) val client = DtlsTransmitter.connect(server, clientConfig).await() + client.send("hi") + assertEquals("hi", receive.await().buffer.decodeToString()) + server.send(Packet("czesc".encodeToByteArray(), receive.await().peerAddress)) + assertEquals("czesc", client.receiveString()) + repeat(5) { i -> client.send("perse$i") - assertEquals("perse$i:resp", client.receiveString()) + assertEquals("perse$i", server.receive(1.seconds).await().buffer.decodeToString()) } assertEquals(1, server.numberOfSessions()) @@ -98,8 +106,8 @@ class DtlsServerTest { val clients = (1..MAX) .map { - val ch = ConnectedDatagramTransmitter.connect(localAddress(server.localPort()), 0) - DtlsTransmitter.connect(clientCertConf, ch, executors[it % executors.size]) + val ch = DatagramChannelAdapter.connect(localAddress(server.localPort()), 0) + DtlsTransmitter.connect(localAddress(server.localPort()), clientCertConf, ch, executors[it % executors.size]) }.map { it.get(30, TimeUnit.SECONDS) }.map { client -> @@ -116,7 +124,8 @@ class DtlsServerTest { @Test fun testFailedHandshake() { // given - server = DtlsServer.create(conf).listen(echoHandler) + server = DtlsServer.create(conf) + val srvReceive = server.receive(2.seconds) val clientFut = DtlsTransmitter.connect(server, SslConfig.client(psk.first, byteArrayOf(-128))) // when @@ -126,6 +135,7 @@ class DtlsServerTest { await.untilAsserted { assertEquals(0, server.numberOfSessions()) } + assertFalse(srvReceive.isDone) } @Test @@ -136,7 +146,7 @@ class DtlsServerTest { client.send("perse") // when - client.cnnTrans.send("malformed dtls packet".toByteBuffer()) + client.transport.send("malformed dtls packet".toByteBuffer()) client.send("perse") // then @@ -228,12 +238,12 @@ class DtlsServerTest { @Test fun `should successfully handshake with retransmission`() { server = DtlsServer.create(timeoutConf).listen(echoHandler) - val cli: ConnectedDatagramTransmitter = ConnectedDatagramTransmitter + val cli = DatagramChannelAdapter .connect(localAddress(server.localPort())) .dropReceive { it == 1 } // drop ServerHello, the only message that server will retry // when - val sslSession = DtlsTransmitter.connect(timeoutClientConf, cli).await() + val sslSession = DtlsTransmitter.connect(server.localAddress(), timeoutClientConf, cli).await() // then sslSession.close() @@ -243,12 +253,12 @@ class DtlsServerTest { @Test fun `should remove handshake session when handshake timeout`() { server = DtlsServer.create(timeoutConf).listen(echoHandler) - val cli: ConnectedDatagramTransmitter = ConnectedDatagramTransmitter - .connect(localAddress(server.localPort())) + val cli = DatagramChannelAdapter + .connect(server.localAddress()) .dropReceive { it > 0 } // drop everything after client hello with verify // when - DtlsTransmitter.connect(timeoutClientConf, cli) + DtlsTransmitter.connect(server.localAddress(), timeoutClientConf, cli) // then await.untilAsserted { @@ -312,8 +322,8 @@ class DtlsServerTest { // establish dtls connections val clients = (1..MAX) .map { - val ch = ConnectedDatagramTransmitter.connect(localAddress(server.localPort()), 0) - DtlsTransmitter.connect(clientConfig, ch, executors[it % executors.size]) + val ch = DatagramChannelAdapter.connect(server.localAddress(), 0) + DtlsTransmitter.connect(server.localAddress(), clientConfig, ch, executors[it % executors.size]) .get(30, TimeUnit.SECONDS) .also { it.send("hello") } } @@ -339,22 +349,29 @@ class DtlsServerTest { assertTrue(server.executor is ScheduledThreadPoolExecutor) } - private fun ConnectedDatagramTransmitter.dropReceive(drop: (Int) -> Boolean): ConnectedDatagramTransmitter { + private fun Transport.dropReceive(drop: (Int) -> Boolean): Transport { val underlying = this var i = 0 - return object : ConnectedDatagramTransmitter by this { + return object : Transport by this { private val logger = LoggerFactory.getLogger(javaClass) - override fun receive(buf: ByteBuffer, timeout: Duration) { - underlying.receive(buf, timeout) - if (drop(i++)) { - logger.info("receive DROPPED {}", buf.remaining()) - receive(buf, timeout) - } else { - logger.info("receive {}", buf.remaining()) - } + override fun receive(timeout: Duration): CompletableFuture { + return underlying.receive(timeout) + .thenCompose { + if (drop(i++)) { + logger.info("receive DROPPED {}", it) + receive(timeout) + } else { + logger.info("receive {}", it) + completedFuture(it) + } + } } } } } + +fun Transport.receiveString(): String { + return receive(Duration.ofSeconds(5)).join().decodeToString() +} diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterCertTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterCertTest.kt index 5b8970e7..aef04700 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterCertTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterCertTest.kt @@ -31,13 +31,13 @@ import org.slf4j.LoggerFactory import java.nio.ByteBuffer import java.time.Duration import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletableFuture.completedFuture class DtlsTransmitterCertTest { - private lateinit var srvTrans: ConnectedDatagramTransmitter + private lateinit var srvTrans: Transport private val randomCid = RandomCidSupplier(16) private var serverConf = SslConfig.server(Certs.serverChain, Certs.server.privateKey, listOf(Certs.root.asX509())) - private val logger = LoggerFactory.getLogger(javaClass) @AfterEach fun after() { @@ -45,8 +45,8 @@ class DtlsTransmitterCertTest { } private fun newServerDtlsTransmitter(destLocalPort: Int): CompletableFuture { - srvTrans = ConnectedDatagramTransmitter.connect(localAddress(destLocalPort), 0) - return DtlsTransmitter.connect(serverConf, srvTrans) + srvTrans = DatagramChannelAdapter.connect(localAddress(destLocalPort), 0) + return DtlsTransmitter.connect(localAddress(destLocalPort), serverConf, srvTrans) } @Test @@ -139,12 +139,12 @@ class DtlsTransmitterCertTest { retransmitMin = Duration.ofMillis(10), retransmitMax = Duration.ofMillis(100) ) - val cli: ConnectedDatagramTransmitter = ConnectedDatagramTransmitter + val cli = DatagramChannelAdapter .connect(srvTrans.localAddress(), 7007) .dropSend { it % 3 != 2 } // when - val sslSession = DtlsTransmitter.connect(clientConf, cli).await() + val sslSession = DtlsTransmitter.connect(srvTrans.localAddress(), clientConf, cli).await() // then sslSession.close() @@ -161,12 +161,12 @@ class DtlsTransmitterCertTest { retransmitMin = Duration.ofMillis(10), retransmitMax = Duration.ofMillis(100) ) - val cli: ConnectedDatagramTransmitter = ConnectedDatagramTransmitter + val cli = DatagramChannelAdapter .connect(srvTrans.localAddress(), 7008) .dropSend { true } // when - val res = runCatching { DtlsTransmitter.connect(clientConf, cli).await() } + val res = runCatching { DtlsTransmitter.connect(srvTrans.localAddress(), clientConf, cli).await() } // then assertEquals("SSL - The operation timed out [-0x6800]", res.exceptionOrNull()?.cause?.message) @@ -175,18 +175,19 @@ class DtlsTransmitterCertTest { } } -internal fun ConnectedDatagramTransmitter.dropSend(drop: (Int) -> Boolean): ConnectedDatagramTransmitter { +internal fun

Transport

.dropSend(drop: (Int) -> Boolean): Transport

{ val underlying = this var i = 0 - return object : ConnectedDatagramTransmitter by this { + return object : Transport

by this { private val logger = LoggerFactory.getLogger(javaClass) - override fun send(buf: ByteBuffer) { - if (!drop(i++)) { - underlying.send(buf) + override fun send(packet: P): CompletableFuture { + return if (!drop(i++)) { + underlying.send(packet) } else { - logger.info("send DROPPED {}", buf.remaining()) + logger.info("send DROPPED {}", packet) + completedFuture(true) } } } diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterTest.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterTest.kt index 22c59c14..40e70759 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterTest.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/transport/DtlsTransmitterTest.kt @@ -26,6 +26,7 @@ import org.opencoap.ssl.util.await import org.opencoap.ssl.util.decodeHex import org.opencoap.ssl.util.localAddress import org.opencoap.ssl.util.runGC +import java.nio.ByteBuffer import java.time.Duration import java.util.concurrent.CompletableFuture import kotlin.random.Random @@ -34,7 +35,7 @@ class DtlsTransmitterTest { private val cidSupplier = { Random.nextBytes(16) } private val serverConf = SslConfig.server("dupa".encodeToByteArray(), byteArrayOf(0x01, 0x02), cidSupplier = cidSupplier) - private lateinit var srvTrans: ConnectedDatagramTransmitter + private lateinit var srvTrans: Transport @AfterEach fun after() { @@ -43,8 +44,8 @@ class DtlsTransmitterTest { } private fun newServerDtlsTransmitter(destLocalPort: Int): CompletableFuture { - srvTrans = ConnectedDatagramTransmitter.connect(localAddress(destLocalPort), 1_5684) - return DtlsTransmitter.connect(serverConf, srvTrans) + srvTrans = DatagramChannelAdapter.connect(localAddress(destLocalPort), 1_5684) + return DtlsTransmitter.connect(localAddress(destLocalPort), serverConf, srvTrans) } @Test @@ -55,13 +56,14 @@ class DtlsTransmitterTest { // when val client = DtlsTransmitter.connect(localAddress(1_5684), conf, 6001).await() + assertEquals(localAddress(1_5684), client.remoteAddress) runGC() // make sure none of needed objects is garbage collected // then client.send("dupa") assertEquals("dupa", server.await().receiveString()) // and read with timeout - assertTrue(client.receive(Duration.ofMillis(1)).isEmpty()) + assertTrue(client.receive(Duration.ofMillis(1)).join().isEmpty()) assertNotNull(client.getCipherSuite()) client.close() @@ -71,7 +73,7 @@ class DtlsTransmitterTest { @Test fun `should fail to handshake - wrong psk`() { - val server = newServerDtlsTransmitter(6002) + newServerDtlsTransmitter(6002) val conf = SslConfig.client("dupa".encodeToByteArray(), "bad".encodeToByteArray()) val client = DtlsTransmitter.connect(localAddress(1_5684), conf, 6002) @@ -118,11 +120,11 @@ class DtlsTransmitterTest { val cliSession = "030201003700000f0000006b030000000063495efcc0a420fcd161a09184307644d53c759d3e15a56ff410967160e5ab24f6f6576ec3df661713ceff637a5d525f4d903e440f01eb538628c8598e77a933daf8c96540ba4330398e1eb5d51b5e16a2531589c10c2300000000000000000000000000000063495efce6b125a4061b5f94f80b1b5a9eb0b9fbc08fa5ea7f44359d477ff1cd63495efc3b50619fde84b36978e2752e217c80f2aa79e7465f6940f8f7cc6c2f010110c8d5148adf5ddd18c92bd799044643510000000000000000000000000000000000000001000001000000000002000000".decodeHex() val srvSession = "030201003700000f0000006b030000000063495efcc0a420fcd161a09184307644d53c759d3e15a56ff410967160e5ab24f6f6576ec3df661713ceff637a5d525f4d903e440f01eb538628c8598e77a933daf8c96540ba4330398e1eb5d51b5e16a2531589c10c2300000000000000000000000000000063495efce6b125a4061b5f94f80b1b5a9eb0b9fbc08fa5ea7f44359d477ff1cd63495efc3b50619fde84b36978e2752e217c80f2aa79e7465f6940f8f7cc6c2f10c8d5148adf5ddd18c92bd7990446435101010000000000000000000000010000000000000003000001000000000001000000".decodeHex() val clientConf = SslConfig.client("dupa".encodeToByteArray(), byteArrayOf(0x01, 0x02), listOf("TLS-PSK-WITH-AES-128-CCM"), { byteArrayOf(0x01) }) - srvTrans = ConnectedDatagramTransmitter.connect(localAddress(6004), 2_5684) + srvTrans = DatagramChannelAdapter.connect(localAddress(6004), 2_5684) // when val client = DtlsTransmitter.create(localAddress(2_5684), clientConf.loadSession(byteArrayOf(), cliSession, localAddress(2_5684)), 6004) - val server = DtlsTransmitter.create(serverConf.loadSession(byteArrayOf(), srvSession, localAddress(6004)), srvTrans) + val server = DtlsTransmitter.create(localAddress(6004), serverConf.loadSession(byteArrayOf(), srvSession, localAddress(6004)), srvTrans) runGC() // then diff --git a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/util/Utils.kt b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/util/Utils.kt index f8e81412..6b772c74 100644 --- a/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/util/Utils.kt +++ b/kotlin-mbedtls/src/test/kotlin/org/opencoap/ssl/util/Utils.kt @@ -17,6 +17,7 @@ package org.opencoap.ssl.util import com.sun.jna.Memory +import org.opencoap.ssl.transport.Transport import java.net.InetAddress import java.net.InetSocketAddress import java.nio.ByteBuffer @@ -67,3 +68,6 @@ val Int.seconds: Duration val Int.millis: Duration get() = Duration.ofMillis(this.toLong()) + +fun Transport.localAddress(): InetSocketAddress = + InetSocketAddress(InetAddress.getLocalHost(), this.localPort())