From c6dab678a2fc129869fceb5ab487d9d3aed56898 Mon Sep 17 00:00:00 2001 From: Yun Date: Tue, 19 Dec 2023 11:56:54 -0800 Subject: [PATCH] Torii grpc member clause query (#1312) --- Cargo.lock | 1 + crates/torii/grpc/Cargo.toml | 1 + crates/torii/grpc/src/server/mod.rs | 61 ++++++++++++++++++++++++++--- crates/torii/grpc/src/types/mod.rs | 38 +++++++++++++++++- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 05748eb8b5..039efe4190 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9872,6 +9872,7 @@ dependencies = [ "sqlx", "starknet", "starknet-crypto 0.6.1", + "strum 0.25.0", "strum_macros 0.25.3", "thiserror", "tokio", diff --git a/crates/torii/grpc/Cargo.toml b/crates/torii/grpc/Cargo.toml index b1101c82fc..24987b5534 100644 --- a/crates/torii/grpc/Cargo.toml +++ b/crates/torii/grpc/Cargo.toml @@ -18,6 +18,7 @@ thiserror.workspace = true torii-core = { path = "../core", optional = true } serde.workspace = true +strum.workspace = true strum_macros.workspace = true crypto-bigint.workspace = true diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 8216990e27..a6c3625a28 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -34,6 +34,7 @@ use crate::proto::types::clause::ClauseType; use crate::proto::world::world_server::WorldServer; use crate::proto::world::{SubscribeEntitiesRequest, SubscribeEntityResponse}; use crate::proto::{self}; +use crate::types::ComparisonOperator; #[derive(Clone)] pub struct DojoWorld { @@ -242,14 +243,62 @@ impl DojoWorld { db_entities.iter().map(|row| Self::map_row_to_entity(row, &schemas)).collect() } - async fn entities_by_attribute( + async fn entities_by_member( &self, - _attribute: proto::types::MemberClause, + member_clause: proto::types::MemberClause, _limit: u32, _offset: u32, ) -> Result, Error> { - // TODO: Implement - Err(QueryError::UnsupportedQuery.into()) + let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) + .expect("invalid comparison operator"); + + let value_type = member_clause + .value + .ok_or(QueryError::MissingParam("value".into()))? + .value_type + .ok_or(QueryError::MissingParam("value_type".into()))?; + + let comparison_value = match value_type { + proto::types::value::ValueType::StringValue(string) => string, + proto::types::value::ValueType::IntValue(int) => int.to_string(), + proto::types::value::ValueType::UintValue(uint) => uint.to_string(), + proto::types::value::ValueType::BoolValue(bool) => { + if bool { + "1".to_string() + } else { + "0".to_string() + } + } + _ => return Err(QueryError::UnsupportedQuery.into()), + }; + + let models_query = format!( + r#" + SELECT group_concat(entity_model.model_id) as model_names + FROM entities + JOIN entity_model ON entities.id = entity_model.entity_id + GROUP BY entities.id + HAVING model_names REGEXP '(^|,){}(,|$)' + LIMIT 1 + "#, + member_clause.model + ); + let (models_str,): (String,) = sqlx::query_as(&models_query).fetch_one(&self.pool).await?; + + let model_names = models_str.split(',').collect::>(); + let schemas = self.model_cache.schemas(model_names).await?; + + let table_name = member_clause.model; + let column_name = format!("external_{}", member_clause.member); + let member_query = format!( + "{} WHERE {table_name}.{column_name} {comparison_operator} ?", + build_sql_query(&schemas)? + ); + + let db_entities = + sqlx::query(&member_query).bind(comparison_value).fetch_all(&self.pool).await?; + + db_entities.iter().map(|row| Self::map_row_to_entity(row, &schemas)).collect() } async fn entities_by_composite( @@ -350,8 +399,8 @@ impl DojoWorld { self.entities_by_keys(keys, query.limit, query.offset).await? } - ClauseType::Member(attribute) => { - self.entities_by_attribute(attribute, query.limit, query.offset).await? + ClauseType::Member(member) => { + self.entities_by_member(member, query.limit, query.offset).await? } ClauseType::Composite(composite) => { self.entities_by_composite(composite, query.limit, query.offset).await? diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index da97af938b..bc18d7f5cf 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -1,3 +1,4 @@ +use core::fmt; use std::collections::HashMap; use std::str::FromStr; @@ -8,6 +9,7 @@ use starknet::core::types::{ ContractStorageDiffItem, FromByteSliceError, FromStrError, StateDiff, StateUpdate, StorageEntry, }; use starknet_crypto::FieldElement; +use strum_macros::{AsRefStr, EnumIter, FromRepr}; use crate::proto::{self}; @@ -48,13 +50,19 @@ pub struct CompositeClause { pub clauses: Vec, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +#[derive( + Debug, AsRefStr, Serialize, Deserialize, EnumIter, FromRepr, PartialEq, Hash, Eq, Clone, +)] +#[strum(serialize_all = "UPPERCASE")] pub enum LogicalOperator { And, Or, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] +#[derive( + Debug, AsRefStr, Serialize, Deserialize, EnumIter, FromRepr, PartialEq, Hash, Eq, Clone, +)] +#[strum(serialize_all = "UPPERCASE")] pub enum ComparisonOperator { Eq, Neq, @@ -64,6 +72,32 @@ pub enum ComparisonOperator { Lte, } +impl fmt::Display for ComparisonOperator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ComparisonOperator::Gt => write!(f, ">"), + ComparisonOperator::Gte => write!(f, ">="), + ComparisonOperator::Lt => write!(f, "<"), + ComparisonOperator::Lte => write!(f, "<="), + ComparisonOperator::Neq => write!(f, "!="), + ComparisonOperator::Eq => write!(f, "="), + } + } +} + +impl From for ComparisonOperator { + fn from(operator: proto::types::ComparisonOperator) -> Self { + match operator { + proto::types::ComparisonOperator::Eq => ComparisonOperator::Eq, + proto::types::ComparisonOperator::Gte => ComparisonOperator::Gte, + proto::types::ComparisonOperator::Gt => ComparisonOperator::Gt, + proto::types::ComparisonOperator::Lt => ComparisonOperator::Lt, + proto::types::ComparisonOperator::Lte => ComparisonOperator::Lte, + proto::types::ComparisonOperator::Neq => ComparisonOperator::Neq, + } + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] pub struct Value { pub primitive_type: Primitive,