Skip to content

Commit

Permalink
sui-graphql-client: add generic stream impl (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-mysten authored Oct 31, 2024
1 parent 12a24f9 commit c2d3080
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 76 deletions.
191 changes: 120 additions & 71 deletions crates/sui-graphql-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

pub mod faucet;
pub mod query_types;
pub mod streams;

use query_types::ActiveValidatorsArgs;
use query_types::ActiveValidatorsQuery;
Expand Down Expand Up @@ -61,6 +62,7 @@ use query_types::TransactionBlocksQueryArgs;
use query_types::TransactionMetadata;
use query_types::TransactionsFilter;
use query_types::Validator;
use streams::stream_paginated_query;

use sui_types::types::framework::Coin;
use sui_types::types::Address;
Expand Down Expand Up @@ -90,11 +92,12 @@ use futures::Stream;
use reqwest::Url;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::pin::Pin;
use std::str::FromStr;

use crate::query_types::CheckpointTotalTxQuery;

const DEFAULT_ITEMS_PER_PAGE: i32 = 10;

const MAINNET_HOST: &str = "https://sui-mainnet.mystenlabs.com/graphql";
const TESTNET_HOST: &str = "https://sui-testnet.mystenlabs.com/graphql";
const DEVNET_HOST: &str = "https://sui-devnet.mystenlabs.com/graphql";
Expand All @@ -114,7 +117,7 @@ pub struct DryRunResult {
}

/// The name part of a dynamic field, including its type, bcs, and json representation.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct DynamicFieldName {
/// The type name of this dynamic field name
pub type_: TypeTag,
Expand All @@ -126,7 +129,7 @@ pub struct DynamicFieldName {

/// The output of a dynamic field query, that includes the name, value, and value's json
/// representation.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct DynamicFieldOutput {
/// The name of the dynamic field
pub name: DynamicFieldName,
Expand All @@ -142,7 +145,7 @@ pub struct NameValue(Vec<u8>);
/// Helper struct for passing a raw bcs value.
pub struct BcsName(pub Vec<u8>);

#[derive(Debug)]
#[derive(Clone, Debug)]
/// A page of items returned by the GraphQL server.
pub struct Page<T> {
/// Information about the page, such as the cursor and whether there are more pages.
Expand Down Expand Up @@ -176,10 +179,14 @@ impl<T> Page<T> {
fn new_empty() -> Self {
Self::new(PageInfo::default(), vec![])
}

pub fn into_parts(self) -> (PageInfo, Vec<T>) {
(self.page_info, self.data)
}
}

/// Pagination direction.
#[derive(Default)]
#[derive(Clone, Debug, Default)]
pub enum Direction {
#[default]
Forward,
Expand All @@ -188,7 +195,7 @@ pub enum Direction {

/// Pagination options for querying the GraphQL server. It defaults to forward pagination with the
/// GraphQL server's default items per page limit.
#[derive(Default)]
#[derive(Clone, Debug, Default)]
pub struct PaginationFilter {
/// The direction of pagination.
pub direction: Direction,
Expand Down Expand Up @@ -308,23 +315,23 @@ impl Client {
}

/// Internal function to handle pagination filters and return the appropriate values.
fn pagination_filter(
async fn pagination_filter(
&self,
pagination_filter: PaginationFilter,
) -> (Option<String>, Option<String>, Option<i32>, Option<i32>) {
let limit = if let Some(limit) = pagination_filter.limit {
limit
} else {
let cfg = self.service_config().await;
if let Ok(cfg) = cfg {
cfg.max_page_size
} else {
DEFAULT_ITEMS_PER_PAGE
}
};
let (after, before, first, last) = match pagination_filter.direction {
Direction::Forward => (
pagination_filter.cursor,
None,
pagination_filter.limit,
None,
),
Direction::Backward => (
None,
pagination_filter.cursor,
None,
pagination_filter.limit,
),
Direction::Forward => (pagination_filter.cursor, None, Some(limit), None),
Direction::Backward => (None, pagination_filter.cursor, None, Some(limit)),
};
(after, before, first, last)
}
Expand Down Expand Up @@ -416,7 +423,7 @@ impl Client {
epoch: Option<u64>,
pagination_filter: PaginationFilter,
) -> Result<Page<Validator>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;

let operation = ActiveValidatorsQuery::build(ActiveValidatorsArgs {
id: epoch,
Expand Down Expand Up @@ -573,41 +580,20 @@ impl Client {
))
}

/// Stream of coins for the specified address and coin type.
pub fn coins_stream<'a>(
&'a self,
owner: Address,
coin_type: Option<&'a str>,
) -> Pin<Box<dyn Stream<Item = Result<Coin, Error>> + 'a>> {
Box::pin(async_stream::try_stream! {
let mut after = None;
loop {
let response = self.objects(
Some(ObjectFilter {
type_: Some(coin_type.unwrap_or("0x2::coin::Coin")),
owner: Some(owner),
object_ids: None,
object_keys: None,
}),
PaginationFilter {
cursor: after,
..Default::default()
},
).await?;

for object in response.data() {
if let Some(coin) = Coin::try_from_object(object) {
yield coin.into_owned();
}
}

if let Some(end_cursor) = response.page_info.end_cursor {
after = Some(end_cursor);
} else {
break;
}
}
})
/// Get the list of coins for the specified address as a stream.
///
/// If `coin_type` is not provided, it will default to `0x2::coin::Coin`, which will return all
/// coins. For SUI coin, pass in the coin type: `0x2::coin::Coin<0x2::sui::SUI>`.
pub async fn coins_stream(
&self,
address: Address,
coin_type: Option<&'static str>,
streaming_direction: Direction,
) -> impl Stream<Item = Result<Coin, Error>> {
stream_paginated_query(
move |filter| self.coins(address, coin_type, filter),
streaming_direction,
)
}

/// Get the coin metadata for the coin type.
Expand Down Expand Up @@ -670,8 +656,8 @@ impl Client {
pub async fn checkpoints<'a>(
&self,
pagination_filter: PaginationFilter,
) -> Result<Option<Page<CheckpointSummary>>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
) -> Result<Page<CheckpointSummary>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;

let operation = CheckpointsQuery::build(CheckpointsArgs {
after: after.as_deref(),
Expand All @@ -694,12 +680,21 @@ impl Client {
.map(|c| c.try_into())
.collect::<Result<Vec<CheckpointSummary>, _>>()?;

Ok(Some(Page::new(page_info, nodes)))
Ok(Page::new(page_info, nodes))
} else {
Ok(None)
Ok(Page::new_empty())
}
}

/// Get a stream of [`CheckpointSummary`]. Note that this will fetch all checkpoints which may
/// trigger a lot of requests.
pub async fn checkpoints_stream(
&self,
streaming_direction: Direction,
) -> impl Stream<Item = Result<CheckpointSummary, Error>> + '_ {
stream_paginated_query(move |filter| self.checkpoints(filter), streaming_direction)
}

/// Return the sequence number of the latest checkpoint that has been executed.
pub async fn latest_checkpoint_sequence_number(
&self,
Expand Down Expand Up @@ -814,7 +809,7 @@ impl Client {
address: Address,
pagination_filter: PaginationFilter,
) -> Result<Page<DynamicFieldOutput>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;
let operation = DynamicFieldsOwnerQuery::build(DynamicFieldConnectionArgs {
address,
after: after.as_deref(),
Expand Down Expand Up @@ -843,6 +838,19 @@ impl Client {
))
}

/// Get a stream of dynamic fields for the provided address. Note that this will also fetch
/// dynamic fields on wrapped objects.
pub async fn dynamic_fields_stream(
&self,
address: Address,
streaming_direction: Direction,
) -> impl Stream<Item = Result<DynamicFieldOutput, Error>> + '_ {
stream_paginated_query(
move |filter| self.dynamic_fields(address, filter),
streaming_direction,
)
}

// ===========================================================================
// Epoch API
// ===========================================================================
Expand Down Expand Up @@ -894,13 +902,13 @@ impl Client {
// Events API
// ===========================================================================

/// Return a page of events based on the provided filters.
/// Return a page of events based on the (optional) event filter.
pub async fn events(
&self,
filter: Option<EventFilter>,
pagination_filter: PaginationFilter,
) -> Result<Page<Event>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;

let operation = EventsQuery::build(EventsQueryArgs {
filter,
Expand Down Expand Up @@ -937,6 +945,18 @@ impl Client {
}
}

/// Return a stream of events based on the (optional) event filter.
pub async fn events_stream(
&self,
filter: Option<EventFilter>,
streaming_direction: Direction,
) -> impl Stream<Item = Result<Event, Error>> + '_ {
stream_paginated_query(
move |pag_filter| self.events(filter.clone(), pag_filter),
streaming_direction,
)
}

// ===========================================================================
// Objects API
// ===========================================================================
Expand Down Expand Up @@ -998,7 +1018,7 @@ impl Client {
filter: Option<ObjectFilter<'_>>,
pagination_filter: PaginationFilter,
) -> Result<Page<Object>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;
let operation = ObjectsQuery::build(ObjectsQueryArgs {
after: after.as_deref(),
before: before.as_deref(),
Expand Down Expand Up @@ -1037,6 +1057,18 @@ impl Client {
}
}

/// Return a stream of objects based on the (optional) object filter.
pub async fn objects_stream<'a>(
&'a self,
filter: Option<ObjectFilter<'a>>,
streaming_direction: Direction,
) -> impl Stream<Item = Result<Object, Error>> + 'a {
stream_paginated_query(
move |pag_filter| self.objects(filter.clone(), pag_filter),
streaming_direction,
)
}

/// Return the object's bcs content [`Vec<u8>`] based on the provided [`Address`].
pub async fn object_bcs(&self, object_id: Address) -> Result<Option<Vec<u8>>, Error> {
let operation = ObjectQuery::build(ObjectQueryArgs {
Expand Down Expand Up @@ -1224,7 +1256,7 @@ impl Client {
filter: Option<TransactionsFilter<'a>>,
pagination_filter: PaginationFilter,
) -> Result<Page<SignedTransaction>, Error> {
let (after, before, first, last) = self.pagination_filter(pagination_filter);
let (after, before, first, last) = self.pagination_filter(pagination_filter).await;

let operation = TransactionBlocksQuery::build(TransactionBlocksQueryArgs {
after: after.as_deref(),
Expand Down Expand Up @@ -1252,6 +1284,18 @@ impl Client {
}
}

/// Get a stream of transactions based on the (optional) transaction filter.
pub async fn transactions_stream<'a>(
&'a self,
filter: Option<TransactionsFilter<'a>>,
streaming_direction: Direction,
) -> impl Stream<Item = Result<SignedTransaction, Error>> + 'a {
stream_paginated_query(
move |pag_filter| self.transactions(filter.clone(), pag_filter),
streaming_direction,
)
}

/// Execute a transaction.
pub async fn execute_tx(
&self,
Expand Down Expand Up @@ -1329,13 +1373,13 @@ impl Client {
pagination_filter_structs: PaginationFilter,
) -> Result<Option<MoveModule>, Error> {
let (after_enums, before_enums, first_enums, last_enums) =
self.pagination_filter(pagination_filter_enums);
self.pagination_filter(pagination_filter_enums).await;
let (after_friends, before_friends, first_friends, last_friends) =
self.pagination_filter(pagination_filter_friends);
self.pagination_filter(pagination_filter_friends).await;
let (after_functions, before_functions, first_functions, last_functions) =
self.pagination_filter(pagination_filter_functions);
self.pagination_filter(pagination_filter_functions).await;
let (after_structs, before_structs, first_structs, last_structs) =
self.pagination_filter(pagination_filter_structs);
self.pagination_filter(pagination_filter_structs).await;
let operation = NormalizedMoveModuleQuery::build(NormalizedMoveModuleQueryArgs {
package: Address::from_str(package)?,
module,
Expand Down Expand Up @@ -1412,6 +1456,7 @@ mod tests {
use crate::faucet::FaucetClient;
use crate::BcsName;
use crate::Client;
use crate::Direction;
use crate::PaginationFilter;
use crate::DEVNET_HOST;
use crate::LOCAL_HOST;
Expand Down Expand Up @@ -1679,6 +1724,7 @@ mod tests {

#[tokio::test]
async fn test_coins_stream() {
const NUM_COINS_FROM_FAUCET: usize = 5;
let client = test_client();
let faucet = match client.rpc_server() {
LOCAL_HOST => FaucetClient::local(),
Expand All @@ -1689,13 +1735,16 @@ mod tests {
let key = Ed25519PublicKey::generate(rand::thread_rng());
let address = key.to_address();
faucet.request_and_wait(address).await.unwrap();
let mut stream = client.coins_stream(address, None);
let mut stream = client
.coins_stream(address, None, Direction::default())
.await;
let mut num_coins = 0;

while let Some(result) = stream.next().await {
assert!(result.is_ok());
num_coins = 1;
num_coins += 1;
}
assert!(num_coins > 0);
assert!(num_coins == NUM_COINS_FROM_FAUCET);
}

#[tokio::test]
Expand Down
Loading

0 comments on commit c2d3080

Please sign in to comment.