diff --git a/crates/torii/client/src/client/mod.rs b/crates/torii/client/src/client/mod.rs index fb71a18a11..869ff87d2e 100644 --- a/crates/torii/client/src/client/mod.rs +++ b/crates/torii/client/src/client/mod.rs @@ -141,23 +141,45 @@ impl Client { /// A direct stream to grpc subscribe entities pub async fn on_entity_updated( &self, - clause: Option, + clauses: Vec, ) -> Result { let mut grpc_client = self.inner.write().await; - let stream = grpc_client.subscribe_entities(clause).await?; + let stream = grpc_client.subscribe_entities(clauses).await?; Ok(stream) } + /// Update the entities subscription + pub async fn update_entity_subscription( + &self, + subscription_id: u64, + clauses: Vec, + ) -> Result<(), Error> { + let mut grpc_client = self.inner.write().await; + grpc_client.update_entities_subscription(subscription_id, clauses).await?; + Ok(()) + } + /// A direct stream to grpc subscribe event messages pub async fn on_event_message_updated( &self, - clause: Option, + clauses: Vec, ) -> Result { let mut grpc_client = self.inner.write().await; - let stream = grpc_client.subscribe_event_messages(clause).await?; + let stream = grpc_client.subscribe_event_messages(clauses).await?; Ok(stream) } + /// Update the event messages subscription + pub async fn update_event_message_subscription( + &self, + subscription_id: u64, + clauses: Vec, + ) -> Result<(), Error> { + let mut grpc_client = self.inner.write().await; + grpc_client.update_event_messages_subscription(subscription_id, clauses).await?; + Ok(()) + } + pub async fn on_starknet_event( &self, keys: Option, diff --git a/crates/torii/grpc/proto/world.proto b/crates/torii/grpc/proto/world.proto index 76b88981a7..8e8010fef1 100644 --- a/crates/torii/grpc/proto/world.proto +++ b/crates/torii/grpc/proto/world.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package world; import "types.proto"; +import "google/protobuf/empty.proto"; + // The World service provides information about the world. service World { @@ -14,12 +16,18 @@ service World { // Subscribe to entity updates. rpc SubscribeEntities (SubscribeEntitiesRequest) returns (stream SubscribeEntityResponse); + // Update entity subscription + rpc UpdateEntitiesSubscription (UpdateEntitiesSubscriptionRequest) returns (google.protobuf.Empty); + // Retrieve entities rpc RetrieveEntities (RetrieveEntitiesRequest) returns (RetrieveEntitiesResponse); // Subscribe to entity updates. rpc SubscribeEventMessages (SubscribeEntitiesRequest) returns (stream SubscribeEntityResponse); + // Update entity subscription + rpc UpdateEventMessagesSubscription (UpdateEntitiesSubscriptionRequest) returns (google.protobuf.Empty); + // Retrieve entities rpc RetrieveEventMessages (RetrieveEntitiesRequest) returns (RetrieveEntitiesResponse); @@ -52,15 +60,17 @@ message SubscribeModelsResponse { } message SubscribeEntitiesRequest { - types.EntityKeysClause clause = 1; + repeated types.EntityKeysClause clauses = 1; } -message SubscribeEventMessagesRequest { - types.EntityKeysClause clause = 1; +message UpdateEntitiesSubscriptionRequest { + uint64 subscription_id = 1; + repeated types.EntityKeysClause clauses = 2; } message SubscribeEntityResponse { types.Entity entity = 1; + uint64 subscription_id = 2; } message RetrieveEntitiesRequest { diff --git a/crates/torii/grpc/src/client.rs b/crates/torii/grpc/src/client.rs index 035922b141..9e5fb71bbd 100644 --- a/crates/torii/grpc/src/client.rs +++ b/crates/torii/grpc/src/client.rs @@ -9,7 +9,7 @@ use crate::proto::world::{ world_client, MetadataRequest, RetrieveEntitiesRequest, RetrieveEntitiesResponse, RetrieveEventsRequest, RetrieveEventsResponse, SubscribeEntitiesRequest, SubscribeEntityResponse, SubscribeEventsRequest, SubscribeEventsResponse, - SubscribeModelsRequest, SubscribeModelsResponse, + SubscribeModelsRequest, SubscribeModelsResponse, UpdateEntitiesSubscriptionRequest, }; use crate::types::schema::{self, Entity, SchemaError}; use crate::types::{EntityKeysClause, Event, EventQuery, KeysClause, ModelKeysClause, Query}; @@ -104,41 +104,80 @@ impl WorldClient { /// Subscribe to entities updates of a World. pub async fn subscribe_entities( &mut self, - clause: Option, + clauses: Vec, ) -> Result { - let clause = clause.map(|c| c.into()); + let clauses = clauses.into_iter().map(|c| c.into()).collect(); let stream = self .inner - .subscribe_entities(SubscribeEntitiesRequest { clause }) + .subscribe_entities(SubscribeEntitiesRequest { clauses }) .await .map_err(Error::Grpc) .map(|res| res.into_inner())?; - Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| match res.entity { - Some(entity) => entity.try_into().expect("must able to serialize"), - None => Entity { hashed_keys: Felt::ZERO, models: vec![] }, + Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| { + res.entity.map_or( + (res.subscription_id, Entity { hashed_keys: Felt::ZERO, models: vec![] }), + |entity| (res.subscription_id, entity.try_into().expect("must able to serialize")), + ) })))) } + /// Update an entities subscription. + pub async fn update_entities_subscription( + &mut self, + subscription_id: u64, + clauses: Vec, + ) -> Result<(), Error> { + let clauses = clauses.into_iter().map(|c| c.into()).collect(); + + self.inner + .update_entities_subscription(UpdateEntitiesSubscriptionRequest { + subscription_id, + clauses, + }) + .await + .map_err(Error::Grpc) + .map(|res| res.into_inner()) + } + /// Subscribe to event messages of a World. pub async fn subscribe_event_messages( &mut self, - clause: Option, + clauses: Vec, ) -> Result { - let clause = clause.map(|c| c.into()); + let clauses = clauses.into_iter().map(|c| c.into()).collect(); let stream = self .inner - .subscribe_event_messages(SubscribeEntitiesRequest { clause }) + .subscribe_event_messages(SubscribeEntitiesRequest { clauses }) .await .map_err(Error::Grpc) .map(|res| res.into_inner())?; - Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| match res.entity { - Some(entity) => entity.try_into().expect("must able to serialize"), - None => Entity { hashed_keys: Felt::ZERO, models: vec![] }, + Ok(EntityUpdateStreaming(stream.map_ok(Box::new(|res| { + res.entity.map_or( + (res.subscription_id, Entity { hashed_keys: Felt::ZERO, models: vec![] }), + |entity| (res.subscription_id, entity.try_into().expect("must able to serialize")), + ) })))) } + /// Update an event messages subscription. + pub async fn update_event_messages_subscription( + &mut self, + subscription_id: u64, + clauses: Vec, + ) -> Result<(), Error> { + let clauses = clauses.into_iter().map(|c| c.into()).collect(); + self.inner + .update_event_messages_subscription(UpdateEntitiesSubscriptionRequest { + subscription_id, + clauses, + }) + .await + .map_err(Error::Grpc) + .map(|res| res.into_inner()) + } + /// Subscribe to the events of a World. pub async fn subscribe_events( &mut self, @@ -200,9 +239,10 @@ impl Stream for ModelDiffsStreaming { } } +type SubscriptionId = u64; type EntityMappedStream = MapOk< tonic::Streaming, - Box Entity + Send>, + Box (SubscriptionId, Entity) + Send>, >; #[derive(Debug)] diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index f667a60ba3..e69ada8461 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -18,6 +18,7 @@ use futures::Stream; use proto::world::{ MetadataRequest, MetadataResponse, RetrieveEntitiesRequest, RetrieveEntitiesResponse, RetrieveEventsRequest, RetrieveEventsResponse, SubscribeModelsRequest, SubscribeModelsResponse, + UpdateEntitiesSubscriptionRequest, }; use sqlx::prelude::FromRow; use sqlx::sqlite::SqliteRow; @@ -772,9 +773,9 @@ impl DojoWorld { async fn subscribe_entities( &self, - keys: Option, + keys: Vec, ) -> Result>, Error> { - self.entity_manager.add_subscriber(keys.map(|keys| keys.into())).await + self.entity_manager.add_subscriber(keys.into_iter().map(|keys| keys.into()).collect()).await } async fn retrieve_entities( @@ -849,9 +850,11 @@ impl DojoWorld { async fn subscribe_event_messages( &self, - keys: Option, + clauses: Vec, ) -> Result>, Error> { - self.event_message_manager.add_subscriber(keys.map(|keys| keys.into())).await + self.event_message_manager + .add_subscriber(clauses.into_iter().map(|keys| keys.into()).collect()) + .await } async fn retrieve_event_messages( @@ -1054,13 +1057,28 @@ impl proto::world::world_server::World for DojoWorld { &self, request: Request, ) -> ServiceResult { - let SubscribeEntitiesRequest { clause } = request.into_inner(); + let SubscribeEntitiesRequest { clauses } = request.into_inner(); let rx = - self.subscribe_entities(clause).await.map_err(|e| Status::internal(e.to_string()))?; + self.subscribe_entities(clauses).await.map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream)) } + async fn update_entities_subscription( + &self, + request: Request, + ) -> ServiceResult<()> { + let UpdateEntitiesSubscriptionRequest { subscription_id, clauses } = request.into_inner(); + self.entity_manager + .update_subscriber( + subscription_id, + clauses.into_iter().map(|keys| keys.into()).collect(), + ) + .await; + + Ok(Response::new(())) + } + async fn retrieve_entities( &self, request: Request, @@ -1080,15 +1098,30 @@ impl proto::world::world_server::World for DojoWorld { &self, request: Request, ) -> ServiceResult { - let SubscribeEntitiesRequest { clause } = request.into_inner(); + let SubscribeEntitiesRequest { clauses } = request.into_inner(); let rx = self - .subscribe_event_messages(clause) + .subscribe_event_messages(clauses) .await .map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream)) } + async fn update_event_messages_subscription( + &self, + request: Request, + ) -> ServiceResult<()> { + let UpdateEntitiesSubscriptionRequest { subscription_id, clauses } = request.into_inner(); + self.event_message_manager + .update_subscriber( + subscription_id, + clauses.into_iter().map(|keys| keys.into()).collect(), + ) + .await; + + Ok(Response::new(())) + } + async fn retrieve_event_messages( &self, request: Request, diff --git a/crates/torii/grpc/src/server/subscriptions/entity.rs b/crates/torii/grpc/src/server/subscriptions/entity.rs index 252c4df755..2b0c976d69 100644 --- a/crates/torii/grpc/src/server/subscriptions/entity.rs +++ b/crates/torii/grpc/src/server/subscriptions/entity.rs @@ -30,35 +30,50 @@ pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::entity" #[derive(Debug)] pub struct EntitiesSubscriber { /// Entity ids that the subscriber is interested in - keys: Option, + pub(crate) clauses: Vec, /// The channel to send the response back to the subscriber. - sender: Sender>, + pub(crate) sender: Sender>, } - #[derive(Debug, Default)] pub struct EntityManager { - subscribers: RwLock>, + subscribers: RwLock>, } impl EntityManager { pub async fn add_subscriber( &self, - keys: Option, + clauses: Vec, ) -> Result>, Error> { - let id = rand::thread_rng().gen::(); + let subscription_id = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); // NOTE: unlock issue with firefox/safari // initially send empty stream message to return from // initial subscribe call - let _ = sender.send(Ok(SubscribeEntityResponse { entity: None })).await; + let _ = sender.send(Ok(SubscribeEntityResponse { entity: None, subscription_id })).await; - self.subscribers.write().await.insert(id, EntitiesSubscriber { keys, sender }); + self.subscribers + .write() + .await + .insert(subscription_id, EntitiesSubscriber { clauses, sender }); Ok(receiver) } - pub(super) async fn remove_subscriber(&self, id: usize) { + pub async fn update_subscriber(&self, id: u64, clauses: Vec) { + let sender = { + let subscribers = self.subscribers.read().await; + if let Some(subscriber) = subscribers.get(&id) { + subscriber.sender.clone() + } else { + return; // Subscriber not found, exit early + } + }; + + self.subscribers.write().await.insert(id, EntitiesSubscriber { clauses, sender }); + } + + pub(super) async fn remove_subscriber(&self, id: u64) { self.subscribers.write().await.remove(&id); } } @@ -109,13 +124,11 @@ impl Service { // If we have a clause of keys, then check that the key pattern of the entity // matches the key pattern of the subscriber. - match &sub.keys { - Some(EntityKeysClause::HashedKeys(hashed_keys)) => { - if !hashed_keys.is_empty() && !hashed_keys.contains(&hashed) { - continue; - } + if !sub.clauses.iter().any(|clause| match clause { + EntityKeysClause::HashedKeys(hashed_keys) => { + hashed_keys.is_empty() || hashed_keys.contains(&hashed) } - Some(EntityKeysClause::Keys(clause)) => { + EntityKeysClause::Keys(clause) => { // if we have a model clause, then we need to check that the entity // has an updated model and that the model name matches the clause if let Some(updated_model) = &entity.updated_model { @@ -139,7 +152,7 @@ impl Service { || clause_model == "*") }) { - continue; + return false; } } @@ -148,10 +161,10 @@ impl Service { if clause.pattern_matching == PatternMatching::FixedLen && keys.len() != clause.keys.len() { - continue; + return false; } - if !keys.iter().enumerate().all(|(idx, key)| { + return keys.iter().enumerate().all(|(idx, key)| { // this is going to be None if our key pattern overflows the subscriber // key pattern in this case we should skip let sub_key = clause.keys.get(idx); @@ -166,12 +179,10 @@ impl Service { // so we should match all next keys _ => true, } - }) { - continue; - } + }); } - // if None, then we are interested in all entities - None => {} + }) { + continue; } if entity.updated_model.is_none() { @@ -180,6 +191,7 @@ impl Service { hashed_keys: hashed.to_bytes_be().to_vec(), models: vec![], }), + subscription_id: *idx, }; if sub.sender.send(Ok(resp)).await.is_err() { @@ -222,6 +234,7 @@ impl Service { let resp = proto::world::SubscribeEntityResponse { entity: Some(map_row_to_entity(&row, &arrays_rows, schemas.clone())?), + subscription_id: *idx, }; if sub.sender.send(Ok(resp)).await.is_err() { diff --git a/crates/torii/grpc/src/server/subscriptions/event_message.rs b/crates/torii/grpc/src/server/subscriptions/event_message.rs index cc1739f0bb..5cabd5f03a 100644 --- a/crates/torii/grpc/src/server/subscriptions/event_message.rs +++ b/crates/torii/grpc/src/server/subscriptions/event_message.rs @@ -10,7 +10,7 @@ use futures_util::StreamExt; use rand::Rng; use sqlx::{Pool, Sqlite}; use starknet::core::types::Felt; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::mpsc::{channel, Receiver}; use tokio::sync::RwLock; use torii_core::cache::ModelCache; use torii_core::error::{Error, ParseError}; @@ -20,44 +20,54 @@ use torii_core::sql::FELT_DELIMITER; use torii_core::types::EventMessage; use tracing::{error, trace}; +use super::entity::EntitiesSubscriber; use crate::proto; use crate::proto::world::SubscribeEntityResponse; use crate::server::map_row_to_entity; use crate::types::{EntityKeysClause, PatternMatching}; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::event_message"; -#[derive(Debug)] -pub struct EventMessagesSubscriber { - /// Entity keys that the subscriber is interested in - keys: Option, - /// The channel to send the response back to the subscriber. - sender: Sender>, -} #[derive(Debug, Default)] pub struct EventMessageManager { - subscribers: RwLock>, + subscribers: RwLock>, } impl EventMessageManager { pub async fn add_subscriber( &self, - keys: Option, + clauses: Vec, ) -> Result>, Error> { - let id = rand::thread_rng().gen::(); + let subscription_id = rand::thread_rng().gen::(); let (sender, receiver) = channel(1); // NOTE: unlock issue with firefox/safari // initially send empty stream message to return from // initial subscribe call - let _ = sender.send(Ok(SubscribeEntityResponse { entity: None })).await; + let _ = sender.send(Ok(SubscribeEntityResponse { entity: None, subscription_id })).await; - self.subscribers.write().await.insert(id, EventMessagesSubscriber { keys, sender }); + self.subscribers + .write() + .await + .insert(subscription_id, EntitiesSubscriber { clauses, sender }); Ok(receiver) } - pub(super) async fn remove_subscriber(&self, id: usize) { + pub async fn update_subscriber(&self, id: u64, clauses: Vec) { + let sender = { + let subscribers = self.subscribers.read().await; + if let Some(subscriber) = subscribers.get(&id) { + subscriber.sender.clone() + } else { + return; // Subscriber not found, exit early + } + }; + + self.subscribers.write().await.insert(id, EntitiesSubscriber { clauses, sender }); + } + + pub(super) async fn remove_subscriber(&self, id: u64) { self.subscribers.write().await.remove(&id); } } @@ -108,13 +118,11 @@ impl Service { // If we have a clause of keys, then check that the key pattern of the entity // matches the key pattern of the subscriber. - match &sub.keys { - Some(EntityKeysClause::HashedKeys(hashed_keys)) => { - if !hashed_keys.is_empty() && !hashed_keys.contains(&hashed) { - continue; - } + if !sub.clauses.iter().any(|clause| match clause { + EntityKeysClause::HashedKeys(hashed_keys) => { + hashed_keys.is_empty() || hashed_keys.contains(&hashed) } - Some(EntityKeysClause::Keys(clause)) => { + EntityKeysClause::Keys(clause) => { // if we have a model clause, then we need to check that the entity // has an updated model and that the model name matches the clause if let Some(updated_model) = &entity.updated_model { @@ -138,7 +146,7 @@ impl Service { || clause_model == "*") }) { - continue; + return false; } } @@ -147,10 +155,10 @@ impl Service { if clause.pattern_matching == PatternMatching::FixedLen && keys.len() != clause.keys.len() { - continue; + return false; } - if !keys.iter().enumerate().all(|(idx, key)| { + return keys.iter().enumerate().all(|(idx, key)| { // this is going to be None if our key pattern overflows the subscriber // key pattern in this case we should skip let sub_key = clause.keys.get(idx); @@ -165,12 +173,10 @@ impl Service { // so we should match all next keys _ => true, } - }) { - continue; - } + }); } - // if None, then we are interested in all entities - None => {} + }) { + continue; } // publish all updates if ids is empty or only ids that are subscribed to @@ -207,6 +213,7 @@ impl Service { let resp = proto::world::SubscribeEntityResponse { entity: Some(map_row_to_entity(&row, &arrays_rows, schemas.clone())?), + subscription_id: *idx, }; if sub.sender.send(Ok(resp)).await.is_err() {