Skip to content

Commit

Permalink
[chore][GraphQL/Limits] Separate QueryLimitChecker extension/factory (#…
Browse files Browse the repository at this point in the history
…18660)

## Description

Split up the extension factory from the extension itself, similar to
what we did for `Timeout` before. This avoids the confusion of the
single type being created with defaulted fields to act as the factory
and then creating new versions of itself to act as the extension.

## Test plan

CI

---

## Release notes

Check each box that your changes affect. If none of the boxes relate to
your changes, release notes aren't required.

For each box you select, include information after the relevant heading
that describes the impact of your changes that a user might notice and
any actions they must take to implement updates.

- [ ] Protocol:
- [ ] Nodes (Validators and Full nodes):
- [ ] Indexer:
- [ ] JSON-RPC:
- [ ] GraphQL:
- [ ] CLI:
- [ ] Rust SDK:
  • Loading branch information
amnn committed Jul 16, 2024
1 parent cb2f63e commit 3731f28
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 108 deletions.
206 changes: 102 additions & 104 deletions crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ use tokio::sync::Mutex;
use tracing::info;
use uuid::Uuid;

/// Extension factory for adding checks that the query is within configurable limits.
pub(crate) struct QueryLimitsChecker;

/// Only display usage information if this header was in the request.
pub(crate) struct ShowUsage;

Expand All @@ -39,7 +42,7 @@ struct ValidationRes {
}

#[derive(Debug, Default)]
pub(crate) struct QueryLimitsChecker {
struct QueryLimitsCheckerExt {
validation_result: Mutex<Option<ValidationRes>>,
}

Expand All @@ -64,7 +67,7 @@ impl headers::Header for ShowUsage {

impl ExtensionFactory for QueryLimitsChecker {
fn create(&self) -> Arc<dyn Extension> {
Arc::new(QueryLimitsChecker {
Arc::new(QueryLimitsCheckerExt {
validation_result: Mutex::new(None),
})
}
Expand All @@ -90,7 +93,7 @@ impl std::ops::Add for ComponentCost {
}

#[async_trait::async_trait]
impl Extension for QueryLimitsChecker {
impl Extension for QueryLimitsCheckerExt {
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response {
let resp = next.run(ctx).await;
let validation_result = self.validation_result.lock().await.take();
Expand Down Expand Up @@ -179,7 +182,7 @@ impl Extension for QueryLimitsChecker {
}

running_costs.depth = 0;
self.analyze_selection_set(
analyze_selection_set(
&cfg.limits,
&doc.fragments,
sel_set,
Expand Down Expand Up @@ -220,118 +223,113 @@ impl Extension for QueryLimitsChecker {
}
}

impl QueryLimitsChecker {
/// Parse the selected fields in one operation and check if it conforms to configured limits.
fn analyze_selection_set(
&self,
limits: &Limits,
fragment_defs: &HashMap<Name, Positioned<FragmentDefinition>>,
sel_set: &Positioned<SelectionSet>,
cost: &mut ComponentCost,
variables: &Variables,
ctx: &ExtensionContext<'_>,
) -> ServerResult<()> {
// Use BFS to analyze the query and count the number of nodes and the depth of the query
struct ToVisit<'s> {
selection: &'s Positioned<Selection>,
parent_node_count: u32,
}
/// Parse the selected fields in one operation and check if it conforms to configured limits.
fn analyze_selection_set(
limits: &Limits,
fragment_defs: &HashMap<Name, Positioned<FragmentDefinition>>,
sel_set: &Positioned<SelectionSet>,
cost: &mut ComponentCost,
variables: &Variables,
ctx: &ExtensionContext<'_>,
) -> ServerResult<()> {
// Use BFS to analyze the query and count the number of nodes and the depth of the query
struct ToVisit<'s> {
selection: &'s Positioned<Selection>,
parent_node_count: u32,
}

// Queue to store the nodes at each level
let mut que = VecDeque::new();
// Queue to store the nodes at each level
let mut que = VecDeque::new();

for selection in sel_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count: 1,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}
for selection in sel_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count: 1,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}

// Track the number of nodes at first level if any
let mut level_len = que.len();

while !que.is_empty() {
// Signifies the start of a new level
cost.depth += 1;
check_limits(limits, cost, None, ctx)?;
while level_len > 0 {
// Ok to unwrap since we checked for empty queue
// and level_len > 0
let ToVisit {
selection,
parent_node_count,
} = que.pop_front().unwrap();

match &selection.node {
Selection::Field(f) => {
check_directives(&f.node.directives)?;

let current_count = estimate_output_nodes_for_curr_node(
f,
variables,
limits.default_page_size,
) * parent_node_count;

cost.output_nodes += current_count;

for field_sel in f.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection: field_sel,
parent_node_count: current_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(field_sel.pos), ctx)?;
}
// Track the number of nodes at first level if any
let mut level_len = que.len();

while !que.is_empty() {
// Signifies the start of a new level
cost.depth += 1;
check_limits(limits, cost, None, ctx)?;
while level_len > 0 {
// Ok to unwrap since we checked for empty queue
// and level_len > 0
let ToVisit {
selection,
parent_node_count,
} = que.pop_front().unwrap();

match &selection.node {
Selection::Field(f) => {
check_directives(&f.node.directives)?;

let current_count =
estimate_output_nodes_for_curr_node(f, variables, limits.default_page_size)
* parent_node_count;

cost.output_nodes += current_count;

for field_sel in f.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection: field_sel,
parent_node_count: current_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(field_sel.pos), ctx)?;
}
}

Selection::FragmentSpread(fs) => {
let frag_name = &fs.node.fragment_name.node;
let frag_def = fragment_defs.get(frag_name).ok_or_else(|| {
graphql_error_at_pos(
code::INTERNAL_SERVER_ERROR,
format!(
"Fragment {} not found but present in fragment list",
frag_name
),
fs.pos,
)
})?;

// TODO: this is inefficient as we might loop over same fragment multiple times
// Ideally web should cache the costs of fragments we've seen before
// Will do as enhancement
check_directives(&frag_def.node.directives)?;
for selection in frag_def.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}
Selection::FragmentSpread(fs) => {
let frag_name = &fs.node.fragment_name.node;
let frag_def = fragment_defs.get(frag_name).ok_or_else(|| {
graphql_error_at_pos(
code::INTERNAL_SERVER_ERROR,
format!(
"Fragment {} not found but present in fragment list",
frag_name
),
fs.pos,
)
})?;

// TODO: this is inefficient as we might loop over same fragment multiple times
// Ideally web should cache the costs of fragments we've seen before
// Will do as enhancement
check_directives(&frag_def.node.directives)?;
for selection in frag_def.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}
}

Selection::InlineFragment(fs) => {
check_directives(&fs.node.directives)?;
for selection in fs.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}
Selection::InlineFragment(fs) => {
check_directives(&fs.node.directives)?;
for selection in fs.node.selection_set.node.items.iter() {
que.push_back(ToVisit {
selection,
parent_node_count,
});
cost.input_nodes += 1;
check_limits(limits, cost, Some(selection.pos), ctx)?;
}
}
level_len -= 1;
}
level_len = que.len();
level_len -= 1;
}

Ok(())
level_len = que.len();
}

Ok(())
}

fn check_limits(
Expand Down
8 changes: 4 additions & 4 deletions crates/sui-graphql-rpc/src/server/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl ServerBuilder {
builder = builder.extension(Logger::default());
}
if config.internal_features.query_limits_checker {
builder = builder.extension(QueryLimitsChecker::default());
builder = builder.extension(QueryLimitsChecker);
}
if config.internal_features.query_timeout {
builder = builder.extension(Timeout);
Expand Down Expand Up @@ -864,7 +864,7 @@ pub mod tests {
};

let schema = prep_schema(None, Some(service_config))
.extension(QueryLimitsChecker::default())
.extension(QueryLimitsChecker)
.build_schema();
schema.execute(query).await
}
Expand Down Expand Up @@ -922,7 +922,7 @@ pub mod tests {
};

let schema = prep_schema(None, Some(service_config))
.extension(QueryLimitsChecker::default())
.extension(QueryLimitsChecker)
.build_schema();
schema.execute(query).await
}
Expand Down Expand Up @@ -1042,7 +1042,7 @@ pub mod tests {
let server_builder = prep_schema(None, None);
let metrics = server_builder.state.metrics.clone();
let schema = server_builder
.extension(QueryLimitsChecker::default()) // QueryLimitsChecker is where we actually set the metrics
.extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics
.build_schema();

schema
Expand Down

0 comments on commit 3731f28

Please sign in to comment.