diff --git a/crates/papyrus_network/src/streamed_data_protocol/handler.rs b/crates/papyrus_network/src/streamed_data_protocol/handler.rs index 65fda3059c..e558705471 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/handler.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/handler.rs @@ -269,11 +269,21 @@ impl ConnectionHandler for Handler { self.inbound_sessions_marked_to_end.insert(inbound_session_id); + self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( + ToBehaviourEvent::SessionClosedByRequest { + session_id: SessionId::InboundSessionId(inbound_session_id), + }, + )); } RequestFromBehaviourEvent::CloseSession { session_id: SessionId::OutboundSessionId(outbound_session_id), } => { self.id_to_outbound_session.remove(&outbound_session_id); + self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( + ToBehaviourEvent::SessionClosedByRequest { + session_id: SessionId::OutboundSessionId(outbound_session_id), + }, + )); } } } diff --git a/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs b/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs index af248a8acb..bf2e04d157 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/handler_test.rs @@ -107,6 +107,22 @@ async fn validate_received_data_event( + handler: &mut Handler, + session_id: SessionId, +) { + let event = handler.next().await.unwrap(); + assert_matches!( + event, + ConnectionHandlerEvent::NotifyBehaviour(ToBehaviourEvent::SessionClosedByRequest { + session_id: event_session_id + }) if event_session_id == session_id + ); +} + async fn validate_outbound_session_closed_by_peer_event< Query: QueryBound, Data: DataBound + PartialEq, @@ -226,6 +242,11 @@ async fn closed_inbound_session_ignores_behaviour_request_to_send_data() { &mut handler, SessionId::InboundSessionId(inbound_session_id), ); + validate_session_closed_by_request_event( + &mut handler, + SessionId::InboundSessionId(inbound_session_id), + ) + .await; let hardcoded_data_vec = hardcoded_data(); for data in &hardcoded_data_vec { @@ -322,6 +343,11 @@ async fn closed_outbound_session_doesnt_emit_events_when_data_is_sent() { &mut handler, SessionId::OutboundSessionId(outbound_session_id), ); + validate_session_closed_by_request_event( + &mut handler, + SessionId::OutboundSessionId(outbound_session_id), + ) + .await; for data in hardcoded_data() { write_message(data, &mut inbound_stream).await.unwrap(); diff --git a/crates/papyrus_network/src/streamed_data_protocol/mod.rs b/crates/papyrus_network/src/streamed_data_protocol/mod.rs index e667e75711..34f8059199 100644 --- a/crates/papyrus_network/src/streamed_data_protocol/mod.rs +++ b/crates/papyrus_network/src/streamed_data_protocol/mod.rs @@ -15,7 +15,7 @@ pub struct InboundSessionId { value: usize, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] // TODO(shahak) remove allow(dead_code). #[allow(dead_code)] pub(crate) enum SessionId { @@ -38,5 +38,6 @@ pub(crate) enum GenericEvent { NewInboundSession { query: Query, inbound_session_id: InboundSessionId }, ReceivedData { outbound_session_id: OutboundSessionId, data: Data }, SessionFailed { session_id: SessionId, error: SessionError }, + SessionClosedByRequest { session_id: SessionId }, OutboundSessionClosedByPeer { outbound_session_id: OutboundSessionId }, }