From 48c9c2ea078ff4ff3c7074fcd81b30819247b509 Mon Sep 17 00:00:00 2001 From: Nasr Date: Wed, 26 Jun 2024 22:56:01 -0400 Subject: [PATCH] refactor: use array of models for keysclause --- crates/torii/grpc/proto/types.proto | 2 +- crates/torii/grpc/src/server/mod.rs | 93 +++++++++++-------- .../grpc/src/server/subscriptions/entity.rs | 31 ++----- .../src/server/subscriptions/event_message.rs | 30 ++---- .../grpc/src/server/tests/entities_test.rs | 2 +- crates/torii/grpc/src/types/mod.rs | 6 +- 6 files changed, 74 insertions(+), 90 deletions(-) diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index 40ed7e7265..9575820858 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -114,7 +114,7 @@ message EntityKeysClause { message KeysClause { repeated bytes keys = 1; PatternMatching pattern_matching = 2; - optional string model = 3; + repeated string models = 3; } message HashedKeysClause { diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index c146e30346..9d2ad65e56 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -286,7 +286,7 @@ impl DojoWorld { arrays_rows.insert(name, rows); } - entities.push(Self::map_row_to_entity(&row, &arrays_rows, &schemas)?); + entities.push(map_row_to_entity(&row, &arrays_rows, &schemas)?); } Ok((entities, total_count)) @@ -327,19 +327,27 @@ impl DojoWorld { FROM {table} {} "#, - if let Some(model) = &keys_clause.model { + if !keys_clause.models.is_empty() { + let model_ids = keys_clause + .models + .iter() + .map(|model| get_selector_from_name(model).map_err(ParseError::NonAsciiName)) + .collect::, _>>()?; + let model_ids_str = + model_ids.iter().map(|id| format!("'{:#x}'", id)).collect::>().join(","); format!( r#" - JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - WHERE {model_relation_table}.model_id = '{:#x}' AND {table}.keys REGEXP ? - "#, - get_selector_from_name(model).map_err(ParseError::NonAsciiName)? + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + WHERE {model_relation_table}.model_id IN ({}) + AND {table}.keys REGEXP ? + "#, + model_ids_str ) } else { format!( r#" - WHERE {table}.keys REGEXP ? - "# + WHERE {table}.keys REGEXP ? + "# ) } ); @@ -358,19 +366,29 @@ impl DojoWorld { JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id WHERE {table}.keys REGEXP ? GROUP BY {table}.id - ORDER BY {table}.event_id DESC "# ); - if let Some(model) = &keys_clause.model { + if !keys_clause.models.is_empty() { + // filter by models models_query += &format!( - r#" - HAVING INSTR(model_ids, '{:#x}') > 0 - "#, - get_selector_from_name(model).map_err(ParseError::NonAsciiName)? + "HAVING {}", + keys_clause + .models + .iter() + .map(|model| { + let model_id = + get_selector_from_name(model).map_err(ParseError::NonAsciiName)?; + Ok(format!("INSTR(model_ids, '{:#x}') > 0", model_id)) + }) + .collect::, Error>>()? + .join(" OR ") + .as_str() ); } + models_query += &format!(" ORDER BY {table}.event_id DESC"); + if limit.is_some() { models_query += " LIMIT ?"; } @@ -405,7 +423,7 @@ impl DojoWorld { arrays_rows.insert(name, rows); } - entities.push(Self::map_row_to_entity(&row, &arrays_rows, &schemas)?); + entities.push(map_row_to_entity(&row, &arrays_rows, &schemas)?); } Ok((entities, total_count)) @@ -531,7 +549,7 @@ impl DojoWorld { let entities_collection = db_entities .iter() - .map(|row| Self::map_row_to_entity(row, &arrays_rows, &schemas)) + .map(|row| map_row_to_entity(row, &arrays_rows, &schemas)) .collect::, Error>>()?; // Since there is not limit and offset, total_count is same as number of entities let total_count = entities_collection.len() as u32; @@ -780,30 +798,6 @@ impl DojoWorld { ) -> Result>, Error> { self.event_manager.add_subscriber(clause.try_into().unwrap()).await } - - fn map_row_to_entity( - row: &SqliteRow, - arrays_rows: &HashMap>, - schemas: &[Ty], - ) -> Result { - let hashed_keys = - FieldElement::from_str(&row.get::("id")).map_err(ParseError::FromStr)?; - let models = schemas - .iter() - .map(|schema| { - let mut schema = schema.to_owned(); - map_row_to_ty("", &schema.name(), &mut schema, row, arrays_rows)?; - Ok(schema - .as_struct() - .expect("schema should be struct") - .to_owned() - .try_into() - .unwrap()) - }) - .collect::, Error>>()?; - - Ok(proto::types::Entity { hashed_keys: hashed_keys.to_bytes_be().to_vec(), models }) - } } fn process_event_field(data: &str) -> Result>, Error> { @@ -825,6 +819,25 @@ fn map_row_to_event(row: &(String, String, String)) -> Result>, + schemas: &[Ty], +) -> Result { + let hashed_keys = + FieldElement::from_str(&row.get::("id")).map_err(ParseError::FromStr)?; + let models = schemas + .iter() + .map(|schema| { + let mut schema = schema.to_owned(); + map_row_to_ty("", &schema.name(), &mut schema, row, arrays_rows)?; + Ok(schema.as_struct().expect("schema should be struct").to_owned().try_into().unwrap()) + }) + .collect::, Error>>()?; + + Ok(proto::types::Entity { hashed_keys: hashed_keys.to_bytes_be().to_vec(), models }) +} + type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>; diff --git a/crates/torii/grpc/src/server/subscriptions/entity.rs b/crates/torii/grpc/src/server/subscriptions/entity.rs index a73d34d962..a62ee8dd69 100644 --- a/crates/torii/grpc/src/server/subscriptions/entity.rs +++ b/crates/torii/grpc/src/server/subscriptions/entity.rs @@ -14,7 +14,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; use torii_core::cache::ModelCache; use torii_core::error::{Error, ParseError}; -use torii_core::model::{build_sql_query, map_row_to_ty}; +use torii_core::model::build_sql_query; use torii_core::simple_broker::SimpleBroker; use torii_core::sql::FELT_DELIMITER; use torii_core::types::Entity; @@ -22,6 +22,7 @@ use tracing::{error, trace}; use crate::proto; use crate::proto::world::SubscribeEntityResponse; +use crate::server::map_row_to_entity; use crate::types::{EntityKeysClause, PatternMatching}; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::entity"; @@ -115,11 +116,11 @@ impl Service { Some(EntityKeysClause::Keys(clause)) => { // if we have a model clause, then we need to check that the entity // has an updated model and that the model name matches the clause - if let Some(model) = &clause.model { - if let Some(updated_model) = &entity.updated_model { - if updated_model.name() != model.clone() { - continue; - } + if let Some(updated_model) = &entity.updated_model { + if !clause.models.is_empty() + && !clause.models.contains(&updated_model.name()) + { + continue; } } @@ -199,24 +200,8 @@ impl Service { arrays_rows.insert(name, row); } - let models = schemas - .into_iter() - .map(|mut s| { - map_row_to_ty("", &s.name(), &mut s, &row, &arrays_rows)?; - - Ok(s.as_struct() - .expect("schema should be a struct") - .to_owned() - .try_into() - .unwrap()) - }) - .collect::, Error>>()?; - let resp = proto::world::SubscribeEntityResponse { - entity: Some(proto::types::Entity { - hashed_keys: hashed.to_bytes_be().to_vec(), - models, - }), + entity: Some(map_row_to_entity(&row, &arrays_rows, &schemas)?), }; if sub.sender.send(Ok(resp)).await.is_err() { diff --git a/crates/torii/grpc/src/server/subscriptions/event_message.rs b/crates/torii/grpc/src/server/subscriptions/event_message.rs index c929b5354d..0b4b490ded 100644 --- a/crates/torii/grpc/src/server/subscriptions/event_message.rs +++ b/crates/torii/grpc/src/server/subscriptions/event_message.rs @@ -14,7 +14,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; use torii_core::cache::ModelCache; use torii_core::error::{Error, ParseError}; -use torii_core::model::{build_sql_query, map_row_to_ty}; +use torii_core::model::build_sql_query; use torii_core::simple_broker::SimpleBroker; use torii_core::sql::FELT_DELIMITER; use torii_core::types::EventMessage; @@ -22,6 +22,7 @@ use tracing::{error, trace}; use crate::proto; use crate::proto::world::SubscribeEntityResponse; +use crate::server::map_row_to_entity; use crate::types::{EntityKeysClause, PatternMatching}; pub(crate) const LOG_TARGET: &str = "torii::grpc::server::subscriptions::event_message"; @@ -114,11 +115,11 @@ impl Service { Some(EntityKeysClause::Keys(clause)) => { // if we have a model clause, then we need to check that the entity // has an updated model and that the model name matches the clause - if let Some(model) = &clause.model { - if let Some(updated_model) = &entity.updated_model { - if updated_model.name() != model.clone() { - continue; - } + if let Some(updated_model) = &entity.updated_model { + if !clause.models.is_empty() + && !clause.models.contains(&updated_model.name()) + { + continue; } } @@ -184,23 +185,8 @@ impl Service { arrays_rows.insert(name, rows); } - let models = schemas - .into_iter() - .map(|mut s| { - map_row_to_ty("", &s.name(), &mut s, &row, &arrays_rows)?; - Ok(s.as_struct() - .expect("schema should be a struct") - .to_owned() - .try_into() - .unwrap()) - }) - .collect::, Error>>()?; - let resp = proto::world::SubscribeEntityResponse { - entity: Some(proto::types::Entity { - hashed_keys: hashed.to_bytes_be().to_vec(), - models, - }), + entity: Some(map_row_to_entity(&row, &arrays_rows, &schemas)?), }; if sub.sender.send(Ok(resp)).await.is_err() { diff --git a/crates/torii/grpc/src/server/tests/entities_test.rs b/crates/torii/grpc/src/server/tests/entities_test.rs index 442fec4816..9761d86335 100644 --- a/crates/torii/grpc/src/server/tests/entities_test.rs +++ b/crates/torii/grpc/src/server/tests/entities_test.rs @@ -121,7 +121,7 @@ async fn test_entities_queries() { KeysClause { keys: vec![account.address().to_bytes_be().to_vec()], pattern_matching: 0, - model: None, + models: vec![], }, Some(1), None, diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index 31c767e491..cf00ede93f 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -45,7 +45,7 @@ pub struct ModelKeysClause { pub struct KeysClause { pub keys: Vec, pub pattern_matching: PatternMatching, - pub model: Option, + pub models: Vec, } #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] @@ -186,7 +186,7 @@ impl From for proto::types::KeysClause { Self { keys: value.keys.iter().map(|k| k.to_bytes_be().into()).collect(), pattern_matching: value.pattern_matching as i32, - model: value.model, + models: value.models, } } } @@ -201,7 +201,7 @@ impl TryFrom for KeysClause { .map(|k| FieldElement::from_byte_slice_be(k)) .collect::, _>>()?; - Ok(Self { keys, pattern_matching: value.pattern_matching().into(), model: value.model }) + Ok(Self { keys, pattern_matching: value.pattern_matching().into(), models: value.models }) } }