Skip to content

Commit

Permalink
Support subscribing to events via ws (#652)
Browse files Browse the repository at this point in the history
* Simplify event receiving via new util function
  • Loading branch information
FabijanC authored Nov 27, 2024
1 parent 979a986 commit 55fe695
Show file tree
Hide file tree
Showing 9 changed files with 589 additions and 19 deletions.
2 changes: 1 addition & 1 deletion crates/starknet-devnet-core/src/starknet/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub(crate) fn get_events(
/// * `address` - Optional. The address to filter the event by.
/// * `keys_filter` - Optional. The keys to filter the event by.
/// * `event` - The event to check if it applies to the filters.
fn check_if_filter_applies_for_event(
pub fn check_if_filter_applies_for_event(
address: &Option<ContractAddress>,
keys_filter: &Option<Vec<Vec<Felt>>>,
event: &Event,
Expand Down
13 changes: 12 additions & 1 deletion crates/starknet-devnet-core/src/starknet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ mod add_l1_handler_transaction;
mod cheats;
pub(crate) mod defaulter;
mod estimations;
mod events;
pub mod events;
mod get_class_impls;
mod predeployed;
pub mod starknet_config;
Expand Down Expand Up @@ -1034,6 +1034,17 @@ impl Starknet {
.ok_or(Error::NoTransaction)
}

pub fn get_unlimited_events(
&self,
from_block: Option<BlockId>,
to_block: Option<BlockId>,
address: Option<ContractAddress>,
keys: Option<Vec<Vec<Felt>>>,
) -> DevnetResult<Vec<EmittedEvent>> {
events::get_events(self, from_block, to_block, address, keys, 0, None)
.map(|(emitted_events, _)| emitted_events)
}

pub fn get_events(
&self,
from_block: Option<BlockId>,
Expand Down
59 changes: 52 additions & 7 deletions crates/starknet-devnet-server/src/api/json_rpc/endpoints_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use starknet_types::starknet_api::block::{BlockNumber, BlockStatus};

use super::error::ApiError;
use super::models::{
BlockInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput, TransactionBlockInput,
BlockInput, EventsSubscriptionInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput,
TransactionBlockInput,
};
use super::{JsonRpcHandler, JsonRpcSubscriptionRequest};
use crate::rpc_core::request::Id;
Expand All @@ -33,7 +34,9 @@ impl JsonRpcHandler {
JsonRpcSubscriptionRequest::PendingTransactions(data) => {
self.subscribe_pending_txs(data, rpc_request_id, socket_id).await
}
JsonRpcSubscriptionRequest::Events => todo!(),
JsonRpcSubscriptionRequest::Events(data) => {
self.subscribe_events(data, rpc_request_id, socket_id).await
}
JsonRpcSubscriptionRequest::Unsubscribe(SubscriptionIdInput { subscription_id }) => {
let mut sockets = self.api.sockets.lock().await;
let socket_context = sockets.get_mut(&socket_id).ok_or(
Expand All @@ -42,15 +45,14 @@ impl JsonRpcHandler {
}),
)?;

socket_context.unsubscribe(rpc_request_id, subscription_id).await?;
Ok(())
socket_context.unsubscribe(rpc_request_id, subscription_id).await
}
}
}

/// Returns (starting block number, latest block number). Returns an error in case the starting
/// block does not exist or there are too many blocks.
async fn convert_to_block_number_range(
async fn get_validated_block_number_range(
&self,
mut starting_block_id: BlockId,
) -> Result<(u64, u64), ApiError> {
Expand Down Expand Up @@ -105,7 +107,7 @@ impl JsonRpcHandler {
};

let (query_block_number, latest_block_number) =
self.convert_to_block_number_range(block_id).await?;
self.get_validated_block_number_range(block_id).await?;

// perform the actual subscription
let mut sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -233,7 +235,7 @@ impl JsonRpcHandler {
};

let (query_block_number, latest_block_number) =
self.convert_to_block_number_range(query_block_id).await?;
self.get_validated_block_number_range(query_block_id).await?;

// perform the actual subscription
let mut sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -280,4 +282,47 @@ impl JsonRpcHandler {

Ok(())
}

async fn subscribe_events(
&self,
maybe_subscription_input: Option<EventsSubscriptionInput>,
rpc_request_id: Id,
socket_id: SocketId,
) -> Result<(), ApiError> {
let address = maybe_subscription_input
.as_ref()
.and_then(|subscription_input| subscription_input.from_address);

let starting_block_id = maybe_subscription_input
.as_ref()
.and_then(|subscription_input| subscription_input.block.as_ref())
.map(|b| b.0)
.unwrap_or(BlockId::Tag(BlockTag::Latest));

self.get_validated_block_number_range(starting_block_id).await?;

let keys_filter =
maybe_subscription_input.and_then(|subscription_input| subscription_input.keys);

let mut sockets = self.api.sockets.lock().await;
let socket_context = sockets.get_mut(&socket_id).ok_or(ApiError::StarknetDevnetError(
Error::UnexpectedInternalError { msg: format!("Unregistered socket ID: {socket_id}") },
))?;

let subscription = Subscription::Events { address, keys_filter: keys_filter.clone() };
let subscription_id = socket_context.subscribe(rpc_request_id, subscription).await;

let events = self.api.starknet.lock().await.get_unlimited_events(
Some(starting_block_id),
Some(BlockId::Tag(BlockTag::Latest)),
address,
keys_filter,
)?;

for event in events {
socket_context.notify(subscription_id, SubscriptionNotification::Event(event)).await;
}

Ok(())
}
}
19 changes: 15 additions & 4 deletions crates/starknet-devnet-server/src/api/json_rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use models::{
BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput, BlockInput,
CallInput, EstimateFeeInput, EventsInput, GetStorageInput, L1TransactionHashInput,
PendingTransactionsSubscriptionInput, SubscriptionIdInput, TransactionBlockInput,
TransactionHashInput, TransactionHashOutput,
CallInput, EstimateFeeInput, EventsInput, EventsSubscriptionInput, GetStorageInput,
L1TransactionHashInput, PendingTransactionsSubscriptionInput, SubscriptionIdInput,
TransactionBlockInput, TransactionHashInput, TransactionHashOutput,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -311,6 +311,17 @@ impl JsonRpcHandler {
}),
));
}

let events = starknet.get_unlimited_events(
Some(BlockId::Tag(BlockTag::Latest)),
Some(BlockId::Tag(BlockTag::Latest)),
None,
None,
)?;

for event in events {
notifications.push(SubscriptionNotification::Event(event));
}
}

let sockets = self.api.sockets.lock().await;
Expand Down Expand Up @@ -757,7 +768,7 @@ pub enum JsonRpcSubscriptionRequest {
#[serde(rename = "starknet_subscribePendingTransactions", with = "optional_params")]
PendingTransactions(Option<PendingTransactionsSubscriptionInput>),
#[serde(rename = "starknet_subscribeEvents")]
Events,
Events(Option<EventsSubscriptionInput>),
#[serde(rename = "starknet_unsubscribe")]
Unsubscribe(SubscriptionIdInput),
}
Expand Down
12 changes: 11 additions & 1 deletion crates/starknet-devnet-server/src/api/json_rpc/models.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use starknet_rs_core::types::{Hash256, TransactionExecutionStatus, TransactionFinalityStatus};
use starknet_rs_core::types::{
Felt, Hash256, TransactionExecutionStatus, TransactionFinalityStatus,
};
use starknet_types::contract_address::ContractAddress;
use starknet_types::felt::{BlockHash, ClassHash, TransactionHash};
use starknet_types::patricia_key::PatriciaKey;
Expand Down Expand Up @@ -205,6 +207,14 @@ pub struct PendingTransactionsSubscriptionInput {
pub sender_address: Option<Vec<ContractAddress>>,
}

#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct EventsSubscriptionInput {
pub block: Option<BlockId>,
pub from_address: Option<ContractAddress>,
pub keys: Option<Vec<Vec<Felt>>>,
}

#[cfg(test)]
mod tests {
use starknet_rs_core::types::{BlockId as ImportedBlockId, BlockTag, Felt};
Expand Down
18 changes: 13 additions & 5 deletions crates/starknet-devnet-server/src/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
use futures::SinkExt;
use serde::{self, Serialize};
use starknet_rs_core::types::BlockTag;
use starknet_core::starknet::events::check_if_filter_applies_for_event;
use starknet_rs_core::types::{BlockTag, Felt};
use starknet_types::contract_address::ContractAddress;
use starknet_types::emitted_event::EmittedEvent;
use starknet_types::felt::TransactionHash;
use starknet_types::rpc::block::BlockHeader;
use starknet_types::rpc::transactions::{TransactionStatus, TransactionWithHash};
Expand Down Expand Up @@ -39,7 +41,7 @@ pub enum Subscription {
TransactionStatus { tag: BlockTag, transaction_hash: TransactionHash },
PendingTransactionsFull { address_filter: AddressFilter },
PendingTransactionsHash { address_filter: AddressFilter },
Events,
Events { address: Option<ContractAddress>, keys_filter: Option<Vec<Vec<Felt>>> },
}

impl Subscription {
Expand All @@ -51,7 +53,7 @@ impl Subscription {
| Subscription::PendingTransactionsHash { .. } => {
SubscriptionConfirmation::NewSubscription(id)
}
Subscription::Events => SubscriptionConfirmation::NewSubscription(id),
Subscription::Events { .. } => SubscriptionConfirmation::NewSubscription(id),
}
}

Expand Down Expand Up @@ -90,7 +92,11 @@ impl Subscription {
};
}
}
Subscription::Events => todo!(),
Subscription::Events { address, keys_filter } => {
if let SubscriptionNotification::Event(event) = notification {
return check_if_filter_applies_for_event(address, keys_filter, &event.into());
}
}
}

false
Expand Down Expand Up @@ -141,6 +147,7 @@ pub enum SubscriptionNotification {
NewHeads(Box<BlockHeader>),
TransactionStatus(NewTransactionStatus),
PendingTransaction(PendingTransactionNotification),
Event(EmittedEvent),
}

impl SubscriptionNotification {
Expand All @@ -152,7 +159,8 @@ impl SubscriptionNotification {
}
SubscriptionNotification::PendingTransaction(_) => {
"starknet_subscriptionPendingTransactions"
} // SubscriptionNotification::Events => "starknet_subscriptionEvents",
}
SubscriptionNotification::Event(_) => "starknet_subscriptionEvents",
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/starknet-devnet-types/src/rpc/emitted_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,13 @@ impl From<&blockifier::execution::call_info::OrderedEvent> for OrderedEvent {
}
}
}

impl From<&EmittedEvent> for Event {
fn from(emitted_event: &EmittedEvent) -> Self {
Self {
from_address: emitted_event.from_address,
keys: emitted_event.keys.clone(),
data: emitted_event.data.clone(),
}
}
}
13 changes: 13 additions & 0 deletions crates/starknet-devnet/tests/common/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,19 @@ pub async fn receive_rpc_via_ws(
Ok(serde_json::from_str(&msg.into_text()?)?)
}

/// Extract `result` from the notification and assert general properties
pub async fn receive_notification(
ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
method: &str,
expected_subscription_id: i64,
) -> Result<serde_json::Value, anyhow::Error> {
let mut notification = receive_rpc_via_ws(ws).await?;
assert_eq!(notification["jsonrpc"], "2.0");
assert_eq!(notification["method"], method);
assert_eq!(notification["params"]["subscription_id"], expected_subscription_id);
Ok(notification["params"].take()["result"].take())
}

pub async fn assert_no_notifications(ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>) {
match receive_rpc_via_ws(ws).await {
Ok(resp) => panic!("Expected no notifications; found: {resp}"),
Expand Down
Loading

0 comments on commit 55fe695

Please sign in to comment.