From 8ce7fcd4870904f3482643d16cb2414c6c699162 Mon Sep 17 00:00:00 2001 From: broody Date: Sun, 10 Dec 2023 11:12:37 -1000 Subject: [PATCH] entity sub all --- crates/torii/grpc/src/server/mod.rs | 58 +++++-- .../grpc/src/server/subscriptions/entity.rs | 151 ++++++++++++++++++ .../src/server/subscriptions/entity_update.rs | 92 ----------- .../grpc/src/server/subscriptions/error.rs | 2 + .../grpc/src/server/subscriptions/mod.rs | 8 +- .../{model_update.rs => state_diff.rs} | 0 crates/torii/grpc/src/types/schema.rs | 2 +- scripts/rust_fmt.sh | 2 +- 8 files changed, 207 insertions(+), 108 deletions(-) create mode 100644 crates/torii/grpc/src/server/subscriptions/entity.rs delete mode 100644 crates/torii/grpc/src/server/subscriptions/entity_update.rs rename crates/torii/grpc/src/server/subscriptions/{model_update.rs => state_diff.rs} (100%) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 058c9303e1..f8a42346a2 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -28,17 +28,20 @@ use torii_core::cache::ModelCache; use torii_core::error::{Error, ParseError, QueryError}; use torii_core::model::{build_sql_query, map_row_to_ty}; -use self::subscriptions::model_update::{ModelSubscriptionRequest, ModelSubscriberManager}; +use self::subscriptions::entity::EntitySubscriberManager; +use self::subscriptions::state_diff::{ModelSubscriberManager, ModelSubscriptionRequest}; use crate::proto::types::clause::ClauseType; use crate::proto::world::world_server::WorldServer; +use crate::proto::world::{SubscribeEntitiesRequest, SubscribeEntityResponse}; use crate::proto::{self}; #[derive(Clone)] pub struct DojoWorld { - world_address: FieldElement, pool: Pool, - model_subscriber: Arc, + world_address: FieldElement, model_cache: Arc, + model_manager: Arc, + entity_manager: Arc, } impl DojoWorld { @@ -48,18 +51,24 @@ impl DojoWorld { world_address: FieldElement, provider: Arc>, ) -> Self { - let model_subscriber = Arc::new(ModelSubscriberManager::default()); + let model_cache = Arc::new(ModelCache::new(pool.clone())); + let model_manager = Arc::new(ModelSubscriberManager::default()); + let entity_manager = Arc::new(EntitySubscriberManager::default()); - tokio::task::spawn(subscriptions::model_update::Service::new_with_block_rcv( + tokio::task::spawn(subscriptions::state_diff::Service::new_with_block_rcv( block_rx, world_address, provider, - Arc::clone(&model_subscriber), + Arc::clone(&model_manager), )); - let model_cache = Arc::new(ModelCache::new(pool.clone())); + tokio::task::spawn(subscriptions::entity::Service::new( + pool.clone(), + Arc::clone(&entity_manager), + Arc::clone(&model_cache), + )); - Self { pool, model_cache, world_address, model_subscriber } + Self { pool, world_address, model_cache, model_manager, entity_manager } } } @@ -260,14 +269,26 @@ impl DojoWorld { subs.push(ModelSubscriptionRequest { keys, - model: subscriptions::model_update::ModelMetadata { + model: subscriptions::state_diff::ModelMetadata { name: model, packed_size: packed_size as usize, }, }); } - self.model_subscriber.add_subscriber(subs).await + self.model_manager.add_subscriber(subs).await + } + + async fn subscribe_entities( + &self, + ids: Vec, + ) -> Result>, Error> { + let ids = ids + .iter() + .map(|id| Ok(FieldElement::from_str(&id).map_err(ParseError::FromStr)?)) + .collect::, Error>>()?; + + self.entity_manager.add_subscriber(ids).await } async fn retrieve_entities( @@ -325,9 +346,14 @@ impl DojoWorld { type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>; +type SubscribeEntitiesResponseStream = + Pin> + Send>>; #[tonic::async_trait] impl proto::world::world_server::World for DojoWorld { + type SubscribeModelsStream = SubscribeModelsResponseStream; + type SubscribeEntitiesStream = SubscribeEntitiesResponseStream; + async fn world_metadata( &self, _request: Request, @@ -340,8 +366,6 @@ impl proto::world::world_server::World for DojoWorld { Ok(Response::new(MetadataResponse { metadata: Some(metadata) })) } - type SubscribeModelsStream = SubscribeModelsResponseStream; - async fn subscribe_models( &self, request: Request, @@ -354,6 +378,16 @@ impl proto::world::world_server::World for DojoWorld { Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeModelsStream)) } + async fn subscribe_entities( + &self, + request: Request, + ) -> ServiceResult { + let SubscribeEntitiesRequest { ids } = request.into_inner(); + let rx = self.subscribe_entities(ids).await.map_err(|e| Status::internal(e.to_string()))?; + + Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream)) + } + async fn retrieve_entities( &self, request: Request, diff --git a/crates/torii/grpc/src/server/subscriptions/entity.rs b/crates/torii/grpc/src/server/subscriptions/entity.rs new file mode 100644 index 0000000000..31f645a699 --- /dev/null +++ b/crates/torii/grpc/src/server/subscriptions/entity.rs @@ -0,0 +1,151 @@ +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use dojo_types::packing::ParseError; +use futures::Stream; +use futures_util::StreamExt; +use rand::Rng; +use sqlx::{Pool, Sqlite}; +use starknet_crypto::FieldElement; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; +use torii_core::cache::ModelCache; +use torii_core::error::Error; +use torii_core::model::{build_sql_query, map_row_to_ty}; +use torii_core::simple_broker::SimpleBroker; +use torii_core::types::Entity; +use tracing::trace; + +use crate::proto; + +pub struct EntitiesSubscriber { + /// Entity ids that the subscriber is interested in + ids: HashSet, + /// The channel to send the response back to the subscriber. + sender: Sender>, +} + +#[derive(Default)] +pub struct EntitySubscriberManager { + subscribers: RwLock>, +} + +impl EntitySubscriberManager { + pub async fn add_subscriber( + &self, + ids: Vec, + ) -> Result>, Error> { + let id = rand::thread_rng().gen::(); + let (sender, receiver) = channel(1); + + self.subscribers + .write() + .await + .insert(id, EntitiesSubscriber { ids: ids.iter().cloned().collect(), sender }); + + Ok(receiver) + } + + pub(super) async fn remove_subscriber(&self, id: usize) { + self.subscribers.write().await.remove(&id); + } +} + +#[must_use = "Service does nothing unless polled"] +pub struct Service { + pool: Pool, + subs_manager: Arc, + model_cache: Arc, + simple_broker: Pin + Send>>, +} + +impl Service { + pub fn new( + pool: Pool, + subs_manager: Arc, + model_cache: Arc, + ) -> Self { + Self { + pool, + subs_manager, + model_cache, + simple_broker: Box::pin(SimpleBroker::::subscribe()), + } + } + + async fn publish_updates( + subs: Arc, + cache: Arc, + pool: Pool, + id: &str, + ) -> Result<(), Error> { + let mut closed_stream = Vec::new(); + + for (idx, sub) in subs.subscribers.read().await.iter() { + let query = r#" + SELECT group_concat(entity_model.model_id) as model_names + FROM entities + JOIN entity_model ON entities.id = entity_model.entity_id + WHERE entities.id = ? + GROUP BY entities.id + "#; + let result: (String,) = sqlx::query_as(query).bind(id).fetch_one(&pool).await?; + let model_names: Vec<&str> = result.0.split(',').collect(); + let schemas = cache.schemas(model_names).await?; + + let entity_query = format!("{} WHERE entities.id = ?", build_sql_query(&schemas)?); + let row = sqlx::query(&entity_query).bind(&id).fetch_one(&pool).await?; + + let models = schemas + .iter() + .map(|s| { + let mut struct_ty = s.as_struct().expect("schema should be struct").to_owned(); + map_row_to_ty(&s.name(), &mut struct_ty, &row)?; + + Ok(struct_ty.try_into().unwrap()) + }) + .collect::, Error>>()?; + + let resp = proto::world::SubscribeEntityResponse { + entity: Some(proto::types::Entity { + id: FieldElement::from_str(&id).unwrap().to_bytes_be().to_vec(), + models, + }), + }; + + if sub.sender.send(Ok(resp)).await.is_err() { + closed_stream.push(*idx); + } + } + + for id in closed_stream { + trace!(target = "subscription", "closing stream idx: {id}"); + subs.remove_subscriber(id).await + } + + Ok(()) + } +} + +impl Future for Service { + type Output = (); + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll { + let pin = self.get_mut(); + + while let Poll::Ready(Some(entity)) = pin.simple_broker.poll_next_unpin(cx) { + let subs = Arc::clone(&pin.subs_manager); + let cache = Arc::clone(&pin.model_cache); + let pool = pin.pool.clone(); + tokio::spawn( + async move { Service::publish_updates(subs, cache, pool, &entity.id).await }, + ); + } + + Poll::Pending + } +} diff --git a/crates/torii/grpc/src/server/subscriptions/entity_update.rs b/crates/torii/grpc/src/server/subscriptions/entity_update.rs deleted file mode 100644 index f3585b29ea..0000000000 --- a/crates/torii/grpc/src/server/subscriptions/entity_update.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::collections::{HashMap, HashSet, VecDeque}; -use std::future::Future; -use std::sync::Arc; -use std::task::{Poll, Context}; -use std::pin::Pin; - -use futures_util::StreamExt; -use rand::Rng; - -use sqlx::{Pool, Sqlite}; -use starknet::macros::short_string; -use starknet::providers::Provider; -use starknet_crypto::{poseidon_hash_many, FieldElement}; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio::sync::RwLock; -use torii_core::error::{Error, ParseError}; -use torii_core::simple_broker::SimpleBroker; -use torii_core::types::Entity; -use futures::Stream; -use tracing::{debug, error, trace}; - -use super::error::SubscriptionError; -use crate::proto; - -pub struct EntitiesSubscriber { - /// Entity ids that the subscriber is interested in - ids: HashSet, - /// The channel to send the response back to the subscriber. - sender: Sender>, -} - -#[derive(Default)] -pub struct EntitySubscriberManager { - subscribers: RwLock>, -} - -impl EntitySubscriberManager { - pub async fn add_subscriber( - &self, - ids: Vec - ) -> Result>, Error> { - let id = rand::thread_rng().gen::(); - let (sender, receiver) = channel(1); - - self.subscribers.write().await.insert( - id, - EntitiesSubscriber { ids: ids.iter().cloned().collect(), sender } - ); - - Ok(receiver) - } -} - -#[must_use = "Service does nothing unless polled"] -pub struct Service { - pool: Pool, - simple_broker: Pin + Send>>, - entity_update_queue: VecDeque -} - -type PublishEntityUpdateResult = Result<(), SubscriptionError>; - -impl Service { - pub fn new(pool: Pool) -> Self { - Self { - pool, - simple_broker: Box::pin(SimpleBroker::::subscribe()), - entity_update_queue: VecDeque::new() - } - } - - async fn publish_entity_updates(subs: Arc, id: FieldElement) { - - } -} - -impl Future for Service { - type Output = (); - - fn poll( - self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_> - ) -> std::task::Poll { - let pin = self.get_mut(); - - while let Poll::Ready(Some(entity)) = pin.simple_broker.poll_next_unpin(cx) { - println!("GOT IT!"); - } - - Poll::Pending - } -} \ No newline at end of file diff --git a/crates/torii/grpc/src/server/subscriptions/error.rs b/crates/torii/grpc/src/server/subscriptions/error.rs index c901113f98..ed8d0a1da6 100644 --- a/crates/torii/grpc/src/server/subscriptions/error.rs +++ b/crates/torii/grpc/src/server/subscriptions/error.rs @@ -7,4 +7,6 @@ pub enum SubscriptionError { Parse(#[from] ParseError), #[error(transparent)] Provider(ProviderError), + #[error(transparent)] + Sql(#[from] sqlx::Error), } diff --git a/crates/torii/grpc/src/server/subscriptions/mod.rs b/crates/torii/grpc/src/server/subscriptions/mod.rs index dda72cd19f..882ef899b8 100644 --- a/crates/torii/grpc/src/server/subscriptions/mod.rs +++ b/crates/torii/grpc/src/server/subscriptions/mod.rs @@ -1,3 +1,7 @@ -pub mod model_update; -pub mod entity_update; +pub mod entity; pub mod error; +pub mod state_diff; + +// TODO +// pub mod event +// pub mod transaction diff --git a/crates/torii/grpc/src/server/subscriptions/model_update.rs b/crates/torii/grpc/src/server/subscriptions/state_diff.rs similarity index 100% rename from crates/torii/grpc/src/server/subscriptions/model_update.rs rename to crates/torii/grpc/src/server/subscriptions/state_diff.rs diff --git a/crates/torii/grpc/src/types/schema.rs b/crates/torii/grpc/src/types/schema.rs index 72047b6f06..1931989a5c 100644 --- a/crates/torii/grpc/src/types/schema.rs +++ b/crates/torii/grpc/src/types/schema.rs @@ -22,7 +22,7 @@ impl TryFrom for Entity { type Error = ClientError; fn try_from(entity: proto::types::Entity) -> Result { Ok(Self { - key: FieldElement::from_byte_slice_be(&entity.key).map_err(ClientError::SliceError)?, + key: FieldElement::from_byte_slice_be(&entity.id).map_err(ClientError::SliceError)?, models: entity .models .into_iter() diff --git a/scripts/rust_fmt.sh b/scripts/rust_fmt.sh index 62a418693a..095c105201 100755 --- a/scripts/rust_fmt.sh +++ b/scripts/rust_fmt.sh @@ -1,3 +1,3 @@ #!/bin/bash -cargo +nightly fmt --check --all -- "$@" +cargo +nightly fmt --all -- "$@"