diff --git a/crates/torii/client/src/client/mod.rs b/crates/torii/client/src/client/mod.rs index 7fa96c17ab..e2fd1f58bd 100644 --- a/crates/torii/client/src/client/mod.rs +++ b/crates/torii/client/src/client/mod.rs @@ -10,7 +10,9 @@ use starknet::core::types::Felt; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; use tokio::sync::RwLock as AsyncRwLock; -use torii_grpc::client::{EntityUpdateStreaming, EventUpdateStreaming, IndexerUpdateStreaming}; +use torii_grpc::client::{ + EntityUpdateStreaming, EventUpdateStreaming, IndexerUpdateStreaming, TokenBalanceStreaming, +}; use torii_grpc::proto::world::{ RetrieveEntitiesResponse, RetrieveEventsResponse, RetrieveTokenBalancesResponse, RetrieveTokensResponse, @@ -209,4 +211,37 @@ impl Client { .await?; Ok(stream) } + + /// Subscribes to token balances updates. + /// If no contract addresses are provided, it will subscribe to updates for all contract + /// addresses. If no account addresses are provided, it will subscribe to updates for all + /// account addresses. + pub async fn on_token_balance_updated( + &self, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result { + let mut grpc_client = self.inner.write().await; + let stream = + grpc_client.subscribe_token_balances(contract_addresses, account_addresses).await?; + Ok(stream) + } + + /// Update the token balances subscription + pub async fn update_token_balance_subscription( + &self, + subscription_id: u64, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result<(), Error> { + let mut grpc_client = self.inner.write().await; + grpc_client + .update_token_balances_subscription( + subscription_id, + contract_addresses, + account_addresses, + ) + .await?; + Ok(()) + } } diff --git a/crates/torii/core/src/executor/erc.rs b/crates/torii/core/src/executor/erc.rs index 2f4225f345..3e115e834e 100644 --- a/crates/torii/core/src/executor/erc.rs +++ b/crates/torii/core/src/executor/erc.rs @@ -15,8 +15,9 @@ use tracing::{debug, trace, warn}; use super::{ApplyBalanceDiffQuery, Executor}; use crate::constants::{IPFS_CLIENT_MAX_RETRY, SQL_FELT_DELIMITER, TOKEN_BALANCE_TABLE}; use crate::executor::LOG_TARGET; +use crate::simple_broker::SimpleBroker; use crate::sql::utils::{felt_to_sql_string, sql_string_to_u256, u256_to_sql_string, I256}; -use crate::types::ContractType; +use crate::types::{ContractType, TokenBalance}; use crate::utils::fetch_content_from_ipfs; #[derive(Debug, Clone)] @@ -159,18 +160,21 @@ impl<'c, P: Provider + Sync + Send + 'static> Executor<'c, P> { } // write the new balance to the database - sqlx::query(&format!( + let token_balance: TokenBalance = sqlx::query_as(&format!( "INSERT OR REPLACE INTO {TOKEN_BALANCE_TABLE} (id, contract_address, account_address, \ - token_id, balance) VALUES (?, ?, ?, ?, ?)", + token_id, balance) VALUES (?, ?, ?, ?, ?) RETURNING *", )) .bind(id) .bind(contract_address) .bind(account_address) .bind(token_id) .bind(u256_to_sql_string(&balance)) - .execute(&mut **tx) + .fetch_one(&mut **tx) .await?; + debug!(target: LOG_TARGET, token_balance = ?token_balance, "Applied balance diff"); + SimpleBroker::publish(token_balance); + Ok(()) } diff --git a/crates/torii/grpc/proto/world.proto b/crates/torii/grpc/proto/world.proto index 2c7e7b1270..57e4ef76db 100644 --- a/crates/torii/grpc/proto/world.proto +++ b/crates/torii/grpc/proto/world.proto @@ -34,6 +34,12 @@ service World { // Update entity subscription rpc UpdateEventMessagesSubscription (UpdateEventMessagesSubscriptionRequest) returns (google.protobuf.Empty); + // Subscribe to token balance updates. + rpc SubscribeTokenBalances (RetrieveTokenBalancesRequest) returns (stream SubscribeTokenBalancesResponse); + + // Update token balance subscription + rpc UpdateTokenBalancesSubscription (UpdateTokenBalancesSubscriptionRequest) returns (google.protobuf.Empty); + // Retrieve entities rpc RetrieveEventMessages (RetrieveEventMessagesRequest) returns (RetrieveEntitiesResponse); @@ -50,6 +56,24 @@ service World { rpc RetrieveTokenBalances (RetrieveTokenBalancesRequest) returns (RetrieveTokenBalancesResponse); } +// A request to update a token balance subscription +message UpdateTokenBalancesSubscriptionRequest { + // The subscription ID + uint64 subscription_id = 1; + // The list of contract addresses to subscribe to + repeated bytes contract_addresses = 2; + // The list of account addresses to subscribe to + repeated bytes account_addresses = 3; +} + +// A response containing token balances +message SubscribeTokenBalancesResponse { + // The subscription ID + uint64 subscription_id = 1; + // The token balance + types.TokenBalance balance = 2; +} + // A request to retrieve tokens message RetrieveTokensRequest { // The list of contract addresses to retrieve tokens for diff --git a/crates/torii/grpc/src/client.rs b/crates/torii/grpc/src/client.rs index f24d5f5f7e..53d34f77ff 100644 --- a/crates/torii/grpc/src/client.rs +++ b/crates/torii/grpc/src/client.rs @@ -16,11 +16,14 @@ use crate::proto::world::{ SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventMessagesRequest, SubscribeEventsRequest, SubscribeEventsResponse, SubscribeIndexerRequest, SubscribeIndexerResponse, SubscribeModelsRequest, SubscribeModelsResponse, - UpdateEntitiesSubscriptionRequest, UpdateEventMessagesSubscriptionRequest, + SubscribeTokenBalancesResponse, UpdateEntitiesSubscriptionRequest, + UpdateEventMessagesSubscriptionRequest, UpdateTokenBalancesSubscriptionRequest, WorldMetadataRequest, }; use crate::types::schema::{Entity, SchemaError}; -use crate::types::{EntityKeysClause, Event, EventQuery, IndexerUpdate, ModelKeysClause, Query}; +use crate::types::{ + EntityKeysClause, Event, EventQuery, IndexerUpdate, ModelKeysClause, Query, TokenBalance, +}; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -295,6 +298,76 @@ impl WorldClient { None => empty_state_update(), })))) } + + /// Subscribe to token balances. + pub async fn subscribe_token_balances( + &mut self, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result { + let request = RetrieveTokenBalancesRequest { + contract_addresses: contract_addresses + .into_iter() + .map(|c| c.to_bytes_be().to_vec()) + .collect(), + account_addresses: account_addresses + .into_iter() + .map(|a| a.to_bytes_be().to_vec()) + .collect(), + }; + let stream = self + .inner + .subscribe_token_balances(request) + .await + .map_err(Error::Grpc) + .map(|res| res.into_inner())?; + Ok(TokenBalanceStreaming(stream.map_ok(Box::new(|res| { + (res.subscription_id, res.balance.unwrap().try_into().expect("must able to serialize")) + })))) + } + + /// Update a token balances subscription. + pub async fn update_token_balances_subscription( + &mut self, + subscription_id: u64, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result<(), Error> { + let request = UpdateTokenBalancesSubscriptionRequest { + subscription_id, + contract_addresses: contract_addresses + .into_iter() + .map(|c| c.to_bytes_be().to_vec()) + .collect(), + account_addresses: account_addresses + .into_iter() + .map(|a| a.to_bytes_be().to_vec()) + .collect(), + }; + self.inner + .update_token_balances_subscription(request) + .await + .map_err(Error::Grpc) + .map(|res| res.into_inner()) + } +} + +type TokenBalanceMappedStream = MapOk< + tonic::Streaming, + Box (SubscriptionId, TokenBalance) + Send>, +>; + +#[derive(Debug)] +pub struct TokenBalanceStreaming(TokenBalanceMappedStream); + +impl Stream for TokenBalanceStreaming { + type Item = ::Item; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_next_unpin(cx) + } } type ModelDiffMappedStream = MapOk< diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 8e884a08ce..3e208d7acb 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -34,6 +34,7 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; use subscriptions::event::EventManager; use subscriptions::indexer::IndexerManager; +use subscriptions::token_balance::TokenBalanceManager; use tokio::net::TcpListener; use tokio::sync::mpsc::{channel, Receiver}; use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; @@ -59,7 +60,8 @@ use crate::proto::world::{ RetrieveTokenBalancesResponse, RetrieveTokensRequest, RetrieveTokensResponse, SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventMessagesRequest, SubscribeEventsResponse, SubscribeIndexerRequest, SubscribeIndexerResponse, - UpdateEventMessagesSubscriptionRequest, WorldMetadataRequest, WorldMetadataResponse, + SubscribeTokenBalancesResponse, UpdateEventMessagesSubscriptionRequest, + UpdateTokenBalancesSubscriptionRequest, WorldMetadataRequest, WorldMetadataResponse, }; use crate::proto::{self}; use crate::types::schema::SchemaError; @@ -123,6 +125,7 @@ pub struct DojoWorld { event_manager: Arc, state_diff_manager: Arc, indexer_manager: Arc, + token_balance_manager: Arc, } impl DojoWorld { @@ -138,6 +141,7 @@ impl DojoWorld { let event_manager = Arc::new(EventManager::default()); let state_diff_manager = Arc::new(StateDiffManager::default()); let indexer_manager = Arc::new(IndexerManager::default()); + let token_balance_manager = Arc::new(TokenBalanceManager::default()); tokio::task::spawn(subscriptions::model_diff::Service::new_with_block_rcv( block_rx, @@ -156,6 +160,10 @@ impl DojoWorld { tokio::task::spawn(subscriptions::indexer::Service::new(Arc::clone(&indexer_manager))); + tokio::task::spawn(subscriptions::token_balance::Service::new(Arc::clone( + &token_balance_manager, + ))); + Self { pool, world_address, @@ -165,6 +173,7 @@ impl DojoWorld { event_manager, state_diff_manager, indexer_manager, + token_balance_manager, } } } @@ -1056,6 +1065,15 @@ impl DojoWorld { Ok(RetrieveTokenBalancesResponse { balances }) } + async fn subscribe_token_balances( + &self, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result>, Error> + { + self.token_balance_manager.add_subscriber(contract_addresses, account_addresses).await + } + async fn subscribe_indexer( &self, contract_address: Felt, @@ -1508,6 +1526,8 @@ type SubscribeIndexerResponseStream = Pin> + Send>>; type RetrieveEntitiesStreamingResponseStream = Pin> + Send>>; +type SubscribeTokenBalancesResponseStream = + Pin> + Send>>; #[tonic::async_trait] impl proto::world::world_server::World for DojoWorld { @@ -1517,6 +1537,7 @@ impl proto::world::world_server::World for DojoWorld { type SubscribeEventsStream = SubscribeEventsResponseStream; type SubscribeIndexerStream = SubscribeIndexerResponseStream; type RetrieveEntitiesStreamingStream = RetrieveEntitiesStreamingResponseStream; + type SubscribeTokenBalancesStream = SubscribeTokenBalancesResponseStream; async fn world_metadata( &self, @@ -1619,6 +1640,52 @@ impl proto::world::world_server::World for DojoWorld { Ok(Response::new(())) } + async fn subscribe_token_balances( + &self, + request: Request, + ) -> ServiceResult { + let RetrieveTokenBalancesRequest { contract_addresses, account_addresses } = + request.into_inner(); + let contract_addresses = contract_addresses + .iter() + .map(|address| Felt::from_bytes_be_slice(address)) + .collect::>(); + let account_addresses = account_addresses + .iter() + .map(|address| Felt::from_bytes_be_slice(address)) + .collect::>(); + + let rx = self + .subscribe_token_balances(contract_addresses, account_addresses) + .await + .map_err(|e| Status::internal(e.to_string()))?; + Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeTokenBalancesStream)) + } + + async fn update_token_balances_subscription( + &self, + request: Request, + ) -> ServiceResult<()> { + let UpdateTokenBalancesSubscriptionRequest { + subscription_id, + contract_addresses, + account_addresses, + } = request.into_inner(); + let contract_addresses = contract_addresses + .iter() + .map(|address| Felt::from_bytes_be_slice(address)) + .collect::>(); + let account_addresses = account_addresses + .iter() + .map(|address| Felt::from_bytes_be_slice(address)) + .collect::>(); + + self.token_balance_manager + .update_subscriber(subscription_id, contract_addresses, account_addresses) + .await; + Ok(Response::new(())) + } + async fn retrieve_entities( &self, request: Request, diff --git a/crates/torii/grpc/src/server/subscriptions/mod.rs b/crates/torii/grpc/src/server/subscriptions/mod.rs index b58810d611..caaa38736e 100644 --- a/crates/torii/grpc/src/server/subscriptions/mod.rs +++ b/crates/torii/grpc/src/server/subscriptions/mod.rs @@ -9,6 +9,7 @@ pub mod event; pub mod event_message; pub mod indexer; pub mod model_diff; +pub mod token_balance; pub(crate) fn match_entity_keys( id: Felt, diff --git a/crates/torii/grpc/src/server/subscriptions/token_balance.rs b/crates/torii/grpc/src/server/subscriptions/token_balance.rs new file mode 100644 index 0000000000..cf932894d4 --- /dev/null +++ b/crates/torii/grpc/src/server/subscriptions/token_balance.rs @@ -0,0 +1,193 @@ +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::{Stream, StreamExt}; +use rand::Rng; +use starknet_crypto::Felt; +use tokio::sync::mpsc::{ + channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, +}; +use tokio::sync::RwLock; +use torii_core::error::{Error, ParseError}; +use torii_core::simple_broker::SimpleBroker; +use torii_core::types::TokenBalance; +use tracing::{error, trace}; + +use crate::proto; +use crate::proto::world::SubscribeTokenBalancesResponse; + +pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::balance"; + +#[derive(Debug)] +pub struct TokenBalanceSubscriber { + /// Contract addresses that the subscriber is interested in + /// If empty, subscriber receives updates for all contracts + pub contract_addresses: HashSet, + /// Account addresses that the subscriber is interested in + /// If empty, subscriber receives updates for all accounts + pub account_addresses: HashSet, + /// The channel to send the response back to the subscriber. + pub sender: Sender>, +} + +#[derive(Debug, Default)] +pub struct TokenBalanceManager { + subscribers: RwLock>, +} + +impl TokenBalanceManager { + pub async fn add_subscriber( + &self, + contract_addresses: Vec, + account_addresses: Vec, + ) -> Result>, Error> { + let subscription_id = rand::thread_rng().gen::(); + let (sender, receiver) = channel(1); + + // Send initial empty response + let _ = sender + .send(Ok(SubscribeTokenBalancesResponse { subscription_id, balance: None })) + .await; + + self.subscribers.write().await.insert( + subscription_id, + TokenBalanceSubscriber { + contract_addresses: contract_addresses.into_iter().collect(), + account_addresses: account_addresses.into_iter().collect(), + sender, + }, + ); + + Ok(receiver) + } + + pub async fn update_subscriber( + &self, + id: u64, + contract_addresses: Vec, + account_addresses: Vec, + ) { + let sender = { + let subscribers = self.subscribers.read().await; + if let Some(subscriber) = subscribers.get(&id) { + subscriber.sender.clone() + } else { + return; // Subscriber not found, exit early + } + }; + + self.subscribers.write().await.insert( + id, + TokenBalanceSubscriber { + contract_addresses: contract_addresses.into_iter().collect(), + account_addresses: account_addresses.into_iter().collect(), + sender, + }, + ); + } + + pub(super) async fn remove_subscriber(&self, id: u64) { + self.subscribers.write().await.remove(&id); + } +} + +#[must_use = "Service does nothing unless polled"] +#[allow(missing_debug_implementations)] +pub struct Service { + simple_broker: Pin + Send>>, + balance_sender: UnboundedSender, +} + +impl Service { + pub fn new(subs_manager: Arc) -> Self { + let (balance_sender, balance_receiver) = unbounded_channel(); + let service = Self { + simple_broker: Box::pin(SimpleBroker::::subscribe()), + balance_sender, + }; + + tokio::spawn(Self::publish_updates(subs_manager, balance_receiver)); + + service + } + + async fn publish_updates( + subs: Arc, + mut balance_receiver: UnboundedReceiver, + ) { + while let Some(balance) = balance_receiver.recv().await { + if let Err(e) = Self::process_balance_update(&subs, &balance).await { + error!(target = LOG_TARGET, error = %e, "Processing balance update."); + } + } + } + + async fn process_balance_update( + subs: &Arc, + balance: &TokenBalance, + ) -> Result<(), Error> { + let mut closed_stream = Vec::new(); + + for (idx, sub) in subs.subscribers.read().await.iter() { + let contract_address = + Felt::from_str(&balance.contract_address).map_err(ParseError::FromStr)?; + let account_address = + Felt::from_str(&balance.account_address).map_err(ParseError::FromStr)?; + + // Skip if contract address filter doesn't match + if !sub.contract_addresses.is_empty() + && !sub.contract_addresses.contains(&contract_address) + { + continue; + } + + // Skip if account address filter doesn't match + if !sub.account_addresses.is_empty() + && !sub.account_addresses.contains(&account_address) + { + continue; + } + + let resp = SubscribeTokenBalancesResponse { + subscription_id: *idx, + balance: Some(proto::types::TokenBalance { + contract_address: balance.contract_address.clone(), + account_address: balance.account_address.clone(), + token_id: balance.token_id.clone(), + balance: balance.balance.clone(), + }), + }; + + if sub.sender.send(Ok(resp)).await.is_err() { + closed_stream.push(*idx); + } + } + + for id in closed_stream { + trace!(target = LOG_TARGET, id = %id, "Closing balance stream."); + subs.remove_subscriber(id).await + } + + Ok(()) + } +} + +impl Future for Service { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + while let Poll::Ready(Some(balance)) = this.simple_broker.poll_next_unpin(cx) { + if let Err(e) = this.balance_sender.send(balance) { + error!(target = LOG_TARGET, error = %e, "Sending balance update to processor."); + } + } + + Poll::Pending + } +}