diff --git a/libi2pd/SSU2.cpp b/libi2pd/SSU2.cpp index 5e2533f4c59..2005294be47 100644 --- a/libi2pd/SSU2.cpp +++ b/libi2pd/SSU2.cpp @@ -450,7 +450,8 @@ namespace transport if (session) { m_Sessions.emplace (session->GetConnID (), session); - AddSessionByRouterHash (session); + if (session->GetState () != eSSU2SessionStatePeerTest) + AddSessionByRouterHash (session); } } @@ -459,13 +460,16 @@ namespace transport auto it = m_Sessions.find (connID); if (it != m_Sessions.end ()) { - auto ident = it->second->GetRemoteIdentity (); - if (ident) - { - auto it1 = m_SessionsByRouterHash.find (ident->GetIdentHash ()); - if (it1 != m_SessionsByRouterHash.end () && it->second == it1->second) - m_SessionsByRouterHash.erase (it1); - } + if (it->second->GetState () != eSSU2SessionStatePeerTest) + { + auto ident = it->second->GetRemoteIdentity (); + if (ident) + { + auto it1 = m_SessionsByRouterHash.find (ident->GetIdentHash ()); + if (it1 != m_SessionsByRouterHash.end () && it->second == it1->second.lock ()) + m_SessionsByRouterHash.erase (it1); + } + } if (m_LastSession == it->second) m_LastSession = nullptr; m_Sessions.erase (it); @@ -480,16 +484,20 @@ namespace transport if (ident) { auto ret = m_SessionsByRouterHash.emplace (ident->GetIdentHash (), session); - if (!ret.second && ret.first->second != session) + if (!ret.second) { - // session already exists - LogPrint (eLogWarning, "SSU2: Session to ", ident->GetIdentHash ().ToBase64 (), " already exists"); - // move unsent msgs to new session - ret.first->second->MoveSendQueue (session); - // terminate existing - GetService ().post (std::bind (&SSU2Session::RequestTermination, ret.first->second, eSSU2TerminationReasonReplacedByNewSession)); - // update session - ret.first->second = session; + auto oldSession = ret.first->second.lock (); + if (oldSession != session) + { + // session already exists + LogPrint (eLogWarning, "SSU2: Session to ", ident->GetIdentHash ().ToBase64 (), " already exists"); + // move unsent msgs to new session + oldSession->MoveSendQueue (session); + // terminate existing + GetService ().post (std::bind (&SSU2Session::RequestTermination, oldSession, eSSU2TerminationReasonReplacedByNewSession)); + // update session + ret.first->second = session; + } } } } @@ -506,7 +514,7 @@ namespace transport { auto it = m_SessionsByRouterHash.find (ident); if (it != m_SessionsByRouterHash.end ()) - return it->second; + return it->second.lock (); return nullptr; } @@ -761,11 +769,9 @@ namespace transport if (it != m_SessionsByRouterHash.end ()) { // session with router found, trying to send peer test if requested - if (peerTest && it->second->IsEstablished ()) - { - auto session = it->second; + auto session = it->second.lock (); + if (peerTest && session && session->IsEstablished ()) GetService ().post ([session]() { session->SendPeerTest (); }); - } return false; } // check is no pending session @@ -825,11 +831,15 @@ namespace transport auto it1 = m_SessionsByRouterHash.find (it.iH); if (it1 != m_SessionsByRouterHash.end ()) { - auto addr = it1->second->GetAddress (); - if (addr && addr->IsIntroducer ()) + auto s = it1->second.lock (); + if (s) { - it1->second->Introduce (session, it.iTag); - return; + auto addr = s->GetAddress (); + if (addr && addr->IsIntroducer ()) + { + s->Introduce (session, it.iTag); + return; + } } } else @@ -936,17 +946,19 @@ namespace transport if (!router) return false; auto addr = v4 ? router->GetSSU2V4Address () : router->GetSSU2V6Address (); if (!addr) return false; + std::shared_ptr session; auto it = m_SessionsByRouterHash.find (router->GetIdentHash ()); if (it != m_SessionsByRouterHash.end ()) + session = it->second.lock (); + if (session) { - auto remoteAddr = it->second->GetAddress (); + auto remoteAddr = session->GetAddress (); if (!remoteAddr || !remoteAddr->IsPeerTesting () || - (v4 && !remoteAddr->IsV4 ()) || (!v4 && !remoteAddr->IsV6 ())) return false; - auto s = it->second; - if (s->IsEstablished ()) - GetService ().post ([s]() { s->SendPeerTest (); }); + (v4 && !remoteAddr->IsV4 ()) || (!v4 && !remoteAddr->IsV6 ())) return false; + if (session->IsEstablished ()) + GetService ().post ([session]() { session->SendPeerTest (); }); else - s->SetOnEstablished ([s]() { s->SendPeerTest (); }); + session->SetOnEstablished ([session]() { session->SendPeerTest (); }); return true; } else @@ -997,7 +1009,7 @@ namespace transport for (auto it = m_SessionsByRouterHash.begin (); it != m_SessionsByRouterHash.begin ();) { - if (it->second && it->second->GetState () == eSSU2SessionStateTerminated) + if (it->second.expired ()) it = m_SessionsByRouterHash.erase (it); else it++; @@ -1183,7 +1195,7 @@ namespace transport auto it1 = m_SessionsByRouterHash.find (it); if (it1 != m_SessionsByRouterHash.end ()) { - session = it1->second; + session = it1->second.lock (); excluded.insert (it); } if (session && session->IsEstablished () && session->GetRelayTag () && session->IsOutgoing () && // still session with introducer? @@ -1217,8 +1229,8 @@ namespace transport auto it1 = m_SessionsByRouterHash.find (it); if (it1 != m_SessionsByRouterHash.end ()) { - auto session = it1->second; - if (session->IsEstablished () && session->GetRelayTag () && session->IsOutgoing ()) + auto session = it1->second.lock (); + if (session && session->IsEstablished () && session->GetRelayTag () && session->IsOutgoing ()) { session->SetCreationTime (session->GetCreationTime () + SSU2_TO_INTRODUCER_SESSION_DURATION); if (std::find (newList.begin (), newList.end (), it) == newList.end ()) diff --git a/libi2pd/SSU2.h b/libi2pd/SSU2.h index 913cc7475be..54eca83ff43 100644 --- a/libi2pd/SSU2.h +++ b/libi2pd/SSU2.h @@ -162,7 +162,7 @@ namespace transport boost::asio::ip::udp::socket m_SocketV4, m_SocketV6; boost::asio::ip::address m_AddressV4, m_AddressV6; std::unordered_map > m_Sessions; - std::unordered_map > m_SessionsByRouterHash; + std::unordered_map > m_SessionsByRouterHash; std::map > m_PendingOutgoingSessions; mutable std::mutex m_PendingOutgoingSessionsMutex; std::map > m_IncomingTokens, m_OutgoingTokens; // remote endpoint -> (token, expires in seconds) diff --git a/libi2pd/SSU2Session.cpp b/libi2pd/SSU2Session.cpp index 7d5763a27ff..7a7154a9b41 100644 --- a/libi2pd/SSU2Session.cpp +++ b/libi2pd/SSU2Session.cpp @@ -3140,7 +3140,7 @@ namespace transport SendPeerTest (7, buf + offset, len - offset); else LogPrint (eLogWarning, "SSU2: Unknown address for peer test 6"); - GetServer ().RemoveSession (~htobe64 (((uint64_t)nonce << 32) | nonce)); + GetServer ().RemoveSession (GetConnID ()); break; } case 7: // Alice from Charlie 2 @@ -3148,7 +3148,7 @@ namespace transport auto addr = GetAddress (); if (addr && addr->IsV6 ()) i2p::context.SetStatusV6 (eRouterStatusOK); // set status OK for ipv6 even if from SSU2 - GetServer ().RemoveSession (htobe64 (((uint64_t)nonce << 32) | nonce)); + GetServer ().RemoveSession (GetConnID ()); break; } default: @@ -3203,5 +3203,16 @@ namespace transport SetAddress (addr); SendPeerTest (msg, signedData, signedDataLen); } + + void SSU2PeerTestSession::Connect () + { + LogPrint (eLogError, "SSU2: Can't connect peer test session"); + } + + bool SSU2PeerTestSession::ProcessFirstIncomingMessage (uint64_t connID, uint8_t * buf, size_t len) + { + LogPrint (eLogError, "SSU2: Can't handle incoming message in peer test session"); + return false; + } } } diff --git a/libi2pd/SSU2Session.h b/libi2pd/SSU2Session.h index eaf727cc85d..d261bc6dc1a 100644 --- a/libi2pd/SSU2Session.h +++ b/libi2pd/SSU2Session.h @@ -248,7 +248,7 @@ namespace transport void SetOnEstablished (OnEstablished e) { m_OnEstablished = e; }; OnEstablished GetOnEstablished () const { return m_OnEstablished; }; - void Connect (); + virtual void Connect (); bool Introduce (std::shared_ptr session, uint32_t relayTag); void WaitForIntroduction (); void SendPeerTest (); // Alice, Data message @@ -268,7 +268,7 @@ namespace transport SSU2SessionState GetState () const { return m_State; }; void SetState (SSU2SessionState state) { m_State = state; }; - bool ProcessFirstIncomingMessage (uint64_t connID, uint8_t * buf, size_t len); + virtual bool ProcessFirstIncomingMessage (uint64_t connID, uint8_t * buf, size_t len); bool ProcessSessionCreated (uint8_t * buf, size_t len); bool ProcessSessionConfirmed (uint8_t * buf, size_t len); bool ProcessRetry (uint8_t * buf, size_t len); @@ -404,7 +404,9 @@ namespace transport void SendPeerTest (uint8_t msg, const uint8_t * signedData, size_t signedDataLen, std::shared_ptr addr); bool ProcessPeerTest (uint8_t * buf, size_t len) override; - + void Connect () override; // outgoing + bool ProcessFirstIncomingMessage (uint64_t connID, uint8_t * buf, size_t len) override; // incoming + private: void SendPeerTest (uint8_t msg, const uint8_t * signedData, size_t signedDataLen); // PeerTest message