Skip to content

Commit

Permalink
fix(dcutr): update the DCUtR initiator transport direction to Inbound (
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomrsantos authored Nov 29, 2023
1 parent ce0685c commit deb72c8
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 50 deletions.
2 changes: 1 addition & 1 deletion libp2p/dial.nim
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async, base.} =
dir = Direction.Out) {.async, base.} =
## connect remote peer without negotiating
## a protocol
##
Expand Down
26 changes: 15 additions & 11 deletions libp2p/dialer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ proc dialAndUpgrade(
peerId: Opt[PeerId],
hostname: string,
address: MultiAddress,
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =

for transport in self.transports: # for each transport
Expand All @@ -75,23 +75,27 @@ proc dialAndUpgrade(

let mux =
try:
dialed.transportDir = upgradeDir
await transport.upgrade(dialed, upgradeDir, peerId)
# This is for the very specific case of a simultaneous dial during DCUtR. In this case, both sides will have
# an Outbound direction at the transport level. Therefore we update the DCUtR initiator transport direction to Inbound.
# The if below is more general and might handle other use cases in the future.
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeeded through another - no use in trying again
await dialed.close()
debug "Upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if exc isnot CancelledError:
if upgradeDir == Direction.Out:
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()

# Try other address
return nil

doAssert not isNil(mux), "connection died after upgrade " & $upgradeDir
doAssert not isNil(mux), "connection died after upgrade " & $dialed.dir
debug "Dial successful", peerId = mux.connection.peerId
return mux
return nil
Expand Down Expand Up @@ -128,7 +132,7 @@ proc dialAndUpgrade(
self: Dialer,
peerId: Opt[PeerId],
addrs: seq[MultiAddress],
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =

debug "Dialing peer", peerId = peerId.get(default(PeerId))
Expand All @@ -146,7 +150,7 @@ proc dialAndUpgrade(
else: await self.nameResolver.resolveMAddress(expandedAddress)

for resolvedAddress in resolvedAddresses:
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, upgradeDir)
result = await self.dialAndUpgrade(addrPeerId, hostname, resolvedAddress, dir)
if not isNil(result):
return result

Expand All @@ -164,7 +168,7 @@ proc internalConnect(
addrs: seq[MultiAddress],
forceDial: bool,
reuseConnection = true,
upgradeDir = Direction.Out):
dir = Direction.Out):
Future[Muxer] {.async.} =
if Opt.some(self.localPeerId) == peerId:
raise newException(CatchableError, "can't dial self!")
Expand All @@ -182,7 +186,7 @@ proc internalConnect(
let slot = self.connManager.getOutgoingSlot(forceDial)
let muxed =
try:
await self.dialAndUpgrade(peerId, addrs, upgradeDir)
await self.dialAndUpgrade(peerId, addrs, dir)
except CatchableError as exc:
slot.release()
raise exc
Expand All @@ -209,15 +213,15 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
dir = Direction.Out) {.async.} =
## connect remote peer without negotiating
## a protocol
##

if self.connManager.connCount(peerId) > 0 and reuseConnection:
return

discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, upgradeDir)
discard await self.internalConnect(Opt.some(peerId), addrs, forceDial, reuseConnection, dir)

method connect*(
self: Dialer,
Expand Down
2 changes: 1 addition & 1 deletion libp2p/protocols/connectivity/dcutr/client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ proc startSync*(self: DcutrClient, switch: Switch, remotePeerId: PeerId, addrs:

if peerDialableAddrs.len > self.maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<self.maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.In))
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.In))
try:
discard await anyCompleted(futs).wait(self.connectTimeout)
debug "Dcutr initiator has directly connected to the remote peer."
Expand Down
2 changes: 1 addition & 1 deletion libp2p/protocols/connectivity/dcutr/server.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ proc new*(T: typedesc[Dcutr], switch: Switch, connectTimeout = 15.seconds, maxDi

if peerDialableAddrs.len > maxDialableAddrs:
peerDialableAddrs = peerDialableAddrs[0..<maxDialableAddrs]
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, upgradeDir = Direction.Out))
var futs = peerDialableAddrs.mapIt(switch.connect(stream.peerId, @[it], forceDial = true, reuseConnection = false, dir = Direction.Out))
try:
discard await anyCompleted(futs).wait(connectTimeout)
debug "Dcutr receiver has directly connected to the remote peer."
Expand Down
3 changes: 1 addition & 2 deletions libp2p/protocols/secure/secure.nim
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,9 @@ method init*(s: Secure) =

method secure*(s: Secure,
conn: Connection,
initiator: bool,
peerId: Opt[PeerId]):
Future[Connection] {.base.} =
s.handleConn(conn, initiator, peerId)
s.handleConn(conn, conn.dir == Direction.Out, peerId)

method readOnce*(s: SecureConn,
pbytes: pointer,
Expand Down
6 changes: 3 additions & 3 deletions libp2p/switch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ method connect*(
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.public.} =
dir = Direction.Out): Future[void] {.public.} =
## Connects to a peer without opening a stream to it

s.dialer.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
s.dialer.connect(peerId, addrs, forceDial, reuseConnection, dir)

method connect*(
s: Switch,
Expand Down Expand Up @@ -213,7 +213,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil)
s.peerInfo.protocols.add(proto.codec)

proc upgrader(switch: Switch, trans: Transport, conn: Connection) {.async.} =
let muxed = await trans.upgrade(conn, Direction.In, Opt.none(PeerId))
let muxed = await trans.upgrade(conn, Opt.none(PeerId))
switch.connManager.storeMuxer(muxed)
await switch.peerStore.identify(muxed)
trace "Connection upgrade succeeded"
Expand Down
3 changes: 1 addition & 2 deletions libp2p/transports/transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,12 @@ proc dial*(
method upgrade*(
self: Transport,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base, gcsafe.} =
## base upgrade method that the transport uses to perform
## transport specific upgrades
##

self.upgrader.upgrade(conn, direction, peerId)
self.upgrader.upgrade(conn, peerId)

method handles*(
self: Transport,
Expand Down
14 changes: 6 additions & 8 deletions libp2p/upgrademngrs/muxedupgrade.nim
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ proc getMuxerByCodec(self: MuxedUpgrade, muxerName: string): MuxerProvider =

proc mux*(
self: MuxedUpgrade,
conn: Connection,
direction: Direction): Future[Muxer] {.async, gcsafe.} =
conn: Connection): Future[Muxer] {.async, gcsafe.} =
## mux connection

trace "Muxing connection", conn
Expand All @@ -42,7 +41,7 @@ proc mux*(
return

let muxerName =
if direction == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
if conn.dir == Out: await self.ms.select(conn, self.muxers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.muxers.mapIt(it.codec))

if muxerName.len == 0 or muxerName == "na":
Expand All @@ -62,16 +61,15 @@ proc mux*(
method upgrade*(
self: MuxedUpgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.async.} =
trace "Upgrading connection", conn, direction
trace "Upgrading connection", conn, direction = conn.dir

let sconn = await self.secure(conn, direction, peerId) # secure the connection
let sconn = await self.secure(conn, peerId) # secure the connection
if isNil(sconn):
raise newException(UpgradeFailedError,
"unable to secure connection, stopping upgrade")

let muxer = await self.mux(sconn, direction) # mux it if possible
let muxer = await self.mux(sconn) # mux it if possible
if muxer == nil:
raise newException(UpgradeFailedError,
"a muxer is required for outgoing connections")
Expand All @@ -84,7 +82,7 @@ method upgrade*(
raise newException(UpgradeFailedError,
"Connection closed or missing peer info, stopping upgrade")

trace "Upgraded connection", conn, sconn, direction
trace "Upgraded connection", conn, sconn, direction = conn.dir
return muxer

proc new*(
Expand Down
6 changes: 2 additions & 4 deletions libp2p/upgrademngrs/upgrade.nim
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,18 @@ type
method upgrade*(
self: Upgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Muxer] {.base.} =
doAssert(false, "Not implemented!")

proc secure*(
self: Upgrade,
conn: Connection,
direction: Direction,
peerId: Opt[PeerId]): Future[Connection] {.async, gcsafe.} =
if self.secureManagers.len <= 0:
raise newException(UpgradeFailedError, "No secure managers registered!")

let codec =
if direction == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
if conn.dir == Out: await self.ms.select(conn, self.secureManagers.mapIt(it.codec))
else: await MultistreamSelect.handle(conn, self.secureManagers.mapIt(it.codec))
if codec.len == 0:
raise newException(UpgradeFailedError, "Unable to negotiate a secure channel!")
Expand All @@ -65,4 +63,4 @@ proc secure*(
# let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0)

return await secureProtocol[0].secure(conn, direction == Out, peerId)
return await secureProtocol[0].secure(conn, peerId)
8 changes: 4 additions & 4 deletions tests/stubs/switchstub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ type
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.gcsafe, async.}
dir = Direction.Out): Future[void] {.gcsafe, async.}

method connect*(
self: SwitchStub,
peerId: PeerId,
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out) {.async.} =
dir = Direction.Out) {.async.} =
if (self.connectStub != nil):
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, upgradeDir)
await self.connectStub(self, peerId, addrs, forceDial, reuseConnection, dir)
else:
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, upgradeDir)
await self.switch.connect(peerId, addrs, forceDial, reuseConnection, dir)

proc new*(T: typedesc[SwitchStub], switch: Switch, connectStub: connectStubType = nil): T =
return SwitchStub(
Expand Down
8 changes: 4 additions & 4 deletions tests/testdcutr.nim
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis)

let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectTimeoutProc)
Expand All @@ -115,7 +115,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error")

let behindNATSwitch = SwitchStub.new(newStandardSwitch(), connectErrorProc)
Expand Down Expand Up @@ -163,7 +163,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
await sleepAsync(100.millis)

await ductrServerTest(connectProc)
Expand All @@ -175,7 +175,7 @@ suite "Dcutr":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
raise newException(CatchableError, "error")

await ductrServerTest(connectProc)
Expand Down
2 changes: 1 addition & 1 deletion tests/testhpservice.nim
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ suite "Hole Punching":
addrs: seq[MultiAddress],
forceDial = false,
reuseConnection = true,
upgradeDir = Direction.Out): Future[void] {.async.} =
dir = Direction.Out): Future[void] {.async.} =
self.connectStub = nil # this stub should be called only once
raise newException(CatchableError, "error")

Expand Down
16 changes: 8 additions & 8 deletions tests/testnoise.nim
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ suite "Noise":

proc acceptHandler() {.async.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
try:
await sconn.write("Hello!")
finally:
Expand All @@ -115,7 +115,7 @@ suite "Noise":
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])

let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

var msg = newSeq[byte](6)
await sconn.readExactly(addr msg[0], 6)
Expand Down Expand Up @@ -144,7 +144,7 @@ suite "Noise":
var conn: Connection
try:
conn = await transport1.accept()
discard await serverNoise.secure(conn, false, Opt.none(PeerId))
discard await serverNoise.secure(conn, Opt.none(PeerId))
except CatchableError:
discard
finally:
Expand All @@ -160,7 +160,7 @@ suite "Noise":

var sconn: Connection = nil
expect(NoiseDecryptTagError):
sconn = await clientNoise.secure(conn, true, Opt.some(conn.peerId))
sconn = await clientNoise.secure(conn, Opt.some(conn.peerId))

await conn.close()
await handlerWait
Expand All @@ -180,7 +180,7 @@ suite "Noise":

proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer:
await sconn.close()
await conn.close()
Expand All @@ -196,7 +196,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

await sconn.write("Hello!")
await acceptFut
Expand All @@ -223,7 +223,7 @@ suite "Noise":

proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
let sconn = await serverNoise.secure(conn, false, Opt.none(PeerId))
let sconn = await serverNoise.secure(conn, Opt.none(PeerId))
defer:
await sconn.close()
let msg = await sconn.readLp(1024*1024)
Expand All @@ -237,7 +237,7 @@ suite "Noise":
clientInfo = PeerInfo.new(clientPrivKey, transport1.addrs)
clientNoise = Noise.new(rng, clientPrivKey, outgoing = true)
conn = await transport2.dial(transport1.addrs[0])
let sconn = await clientNoise.secure(conn, true, Opt.some(serverInfo.peerId))
let sconn = await clientNoise.secure(conn, Opt.some(serverInfo.peerId))

await sconn.writeLp(hugePayload)
await readTask
Expand Down

0 comments on commit deb72c8

Please sign in to comment.