From 93ccf660f402cb4b3a04076a6bf27692a0406799 Mon Sep 17 00:00:00 2001 From: Larko <59736843+Larkooo@users.noreply.github.com> Date: Fri, 28 Jun 2024 16:58:32 -0400 Subject: [PATCH] feat(torii-grpc): composite query (#2113) * feat(torii-grpc): composite query * chore --- crates/torii/grpc/src/server/mod.rs | 261 +++++++++++++----- .../grpc/src/server/tests/entities_test.rs | 2 +- 2 files changed, 198 insertions(+), 65 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 9d2ad65e56..f0394fb30b 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -297,28 +297,11 @@ impl DojoWorld { table: &str, model_relation_table: &str, entity_relation_column: &str, - keys_clause: proto::types::KeysClause, + keys_clause: &proto::types::KeysClause, limit: Option, offset: Option, ) -> Result<(Vec, u32), Error> { - let keys = keys_clause - .keys - .iter() - .map(|bytes| { - if bytes.is_empty() { - return Ok("0x[0-9a-fA-F]+".to_string()); - } - Ok(FieldElement::from_byte_slice_be(bytes) - .map(|felt| format!("{felt:#x}")) - .map_err(ParseError::FromByteSliceError)?) - }) - .collect::, Error>>()?; - let mut keys_pattern = format!("^{}", keys.join("/")); - - if keys_clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 { - keys_pattern += "(/0x[0-9a-fA-F]+)*"; - } - keys_pattern += "/$"; + let keys_pattern = build_keys_pattern(keys_clause)?; // total count of rows that matches keys_pattern without limit and offset let count_query = format!( @@ -431,28 +414,11 @@ impl DojoWorld { pub(crate) async fn events_by_keys( &self, - keys_clause: proto::types::KeysClause, + keys_clause: &proto::types::KeysClause, limit: Option, offset: Option, ) -> Result, Error> { - let keys = keys_clause - .keys - .iter() - .map(|bytes| { - if bytes.is_empty() { - return Ok("0x[0-9a-fA-F]+".to_string()); - } - Ok(FieldElement::from_byte_slice_be(bytes) - .map(|felt| format!("{felt:#x}")) - .map_err(ParseError::FromByteSliceError)?) - }) - .collect::, Error>>()?; - let mut keys_pattern = format!("^{}", keys.join("/")); - - if keys_clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 { - keys_pattern += "(/0x[0-9a-fA-F]+)*"; - } - keys_pattern += "/$"; + let keys_pattern = build_keys_pattern(keys_clause)?; let events_query = r#" SELECT keys, data, transaction_hash @@ -491,19 +457,7 @@ impl DojoWorld { .value_type .ok_or(QueryError::MissingParam("value_type".into()))?; - let comparison_value = match value_type { - proto::types::value::ValueType::StringValue(string) => string, - proto::types::value::ValueType::IntValue(int) => int.to_string(), - proto::types::value::ValueType::UintValue(uint) => uint.to_string(), - proto::types::value::ValueType::BoolValue(bool) => { - if bool { - "1".to_string() - } else { - "0".to_string() - } - } - _ => return Err(QueryError::UnsupportedQuery.into()), - }; + let comparison_value = value_to_string(&value_type)?; let models_query = format!( r#" @@ -558,15 +512,158 @@ impl DojoWorld { async fn query_by_composite( &self, - _table: &str, - _model_relation_table: &str, - _entity_relation_column: &str, - _composite: proto::types::CompositeClause, - _limit: Option, - _offset: Option, + table: &str, + model_relation_table: &str, + entity_relation_column: &str, + composite: proto::types::CompositeClause, + limit: Option, + offset: Option, ) -> Result<(Vec, u32), Error> { - // TODO: Implement - Err(QueryError::UnsupportedQuery.into()) + // different types of clauses + let mut where_clauses = Vec::new(); + let mut model_clauses: HashMap> = + HashMap::new(); + let mut having_clauses = Vec::new(); + + // bind valeus for prepared statement + let mut bind_values = Vec::new(); + + for clause in composite.clauses { + match clause.clause_type.unwrap() { + ClauseType::HashedKeys(hashed_keys) => { + let ids = hashed_keys + .hashed_keys + .iter() + .map(|id| { + Ok(FieldElement::from_byte_slice_be(id) + .map(|id| format!("{table}.id = '{id:#x}'")) + .map_err(ParseError::FromByteSliceError)?) + }) + .collect::, Error>>()?; + where_clauses.push(format!("({})", ids.join(" OR "))); + } + ClauseType::Keys(keys) => { + let keys_pattern = build_keys_pattern(&keys)?; + where_clauses.push(format!("{table}.keys REGEXP '{keys_pattern}'")); + } + ClauseType::Member(member) => { + let comparison_operator = + ComparisonOperator::from_repr(member.operator as usize) + .expect("invalid comparison operator"); + let value = member.value.unwrap().value_type.unwrap(); + let comparison_value = value_to_string(&value)?; + + let column_name = format!("external_{}", member.member); + + model_clauses.entry(member.model.clone()).or_default().push(( + column_name, + comparison_operator, + comparison_value, + )); + + let model_id = + get_selector_from_name(&member.model).map_err(ParseError::NonAsciiName)?; + having_clauses.push(format!("INSTR(model_ids, '{:#x}') > 0", model_id)); + } + _ => return Err(QueryError::UnsupportedQuery.into()), + } + } + + let mut join_clauses = Vec::new(); + for (model, clauses) in model_clauses { + let model_conditions = clauses + .into_iter() + .map(|(column, op, value)| { + bind_values.push(value); + format!("{}.{} {} ?", model, column, op) + }) + .collect::>() + .join(" AND "); + + join_clauses.push(format!( + "JOIN {} ON {}.id = {}.entity_id AND ({})", + model, table, model, model_conditions + )); + } + + let join_clause = join_clauses.join(" "); + let where_clause = if !where_clauses.is_empty() { + format!("WHERE {}", where_clauses.join(" AND ")) + } else { + String::new() + }; + let having_clause = if !having_clauses.is_empty() { + format!("HAVING {}", having_clauses.join(" AND ")) + } else { + String::new() + }; + + let count_query = format!( + r#" + SELECT COUNT(DISTINCT {table}.id) + FROM {table} + {join_clause} + {where_clause} + "# + ); + + let mut count_query = sqlx::query_scalar::<_, u32>(&count_query); + for value in &bind_values { + count_query = count_query.bind(value); + } + + let total_count = count_query.fetch_one(&self.pool).await?; + + if total_count == 0 { + return Ok((Vec::new(), 0)); + } + + let query = format!( + r#" + SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_ids + FROM {table} + JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id + {join_clause} + {where_clause} + GROUP BY {table}.id + {having_clause} + ORDER BY {table}.event_id DESC + LIMIT ? OFFSET ? + "# + ); + + let mut db_query = sqlx::query_as(&query); + for value in bind_values { + db_query = db_query.bind(value); + } + db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0)); + + let db_entities: Vec<(String, String)> = db_query.fetch_all(&self.pool).await?; + + let mut entities = Vec::with_capacity(db_entities.len()); + for (entity_id, models_str) in &db_entities { + let model_ids: Vec<&str> = models_str.split(',').collect(); + let schemas = self.model_cache.schemas(model_ids).await?; + + let (entity_query, arrays_queries) = build_sql_query( + &schemas, + table, + entity_relation_column, + Some(&format!("{table}.id = ?")), + Some(&format!("{table}.id = ?")), + )?; + + let row = sqlx::query(&entity_query).bind(entity_id).fetch_one(&self.pool).await?; + let mut arrays_rows = HashMap::new(); + for (name, query) in arrays_queries { + let rows = sqlx::query(&query).bind(entity_id).fetch_all(&self.pool).await?; + arrays_rows.insert(name, rows); + } + + entities.push(map_row_to_entity(&row, &arrays_rows, &schemas)?); + } + + Ok((entities, total_count)) } pub async fn model_metadata(&self, model: &str) -> Result { @@ -669,7 +766,7 @@ impl DojoWorld { ENTITIES_TABLE, ENTITIES_MODEL_RELATION_TABLE, ENTITIES_ENTITY_RELATION_COLUMN, - keys, + &keys, Some(query.limit), Some(query.offset), ) @@ -746,7 +843,7 @@ impl DojoWorld { EVENT_MESSAGES_TABLE, EVENT_MESSAGES_MODEL_RELATION_TABLE, EVENT_MESSAGES_ENTITY_RELATION_COLUMN, - keys, + &keys, Some(query.limit), Some(query.offset), ) @@ -783,9 +880,9 @@ impl DojoWorld { async fn retrieve_events( &self, - query: proto::types::EventQuery, + query: &proto::types::EventQuery, ) -> Result { - let events = match query.keys { + let events = match &query.keys { None => self.events_all(query.limit, query.offset).await?, Some(keys) => self.events_by_keys(keys, Some(query.limit), Some(query.offset)).await?, }; @@ -838,6 +935,42 @@ fn map_row_to_entity( Ok(proto::types::Entity { hashed_keys: hashed_keys.to_bytes_be().to_vec(), models }) } +// this builds a sql safe regex pattern to match against for keys +fn build_keys_pattern(clause: &proto::types::KeysClause) -> Result { + let keys = clause + .keys + .iter() + .map(|bytes| { + if bytes.is_empty() { + return Ok("0x[0-9a-fA-F]+".to_string()); + } + Ok(FieldElement::from_byte_slice_be(bytes) + .map(|felt| format!("{felt:#x}")) + .map_err(ParseError::FromByteSliceError)?) + }) + .collect::, Error>>()?; + let mut keys_pattern = format!("^{}", keys.join("/")); + + if clause.pattern_matching == proto::types::PatternMatching::VariableLen as i32 { + keys_pattern += "(/0x[0-9a-fA-F]+)*"; + } + keys_pattern += "/$"; + + Ok(keys_pattern) +} + +fn value_to_string(value: &proto::types::value::ValueType) -> Result { + match value { + proto::types::value::ValueType::StringValue(string) => Ok(string.clone()), + proto::types::value::ValueType::IntValue(int) => Ok(int.to_string()), + proto::types::value::ValueType::UintValue(uint) => Ok(uint.to_string()), + proto::types::value::ValueType::BoolValue(bool) => { + Ok(if *bool { "1".to_string() } else { "0".to_string() }) + } + _ => Err(QueryError::UnsupportedQuery.into()), + } +} + type ServiceResult = Result, Status>; type SubscribeModelsResponseStream = Pin> + Send>>; @@ -943,7 +1076,7 @@ impl proto::world::world_server::World for DojoWorld { .ok_or_else(|| Status::invalid_argument("Missing query argument"))?; let events = - self.retrieve_events(query).await.map_err(|e| Status::internal(e.to_string()))?; + self.retrieve_events(&query).await.map_err(|e| Status::internal(e.to_string()))?; Ok(Response::new(events)) } diff --git a/crates/torii/grpc/src/server/tests/entities_test.rs b/crates/torii/grpc/src/server/tests/entities_test.rs index 9761d86335..3e4b5081f1 100644 --- a/crates/torii/grpc/src/server/tests/entities_test.rs +++ b/crates/torii/grpc/src/server/tests/entities_test.rs @@ -118,7 +118,7 @@ async fn test_entities_queries() { "entities", "entity_model", "entity_id", - KeysClause { + &KeysClause { keys: vec![account.address().to_bytes_be().to_vec()], pattern_matching: 0, models: vec![],