Skip to content

Commit

Permalink
sui-graphql-client: simplify pagination filter usage (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-mysten authored Oct 22, 2024
1 parent c5a25ce commit 494b3a6
Showing 1 changed file with 52 additions and 55 deletions.
107 changes: 52 additions & 55 deletions crates/sui-graphql-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl<T> Page<T> {
Self { page_info, data }
}

/// Check if the page has data.
/// Check if the page has no data.
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
Expand All @@ -122,29 +122,22 @@ impl<T> Page<T> {
}

/// Pagination direction.
#[derive(Default)]
pub enum Direction {
#[default]
Forward,
Backward,
}

/// Pagination options for querying the GraphQL server. It defaults to forward pagination with the
/// GraphQL server's default items per page limit.
#[derive(Default)]
pub struct PaginationFilter<'a> {
direction: Direction,
cursor: Option<&'a str>,
limit: Option<i32>,
}

impl Default for PaginationFilter<'_> {
fn default() -> Self {
Self {
direction: Direction::Forward,
cursor: None,
limit: None,
}
}
}

/// The GraphQL client for interacting with the Sui blockchain.
/// By default, it uses the `reqwest` crate as the HTTP client.
pub struct Client {
Expand Down Expand Up @@ -204,6 +197,28 @@ impl Client {
self.rpc.as_str()
}

/// Internal function to handle pagination filters and return the appropriate values.
fn pagination_filter<'a>(
&self,
pagination_filter: PaginationFilter<'a>,
) -> (Option<&'a str>, Option<&'a str>, Option<i32>, Option<i32>) {
let (after, before, first, last) = match pagination_filter.direction {
Direction::Forward => (
pagination_filter.cursor,
None,
pagination_filter.limit,
None,
),
Direction::Backward => (
None,
pagination_filter.cursor,
None,
pagination_filter.limit,
),
};
(after, before, first, last)
}

/// Run a query on the GraphQL server and return the response.
/// This method returns [`cynic::GraphQlResponse`] over the query type `T`, and it is
/// intended to be used with custom queries.
Expand Down Expand Up @@ -289,13 +304,9 @@ impl Client {
pub async fn active_validators<'a>(
&self,
epoch: Option<u64>,
pagination_filter: Option<PaginationFilter<'a>>,
pagination_filter: PaginationFilter<'a>,
) -> Result<Page<Validator>, Error> {
let pagination = pagination_filter.unwrap_or_default();
let (after, before, first, last) = match pagination.direction {
Direction::Forward => (pagination.cursor, None, pagination.limit, None),
Direction::Backward => (None, pagination.cursor, None, pagination.limit),
};
let (after, before, first, last) = self.pagination_filter(pagination_filter);

let operation = ActiveValidatorsQuery::build(ActiveValidatorsArgs {
id: epoch,
Expand Down Expand Up @@ -371,7 +382,7 @@ impl Client {
&self,
owner: Address,
coin_type: Option<&str>,
pagination_filter: Option<PaginationFilter<'a>>,
pagination_filter: PaginationFilter<'a>,
) -> Result<Page<Coin>, Error> {
let response = self
.objects(
Expand Down Expand Up @@ -412,13 +423,12 @@ impl Client {
object_ids: None,
object_keys: None,
}),
Some(PaginationFilter {
PaginationFilter {
cursor: after.as_deref(),
..Default::default()
}),
},
).await?;

if !response.is_empty() {
for object in response.data() {
if let Some(coin) = Coin::try_from_object(object) {
yield coin.into_owned();
Expand All @@ -430,9 +440,6 @@ impl Client {
} else {
break;
}
} else {
break;
}
}
})
}
Expand Down Expand Up @@ -494,20 +501,18 @@ impl Client {
}

/// Get a page of [`CheckpointSummary`] for the provided parameters.
pub async fn checkpoints(
pub async fn checkpoints<'a>(
&self,
after: Option<&str>,
before: Option<&str>,
first: Option<i32>,
last: Option<i32>,
pagination_filter: PaginationFilter<'a>,
) -> Result<Option<Page<CheckpointSummary>>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);

let operation = CheckpointsQuery::build(CheckpointsArgs {
after,
before,
first,
last,
});

let response = self.run_query(&operation).await?;

if let Some(errors) = response.errors {
Expand Down Expand Up @@ -593,13 +598,9 @@ impl Client {
pub async fn events(
&self,
filter: Option<EventFilter>,
pagination_filter: Option<PaginationFilter<'_>>,
pagination_filter: PaginationFilter<'_>,
) -> Result<Page<Event>, Error> {
let pagination = pagination_filter.unwrap_or_default();
let (after, before, first, last) = match pagination.direction {
Direction::Forward => (pagination.cursor, None, pagination.limit, None),
Direction::Backward => (None, pagination.cursor, None, pagination.limit),
};
let (after, before, first, last) = self.pagination_filter(pagination_filter);

let operation = EventsQuery::build(EventsQueryArgs {
filter,
Expand Down Expand Up @@ -695,14 +696,9 @@ impl Client {
pub async fn objects(
&self,
filter: Option<ObjectFilter<'_>>,
pagination_filter: Option<PaginationFilter<'_>>,
pagination_filter: PaginationFilter<'_>,
) -> Result<Page<Object>, Error> {
let pagination = pagination_filter.unwrap_or_default();
let (after, before, first, last) = match pagination.direction {
Direction::Forward => (pagination.cursor, None, pagination.limit, None),
Direction::Backward => (None, pagination.cursor, None, pagination.limit),
};

let (after, before, first, last) = self.pagination_filter(pagination_filter);
let operation = ObjectsQuery::build(ObjectsQueryArgs {
after,
before,
Expand Down Expand Up @@ -870,13 +866,9 @@ impl Client {
pub async fn transactions<'a>(
&self,
filter: Option<TransactionsFilter<'a>>,
pagination_filter: Option<PaginationFilter<'a>>,
pagination_filter: PaginationFilter<'a>,
) -> Result<Page<SignedTransaction>, Error> {
let pagination = pagination_filter.unwrap_or_default();
let (after, before, first, last) = match pagination.direction {
Direction::Forward => (pagination.cursor, None, pagination.limit, None),
Direction::Backward => (None, pagination.cursor, None, pagination.limit),
};
let (after, before, first, last) = self.pagination_filter(pagination_filter);

let operation = TransactionBlocksQuery::build(TransactionBlocksQueryArgs {
after,
Expand Down Expand Up @@ -942,6 +934,7 @@ mod tests {
use futures::StreamExt;

use crate::Client;
use crate::PaginationFilter;
use crate::DEVNET_HOST;
use crate::LOCAL_HOST;
use crate::MAINNET_HOST;
Expand Down Expand Up @@ -1028,7 +1021,9 @@ mod tests {
async fn test_active_validators() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let av = client.active_validators(None, None).await;
let av = client
.active_validators(None, PaginationFilter::default())
.await;
assert!(
av.is_ok(),
"Active validators query failed for network: {n}. Error: {}",
Expand Down Expand Up @@ -1067,7 +1062,7 @@ mod tests {
async fn test_checkpoints_query() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let c = client.checkpoints(None, None, None, Some(5)).await;
let c = client.checkpoints(PaginationFilter::default()).await;
assert!(
c.is_ok(),
"Checkpoints query failed for network: {n}. Error: {}",
Expand Down Expand Up @@ -1133,7 +1128,7 @@ mod tests {
async fn test_events_query() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let events = client.events(None, None).await;
let events = client.events(None, PaginationFilter::default()).await;
assert!(
events.is_ok(),
"Events query failed for network: {n}. Error: {}",
Expand All @@ -1151,7 +1146,7 @@ mod tests {
async fn test_objects_query() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let objects = client.objects(None, None).await;
let objects = client.objects(None, PaginationFilter::default()).await;
assert!(
objects.is_ok(),
"Objects query failed for network: {n}. Error: {}",
Expand Down Expand Up @@ -1191,7 +1186,9 @@ mod tests {
async fn test_coins_query() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let coins = client.coins("0x1".parse().unwrap(), None, None).await;
let coins = client
.coins("0x1".parse().unwrap(), None, PaginationFilter::default())
.await;
assert!(
coins.is_ok(),
"Coins query failed for network: {n}. Error: {}",
Expand All @@ -1217,7 +1214,7 @@ mod tests {
async fn test_transactions_query() {
for (n, _) in NETWORKS {
let client = Client::new(n).unwrap();
let transactions = client.transactions(None, None).await;
let transactions = client.transactions(None, PaginationFilter::default()).await;
assert!(
transactions.is_ok(),
"Transactions query failed for network: {n}. Error: {}",
Expand Down

0 comments on commit 494b3a6

Please sign in to comment.