diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index 5080e5f83a..2b5545e497 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -120,9 +120,14 @@ message MemberValue { oneof value_type { Primitive primitive = 1; string string = 2; + MemberValueList list = 3; } } +message MemberValueList { + repeated MemberValue values = 1; +} + message MemberClause { string model = 2; string member = 3; @@ -152,6 +157,8 @@ enum ComparisonOperator { GTE = 3; LT = 4; LTE = 5; + IN = 6; + NOT_IN = 7; } message Token { diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 8a272cae8d..8e884a08ce 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -773,15 +773,32 @@ impl DojoWorld { let comparison_operator = ComparisonOperator::from_repr(member_clause.operator as usize) .expect("invalid comparison operator"); - let comparison_value = - match member_clause.value.ok_or(QueryError::MissingParam("value".into()))?.value_type { - Some(ValueType::String(value)) => value, + fn prepare_comparison( + value: &proto::types::MemberValue, + bind_values: &mut Vec, + ) -> Result { + match &value.value_type { + Some(ValueType::String(value)) => { + bind_values.push(value.to_string()); + Ok("?".to_string()) + } Some(ValueType::Primitive(value)) => { - let primitive: Primitive = value.try_into()?; - primitive.to_sql_value() + let primitive: Primitive = (value.clone()).try_into()?; + bind_values.push(primitive.to_sql_value()); + Ok("?".to_string()) } - None => return Err(QueryError::MissingParam("value_type".into()).into()), - }; + Some(ValueType::List(values)) => Ok(format!( + "({})", + values + .values + .iter() + .map(|v| prepare_comparison(v, bind_values)) + .collect::, Error>>()? + .join(", ") + )), + None => Err(QueryError::MissingParam("value_type".into()).into()), + } + } let (namespace, model) = member_clause .model @@ -822,8 +839,15 @@ impl DojoWorld { self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); // Use the member name directly as the column name since it's already flattened - let mut where_clause = - format!("[{}].[{}] {comparison_operator} ?", member_clause.model, member_clause.member); + let mut bind_values = Vec::new(); + let value = prepare_comparison( + &member_clause.value.clone().ok_or(QueryError::MissingParam("value".into()))?, + &mut bind_values, + )?; + let mut where_clause = format!( + "[{}].[{}] {comparison_operator} {value}", + member_clause.model, member_clause.member + ); if entity_updated_after.is_some() { where_clause += &format!(" AND {table}.updated_at >= ?"); } @@ -837,15 +861,19 @@ impl DojoWorld { limit, offset, )?; + let mut count_query = sqlx::query_scalar(&count_query); + for value in &bind_values { + count_query = count_query.bind(value); + } + if let Some(entity_updated_after) = entity_updated_after.clone() { + count_query = count_query.bind(entity_updated_after); + } + let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); - let total_count = sqlx::query_scalar(&count_query) - .bind(comparison_value.clone()) - .bind(entity_updated_after.clone()) - .fetch_optional(&self.pool) - .await? - .unwrap_or(0); - - let mut query = sqlx::query(&entity_query).bind(comparison_value); + let mut query = sqlx::query(&entity_query); + for value in &bind_values { + query = query.bind(value); + } if let Some(entity_updated_after) = entity_updated_after.clone() { query = query.bind(entity_updated_after); } @@ -1356,17 +1384,34 @@ fn build_composite_clause( ClauseType::Member(member) => { let comparison_operator = ComparisonOperator::from_repr(member.operator as usize) .expect("invalid comparison operator"); - let value = member.value.clone(); - let comparison_value = - match value.ok_or(QueryError::MissingParam("value".into()))?.value_type { - Some(ValueType::String(value)) => value, + let value = member.value.clone().ok_or(QueryError::MissingParam("value".into()))?; + fn prepare_comparison( + value: &proto::types::MemberValue, + bind_values: &mut Vec, + ) -> Result { + match &value.value_type { + Some(ValueType::String(value)) => { + bind_values.push(value.to_string()); + Ok("?".to_string()) + } Some(ValueType::Primitive(value)) => { - let primitive: Primitive = value.try_into()?; - primitive.to_sql_value() + let primitive: Primitive = (value.clone()).try_into()?; + bind_values.push(primitive.to_sql_value()); + Ok("?".to_string()) } - None => return Err(QueryError::MissingParam("value_type".into()).into()), - }; - bind_values.push(comparison_value); + Some(ValueType::List(values)) => Ok(format!( + "({})", + values + .values + .iter() + .map(|v| prepare_comparison(v, bind_values)) + .collect::, Error>>()? + .join(", ") + )), + None => Err(QueryError::MissingParam("value_type".into()).into()), + } + } + let value = prepare_comparison(&value, &mut bind_values)?; let model = member.model.clone(); // Get or create unique alias for this model @@ -1394,7 +1439,7 @@ fn build_composite_clause( // Use the column name directly since it's already flattened where_clauses - .push(format!("([{alias}].[{}] {comparison_operator} ?)", member.member)); + .push(format!("([{alias}].[{}] {comparison_operator} {value})", member.member)); } ClauseType::Composite(nested) => { // Handle nested composite by recursively building the clause diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index aec1604b62..e9fe1fca42 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -187,6 +187,8 @@ pub enum ComparisonOperator { Gte, Lt, Lte, + In, + NotIn, } impl fmt::Display for ComparisonOperator { @@ -198,6 +200,8 @@ impl fmt::Display for ComparisonOperator { ComparisonOperator::Lte => write!(f, "<="), ComparisonOperator::Neq => write!(f, "!="), ComparisonOperator::Eq => write!(f, "="), + ComparisonOperator::In => write!(f, "IN"), + ComparisonOperator::NotIn => write!(f, "NOT IN"), } } } @@ -211,6 +215,8 @@ impl From for ComparisonOperator { proto::types::ComparisonOperator::Lt => ComparisonOperator::Lt, proto::types::ComparisonOperator::Lte => ComparisonOperator::Lte, proto::types::ComparisonOperator::Neq => ComparisonOperator::Neq, + proto::types::ComparisonOperator::In => ComparisonOperator::In, + proto::types::ComparisonOperator::NotIn => ComparisonOperator::NotIn, } } }