Skip to content

Commit

Permalink
entity sub all
Browse files Browse the repository at this point in the history
  • Loading branch information
broody committed Dec 10, 2023
1 parent 05fa5a3 commit 8ce7fcd
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 108 deletions.
58 changes: 46 additions & 12 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@ use torii_core::cache::ModelCache;
use torii_core::error::{Error, ParseError, QueryError};
use torii_core::model::{build_sql_query, map_row_to_ty};

use self::subscriptions::model_update::{ModelSubscriptionRequest, ModelSubscriberManager};
use self::subscriptions::entity::EntitySubscriberManager;
use self::subscriptions::state_diff::{ModelSubscriberManager, ModelSubscriptionRequest};
use crate::proto::types::clause::ClauseType;
use crate::proto::world::world_server::WorldServer;
use crate::proto::world::{SubscribeEntitiesRequest, SubscribeEntityResponse};
use crate::proto::{self};

#[derive(Clone)]
pub struct DojoWorld {
world_address: FieldElement,
pool: Pool<Sqlite>,
model_subscriber: Arc<ModelSubscriberManager>,
world_address: FieldElement,
model_cache: Arc<ModelCache>,
model_manager: Arc<ModelSubscriberManager>,
entity_manager: Arc<EntitySubscriberManager>,
}

impl DojoWorld {
Expand All @@ -48,18 +51,24 @@ impl DojoWorld {
world_address: FieldElement,
provider: Arc<JsonRpcClient<HttpTransport>>,
) -> Self {
let model_subscriber = Arc::new(ModelSubscriberManager::default());
let model_cache = Arc::new(ModelCache::new(pool.clone()));
let model_manager = Arc::new(ModelSubscriberManager::default());
let entity_manager = Arc::new(EntitySubscriberManager::default());

tokio::task::spawn(subscriptions::model_update::Service::new_with_block_rcv(
tokio::task::spawn(subscriptions::state_diff::Service::new_with_block_rcv(
block_rx,
world_address,
provider,
Arc::clone(&model_subscriber),
Arc::clone(&model_manager),
));

let model_cache = Arc::new(ModelCache::new(pool.clone()));
tokio::task::spawn(subscriptions::entity::Service::new(
pool.clone(),
Arc::clone(&entity_manager),
Arc::clone(&model_cache),
));

Self { pool, model_cache, world_address, model_subscriber }
Self { pool, world_address, model_cache, model_manager, entity_manager }
}
}

Expand Down Expand Up @@ -260,14 +269,26 @@ impl DojoWorld {

subs.push(ModelSubscriptionRequest {
keys,
model: subscriptions::model_update::ModelMetadata {
model: subscriptions::state_diff::ModelMetadata {
name: model,
packed_size: packed_size as usize,
},
});
}

self.model_subscriber.add_subscriber(subs).await
self.model_manager.add_subscriber(subs).await
}

async fn subscribe_entities(
&self,
ids: Vec<String>,
) -> Result<Receiver<Result<proto::world::SubscribeEntityResponse, tonic::Status>>, Error> {
let ids = ids
.iter()
.map(|id| Ok(FieldElement::from_str(&id).map_err(ParseError::FromStr)?))
.collect::<Result<Vec<_>, Error>>()?;

self.entity_manager.add_subscriber(ids).await
}

async fn retrieve_entities(
Expand Down Expand Up @@ -325,9 +346,14 @@ impl DojoWorld {
type ServiceResult<T> = Result<Response<T>, Status>;
type SubscribeModelsResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeModelsResponse, Status>> + Send>>;
type SubscribeEntitiesResponseStream =
Pin<Box<dyn Stream<Item = Result<SubscribeEntityResponse, Status>> + Send>>;

#[tonic::async_trait]
impl proto::world::world_server::World for DojoWorld {
type SubscribeModelsStream = SubscribeModelsResponseStream;
type SubscribeEntitiesStream = SubscribeEntitiesResponseStream;

async fn world_metadata(
&self,
_request: Request<MetadataRequest>,
Expand All @@ -340,8 +366,6 @@ impl proto::world::world_server::World for DojoWorld {
Ok(Response::new(MetadataResponse { metadata: Some(metadata) }))
}

type SubscribeModelsStream = SubscribeModelsResponseStream;

async fn subscribe_models(
&self,
request: Request<SubscribeModelsRequest>,
Expand All @@ -354,6 +378,16 @@ impl proto::world::world_server::World for DojoWorld {
Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeModelsStream))
}

async fn subscribe_entities(
&self,
request: Request<SubscribeEntitiesRequest>,
) -> ServiceResult<Self::SubscribeEntitiesStream> {
let SubscribeEntitiesRequest { ids } = request.into_inner();
let rx = self.subscribe_entities(ids).await.map_err(|e| Status::internal(e.to_string()))?;

Ok(Response::new(Box::pin(ReceiverStream::new(rx)) as Self::SubscribeEntitiesStream))
}

async fn retrieve_entities(
&self,
request: Request<RetrieveEntitiesRequest>,
Expand Down
151 changes: 151 additions & 0 deletions crates/torii/grpc/src/server/subscriptions/entity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};

use dojo_types::packing::ParseError;
use futures::Stream;
use futures_util::StreamExt;
use rand::Rng;
use sqlx::{Pool, Sqlite};
use starknet_crypto::FieldElement;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock;
use torii_core::cache::ModelCache;
use torii_core::error::Error;
use torii_core::model::{build_sql_query, map_row_to_ty};
use torii_core::simple_broker::SimpleBroker;
use torii_core::types::Entity;
use tracing::trace;

use crate::proto;

pub struct EntitiesSubscriber {
/// Entity ids that the subscriber is interested in
ids: HashSet<FieldElement>,
/// The channel to send the response back to the subscriber.
sender: Sender<Result<proto::world::SubscribeEntityResponse, tonic::Status>>,
}

#[derive(Default)]
pub struct EntitySubscriberManager {
subscribers: RwLock<HashMap<usize, EntitiesSubscriber>>,
}

impl EntitySubscriberManager {
pub async fn add_subscriber(
&self,
ids: Vec<FieldElement>,
) -> Result<Receiver<Result<proto::world::SubscribeEntityResponse, tonic::Status>>, Error> {
let id = rand::thread_rng().gen::<usize>();
let (sender, receiver) = channel(1);

self.subscribers
.write()
.await
.insert(id, EntitiesSubscriber { ids: ids.iter().cloned().collect(), sender });

Ok(receiver)
}

pub(super) async fn remove_subscriber(&self, id: usize) {
self.subscribers.write().await.remove(&id);
}
}

#[must_use = "Service does nothing unless polled"]
pub struct Service {
pool: Pool<Sqlite>,
subs_manager: Arc<EntitySubscriberManager>,
model_cache: Arc<ModelCache>,
simple_broker: Pin<Box<dyn Stream<Item = Entity> + Send>>,
}

impl Service {
pub fn new(
pool: Pool<Sqlite>,
subs_manager: Arc<EntitySubscriberManager>,
model_cache: Arc<ModelCache>,
) -> Self {
Self {
pool,
subs_manager,
model_cache,
simple_broker: Box::pin(SimpleBroker::<Entity>::subscribe()),
}
}

async fn publish_updates(
subs: Arc<EntitySubscriberManager>,
cache: Arc<ModelCache>,
pool: Pool<Sqlite>,
id: &str,
) -> Result<(), Error> {
let mut closed_stream = Vec::new();

for (idx, sub) in subs.subscribers.read().await.iter() {
let query = r#"
SELECT group_concat(entity_model.model_id) as model_names
FROM entities
JOIN entity_model ON entities.id = entity_model.entity_id
WHERE entities.id = ?
GROUP BY entities.id
"#;
let result: (String,) = sqlx::query_as(query).bind(id).fetch_one(&pool).await?;
let model_names: Vec<&str> = result.0.split(',').collect();
let schemas = cache.schemas(model_names).await?;

let entity_query = format!("{} WHERE entities.id = ?", build_sql_query(&schemas)?);
let row = sqlx::query(&entity_query).bind(&id).fetch_one(&pool).await?;

let models = schemas
.iter()
.map(|s| {
let mut struct_ty = s.as_struct().expect("schema should be struct").to_owned();
map_row_to_ty(&s.name(), &mut struct_ty, &row)?;

Ok(struct_ty.try_into().unwrap())
})
.collect::<Result<Vec<_>, Error>>()?;

let resp = proto::world::SubscribeEntityResponse {
entity: Some(proto::types::Entity {
id: FieldElement::from_str(&id).unwrap().to_bytes_be().to_vec(),
models,
}),
};

if sub.sender.send(Ok(resp)).await.is_err() {
closed_stream.push(*idx);
}
}

for id in closed_stream {
trace!(target = "subscription", "closing stream idx: {id}");
subs.remove_subscriber(id).await
}

Ok(())
}
}

impl Future for Service {
type Output = ();

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Self::Output> {
let pin = self.get_mut();

while let Poll::Ready(Some(entity)) = pin.simple_broker.poll_next_unpin(cx) {
let subs = Arc::clone(&pin.subs_manager);
let cache = Arc::clone(&pin.model_cache);
let pool = pin.pool.clone();
tokio::spawn(
async move { Service::publish_updates(subs, cache, pool, &entity.id).await },
);
}

Poll::Pending
}
}
92 changes: 0 additions & 92 deletions crates/torii/grpc/src/server/subscriptions/entity_update.rs

This file was deleted.

2 changes: 2 additions & 0 deletions crates/torii/grpc/src/server/subscriptions/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ pub enum SubscriptionError {
Parse(#[from] ParseError),
#[error(transparent)]
Provider(ProviderError),
#[error(transparent)]
Sql(#[from] sqlx::Error),
}
8 changes: 6 additions & 2 deletions crates/torii/grpc/src/server/subscriptions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
pub mod model_update;
pub mod entity_update;
pub mod entity;
pub mod error;
pub mod state_diff;

// TODO
// pub mod event
// pub mod transaction
2 changes: 1 addition & 1 deletion crates/torii/grpc/src/types/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl TryFrom<proto::types::Entity> for Entity {
type Error = ClientError;
fn try_from(entity: proto::types::Entity) -> Result<Self, Self::Error> {
Ok(Self {
key: FieldElement::from_byte_slice_be(&entity.key).map_err(ClientError::SliceError)?,
key: FieldElement::from_byte_slice_be(&entity.id).map_err(ClientError::SliceError)?,
models: entity
.models
.into_iter()
Expand Down
Loading

0 comments on commit 8ce7fcd

Please sign in to comment.