diff --git a/crates/torii/core/src/model.rs b/crates/torii/core/src/model.rs index cee6ed29ac..4abd5c923a 100644 --- a/crates/torii/core/src/model.rs +++ b/crates/torii/core/src/model.rs @@ -125,7 +125,6 @@ pub fn build_sql_query( order_by: Option<&str>, limit: Option, offset: Option, - internal_updated_at: u64, ) -> Result<(String, String), Error> { fn collect_columns(table_prefix: &str, path: &str, ty: &Ty, selections: &mut Vec) { match ty { @@ -174,7 +173,6 @@ pub fn build_sql_query( selections.push(format!("{}.id", table_name)); selections.push(format!("{}.keys", table_name)); - let mut internal_updated_at_clause = Vec::with_capacity(schemas.len()); // Process each model schema for model in schemas { let model_table = model.name(); @@ -183,10 +181,6 @@ pub fn build_sql_query( [{model_table}].{entity_relation_column}", )); - if internal_updated_at > 0 { - internal_updated_at_clause.push(format!("[{model_table}].internal_updated_at >= ?")); - } - // Collect columns with table prefix collect_columns(&model_table, "", model, &mut selections); } @@ -204,18 +198,6 @@ pub fn build_sql_query( count_query += &format!(" WHERE {}", where_clause); } - if !internal_updated_at_clause.is_empty() { - if where_clause.is_none() { - query += " WHERE "; - count_query += " WHERE "; - } else { - query += " AND "; - count_query += " AND "; - } - query += &format!(" {}", internal_updated_at_clause.join(" AND ")); - count_query += &format!(" {}", internal_updated_at_clause.join(" AND ")); - } - // Use custom order by if provided, otherwise default to event_id DESC if let Some(order_clause) = order_by { query += &format!(" ORDER BY {}", order_clause); @@ -513,7 +495,6 @@ mod tests { None, None, None, - 0, ) .unwrap(); diff --git a/crates/torii/grpc/proto/types.proto b/crates/torii/grpc/proto/types.proto index 1568128ca2..5080e5f83a 100644 --- a/crates/torii/grpc/proto/types.proto +++ b/crates/torii/grpc/proto/types.proto @@ -76,7 +76,7 @@ message Query { bool dont_include_hashed_keys = 4; repeated OrderBy order_by = 5; repeated string entity_models = 6; - uint64 internal_updated_at = 7; + uint64 entity_updated_after = 7; } message EventQuery { diff --git a/crates/torii/grpc/src/server/mod.rs b/crates/torii/grpc/src/server/mod.rs index 8e36402d3c..8a272cae8d 100644 --- a/crates/torii/grpc/src/server/mod.rs +++ b/crates/torii/grpc/src/server/mod.rs @@ -230,7 +230,7 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, + entity_updated_after: Option, ) -> Result<(Vec, u32), Error> { self.query_by_hashed_keys( table, @@ -242,7 +242,7 @@ impl DojoWorld { dont_include_hashed_keys, order_by, entity_models, - internal_updated_at, + entity_updated_after, ) .await } @@ -270,7 +270,6 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, ) -> Result, Error> { let entity_models = entity_models.iter().map(|tag| compute_selector_from_tag(tag)).collect::>(); @@ -359,20 +358,9 @@ impl DojoWorld { order_by, None, None, - internal_updated_at, )?; - let mut query = sqlx::query(&entity_query).bind(models_str); - if internal_updated_at > 0 { - for _ in 0..schemas.len() { - let time = DateTime::::from_timestamp(internal_updated_at as i64, 0) - .ok_or_else(|| { - Error::from(QueryError::InvalidTimestamp(internal_updated_at)) - })? - .to_rfc3339(); - query = query.bind(time.clone()); - } - } + let query = sqlx::query(&entity_query).bind(models_str); let rows = query.fetch_all(&mut *tx).await?; let schemas = Arc::new(schemas); @@ -448,20 +436,32 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, + entity_updated_after: Option, ) -> Result<(Vec, u32), Error> { - // TODO: use prepared statement for where clause - let filter_ids = match hashed_keys { + let where_clause = match &hashed_keys { Some(hashed_keys) => { let ids = hashed_keys .hashed_keys .iter() - .map(|id| Ok(format!("{table}.id = '{:#x}'", Felt::from_bytes_be_slice(id)))) + .map(|_| Ok("{table}.id = ?")) .collect::, Error>>()?; - - format!("WHERE {}", ids.join(" OR ")) + format!( + "WHERE {} {}", + ids.join(" OR "), + if entity_updated_after.is_some() { + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } + ) + } + None => { + if entity_updated_after.is_some() { + format!("WHERE {table}.updated_at >= ?") + } else { + String::new() + } } - None => String::new(), }; // count query that matches filter_ids @@ -469,12 +469,23 @@ impl DojoWorld { r#" SELECT count(*) FROM {table} - {filter_ids} + {where_clause} "# ); + // total count of rows without limit and offset - let total_count: u32 = - sqlx::query_scalar(&count_query).fetch_optional(&self.pool).await?.unwrap_or(0); + let mut count_query = sqlx::query_scalar(&count_query); + if let Some(hashed_keys) = &hashed_keys { + for key in &hashed_keys.hashed_keys { + let key = Felt::from_bytes_be_slice(key); + count_query = count_query.bind(format!("{:#x}", key)); + } + } + + if let Some(entity_updated_after) = entity_updated_after.clone() { + count_query = count_query.bind(entity_updated_after); + } + let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); if total_count == 0 { return Ok((Vec::new(), 0)); } @@ -486,7 +497,7 @@ impl DojoWorld { SELECT {table}.id, {table}.data, {table}.model_id, group_concat({model_relation_table}.model_id) as model_ids FROM {table} JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - {filter_ids} + {where_clause} GROUP BY {table}.event_id ORDER BY {table}.event_id DESC "# @@ -497,7 +508,7 @@ impl DojoWorld { SELECT {table}.id, group_concat({model_relation_table}.model_id) as model_ids FROM {table} JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id - {filter_ids} + {where_clause} GROUP BY {table}.id ORDER BY {table}.event_id DESC "# @@ -518,8 +529,19 @@ impl DojoWorld { return Ok((entities, total_count)); } - let db_entities: Vec<(String, String)> = - sqlx::query_as(&query).bind(limit).bind(offset).fetch_all(&self.pool).await?; + let mut query = sqlx::query_as(&query); + if let Some(hashed_keys) = hashed_keys { + for key in hashed_keys.hashed_keys { + let key = Felt::from_bytes_be_slice(&key); + query = query.bind(format!("{:#x}", key)); + } + } + + if let Some(entity_updated_after) = entity_updated_after.clone() { + query = query.bind(entity_updated_after); + } + query = query.bind(limit).bind(offset); + let db_entities: Vec<(String, String)> = query.fetch_all(&self.pool).await?; let entities = self .fetch_entities( @@ -529,7 +551,6 @@ impl DojoWorld { dont_include_hashed_keys, order_by, entity_models, - internal_updated_at, ) .await?; Ok((entities, total_count)) @@ -547,7 +568,7 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, + entity_updated_after: Option, ) -> Result<(Vec, u32), Error> { let keys_pattern = build_keys_pattern(keys_clause)?; @@ -582,20 +603,33 @@ impl DojoWorld { JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id WHERE {model_relation_table}.model_id IN ({}) AND {table}.keys REGEXP ? + {} "#, - model_ids_str + model_ids_str, + if entity_updated_after.is_some() { + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } ) } else { format!( r#" WHERE {table}.keys REGEXP ? - "# + {} + "#, + if entity_updated_after.is_some() { + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } ) } ); let total_count = sqlx::query_scalar(&count_query) .bind(&keys_pattern) + .bind(entity_updated_after.clone()) .fetch_optional(&self.pool) .await? .unwrap_or(0); @@ -610,8 +644,14 @@ impl DojoWorld { FROM {table} JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id WHERE {table}.keys REGEXP ? + {} GROUP BY {table}.event_id "#, + if entity_updated_after.is_some() { + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } ) } else { format!( @@ -620,8 +660,14 @@ impl DojoWorld { FROM {table} JOIN {model_relation_table} ON {table}.id = {model_relation_table}.entity_id WHERE {table}.keys REGEXP ? + {} GROUP BY {table}.id "#, + if entity_updated_after.is_some() { + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } ) }; @@ -661,12 +707,12 @@ impl DojoWorld { return Ok((entities, total_count)); } - let db_entities: Vec<(String, String)> = sqlx::query_as(&models_query) - .bind(&keys_pattern) - .bind(limit) - .bind(offset) - .fetch_all(&self.pool) - .await?; + let mut query = sqlx::query_as(&models_query).bind(&keys_pattern); + if let Some(entity_updated_after) = entity_updated_after.clone() { + query = query.bind(entity_updated_after); + } + query = query.bind(limit).bind(offset); + let db_entities: Vec<(String, String)> = query.fetch_all(&self.pool).await?; let entities = self .fetch_entities( @@ -676,7 +722,6 @@ impl DojoWorld { dont_include_hashed_keys, order_by, entity_models, - internal_updated_at, ) .await?; Ok((entities, total_count)) @@ -721,7 +766,7 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, + entity_updated_after: Option, ) -> Result<(Vec, u32), Error> { let entity_models = entity_models.iter().map(|model| compute_selector_from_tag(model)).collect::>(); @@ -777,8 +822,11 @@ impl DojoWorld { self.model_cache.models(&model_ids).await?.into_iter().map(|m| m.schema).collect(); // Use the member name directly as the column name since it's already flattened - let where_clause = + let mut where_clause = format!("[{}].[{}] {comparison_operator} ?", member_clause.model, member_clause.member); + if entity_updated_after.is_some() { + where_clause += &format!(" AND {table}.updated_at >= ?"); + } let (entity_query, count_query) = build_sql_query( &schemas, @@ -788,20 +836,21 @@ impl DojoWorld { order_by, limit, offset, - internal_updated_at, )?; let total_count = sqlx::query_scalar(&count_query) .bind(comparison_value.clone()) + .bind(entity_updated_after.clone()) .fetch_optional(&self.pool) .await? .unwrap_or(0); - let db_entities = sqlx::query(&entity_query) - .bind(comparison_value) - .bind(limit) - .bind(offset) - .fetch_all(&self.pool) - .await?; + + let mut query = sqlx::query(&entity_query).bind(comparison_value); + if let Some(entity_updated_after) = entity_updated_after.clone() { + query = query.bind(entity_updated_after); + } + query = query.bind(limit).bind(offset); + let db_entities = query.fetch_all(&self.pool).await?; let entities_collection: Result, Error> = db_entities .par_iter() @@ -822,10 +871,10 @@ impl DojoWorld { dont_include_hashed_keys: bool, order_by: Option<&str>, entity_models: Vec, - internal_updated_at: u64, + entity_updated_after: Option, ) -> Result<(Vec, u32), Error> { let (where_clause, having_clause, join_clause, bind_values) = - build_composite_clause(table, model_relation_table, &composite)?; + build_composite_clause(table, model_relation_table, &composite, entity_updated_after)?; let count_query = if !having_clause.is_empty() { format!( @@ -839,7 +888,7 @@ impl DojoWorld { GROUP BY [{table}].id {having_clause} ) as filtered_count - "# + "#, ) } else { format!( @@ -849,7 +898,7 @@ impl DojoWorld { JOIN {model_relation_table} ON [{table}].id = {model_relation_table}.entity_id {join_clause} {where_clause} - "# + "#, ) }; @@ -857,7 +906,6 @@ impl DojoWorld { for value in &bind_values { count_query = count_query.bind(value); } - let total_count = count_query.fetch_optional(&self.pool).await?.unwrap_or(0); if total_count == 0 { return Ok((Vec::new(), 0)); @@ -881,7 +929,7 @@ impl DojoWorld { for value in &bind_values { db_query = db_query.bind(value); } - db_query = db_query.bind(limit.unwrap_or(u32::MAX)).bind(offset.unwrap_or(0)); + db_query = db_query.bind(limit).bind(offset); let db_entities: Vec<(String, String)> = db_query.fetch_all(&self.pool).await?; @@ -893,7 +941,6 @@ impl DojoWorld { dont_include_hashed_keys, order_by, entity_models, - internal_updated_at, ) .await?; Ok((entities, total_count)) @@ -1051,6 +1098,23 @@ impl DojoWorld { let order_by = if order_by.is_empty() { None } else { Some(order_by.as_str()) }; + let entity_updated_after = match query.entity_updated_after { + 0 => None, + _ => Some( + // This conversion would include a `UTC` suffix, which is not valid for the SQL + // query when comparing the timestamp with equality. + // To have `>=` working, we need to remove the `UTC` suffix. + DateTime::::from_timestamp(query.entity_updated_after as i64, 0) + .ok_or_else(|| { + Error::from(QueryError::InvalidTimestamp(query.entity_updated_after)) + })? + .to_string() + .replace("UTC", "") + .trim() + .to_string(), + ), + }; + let (entities, total_count) = match query.clause { None => { self.entities_all( @@ -1062,7 +1126,7 @@ impl DojoWorld { query.dont_include_hashed_keys, order_by, query.entity_models, - query.internal_updated_at, + entity_updated_after, ) .await? } @@ -1086,7 +1150,7 @@ impl DojoWorld { query.dont_include_hashed_keys, order_by, query.entity_models, - query.internal_updated_at, + entity_updated_after, ) .await? } @@ -1101,7 +1165,7 @@ impl DojoWorld { query.dont_include_hashed_keys, order_by, query.entity_models, - query.internal_updated_at, + entity_updated_after, ) .await? } @@ -1116,7 +1180,7 @@ impl DojoWorld { query.dont_include_hashed_keys, order_by, query.entity_models, - query.internal_updated_at, + entity_updated_after, ) .await? } @@ -1131,7 +1195,7 @@ impl DojoWorld { query.dont_include_hashed_keys, order_by, query.entity_models, - query.internal_updated_at, + entity_updated_after, ) .await? } @@ -1248,6 +1312,7 @@ fn build_composite_clause( table: &str, model_relation_table: &str, composite: &proto::types::CompositeClause, + entity_updated_after: Option, ) -> Result<(String, String, String, Vec), Error> { let is_or = composite.operator == LogicalOperator::Or as i32; let mut where_clauses = Vec::new(); @@ -1334,7 +1399,12 @@ fn build_composite_clause( ClauseType::Composite(nested) => { // Handle nested composite by recursively building the clause let (nested_where, nested_having, nested_join, nested_values) = - build_composite_clause(table, model_relation_table, nested)?; + build_composite_clause( + table, + model_relation_table, + nested, + entity_updated_after.clone(), + )?; if !nested_where.is_empty() { where_clauses.push(format!("({})", nested_where.trim_start_matches("WHERE "))); @@ -1356,10 +1426,23 @@ fn build_composite_clause( let join_clause = join_clauses.join(" "); let where_clause = if !where_clauses.is_empty() { - format!("WHERE {}", where_clauses.join(if is_or { " OR " } else { " AND " })) + format!( + "WHERE {} {}", + where_clauses.join(if is_or { " OR " } else { " AND " }), + if let Some(entity_updated_after) = entity_updated_after.clone() { + bind_values.push(entity_updated_after); + format!("AND {table}.updated_at >= ?") + } else { + String::new() + } + ) + } else if let Some(entity_updated_after) = entity_updated_after.clone() { + bind_values.push(entity_updated_after); + format!("WHERE {table}.updated_at >= ?") } else { String::new() }; + let having_clause = if !having_clauses.is_empty() { format!("HAVING {}", having_clauses.join(if is_or { " OR " } else { " AND " })) } else { diff --git a/crates/torii/grpc/src/server/tests/entities_test.rs b/crates/torii/grpc/src/server/tests/entities_test.rs index 58a795421f..22a38fda91 100644 --- a/crates/torii/grpc/src/server/tests/entities_test.rs +++ b/crates/torii/grpc/src/server/tests/entities_test.rs @@ -143,7 +143,7 @@ async fn test_entities_queries(sequencer: &RunnerCtx) { false, None, vec![], - 0, + None, ) .await .unwrap() diff --git a/crates/torii/grpc/src/types/mod.rs b/crates/torii/grpc/src/types/mod.rs index 35fe1897f7..aec1604b62 100644 --- a/crates/torii/grpc/src/types/mod.rs +++ b/crates/torii/grpc/src/types/mod.rs @@ -103,6 +103,8 @@ pub struct Query { pub clause: Option, pub limit: u32, pub offset: u32, + /// Whether or not to include the hashed keys (entity id) of the entities. + /// This is useful for large queries compressed with GZIP to reduce the size of the response. pub dont_include_hashed_keys: bool, pub order_by: Vec, /// If the array is not empty, only the given models are retrieved. @@ -110,7 +112,7 @@ pub struct Query { pub entity_models: Vec, /// The internal updated at timestamp in seconds (unix timestamp) from which entities are /// retrieved (inclusive). Use 0 to retrieve all entities. - pub internal_updated_at: u64, + pub entity_updated_after: u64, } #[derive(Debug, Serialize, Deserialize, PartialEq, Hash, Eq, Clone)] @@ -276,7 +278,7 @@ impl From for proto::types::Query { dont_include_hashed_keys: value.dont_include_hashed_keys, order_by: value.order_by.into_iter().map(|o| o.into()).collect(), entity_models: value.entity_models, - internal_updated_at: value.internal_updated_at, + entity_updated_after: value.entity_updated_after, } } }