From 2c18b721832d1656a2ebea4ec72633cf401c09d3 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 17 Dec 2024 14:02:08 +0700 Subject: [PATCH 1/4] feat: in operator grpc finish --- crates/torii/grpc/proto/types.proto | 7 +++++ crates/torii/grpc/src/server/mod.rs | 41 ++++++++++++++++++----------- crates/torii/grpc/src/types/mod.rs | 6 +++++ 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index 5080e5f83a..2b5545e497 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -120,9 +120,14 @@ message MemberValue { oneof value_type { Primitive primitive = 1; string string = 2; + MemberValueList list = 3; } } +message MemberValueList { + repeated MemberValue values = 1; +} + message MemberClause { string model = 2; string member = 3; @@ -152,6 +157,8 @@ enum ComparisonOperator { GTE = 3; LT = 4; LTE = 5; + IN = 6; + NOT_IN = 7; } message Token { diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 8a272cae8d..f1d8f77285 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -773,15 +773,19 @@ impl DojoWorld { let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) .expect("invalid comparison operator"); - let comparison_value = - match member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.value_type { - Some(ValueType::String(value)) => value, + fn comparison_value(value: &proto::types::MemberValue) -> Result { + match &value.value_type { + Some(ValueType::String(value)) => Ok(value.to_string()), Some(ValueType::Primitive(value)) => { - let primitive: Primitive = value.try_into()?; - primitive.to_sql_value() + let primitive: Primitive = (value.clone()).try_into()?; + Ok(primitive.to_sql_value()) + } + Some(ValueType::List(values)) => { + Ok(format!("({})", values.values.iter().map(|v| comparison_value(v)).collect::, Error>>()?.join(", "))) } None => return Err(QueryError::MissingParam("value_type".into()).into()), - }; + } + } let (namespace, model) = member_clause .model @@ -838,14 +842,15 @@ impl DojoWorld { offset, )?; + let value = comparison_value(&member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?)?; let total_count = sqlx::query_scalar(&count_query) - .bind(comparison_value.clone()) + .bind(value.clone()) .bind(entity_updated_after.clone()) .fetch_optional(&self.pool) .await? .unwrap_or(0); - let mut query = sqlx::query(&entity_query).bind(comparison_value); + let mut query = sqlx::query(&entity_query).bind(value); if let Some(entity_updated_after) = entity_updated_after.clone() { query = query.bind(entity_updated_after); } @@ -1356,17 +1361,21 @@ fn build_composite_clause( ClauseType::Member(member) => { let comparison_operator = ComparisonOperator::from_repr(member.operator as usize) .expect("invalid comparison operator"); - let value = member.value.clone(); - let comparison_value = - match value.ok_or(QueryError::MissingParam("value".into()))?.value_type { - Some(ValueType::String(value)) => value, + let value = member.value.clone().ok_or(QueryError::MissingParam("value".into()))?; + fn comparison_value(value: &proto::types::MemberValue) -> Result { + match &value.value_type { + Some(ValueType::String(value)) => Ok(value.to_string()), Some(ValueType::Primitive(value)) => { - let primitive: Primitive = value.try_into()?; - primitive.to_sql_value() + let primitive: Primitive = (value.clone()).try_into()?; + Ok(primitive.to_sql_value()) + } + Some(ValueType::List(values)) => { + Ok(format!("({})", values.values.iter().map(|v| comparison_value(v)).collect::, Error>>()?.join(", "))) } None => return Err(QueryError::MissingParam("value_type".into()).into()), - }; - bind_values.push(comparison_value); + } + } + bind_values.push(comparison_value(&value)?); let model = member.model.clone(); // Get or create unique alias for this model diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index aec1604b62..e9fe1fca42 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -187,6 +187,8 @@ pub enum ComparisonOperator { Gte, Lt, Lte, + In, + NotIn, } impl fmt::Display for ComparisonOperator { @@ -198,6 +200,8 @@ impl fmt::Display for ComparisonOperator { ComparisonOperator::Lte => write!(f, "<="), ComparisonOperator::Neq => write!(f, "!="), ComparisonOperator::Eq => write!(f, "="), + ComparisonOperator::In => write!(f, "IN"), + ComparisonOperator::NotIn => write!(f, "NOT IN"), } } } @@ -211,6 +215,8 @@ impl From for ComparisonOperator { proto::types::ComparisonOperator::Lt => ComparisonOperator::Lt, proto::types::ComparisonOperator::Lte => ComparisonOperator::Lte, proto::types::ComparisonOperator::Neq => ComparisonOperator::Neq, + proto::types::ComparisonOperator::In => ComparisonOperator::In, + proto::types::ComparisonOperator::NotIn => ComparisonOperator::NotIn, } } } From 0abc2e80a8522cf1a8cb9bd45433ec6c63d2b990 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 17 Dec 2024 17:21:07 +0700 Subject: [PATCH 2/4] feat: clean up grpc --- crates/torii/grpc/src/server/mod.rs | 82 ++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 24 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index f1d8f77285..f02853371c 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -773,16 +773,29 @@ impl DojoWorld { let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) .expect("invalid comparison operator"); - fn comparison_value(value: &proto::types::MemberValue) -> Result { + fn prepare_comparison( + value: &proto::types::MemberValue, + bind_values: &mut Vec, + ) -> Result { match &value.value_type { - Some(ValueType::String(value)) => Ok(value.to_string()), + Some(ValueType::String(value)) => { + bind_values.push(value.to_string()); + Ok("?".to_string()) + } Some(ValueType::Primitive(value)) => { let primitive: Primitive = (value.clone()).try_into()?; - Ok(primitive.to_sql_value()) - } - Some(ValueType::List(values)) => { - Ok(format!("({})", values.values.iter().map(|v| comparison_value(v)).collect::, Error>>()?.join(", "))) + bind_values.push(primitive.to_sql_value()); + Ok("?".to_string()) } + Some(ValueType::List(values)) => Ok(format!( + "({})", + values + .values + .iter() + .map(|v| prepare_comparison(v, bind_values)) + .collect::, Error>>()? + .join(", ") + )), None => return Err(QueryError::MissingParam("value_type".into()).into()), } } @@ -826,8 +839,13 @@ impl DojoWorld { self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); // Use the member name directly as the column name since it's already flattened + let mut bind_values = Vec::new(); + let value = prepare_comparison( + &member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?, + &mut bind_values, + )?; let mut where_clause = - format!("[{}].[{}] {comparison_operator} ?", member_clause.model, member_clause.member); + format!("[{}].[{}] {comparison_operator} {value}", member_clause.model, member_clause.member); if entity_updated_after.is_some() { where_clause += &format!(" AND {table}.updated_at >= ?"); } @@ -841,16 +859,19 @@ impl DojoWorld { limit, offset, )?; + let mut count_query = sqlx::query_scalar(&count_query); + for value in &bind_values { + count_query = count_query.bind(value); + } + if let Some(entity_updated_after) = entity_updated_after.clone() { + count_query = count_query.bind(entity_updated_after); + } + let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); - let value = comparison_value(&member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?)?; - let total_count = sqlx::query_scalar(&count_query) - .bind(value.clone()) - .bind(entity_updated_after.clone()) - .fetch_optional(&self.pool) - .await? - .unwrap_or(0); - - let mut query = sqlx::query(&entity_query).bind(value); + let mut query = sqlx::query(&entity_query); + for value in &bind_values { + query = query.bind(value); + } if let Some(entity_updated_after) = entity_updated_after.clone() { query = query.bind(entity_updated_after); } @@ -1362,20 +1383,33 @@ fn build_composite_clause( let comparison_operator = ComparisonOperator::from_repr(member.operator as usize) .expect("invalid comparison operator"); let value = member.value.clone().ok_or(QueryError::MissingParam("value".into()))?; - fn comparison_value(value: &proto::types::MemberValue) -> Result { + fn prepare_comparison( + value: &proto::types::MemberValue, + bind_values: &mut Vec, + ) -> Result { match &value.value_type { - Some(ValueType::String(value)) => Ok(value.to_string()), + Some(ValueType::String(value)) => { + bind_values.push(value.to_string()); + Ok("?".to_string()) + } Some(ValueType::Primitive(value)) => { let primitive: Primitive = (value.clone()).try_into()?; - Ok(primitive.to_sql_value()) - } - Some(ValueType::List(values)) => { - Ok(format!("({})", values.values.iter().map(|v| comparison_value(v)).collect::, Error>>()?.join(", "))) + bind_values.push(primitive.to_sql_value()); + Ok("?".to_string()) } + Some(ValueType::List(values)) => Ok(format!( + "({})", + values + .values + .iter() + .map(|v| prepare_comparison(v, bind_values)) + .collect::, Error>>()? + .join(", ") + )), None => return Err(QueryError::MissingParam("value_type".into()).into()), } } - bind_values.push(comparison_value(&value)?); + let value = prepare_comparison(&value, &mut bind_values)?; let model = member.model.clone(); // Get or create unique alias for this model @@ -1403,7 +1437,7 @@ fn build_composite_clause( // Use the column name directly since it's already flattened where_clauses - .push(format!("([{alias}].[{}] {comparison_operator} ?)", member.member)); + .push(format!("([{alias}].[{}] {comparison_operator} {value})", member.member)); } ClauseType::Composite(nested) => { // Handle nested composite by recursively building the clause From 06e557f30b008e7ccefe4cab09c353fde632eb99 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 17 Dec 2024 17:21:32 +0700 Subject: [PATCH 3/4] fmt --- crates/torii/grpc/src/server/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index f02853371c..9c9bd35aca 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -844,8 +844,10 @@ impl DojoWorld { &member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?, &mut bind_values, )?; - let mut where_clause = - format!("[{}].[{}] {comparison_operator} {value}", member_clause.model, member_clause.member); + let mut where_clause = format!( + "[{}].[{}] {comparison_operator} {value}", + member_clause.model, member_clause.member + ); if entity_updated_after.is_some() { where_clause += &format!(" AND {table}.updated_at >= ?"); } From 5dc0c8ccdd3899eef7776e91f09f61ac37e7299f Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 17 Dec 2024 17:24:08 +0700 Subject: [PATCH 4/4] fmt --- crates/torii/grpc/src/server/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 9c9bd35aca..8e884a08ce 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -796,7 +796,7 @@ impl DojoWorld { .collect::, Error>>()? .join(", ") )), - None => return Err(QueryError::MissingParam("value_type".into()).into()), + None => Err(QueryError::MissingParam("value_type".into()).into()), } } @@ -1408,7 +1408,7 @@ fn build_composite_clause( .collect::, Error>>()? .join(", ") )), - None => return Err(QueryError::MissingParam("value_type".into()).into()), + None => Err(QueryError::MissingParam("value_type".into()).into()), } } let value = prepare_comparison(&value, &mut bind_values)?;