Skip to content

Commit

Permalink
Yamux update
Browse files Browse the repository at this point in the history
  • Loading branch information
erwin-kok committed Jan 25, 2024
1 parent 2e2d22e commit c3d07e1
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class Session(
private val streamChannel = SafeChannel<MuxedStream>(16)
private val outputChannel = SafeChannel<Frame>(16)
private val mutex = ReentrantLock()
private val streams = mutableMapOf<MplexStreamId, YamuxMuxedStream>()
private val streams = mutableMapOf<YamuxStreamId, YamuxMuxedStream>()
private val nextId = AtomicLong(0)
private val isClosing = AtomicBoolean(false)
private var closeCause: Error? = null
Expand Down Expand Up @@ -172,7 +172,7 @@ class Session(
_context.complete()
}

internal fun removeStream(streamId: MplexStreamId) {
internal fun removeStream(streamId: YamuxStreamId) {
mutex.withLock {
streams.remove(streamId)
if (isClosing.get() && streams.isEmpty()) {
Expand Down Expand Up @@ -234,19 +234,19 @@ class Session(
private suspend fun processFrame(mplexFrame: Frame) {
val id = mplexFrame.id
val initiator = mplexFrame.initiator
val mplexStreamId = MplexStreamId(!initiator, id)
val yamuxStreamId = YamuxStreamId(!initiator, id)
mutex.lock()
val stream: YamuxMuxedStream? = streams[mplexStreamId]
val stream: YamuxMuxedStream? = streams[yamuxStreamId]
when (mplexFrame) {
is NewStreamFrame -> {
if (stream != null) {
mutex.unlock()
logger.warn { "$this: Remote creates existing new stream: $id. Ignoring." }
} else {
logger.debug { "$this: Remote creates new stream: $id" }
val name = streamName(mplexFrame.name, mplexStreamId)
val newStream = YamuxMuxedStream(scope, this, outputChannel, mplexStreamId, name, pool)
streams[mplexStreamId] = newStream
val name = streamName(mplexFrame.name, yamuxStreamId)
val newStream = YamuxMuxedStream(scope, this, outputChannel, yamuxStreamId, name, pool)
streams[yamuxStreamId] = newStream
mutex.unlock()
streamChannel.send(newStream)
}
Expand All @@ -271,29 +271,29 @@ class Session(
stream.remoteSendsNewMessage(builder.build())
}
if (timeout == null) {
logger.warn { "$this: Reader timeout for stream: $mplexStreamId. Reader is too slow, resetting the stream." }
logger.warn { "$this: Reader timeout for stream: $yamuxStreamId. Reader is too slow, resetting the stream." }
stream.reset()
}
} else {
mutex.unlock()
logger.warn { "$this: Remote sends message on non-existing stream: $mplexStreamId" }
logger.warn { "$this: Remote sends message on non-existing stream: $yamuxStreamId" }
}
}

is CloseFrame -> {
if (logger.isDebugEnabled) {
if (initiator) {
logger.debug("$this: Remote closes his stream: $mplexStreamId")
logger.debug("$this: Remote closes his stream: $yamuxStreamId")
} else {
logger.debug("$this: Remote closes our stream: $mplexStreamId")
logger.debug("$this: Remote closes our stream: $yamuxStreamId")
}
}
if (stream != null) {
mutex.unlock()
stream.remoteClosesWriting()
} else {
mutex.unlock()
logger.debug { "$this: Remote closes non-existing stream: $mplexStreamId" }
logger.debug { "$this: Remote closes non-existing stream: $yamuxStreamId" }
}
}

Expand Down Expand Up @@ -323,7 +323,7 @@ class Session(
}
mutex.lock()
val id = nextId.getAndIncrement()
val streamId = MplexStreamId(true, id)
val streamId = YamuxStreamId(true, id)
logger.debug { "$this: We create stream: $id" }
val name = streamName(newName, streamId)
val muxedStream = YamuxMuxedStream(scope, this, outputChannel, streamId, name, pool)
Expand All @@ -333,7 +333,7 @@ class Session(
return Ok(muxedStream)
}

private fun streamName(name: String?, streamId: MplexStreamId): String {
private fun streamName(name: String?, streamId: YamuxStreamId): String {
if (name != null) {
return name
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,14 @@ class YamuxMuxedStream(
private val scope: CoroutineScope,
private val session: Session,
private val outputChannel: Channel<Frame>,
private val mplexStreamId: MplexStreamId,
private val yamuxStreamId: YamuxStreamId,
override val name: String,
override val pool: ObjectPool<ChunkBuffer>
) : MuxedStream {
// sendWindow uint32
//
// memorySpan MemoryManager
//
// id uint32
// session *Session
//
// recvWindow uint32
// epochStart time.Time
//
Expand All @@ -87,7 +84,7 @@ class YamuxMuxedStream(
private val readerJob: ReaderJob

override val id
get() = mplexStreamId.toString()
get() = yamuxStreamId.toString()
override val jobContext: Job
get() = _context

Expand All @@ -100,7 +97,7 @@ class YamuxMuxedStream(
}.apply {
invokeOnCompletion {
if (readerJob.isCompleted) {
session.removeStream(mplexStreamId)
session.removeStream(yamuxStreamId)
}
}
}
Expand All @@ -111,7 +108,7 @@ class YamuxMuxedStream(
}.apply {
invokeOnCompletion {
if (writerJob.isCompleted) {
session.removeStream(mplexStreamId)
session.removeStream(yamuxStreamId)
}
}
}
Expand Down Expand Up @@ -143,7 +140,7 @@ class YamuxMuxedStream(
if (size > 0) {
buffer.flip()
val packet = buildPacket(pool) { writeFully(buffer) }
val messageFrame = MessageFrame(mplexStreamId, packet)
val messageFrame = MessageFrame(yamuxStreamId, packet)
outputChannel.send(messageFrame)
}
} catch (e: CancellationException) {
Expand All @@ -160,9 +157,9 @@ class YamuxMuxedStream(
}
if (!outputChannel.isClosedForSend) {
if (channel.closedCause is StreamResetException) {
outputChannel.send(ResetFrame(mplexStreamId))
outputChannel.send(ResetFrame(yamuxStreamId))
} else {
outputChannel.send(CloseFrame(mplexStreamId))
outputChannel.send(CloseFrame(yamuxStreamId))
}
}
}
Expand All @@ -182,7 +179,7 @@ class YamuxMuxedStream(
}

override fun toString(): String {
return "mplex-<$mplexStreamId>"
return "mplex-<$yamuxStreamId>"
}

internal suspend fun remoteSendsNewMessage(packet: ByteReadPacket): Boolean {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2024 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details.
package org.erwinkok.libp2p.muxer.yamux

data class MplexStreamId(val initiator: Boolean, val id: Long) {
data class YamuxStreamId(val initiator: Boolean, val id: Long) {
override fun toString(): String {
return String.format("stream%08x/%s", id, if (initiator) "initiator" else "responder")
}
Expand All @@ -10,7 +10,7 @@ data class MplexStreamId(val initiator: Boolean, val id: Long) {
if (other === this) {
return true
}
if (other !is MplexStreamId) {
if (other !is YamuxStreamId) {
return super.equals(other)
}
return (id == other.id) and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import org.erwinkok.libp2p.core.network.readUnsignedVarInt
import org.erwinkok.libp2p.core.network.writeUnsignedVarInt
import org.erwinkok.libp2p.muxer.yamux.MplexStreamId
import org.erwinkok.libp2p.muxer.yamux.YamuxStreamId
import org.erwinkok.result.Error
import org.erwinkok.result.Result
import org.erwinkok.result.map
import org.erwinkok.result.toErrorIf

internal class CloseFrame(streamId: MplexStreamId) : Frame(streamId) {
internal class CloseFrame(streamId: YamuxStreamId) : Frame(streamId) {
override val type: Int
get() {
return if (streamId.initiator) CloseInitiatorTag else CloseReceiverTag
Expand All @@ -24,7 +24,7 @@ internal class CloseFrame(streamId: MplexStreamId) : Frame(streamId) {
}
}

internal suspend fun ByteReadChannel.readCloseFrame(streamId: MplexStreamId): Result<CloseFrame> {
internal suspend fun ByteReadChannel.readCloseFrame(streamId: YamuxStreamId): Result<CloseFrame> {
return readUnsignedVarInt()
.toErrorIf({ it != 0uL }, { Error("CloseFrame should not carry data") })
.map { CloseFrame(streamId) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.core.Closeable
import org.erwinkok.libp2p.core.network.readUnsignedVarInt
import org.erwinkok.libp2p.core.network.writeUnsignedVarInt
import org.erwinkok.libp2p.muxer.yamux.MplexStreamId
import org.erwinkok.libp2p.muxer.yamux.YamuxStreamId
import org.erwinkok.result.Err
import org.erwinkok.result.Result
import org.erwinkok.result.getOrElse

sealed class Frame(val streamId: MplexStreamId) : Closeable {
sealed class Frame(val streamId: YamuxStreamId) : Closeable {
val initiator: Boolean get() = streamId.initiator
val id: Long get() = streamId.id
abstract val type: Int
Expand Down Expand Up @@ -41,12 +41,12 @@ internal suspend fun ByteReadChannel.readMplexFrame(): Result<Frame> {
val id = (header shr 3).toLong()
return when (tag) {
Frame.NewStreamTag -> readNewStreamFrame(id)
Frame.MessageReceiverTag -> readMessageFrame(MplexStreamId(false, id))
Frame.MessageInitiatorTag -> readMessageFrame(MplexStreamId(true, id))
Frame.CloseReceiverTag -> readCloseFrame(MplexStreamId(false, id))
Frame.CloseInitiatorTag -> readCloseFrame(MplexStreamId(true, id))
Frame.ResetReceiverTag -> readResetFrame(MplexStreamId(false, id))
Frame.ResetInitiatorTag -> readResetFrame(MplexStreamId(true, id))
Frame.MessageReceiverTag -> readMessageFrame(YamuxStreamId(false, id))
Frame.MessageInitiatorTag -> readMessageFrame(YamuxStreamId(true, id))
Frame.CloseReceiverTag -> readCloseFrame(YamuxStreamId(false, id))
Frame.CloseInitiatorTag -> readCloseFrame(YamuxStreamId(true, id))
Frame.ResetReceiverTag -> readResetFrame(YamuxStreamId(false, id))
Frame.ResetInitiatorTag -> readResetFrame(YamuxStreamId(true, id))
else -> Err("Unknown Mplex tag type '$tag'")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import io.ktor.utils.io.ByteWriteChannel
import io.ktor.utils.io.core.ByteReadPacket
import org.erwinkok.libp2p.core.network.readUnsignedVarInt
import org.erwinkok.libp2p.core.network.writeUnsignedVarInt
import org.erwinkok.libp2p.muxer.yamux.MplexStreamId
import org.erwinkok.libp2p.muxer.yamux.YamuxStreamId
import org.erwinkok.result.Result
import org.erwinkok.result.map

internal class MessageFrame(streamId: MplexStreamId, val packet: ByteReadPacket) : Frame(streamId) {
internal class MessageFrame(streamId: YamuxStreamId, val packet: ByteReadPacket) : Frame(streamId) {
override val type: Int
get() {
return if (streamId.initiator) MessageInitiatorTag else MessageReceiverTag
Expand All @@ -26,7 +26,7 @@ internal class MessageFrame(streamId: MplexStreamId, val packet: ByteReadPacket)
}
}

internal suspend fun ByteReadChannel.readMessageFrame(streamId: MplexStreamId): Result<MessageFrame> {
internal suspend fun ByteReadChannel.readMessageFrame(streamId: YamuxStreamId): Result<MessageFrame> {
return readUnsignedVarInt()
.map { length -> readPacket(length.toInt()) }
.map { packet -> MessageFrame(streamId, packet) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import io.ktor.utils.io.core.toByteArray
import io.ktor.utils.io.writeFully
import org.erwinkok.libp2p.core.network.readUnsignedVarInt
import org.erwinkok.libp2p.core.network.writeUnsignedVarInt
import org.erwinkok.libp2p.muxer.yamux.MplexStreamId
import org.erwinkok.libp2p.muxer.yamux.YamuxStreamId
import org.erwinkok.result.Result
import org.erwinkok.result.map

internal class NewStreamFrame(id: Long, val name: String) : Frame(MplexStreamId(true, id)) {
internal class NewStreamFrame(id: Long, val name: String) : Frame(YamuxStreamId(true, id)) {
override val type: Int get() = NewStreamTag

override fun close(): Unit = Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import org.erwinkok.libp2p.core.network.readUnsignedVarInt
import org.erwinkok.libp2p.core.network.writeUnsignedVarInt
import org.erwinkok.libp2p.muxer.yamux.MplexStreamId
import org.erwinkok.libp2p.muxer.yamux.YamuxStreamId
import org.erwinkok.result.Error
import org.erwinkok.result.Result
import org.erwinkok.result.map
import org.erwinkok.result.toErrorIf

internal class ResetFrame(streamId: MplexStreamId) : Frame(streamId) {
internal class ResetFrame(streamId: YamuxStreamId) : Frame(streamId) {
override val type: Int
get() {
return if (streamId.initiator) ResetInitiatorTag else ResetReceiverTag
Expand All @@ -24,7 +24,7 @@ internal class ResetFrame(streamId: MplexStreamId) : Frame(streamId) {
}
}

internal suspend fun ByteReadChannel.readResetFrame(streamId: MplexStreamId): Result<ResetFrame> {
internal suspend fun ByteReadChannel.readResetFrame(streamId: YamuxStreamId): Result<ResetFrame> {
return readUnsignedVarInt()
.toErrorIf({ it != 0uL }, { Error("ResetFrame should not carry data") })
.map { ResetFrame(streamId) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
repeat(1000) {
val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors()
assertEquals("newStreamName$it", muxedStream.name)
assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id)
assertEquals(YamuxStreamId(true, it.toLong()).toString(), muxedStream.id)
val actual = connectionPair.remote.input.readMplexFrame().expectNoErrors()
assertInstanceOf(NewStreamFrame::class.java, actual)
assertTrue(actual.initiator)
Expand All @@ -94,7 +94,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
assertEquals("aName$id", muxedStream.name)
assertStreamHasId(false, id, muxedStream)
val random1 = Random.nextBytes(1000)
connectionPair.remote.output.writeMplexFrame(MessageFrame(MplexStreamId(true, id), buildPacket(pool) { writeFully(random1) }))
connectionPair.remote.output.writeMplexFrame(MessageFrame(YamuxStreamId(true, id), buildPacket(pool) { writeFully(random1) }))
connectionPair.remote.output.flush()
assertFalse(muxedStream.input.isClosedForRead)
val random2 = ByteArray(random1.size)
Expand Down Expand Up @@ -137,7 +137,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
repeat(1000) {
val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors()
assertEquals("newStreamName$it", muxedStream.name)
assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id)
assertEquals(YamuxStreamId(true, it.toLong()).toString(), muxedStream.id)
assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote)
val random1 = Random.nextBytes(1000)
assertFalse(muxedStream.output.isClosedForWrite)
Expand All @@ -158,10 +158,10 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
repeat(1000) {
val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors()
assertEquals("newStreamName$it", muxedStream.name)
assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id)
assertEquals(YamuxStreamId(true, it.toLong()).toString(), muxedStream.id)
assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote)
val random1 = Random.nextBytes(1000)
connectionPair.remote.output.writeMplexFrame(MessageFrame(MplexStreamId(false, it.toLong()), buildPacket(pool) { writeFully(random1) }))
connectionPair.remote.output.writeMplexFrame(MessageFrame(YamuxStreamId(false, it.toLong()), buildPacket(pool) { writeFully(random1) }))
connectionPair.remote.output.flush()
assertFalse(muxedStream.input.isClosedForRead)
val random2 = ByteArray(random1.size)
Expand All @@ -186,7 +186,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
assertStreamHasId(false, id, muxedStream)
assertFalse(muxedStream.input.isClosedForRead)
assertFalse(muxedStream.output.isClosedForWrite)
connectionPair.remote.output.writeMplexFrame(CloseFrame(MplexStreamId(true, id)))
connectionPair.remote.output.writeMplexFrame(CloseFrame(YamuxStreamId(true, id)))
connectionPair.remote.output.flush()
val exception = assertThrows<ClosedReceiveChannelException> {
muxedStream.input.readPacket(10)
Expand Down Expand Up @@ -233,7 +233,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
repeat(1000) {
val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors()
assertEquals("newStreamName$it", muxedStream.name)
assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id)
assertEquals(YamuxStreamId(true, it.toLong()).toString(), muxedStream.id)
assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote)
muxedStream.output.close()
yield()
Expand All @@ -257,11 +257,11 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
repeat(1000) {
val muxedStream = mplexMultiplexer.openStream("newStreamName$it").expectNoErrors()
assertEquals("newStreamName$it", muxedStream.name)
assertEquals(MplexStreamId(true, it.toLong()).toString(), muxedStream.id)
assertEquals(YamuxStreamId(true, it.toLong()).toString(), muxedStream.id)
assertNewStreamFrameReceived(it, "newStreamName$it", connectionPair.remote)
assertFalse(muxedStream.input.isClosedForRead)
assertFalse(muxedStream.output.isClosedForWrite)
connectionPair.remote.output.writeMplexFrame(CloseFrame(MplexStreamId(false, it.toLong())))
connectionPair.remote.output.writeMplexFrame(CloseFrame(YamuxStreamId(false, it.toLong())))
connectionPair.remote.output.flush()
val exception = assertThrows<ClosedReceiveChannelException> {
muxedStream.input.readPacket(10)
Expand Down Expand Up @@ -448,7 +448,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
}

private fun assertStreamHasId(initiator: Boolean, id: Long, muxedStream: MuxedStream) {
assertEquals(MplexStreamId(initiator, id).toString(), muxedStream.id)
assertEquals(YamuxStreamId(initiator, id).toString(), muxedStream.id)
}

private suspend fun assertMessageFrameReceived(expected: ByteArray, connection: Connection) {
Expand Down
Loading

0 comments on commit c3d07e1

Please sign in to comment.