diff --git a/.github/workflows/cctl.yml b/.github/workflows/cctl.yml new file mode 100644 index 0000000..ebc4be5 --- /dev/null +++ b/.github/workflows/cctl.yml @@ -0,0 +1,24 @@ +name: Casper 2.0-RC3 CCTL +on: [push] +jobs: + build_and_test: + runs-on: ubuntu-22.04 + services: + casper-cctl: + image: koxu1996/casper-cctl:2.0-rc3 + ports: + - 14101:14101 # RPC + - 21101:21101 # SSE + steps: + - uses: actions/checkout@v3 + - name: Test RPC - info_get_status call + run: > + curl --silent --location 'http://127.0.0.1:21101/rpc' + --header 'Content-Type: application/json' + --data '{"id": 1, "jsonrpc": "2.0", "method": "info_get_status", "params": []}' + | jq + - name: Test SSE - read stream for 5 seconds + continue-on-error: true + run: | + curl --silent --location http://127.0.0.1:14101/events --max-time 5 + (($? != 28)) && { printf '%s\n' "Unexpected exit code"; exit 1; } diff --git a/Cargo.toml b/Cargo.toml index 1b8d169..184d4cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,19 @@ license = "MIT OR Apache-2.0" readme = "README.md" repository = "https://github.com/casper-network/casper-sdk-rs" +# [dependencies] +# l1-types = { path = "../casper-node/types", package = "casper-types", features = [ +# "std", +# ] } +# l1-binary-port = { path = "../casper-node/binary_port", package = "casper-binary-port" } + + [dependencies] -l1-types = { path = "../casper-node/types", package = "casper-types", features = [ - "std", -] } -l1-binary-port = { path = "../casper-node/binary_port", package = "casper-binary-port" } +eventsource-stream = "0.2.3" +reqwest = { version = "0.12.5", features = ["json", "stream"] } +serde = { version = "1.0.189", features = ["derive"] } +serde_json = "1.0.107" +tokio = { version = "1", features = ["full"] } +futures = "0.3.30" +thiserror = "1.0" +casper-types = "4.0.1" diff --git a/src/api/mod.rs b/src/api/mod.rs index a7a02a1..bae9522 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,2 +1,2 @@ -pub(super) mod node; +pub mod node; pub(super) mod sidecar; diff --git a/src/api/node/mod.rs b/src/api/node/mod.rs index 3555650..e5571d5 100644 --- a/src/api/node/mod.rs +++ b/src/api/node/mod.rs @@ -1,3 +1,3 @@ pub(super) mod binary; pub(super) mod rest; -pub(super) mod sse; +pub mod sse; diff --git a/src/api/node/sse/client.rs b/src/api/node/sse/client.rs new file mode 100644 index 0000000..463d60c --- /dev/null +++ b/src/api/node/sse/client.rs @@ -0,0 +1,134 @@ +use super::{ + error::ClientError, + types::{CoreCommand, EventType}, + ClientCore, SseData, +}; +use std::time::Duration; +use tokio::sync::{mpsc, oneshot}; + +pub struct Client { + command_sender: mpsc::Sender, +} + +impl Client { + pub async fn new(url: &str) -> Self { + let client_core = ClientCore::new(url).await; + + let (tx, rx) = mpsc::channel(32); + let _handle = tokio::spawn(async move { + if let Err(e) = run_client_core(rx, client_core).await { + panic!("Unrecoverable client error: {}", e); + } + }); + + Client { command_sender: tx } + } + + pub async fn connect(&self) -> Result<(), ClientError> { + let (tx, rx) = oneshot::channel(); + self.command_sender + .send(CoreCommand::Connect(tx)) + .await + .map_err(|err| ClientError::CommandSendError(err))?; + rx.await.map_err(|err| ClientError::CommandRecvError(err)) + } + + pub async fn on_event( + &mut self, + event_type: EventType, + handler: F, + ) -> Result + where + F: Fn(SseData) + 'static + Send + Sync, + { + let (tx, rx) = oneshot::channel(); + self.command_sender + .send(CoreCommand::AddOnEventHandler( + event_type, + Box::new(handler), + tx, + )) + .await + .map_err(|err| ClientError::CommandSendError(err))?; + rx.await.map_err(|err| ClientError::CommandRecvError(err)) + } + + pub async fn wait_for_event( + &mut self, + event_type: EventType, + predicate: F, + timeout: Duration, + ) -> Result, ClientError> + where + F: Fn(SseData) -> bool + Send + Sync + 'static, + { + let (tx, mut rx) = mpsc::channel(1); + + // Register the event handler + let handler_id = self + .on_event(event_type, move |event_info: SseData| { + if predicate(event_info.clone()) { + // Send the matching event to the channel + let _ = tx + .try_send(event_info) + .map_err(|err| ClientError::ChannelInternalError(err)); + } + }) + .await?; + + // Wait for the event or timeout + let result = if timeout.is_zero() { + rx.recv().await + } else { + tokio::time::timeout(timeout, rx.recv()) + .await + .ok() + .flatten() + }; + + // Remove the event handler after the event is received or timeout occurs + self.remove_handler(handler_id).await?; + + match result { + Some(event_info) => Ok(Some(event_info)), + None => { + eprintln!("Timed out or stream exhausted while waiting for event"); + Ok(None) + } + } + } + + pub async fn remove_handler(&mut self, id: u64) -> Result { + let (tx, rx) = oneshot::channel(); + self.command_sender + .send(CoreCommand::RemoveEventHandler(id, tx)) + .await + .map_err(|err| ClientError::CommandSendError(err))?; + rx.await.map_err(|err| ClientError::CommandRecvError(err)) + } +} + +/// Handles incoming commands and delegates tasks to ClientCore. +async fn run_client_core( + mut rx: mpsc::Receiver, + mut client_core: ClientCore, +) -> Result<(), ClientError> { + loop { + if !client_core.is_connected() { + // Not connected yet, so only process Connect commands. + if let Some(command) = rx.recv().await { + client_core.handle_command(command).await? + } + } else { + tokio::select! { + Ok(Some(event)) = client_core.run_once() => { + client_core.handle_event(event)?; + }, + Some(command) = rx.recv() => { + client_core.handle_command(command) + .await? + }, + } + } + } +} diff --git a/src/api/node/sse/client_core.rs b/src/api/node/sse/client_core.rs new file mode 100644 index 0000000..fdc7103 --- /dev/null +++ b/src/api/node/sse/client_core.rs @@ -0,0 +1,139 @@ +use super::{ + error::ClientError, + types::{BoxedEventStream, CoreCommand, EventType, Handler}, + SseData, +}; +use eventsource_stream::{Event, Eventsource}; +use futures::stream::TryStreamExt; +use std::collections::HashMap; + +pub struct ClientCore { + url: String, + event_stream: Option, + next_handler_id: u64, + event_handlers: HashMap>>, + id_types: HashMap, + is_connected: bool, +} + +impl ClientCore { + pub async fn new(url: &str) -> Self { + ClientCore { + url: url.to_string(), + event_stream: None, + next_handler_id: 0, + event_handlers: HashMap::new(), + id_types: HashMap::new(), + is_connected: false, + } + } + + pub async fn connect(&mut self) -> Result<(), ClientError> { + // Connect to SSE endpoint. + let client = reqwest::Client::new(); + let response = client.get(&self.url).send().await?; + + let stream = response.bytes_stream(); + let mut event_stream = stream.eventsource(); + + // Handle the handshake with API version. + let handshake_event = event_stream + .try_next() + .await? + .ok_or(ClientError::StreamExhausted)?; + let handshake_data: SseData = serde_json::from_str(&handshake_event.data)?; + let _api_version = match handshake_data { + SseData::ApiVersion(v) => Ok(v), + _ => Err(ClientError::InvalidHandshake), + }?; + + // Wrap stream with box and store it. + let boxed_event_stream = Box::pin(event_stream); + self.event_stream = Some(boxed_event_stream); + self.is_connected = true; + + Ok(()) + } + + pub fn remove_handler(&mut self, id: u64) -> bool { + if let Some(event_type) = self.id_types.get(&id) { + match self.event_handlers.get_mut(&event_type) { + Some(handlers_for_type) => { + self.id_types.remove(&id); + handlers_for_type.remove(&id).is_some() + } + None => false, + } + } else { + false //not found + } + } + + pub fn is_connected(&self) -> bool { + self.is_connected + } + + pub fn handle_event(&mut self, event: Event) -> Result<(), ClientError> { + let data: SseData = serde_json::from_str(&event.data)?; + + match data { + SseData::ApiVersion(_) => return Err(ClientError::UnexpectedHandshake), // Should only happen once at connection + SseData::Shutdown => return Err(ClientError::NodeShutdown), + + // For each type, find and invoke registered handlers + event => { + if let Some(handlers) = self.event_handlers.get_mut(&event.event_type()) { + for handler in handlers.values() { + handler(event.clone()); // Invoke each handler for the event + } + } + } + } + Ok(()) + } + + pub async fn run_once(&mut self) -> Result, ClientError> { + if let Some(stream) = self.event_stream.as_mut() { + match stream.try_next().await { + Ok(Some(event)) => Ok(Some(event)), + Ok(None) => Err(ClientError::StreamExhausted), + Err(err) => Err(ClientError::EventStreamError(err)), + } + } else { + Err(ClientError::NoEventStreamAvailable) + } + } + + pub fn add_on_event_handler(&mut self, event_type: EventType, handler: Box) -> u64 { + let handlers = self.event_handlers.entry(event_type).or_default(); + let handler_id = self.next_handler_id; + handlers.insert(handler_id, handler); + self.id_types.insert(handler_id, event_type); + self.next_handler_id += 1; + handler_id + } + + pub async fn handle_command(&mut self, command: CoreCommand) -> Result<(), ClientError> { + match command { + CoreCommand::AddOnEventHandler(event_type, callback, completion_ack) => { + let event_id = self.add_on_event_handler(event_type, callback); + completion_ack + .send(event_id) + .map_err(|_| ClientError::ReciverDroppedError())?; + } + CoreCommand::Connect(completion_ack) => { + self.connect().await.map_err(ClientError::from)?; + completion_ack + .send(()) + .map_err(|_| ClientError::ReciverDroppedError())?; + } + CoreCommand::RemoveEventHandler(id, completion_ack) => { + let removed = self.remove_handler(id); + completion_ack + .send(removed) + .map_err(|_| ClientError::ReciverDroppedError())?; + } + } + Ok(()) + } +} diff --git a/src/api/node/sse/error.rs b/src/api/node/sse/error.rs new file mode 100644 index 0000000..970e77b --- /dev/null +++ b/src/api/node/sse/error.rs @@ -0,0 +1,42 @@ +use super::{types::CoreCommand, SseData}; +use eventsource_stream::EventStreamError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ClientError { + #[error("Failed to connect to SSE endpoint: {0}")] + ConnectionError(#[from] reqwest::Error), + + #[error("SSE stream exhausted unexpectedly")] + StreamExhausted, + + #[error("Invalid handshake event")] + InvalidHandshake, + + #[error("Unexpected handshake event")] + UnexpectedHandshake, + + #[error("Deserialization error: {0}")] + DeserializationError(#[from] serde_json::Error), + + #[error("Node shutdown")] + NodeShutdown, + + #[error("Failed to send command to core: {0}")] + CommandSendError(#[from] tokio::sync::mpsc::error::SendError), + + #[error("Failed to send ack to client")] + ReciverDroppedError(), + + #[error("Failed to recive command from core: {0}")] + CommandRecvError(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("Failed to send Event into the channel: {0}")] + ChannelInternalError(#[from] tokio::sync::mpsc::error::TrySendError), + + #[error("Error reading from event stream:{0}")] + EventStreamError(#[from] EventStreamError), + + #[error("No event stream available")] + NoEventStreamAvailable, +} diff --git a/src/api/node/sse/mod.rs b/src/api/node/sse/mod.rs index e69de29..2938fc5 100644 --- a/src/api/node/sse/mod.rs +++ b/src/api/node/sse/mod.rs @@ -0,0 +1,6 @@ +pub mod client_core; +pub mod error; +pub mod types; +pub use client_core::ClientCore; +pub use types::SseData; +pub mod client; diff --git a/src/api/node/sse/types.rs b/src/api/node/sse/types.rs new file mode 100644 index 0000000..176c9f7 --- /dev/null +++ b/src/api/node/sse/types.rs @@ -0,0 +1,63 @@ +use eventsource_stream::{Event, EventStreamError}; +use futures::stream::BoxStream; +use serde::{Deserialize, Serialize}; +use tokio::sync::oneshot; + +//copied from casper-sidecar + +#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash)] +pub enum EventType { + ApiVersion, + SidecarVersion, + BlockAdded, + TransactionAccepted, + TransactionProcessed, + TransactionExpired, + Fault, + FinalitySignature, + Step, + Shutdown, +} + +/// Casper does not expose SSE types directly, so we have to reimplement them. +/// Source: https://github.com/casper-network/casper-node/blob/8a9a864212b7c20fc17e1d0106b02c813ffded9d/node/src/components/event_stream_server/sse_server.rs#L56. +/// TODO: Add full deserialization details. +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] +pub enum SseData { + ApiVersion(casper_types::ProtocolVersion), + SidecarVersion(serde_json::Value), + BlockAdded(serde_json::Value), + TransactionAccepted(serde_json::Value), + TransactionProcessed(serde_json::Value), + TransactionExpired(serde_json::Value), + Fault(serde_json::Value), + FinalitySignature(serde_json::Value), + Step(serde_json::Value), + Shutdown, +} + +impl SseData { + pub fn event_type(&self) -> EventType { + match self { + SseData::ApiVersion(_) => EventType::ApiVersion, + SseData::SidecarVersion(_) => EventType::SidecarVersion, + SseData::BlockAdded(_) => EventType::BlockAdded, + SseData::TransactionAccepted(_) => EventType::TransactionAccepted, + SseData::TransactionProcessed(_) => EventType::TransactionProcessed, + SseData::TransactionExpired(_) => EventType::TransactionExpired, + SseData::Fault(_) => EventType::Fault, + SseData::FinalitySignature(_) => EventType::FinalitySignature, + SseData::Step(_) => EventType::Step, + SseData::Shutdown => EventType::Shutdown, + } + } +} + +pub enum CoreCommand { + Connect(oneshot::Sender<()>), + AddOnEventHandler(EventType, Box, oneshot::Sender), + RemoveEventHandler(u64, oneshot::Sender), +} + +pub type Handler = dyn Fn(SseData) + 'static + Send + Sync; +pub type BoxedEventStream = BoxStream<'static, Result>>; diff --git a/tests/sse_client.rs b/tests/sse_client.rs new file mode 100644 index 0000000..91deb75 --- /dev/null +++ b/tests/sse_client.rs @@ -0,0 +1,293 @@ +#[cfg(test)] +mod utils; +mod tests { + use crate::utils::MockSse; + use casper_sdk_rs::api::node::sse::error::ClientError; + use casper_sdk_rs::api::node::sse::{client::Client, types::EventType, ClientCore, SseData}; + use casper_types::ProtocolVersion; + use core::panic; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use tokio::sync::mpsc; + + /* + *client_core tests + */ + + #[tokio::test] + async fn test_client_core_connect_and_handshake() { + let mock_server = MockSse::start().await; + let mut client_core = ClientCore::new(&mock_server.url()).await; + + // Test successful connection + let result = client_core.connect().await; + assert!(result.is_ok(), "Connection should succeed"); + assert!(client_core.is_connected(), "Should be marked as connected"); + + mock_server + .send_event(SseData::BlockAdded(serde_json::json!({ + "height": 1, + }))) + .await + .unwrap(); + + let event = client_core.run_once().await.expect("Should get event"); + let event = event.expect("Event should not be None"); + assert!(event.data.contains("BlockAdded")); + match client_core.handle_event(event) { + Ok(()) => (), // Success + Err(err) => panic!("Unexpected error: {:?}", err), + } + } + + #[tokio::test] + async fn test_client_core_double_handshake() { + let mock_server = MockSse::start().await; + let mut client_core = ClientCore::new(&mock_server.url()).await; + + client_core.connect().await.unwrap(); + + let handshake = SseData::ApiVersion(ProtocolVersion::from_parts(2, 0, 0)); + mock_server.send_event(handshake.clone()).await.unwrap(); + let event = client_core.run_once().await.expect("Should get event"); + let event = event.expect("Event should not be None"); + match client_core.handle_event(event) { + Ok(()) => panic!("Expected error"), + Err(ClientError::UnexpectedHandshake) => (), // Success + Err(err) => panic!("Unexpected error: {:?}", err), + } + } + + #[tokio::test] + async fn test_client_core_add_on_event_handler_remove_on_event_handler() { + let mock_server = MockSse::start().await; + let mut client_core = ClientCore::new(&mock_server.url()).await; + + client_core.connect().await.unwrap(); + + let handler_invoked = Arc::new(Mutex::new(false)); + + let flipper_handler = { + let handler_invoked = Arc::clone(&handler_invoked); + move |event_info: SseData| { + let mut handler_invoked = handler_invoked.lock().unwrap(); + *handler_invoked = !*handler_invoked; + println!("🏃 Running Handler: {:?}", event_info.event_type()); + } + }; + + let event_type = EventType::BlockAdded; + let handler_id = client_core.add_on_event_handler(event_type, Box::new(flipper_handler)); + + assert!(handler_id == 0, "Handler ID should be assigned"); + + // Test 1: Handler invocation + let block_added_event = SseData::BlockAdded(serde_json::json!({ + "height": 1, + })); + mock_server + .send_event(block_added_event.clone()) + .await + .unwrap(); + + let event = client_core.run_once().await.unwrap().unwrap(); + client_core.handle_event(event).unwrap(); + + assert!( + *handler_invoked.lock().unwrap(), + "Handler should have been called once" + ); + + // Test 2: Second invocation flips the flag back + mock_server + .send_event(block_added_event.clone()) + .await + .unwrap(); + + let event = client_core.run_once().await.unwrap().unwrap(); + client_core.handle_event(event).unwrap(); + assert!( + !*handler_invoked.lock().unwrap(), + "Handler should have been called twice (and flipped back)" + ); + + // Test 3: Removal of the handler, the flag should not be changed by new events + let res = client_core.remove_handler(handler_id); + assert_eq!(res, true); + + mock_server.send_event(block_added_event).await.unwrap(); + + let event = client_core.run_once().await.unwrap().unwrap(); + client_core.handle_event(event).unwrap(); + assert!( + !*handler_invoked.lock().unwrap(), + "Handler should not be called after removal" + ); + } + + /* + *client tests + */ + + #[tokio::test] + async fn test_client_connect() { + let mock_server = MockSse::start().await; + let client = Client::new(&mock_server.url()).await; + let result = client.connect().await; + assert!(result.is_ok(), "Client should connect successfully"); + } + + #[tokio::test] + async fn test_client_on_event() { + let mock_server = MockSse::start().await; + let mut client = Client::new(&mock_server.url()).await; + + let (tx_block_added, mut rx_block_added) = mpsc::channel(1); // Channel for BlockAdded events + let (tx_tx_processed, mut rx_tx_processed) = mpsc::channel(1); // Channel for TransactionProcessed events + + let block_added_handler = move |event: SseData| { + tx_block_added.try_send(event).unwrap(); + }; + + let tx_processed_handler = move |event: SseData| { + tx_tx_processed.try_send(event).unwrap(); + }; + + // Test 1: Register handler before connect + let block_added_handler_id = client + .on_event(EventType::BlockAdded, block_added_handler) + .await + .unwrap(); + assert_eq!(block_added_handler_id, 0, "First handler should have ID 0"); + + // Connect the client + client.connect().await.unwrap(); + + // Test 2: Register handler after connect + let transaction_processed_handler_id = client + .on_event(EventType::TransactionProcessed, tx_processed_handler) + .await + .unwrap(); + assert_eq!( + transaction_processed_handler_id, 1, + "Second handler should have ID 1" + ); + + // Test 3: TransactionProcessed event handling + let transaction_processed_event = SseData::TransactionProcessed(serde_json::json!({ + "height": 1, + })); + mock_server + .send_event(transaction_processed_event.clone()) + .await + .unwrap(); + + let received_event = rx_tx_processed + .recv() + .await + .expect("Should receive TransactionProcessed event"); + assert_eq!( + received_event, transaction_processed_event, + "Received event should match" + ); + + // Test 4: BlockAdded event handling + let block_added_event = SseData::BlockAdded(serde_json::json!({ + "height": 1, + })); + mock_server + .send_event(block_added_event.clone()) + .await + .unwrap(); + + let received_event = rx_block_added + .recv() + .await + .expect("Should receive BlockAdded event"); + assert_eq!( + received_event, block_added_event, + "Received event should match" + ); + } + + #[tokio::test] + async fn test_client_on_event_multiple_invocations() { + let mock_server = MockSse::start().await; + let mut client = Client::new(&mock_server.url()).await; + client.connect().await.unwrap(); + + let invocation_count = Arc::new(Mutex::new(0)); + + let finality_signature_handler = { + let invocation_count = Arc::clone(&invocation_count); + move |_| { + let mut count = invocation_count.lock().unwrap(); + *count += 1; + } + }; + + let _handler_id = client + .on_event(EventType::FinalitySignature, finality_signature_handler) + .await + .unwrap(); + + // Send events + for i in 0..5 { + let finality_signature_event = SseData::FinalitySignature(serde_json::json!({ + "height": i, + })); + mock_server + .send_event(finality_signature_event) + .await + .unwrap(); + + // Short delay after each event to allow processing + tokio::time::sleep(Duration::from_millis(50)).await; + } + let final_count = *invocation_count.lock().unwrap(); + assert_eq!(final_count, 5, "Handler should have been called 5 times"); + } + + #[tokio::test] + async fn test_client_wait_for_event() { + let mock_server = MockSse::start().await; + let mut client = Client::new(&mock_server.url()).await; + client.connect().await.unwrap(); + + let predicate = |data: SseData| { + if let SseData::BlockAdded(block_data) = data { + if let Some(height) = block_data["height"].as_u64() { + return height == 13; + } + } + false + }; + let timeout = Duration::from_millis(100); + + // Spawn a task to wait for the event + let event_future = tokio::spawn(async move { + client + .wait_for_event(EventType::BlockAdded, predicate, timeout) + .await + }); + + let block_added_event = SseData::BlockAdded(serde_json::json!({ + "height": 13, + })); + mock_server + .send_event(block_added_event.clone()) + .await + .unwrap(); + + let result = event_future.await.unwrap(); + + match result { + Ok(Some(SseData::BlockAdded(block_data))) => { + assert_eq!(block_data["height"].as_u64().unwrap(), 13); + } + Ok(Some(event)) => panic!("Expected BlockAdded event, got {:?}", event), + Ok(None) => panic!("Timed out waiting for event"), + Err(err) => panic!("Unexpected error: {:?}", err), + } + } +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index e69de29..619be1e 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -0,0 +1,155 @@ +use casper_sdk_rs::api::node::sse::SseData; +use casper_types::ProtocolVersion; +use std::net::SocketAddr; +use tokio::io::{AsyncWriteExt, BufWriter}; +use tokio::net::TcpListener; +use tokio::sync::mpsc; + +pub struct MockSse { + addr: String, + tx: mpsc::Sender, +} + +impl MockSse { + pub async fn start() -> Self { + let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); // Use port 0 for dynamic allocation + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = format!("http://{}", listener.local_addr().unwrap().to_string()); + let (tx, mut rx) = mpsc::channel(32); + + tokio::spawn(async move { + if let Ok((mut socket, _)) = listener.accept().await { + let mut writer = BufWriter::new(&mut socket); + let version_event = SseData::ApiVersion(ProtocolVersion::from_parts(2, 0, 0)); + let serialized_event = format!( + "data: {}\n\n", + serde_json::to_string(&version_event).unwrap() + ); + writer + .write_all("HTTP/1.1 200 OK\r\n".as_bytes()) + .await + .unwrap(); + writer + .write_all("content-type: text/event-stream\r\n".as_bytes()) + .await + .unwrap(); + writer + .write_all("Cache-Control: no-cache\r\n".as_bytes()) + .await + .unwrap(); + writer + .write_all("Connection: keep-alive\r\n\r\n".as_bytes()) + .await + .unwrap(); + println!( + "🤝 Initializing SSE stream with API version handshake: {:?}", + serialized_event + ); + writer.write_all(serialized_event.as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); + + while let Some(event) = rx.recv().await { + let serialized_event = + format!("data: {}\n\n", serde_json::to_string(&event).unwrap()); + println!("📣 Broadcasting event: {:?}", serialized_event); + writer.write_all(serialized_event.as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); + } + } + }); + Self { addr, tx } + } + pub fn url(&self) -> String { + self.addr.clone() + } + + pub async fn send_event(&self, data: SseData) -> Result<(), Box> { + self.tx + .send(data) + .await + .map_err(|_| "Failed to send event".into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use eventsource_stream::Eventsource; + use futures::StreamExt; + use reqwest::Client; + use serde_json::json; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn test_mock_sse_start_and_connect() { + let mock_server = MockSse::start().await; + let url = mock_server.url(); + let client = Client::new(); + let response = client.get(&url).send().await.unwrap(); + + assert!(response.status().is_success(), "Invalid HTTP status"); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream", + "Invalid Content-Type" + ); + } + + #[tokio::test] + async fn test_mock_sse_handshake() { + let mock_server = MockSse::start().await; + let url = mock_server.url(); + let client = Client::new(); + let response = client.get(&url).send().await.unwrap(); + + let mut stream = response.bytes_stream().eventsource(); + let event = stream.next().await.unwrap().unwrap(); + let expected_version = ProtocolVersion::from_parts(2, 0, 0); + + let handshake_data: SseData = serde_json::from_str(&event.data).unwrap(); + match handshake_data { + SseData::ApiVersion(v) => assert_eq!(v, expected_version), + _ => panic!("Expected ApiVersion event, got {:?}", handshake_data), + } + } + + #[tokio::test] + async fn test_mock_sse_send_event() { + let mock_server = MockSse::start().await; + let url = mock_server.url(); + let client = Client::new(); + let response = client.get(&url).send().await.unwrap(); + + // Skip the handshake + let mut stream = response.bytes_stream().eventsource().skip(1); + + let test_event = SseData::BlockAdded(json!({ + "height": 100, + })); + mock_server.send_event(test_event.clone()).await.unwrap(); + + let event = timeout(Duration::from_secs(2), stream.next()) + .await + .expect("Should receive sent event") + .expect("Event stream should not be empty") + .unwrap(); + + let received_data: SseData = serde_json::from_str(&event.data).unwrap(); + assert_eq!(received_data, test_event); + + let test_event = SseData::FinalitySignature(json!({ + "height": 100, + })); + mock_server.send_event(test_event.clone()).await.unwrap(); + + let event = timeout(Duration::from_secs(2), stream.next()) + .await + .expect("Should receive sent event") + .expect("Event stream should not be empty") + .unwrap(); + + let received_data: SseData = serde_json::from_str(&event.data).unwrap(); + assert_eq!(received_data, test_event); + } +}