Skip to content

Commit

Permalink
refactor: grpc use model hash as id & libp2p command queue (#1754)
Browse files Browse the repository at this point in the history
* refactor: use model hash as ids for grpc

* fmt

* remove println

* refactor: get rid of wait for relay

* relay runner

* remove async
  • Loading branch information
Larkooo authored Apr 3, 2024
1 parent 24f40c6 commit f9acaad
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 64 deletions.
10 changes: 7 additions & 3 deletions crates/torii/client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use dojo_types::packing::unpack;
use dojo_types::schema::Ty;
use dojo_types::WorldMetadata;
use dojo_world::contracts::WorldContractReader;
use futures::lock::Mutex;
use futures::Future;
use parking_lot::{RwLock, RwLockReadGuard};
use starknet::core::utils::cairo_short_string_to_felt;
use starknet::providers::jsonrpc::HttpTransport;
Expand All @@ -20,6 +22,7 @@ use torii_grpc::client::{EntityUpdateStreaming, ModelDiffsStreaming};
use torii_grpc::proto::world::RetrieveEntitiesResponse;
use torii_grpc::types::schema::Entity;
use torii_grpc::types::{KeysClause, Query};
use torii_relay::client::EventLoop;
use torii_relay::types::Message;

use crate::client::error::{Error, ParseError};
Expand Down Expand Up @@ -99,9 +102,10 @@ impl Client {
})
}

/// Waits for the relay to be ready and listening for messages.
pub async fn wait_for_relay(&self) -> Result<(), Error> {
self.relay_client.command_sender.wait_for_relay().await.map_err(Error::RelayClient)
/// Returns a shared handle to the relay client's event loop.
/// The returned `EventLoop` does nothing until driven: lock it and call its
/// `run()` method (a long-running future) on a separate task, e.g.
/// `tokio::spawn(async move { runner.lock().await.run().await })`.
/// This accessor itself is cheap and non-blocking — it only clones the `Arc`.
pub fn relay_runner(&self) -> Arc<Mutex<EventLoop>> {
    self.relay_client.event_loop.clone()
}

/// Publishes a message to a topic.
Expand Down
6 changes: 5 additions & 1 deletion crates/torii/core/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ impl ModelCache {
}

async fn update_schema(&self, model: &str) -> Result<Ty, Error> {
let model_name: String = sqlx::query_scalar("SELECT name FROM models WHERE id = ?")
.bind(model)
.fetch_one(&self.pool)
.await?;
let model_members: Vec<SqlModelMember> = sqlx::query_as(
"SELECT id, model_idx, member_idx, name, type, type_enum, enum_options, key FROM \
model_members WHERE model_id = ? ORDER BY model_idx ASC, member_idx ASC",
Expand All @@ -52,7 +56,7 @@ impl ModelCache {
return Err(QueryError::ModelNotFound(model.into()).into());
}

let ty = parse_sql_model_members(model, &model_members);
let ty = parse_sql_model_members(&model_name, &model_members);
let mut cache = self.cache.write().await;
cache.insert(model.into(), ty.clone());

Expand Down
55 changes: 29 additions & 26 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use proto::world::{
};
use sqlx::sqlite::SqliteRow;
use sqlx::{Pool, Row, Sqlite};
use starknet::core::utils::cairo_short_string_to_felt;
use starknet::core::utils::{cairo_short_string_to_felt, get_selector_from_name};
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::JsonRpcClient;
use starknet_crypto::FieldElement;
Expand Down Expand Up @@ -99,9 +99,9 @@ impl DojoWorld {
.fetch_one(&self.pool)
.await?;

let models: Vec<(String, String, String, u32, u32, String)> = sqlx::query_as(
"SELECT name, class_hash, contract_address, packed_size, unpacked_size, layout FROM \
models",
let models: Vec<(String, String, String, String, u32, u32, String)> = sqlx::query_as(
"SELECT id, name, class_hash, contract_address, packed_size, unpacked_size, layout \
FROM models",
)
.fetch_all(&self.pool)
.await?;
Expand All @@ -110,12 +110,12 @@ impl DojoWorld {
for model in models {
let schema = self.model_cache.schema(&model.0).await?;
models_metadata.push(proto::types::ModelMetadata {
name: model.0,
class_hash: model.1,
contract_address: model.2,
packed_size: model.3,
unpacked_size: model.4,
layout: hex::decode(&model.5).unwrap(),
name: model.1,
class_hash: model.2,
contract_address: model.3,
packed_size: model.4,
unpacked_size: model.5,
layout: hex::decode(&model.6).unwrap(),
schema: serde_json::to_vec(&schema).unwrap(),
});
}
Expand Down Expand Up @@ -191,7 +191,7 @@ impl DojoWorld {
// query to filter with limit and offset
let query = format!(
r#"
SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_names
SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
{filter_ids}
Expand All @@ -206,8 +206,8 @@ impl DojoWorld {

let mut entities = Vec::with_capacity(db_entities.len());
for (entity_id, models_str) in db_entities {
let model_names: Vec<&str> = models_str.split(',').collect();
let schemas = self.model_cache.schemas(model_names).await?;
let model_ids: Vec<&str> = models_str.split(',').collect();
let schemas = self.model_cache.schemas(model_ids).await?;

let entity_query = format!("{} WHERE {table}.id = ?", build_sql_query(&schemas)?);
let row = sqlx::query(&entity_query).bind(&entity_id).fetch_one(&self.pool).await?;
Expand Down Expand Up @@ -261,7 +261,7 @@ impl DojoWorld {
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
WHERE {model_relation_table}.model_id = '{}' and {table}.keys LIKE ?
"#,
keys_clause.model
get_selector_from_name(&keys_clause.model).map_err(ParseError::NonAsciiName)?,
);

// total count of rows that matches keys_pattern without limit and offset
Expand All @@ -270,21 +270,21 @@ impl DojoWorld {

let models_query = format!(
r#"
SELECT group_concat({model_relation_table}.model_id) as model_names
SELECT group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
WHERE {table}.keys LIKE ?
GROUP BY {table}.id
HAVING model_names REGEXP '(^|,){}(,|$)'
HAVING model_ids REGEXP '(^|,){}(,|$)'
LIMIT 1
"#,
keys_clause.model
get_selector_from_name(&keys_clause.model).map_err(ParseError::NonAsciiName)?,
);
let (models_str,): (String,) =
sqlx::query_as(&models_query).bind(&keys_pattern).fetch_one(&self.pool).await?;

let model_names = models_str.split(',').collect::<Vec<&str>>();
let schemas = self.model_cache.schemas(model_names).await?;
let model_ids = models_str.split(',').collect::<Vec<&str>>();
let schemas = self.model_cache.schemas(model_ids).await?;

// query to filter with limit and offset
let entities_query = format!(
Expand Down Expand Up @@ -377,19 +377,19 @@ impl DojoWorld {

let models_query = format!(
r#"
SELECT group_concat({model_relation_table}.model_id) as model_names
SELECT group_concat({model_relation_table}.model_id) as model_ids
FROM {table}
JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id
GROUP BY {table}.id
HAVING model_names REGEXP '(^|,){}(,|$)'
HAVING model_ids REGEXP '(^|,){}(,|$)'
LIMIT 1
"#,
member_clause.model
get_selector_from_name(&member_clause.model).map_err(ParseError::NonAsciiName)?,
);
let (models_str,): (String,) = sqlx::query_as(&models_query).fetch_one(&self.pool).await?;

let model_names = models_str.split(',').collect::<Vec<&str>>();
let schemas = self.model_cache.schemas(model_names).await?;
let model_ids = models_str.split(',').collect::<Vec<&str>>();
let schemas = self.model_cache.schemas(model_ids).await?;

let table_name = member_clause.model;
let column_name = format!("external_{}", member_clause.member);
Expand Down Expand Up @@ -422,6 +422,9 @@ impl DojoWorld {
}

pub async fn model_metadata(&self, model: &str) -> Result<proto::types::ModelMetadata, Error> {
// selector
let model = get_selector_from_name(model).map_err(ParseError::NonAsciiName)?;

let (name, class_hash, contract_address, packed_size, unpacked_size, layout): (
String,
String,
Expand All @@ -433,11 +436,11 @@ impl DojoWorld {
"SELECT name, class_hash, contract_address, packed_size, unpacked_size, layout FROM \
models WHERE id = ?",
)
.bind(model)
.bind(format!("{:#x}", model))
.fetch_one(&self.pool)
.await?;

let schema = self.model_cache.schema(model).await?;
let schema = self.model_cache.schema(&format!("{:#x}", model)).await?;
let layout = hex::decode(&layout).unwrap();

Ok(proto::types::ModelMetadata {
Expand Down
8 changes: 4 additions & 4 deletions crates/torii/grpc/src/server/subscriptions/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ impl Service {
// publish all updates if ids is empty or only ids that are subscribed to
if sub.hashed_keys.is_empty() || sub.hashed_keys.contains(&hashed) {
let models_query = r#"
SELECT group_concat(entity_model.model_id) as model_names
SELECT group_concat(entity_model.model_id) as model_ids
FROM entities
JOIN entity_model ON entities.id = entity_model.entity_id
WHERE entities.id = ?
GROUP BY entities.id
"#;
let (model_names,): (String,) =
let (model_ids,): (String,) =
sqlx::query_as(models_query).bind(hashed_keys).fetch_one(&pool).await?;
let model_names: Vec<&str> = model_names.split(',').collect();
let schemas = cache.schemas(model_names).await?;
let model_ids: Vec<&str> = model_ids.split(',').collect();
let schemas = cache.schemas(model_ids).await?;

let entity_query = format!("{} WHERE entities.id = ?", build_sql_query(&schemas)?);
let row = sqlx::query(&entity_query).bind(hashed_keys).fetch_one(&pool).await?;
Expand Down
8 changes: 4 additions & 4 deletions crates/torii/grpc/src/server/subscriptions/event_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ impl Service {
// publish all updates if ids is empty or only ids that are subscribed to
if sub.hashed_keys.is_empty() || sub.hashed_keys.contains(&hashed) {
let models_query = r#"
SELECT group_concat(event_model.model_id) as model_names
SELECT group_concat(event_model.model_id) as model_ids
FROM event_messages
JOIN event_model ON event_messages.id = event_model.entity_id
WHERE event_messages.id = ?
GROUP BY event_messages.id
"#;
let (model_names,): (String,) =
let (model_ids,): (String,) =
sqlx::query_as(models_query).bind(hashed_keys).fetch_one(&pool).await?;
let model_names: Vec<&str> = model_names.split(',').collect();
let schemas = cache.schemas(model_names).await?;
let model_ids: Vec<&str> = model_ids.split(',').collect();
let schemas = cache.schemas(model_ids).await?;

let entity_query =
format!("{} WHERE event_messages.id = ?", build_sql_query(&schemas)?);
Expand Down
52 changes: 27 additions & 25 deletions crates/torii/libp2p/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ pub struct EventLoop {
#[derive(Debug)]
enum Command {
Publish(Message, oneshot::Sender<Result<MessageId, Error>>),
WaitForRelay(oneshot::Sender<Result<(), Error>>),
}

impl RelayClient {
Expand Down Expand Up @@ -162,47 +161,50 @@ impl CommandSender {

rx.await.expect("Failed to receive response")
}

pub async fn wait_for_relay(&self) -> Result<(), Error> {
let (tx, rx) = oneshot::channel();

self.sender.unbounded_send(Command::WaitForRelay(tx)).expect("Failed to send command");

rx.await.expect("Failed to receive response")
}
}

impl EventLoop {
/// Dispatches a single client `Command`.
///
/// * `command` - the command received from the `CommandSender` channel.
/// * `is_relay_ready` - whether a gossipsub subscription confirmation has been
///   seen yet (set by the event loop when the relay acknowledges the topic).
/// * `commands_queue` - shared buffer of commands deferred until the relay is
///   ready; the event loop drains it once readiness is observed.
///
/// While the relay is not ready, `Publish` commands are queued (the caller's
/// oneshot `sender` is kept alive inside the queued command, so the caller
/// keeps awaiting its response). Once ready, the message is published
/// immediately and the result is sent back over the oneshot channel.
/// NOTE(review): `sender.send(...)` panics via `expect` if the caller dropped
/// its receiver — presumably acceptable here; confirm callers never time out
/// and drop the oneshot.
async fn handle_command(
    &mut self,
    command: Command,
    is_relay_ready: bool,
    commands_queue: Arc<Mutex<Vec<Command>>>,
) {
    match command {
        Command::Publish(data, sender) => {
            // if the relay is not ready yet, add the message to the queue
            if !is_relay_ready {
                commands_queue.lock().await.push(Command::Publish(data, sender));
            } else {
                sender.send(self.publish(&data)).expect("Failed to send response");
            }
        }
    }
}

pub async fn run(&mut self) {
let mut is_relay_ready = false;
let mut relay_ready_tx = None;
let commands_queue = Arc::new(Mutex::new(Vec::new()));

loop {
// Poll the swarm for new events.
select! {
command = self.command_receiver.select_next_some() => {
match command {
Command::Publish(data, sender) => {
sender.send(self.publish(&data)).expect("Failed to send response");
}
Command::WaitForRelay(sender) => {
if is_relay_ready {
sender.send(Ok(())).expect("Failed to send response");
} else {
relay_ready_tx = Some(sender);
}
}
}
self.handle_command(command, is_relay_ready, commands_queue.clone()).await;
},
event = self.swarm.select_next_some() => {
match event {
SwarmEvent::Behaviour(ClientEvent::Gossipsub(gossipsub::Event::Subscribed { topic, .. })) => {
// Handle behaviour events.
info!(target: LOG_TARGET, topic = ?topic, "Relay ready. Received subscription confirmation.");

is_relay_ready = true;
if let Some(tx) = relay_ready_tx.take() {
tx.send(Ok(())).expect("Failed to send response");
if !is_relay_ready {
is_relay_ready = true;

// Execute all the commands that were queued while the relay was not ready.
for command in commands_queue.lock().await.drain(..) {
self.handle_command(command, is_relay_ready, commands_queue.clone()).await;
}
}
}
SwarmEvent::ConnectionClosed { cause: Some(cause), .. } => {
Expand Down
1 change: 0 additions & 1 deletion crates/torii/libp2p/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ mod test {
client.event_loop.lock().await.run().await;
});

client.command_sender.wait_for_relay().await?;
let mut data = Struct { name: "Message".to_string(), children: vec![] };

data.children.push(Member {
Expand Down

0 comments on commit f9acaad

Please sign in to comment.