diff --git a/crates/torii/graphql/src/object/connection/mod.rs b/crates/torii/graphql/src/object/connection/mod.rs index 787e36229e..802236ec07 100644 --- a/crates/torii/graphql/src/object/connection/mod.rs +++ b/crates/torii/graphql/src/object/connection/mod.rs @@ -7,6 +7,7 @@ use super::ObjectTrait; use crate::query::order::Order; use crate::query::value_mapping_from_row; use crate::types::{GraphqlType, TypeData, TypeMapping, ValueMapping}; +use crate::utils::parse_argument::ParseArgument; pub mod cursor; pub mod edge; @@ -14,10 +15,12 @@ pub mod page_info; #[derive(Debug)] pub struct ConnectionArguments { - pub first: Option, - pub last: Option, + pub first: Option, + pub last: Option, pub after: Option, pub before: Option, + pub offset: Option, + pub limit: Option, } pub struct ConnectionObject { @@ -59,11 +62,12 @@ impl ObjectTrait for ConnectionObject { } pub fn parse_connection_arguments(ctx: &ResolverContext<'_>) -> Result { - let first = ctx.args.try_get("first").and_then(|first| first.i64()).ok(); - let last = ctx.args.try_get("last").and_then(|last| last.i64()).ok(); - let after = ctx.args.try_get("after").and_then(|after| Ok(after.string()?.to_string())).ok(); - let before = - ctx.args.try_get("before").and_then(|before| Ok(before.string()?.to_string())).ok(); + let first: Option = ParseArgument::parse(ctx, "first").ok(); + let last: Option = ParseArgument::parse(ctx, "last").ok(); + let after: Option = ParseArgument::parse(ctx, "after").ok(); + let before: Option = ParseArgument::parse(ctx, "before").ok(); + let offset: Option = ParseArgument::parse(ctx, "offset").ok(); + let limit: Option = ParseArgument::parse(ctx, "limit").ok(); if first.is_some() && last.is_some() { return Err( @@ -77,19 +81,15 @@ pub fn parse_connection_arguments(ctx: &ResolverContext<'_>) -> Result Field { @@ -98,6 +98,8 @@ pub fn connection_arguments(field: Field) -> Field { .argument(InputValue::new("last", TypeRef::named(TypeRef::INT))) .argument(InputValue::new("before", TypeRef::named(GraphqlType::Cursor.to_string()))) .argument(InputValue::new("after", TypeRef::named(GraphqlType::Cursor.to_string()))) + .argument(InputValue::new("offset", TypeRef::named(TypeRef::INT))) + .argument(InputValue::new("limit", TypeRef::named(TypeRef::INT))) } pub fn connection_output( diff --git a/crates/torii/graphql/src/query/constants.rs b/crates/torii/graphql/src/query/constants.rs index 6e4f335f80..c14f5015fa 100644 --- a/crates/torii/graphql/src/query/constants.rs +++ b/crates/torii/graphql/src/query/constants.rs @@ -1,4 +1,4 @@ -pub const DEFAULT_LIMIT: i64 = 10; +pub const DEFAULT_LIMIT: u64 = 10; pub const BOOLEAN_TRUE: i64 = 1; pub const ENTITY_TABLE: &str = "entities"; diff --git a/crates/torii/graphql/src/query/data.rs b/crates/torii/graphql/src/query/data.rs index 889f10912e..d790cc8672 100644 --- a/crates/torii/graphql/src/query/data.rs +++ b/crates/torii/graphql/src/query/data.rs @@ -85,9 +85,10 @@ pub async fn fetch_multiple_rows( query.push_str(&format!(" WHERE {}", conditions.join(" AND "))); } + let limit = connection.first.or(connection.last).or(connection.limit).unwrap_or(DEFAULT_LIMIT); + // NOTE: Order is determined by the `order` param if provided, otherwise it's inferred from the // `first` or `last` param. Explicit ordering take precedence - let limit = connection.first.or(connection.last).unwrap_or(DEFAULT_LIMIT); match order { Some(order) => { let column_name = format!("external_{}", order.field); @@ -111,6 +112,10 @@ pub async fn fetch_multiple_rows( } }; + if let Some(offset) = connection.offset { + query.push_str(&format!(" OFFSET {}", offset)); + } + sqlx::query(&query).fetch_all(conn).await } diff --git a/crates/torii/graphql/src/tests/entities_test.rs b/crates/torii/graphql/src/tests/entities_test.rs index 1755742cc1..d27bb664de 100644 --- a/crates/torii/graphql/src/tests/entities_test.rs +++ b/crates/torii/graphql/src/tests/entities_test.rs @@ -6,7 +6,8 @@ mod tests { use torii_core::sql::Sql; use crate::tests::{ - entity_fixtures, paginate, run_graphql_query, Entity, Moves, Paginate, Position, + cursor_paginate, entity_fixtures, offset_paginate, run_graphql_query, Entity, Moves, + Paginate, Position, }; #[sqlx::test(migrations = "../migrations")] @@ -77,35 +78,59 @@ mod tests { } #[sqlx::test(migrations = "../migrations")] - async fn test_entities_pagination(pool: SqlitePool) { + async fn test_entities_cursor_pagination(pool: SqlitePool) { let mut db = Sql::new(pool.clone(), FieldElement::ZERO).await.unwrap(); entity_fixtures(&mut db).await; let page_size = 2; // Forward pagination - let entities_connection = paginate(&pool, None, Paginate::Forward, page_size).await; + let entities_connection = cursor_paginate(&pool, None, Paginate::Forward, page_size).await; assert_eq!(entities_connection.total_count, 3); assert_eq!(entities_connection.edges.len(), page_size); let cursor: String = entities_connection.edges[0].cursor.clone(); let next_cursor: String = entities_connection.edges[1].cursor.clone(); - let entities_connection = paginate(&pool, Some(cursor), Paginate::Forward, page_size).await; + let entities_connection = + cursor_paginate(&pool, Some(cursor), Paginate::Forward, page_size).await; assert_eq!(entities_connection.total_count, 3); assert_eq!(entities_connection.edges.len(), page_size); assert_eq!(entities_connection.edges[0].cursor, next_cursor); // Backward pagination - let entities_connection = paginate(&pool, None, Paginate::Backward, page_size).await; + let entities_connection = cursor_paginate(&pool, None, Paginate::Backward, page_size).await; assert_eq!(entities_connection.total_count, 3); assert_eq!(entities_connection.edges.len(), page_size); let cursor: String = entities_connection.edges[0].cursor.clone(); let next_cursor: String = entities_connection.edges[1].cursor.clone(); let entities_connection = - paginate(&pool, Some(cursor), Paginate::Backward, page_size).await; + cursor_paginate(&pool, Some(cursor), Paginate::Backward, page_size).await; assert_eq!(entities_connection.total_count, 3); assert_eq!(entities_connection.edges.len(), page_size); assert_eq!(entities_connection.edges[0].cursor, next_cursor); } + + #[sqlx::test(migrations = "../migrations")] + async fn test_entities_offset_pagination(pool: SqlitePool) { + let mut db = Sql::new(pool.clone(), FieldElement::ZERO).await.unwrap(); + entity_fixtures(&mut db).await; + + let limit = 3; + let mut offset = 0; + let entities_connection = offset_paginate(&pool, offset, limit).await; + let offset_plus_one = entities_connection.edges[1].node.model_names.clone(); + let offset_plus_two = entities_connection.edges[2].node.model_names.clone(); + assert_eq!(entities_connection.edges.len(), 3); + + offset = 1; + let entities_connection = offset_paginate(&pool, offset, limit).await; + assert_eq!(entities_connection.edges[0].node.model_names, offset_plus_one); + assert_eq!(entities_connection.edges.len(), 2); + + offset = 2; + let entities_connection = offset_paginate(&pool, offset, limit).await; + assert_eq!(entities_connection.edges[0].node.model_names, offset_plus_two); + assert_eq!(entities_connection.edges.len(), 1); + } } diff --git a/crates/torii/graphql/src/tests/mod.rs b/crates/torii/graphql/src/tests/mod.rs index 392a98e337..6ad0aa7cdc 100644 --- a/crates/torii/graphql/src/tests/mod.rs +++ b/crates/torii/graphql/src/tests/mod.rs @@ -301,7 +301,7 @@ pub async fn entity_fixtures(db: &mut Sql) { db.execute().await.unwrap(); } -pub async fn paginate( +pub async fn cursor_paginate( pool: &SqlitePool, cursor: Option, direction: Paginate, @@ -334,3 +334,26 @@ pub async fn paginate( let entities = value.get("entities").ok_or("entities not found").unwrap(); serde_json::from_value(entities.clone()).unwrap() } + +pub async fn offset_paginate(pool: &SqlitePool, offset: u64, limit: u64) -> Connection { + let query = format!( + " + {{ + entities (offset: {offset}, limit: {limit}) + {{ + total_count + edges {{ + cursor + node {{ + model_names + }} + }} + }} + }} + " + ); + + let value = run_graphql_query(pool, &query).await; + let entities = value.get("entities").ok_or("entities not found").unwrap(); + serde_json::from_value(entities.clone()).unwrap() +} diff --git a/crates/torii/graphql/src/tests/types-test/src/systems.cairo b/crates/torii/graphql/src/tests/types-test/src/systems.cairo index 1d72a36b59..be57506a21 100644 --- a/crates/torii/graphql/src/tests/types-test/src/systems.cairo +++ b/crates/torii/graphql/src/tests/types-test/src/systems.cairo @@ -1,4 +1,3 @@ -use dojo::world::{IWorldDispatcher, IWorldDispatcherTrait}; use starknet::{ContractAddress, ClassHash}; #[starknet::interface] @@ -18,7 +17,7 @@ mod records { fn create(self: @ContractState, num_records: u8) { let world = self.world_dispatcher.read(); let mut record_idx = 0; - + loop { if record_idx == num_records { break (); diff --git a/crates/torii/graphql/src/utils/parse_argument.rs b/crates/torii/graphql/src/utils/parse_argument.rs index e5246252bc..1076739eea 100644 --- a/crates/torii/graphql/src/utils/parse_argument.rs +++ b/crates/torii/graphql/src/utils/parse_argument.rs @@ -2,19 +2,19 @@ use async_graphql::dynamic::ResolverContext; use async_graphql::Result; pub trait ParseArgument: Sized { - fn parse(ctx: &ResolverContext<'_>, input: String) -> Result; + fn parse(ctx: &ResolverContext<'_>, input: &str) -> Result; } impl ParseArgument for u64 { - fn parse(ctx: &ResolverContext<'_>, input: String) -> Result { - let arg = ctx.args.try_get(input.as_str()); + fn parse(ctx: &ResolverContext<'_>, input: &str) -> Result { + let arg = ctx.args.try_get(input); arg?.u64() } } impl ParseArgument for String { - fn parse(ctx: &ResolverContext<'_>, input: String) -> Result { - let arg = ctx.args.try_get(input.as_str()); + fn parse(ctx: &ResolverContext<'_>, input: &str) -> Result { + let arg = ctx.args.try_get(input); Ok(arg?.string()?.to_string()) } }