From 3731f2887c9619751677255ec1325a5467b341b7 Mon Sep 17 00:00:00 2001 From: Ashok Menon Date: Sat, 13 Jul 2024 14:52:49 +0100 Subject: [PATCH] [chore][GraphQL/Limits] Separate QueryLimitChecker extension/factory (#18660) ## Description Split up the extension factory from the extension itself, similar to what we did for `Timeout` before. This avoids the confusion of the single type being created with defaulted fields to act as the factory and then creating new versions of itself to act as the extension. ## Test plan CI --- ## Release notes Check each box that your changes affect. If none of the boxes relate to your changes, release notes aren't required. For each box you select, include information after the relevant heading that describes the impact of your changes that a user might notice and any actions they must take to implement updates. - [ ] Protocol: - [ ] Nodes (Validators and Full nodes): - [ ] Indexer: - [ ] JSON-RPC: - [ ] GraphQL: - [ ] CLI: - [ ] Rust SDK: --- .../src/extensions/query_limits_checker.rs | 206 +++++++++--------- crates/sui-graphql-rpc/src/server/builder.rs | 8 +- 2 files changed, 106 insertions(+), 108 deletions(-) diff --git a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs index 98acec0add6f4..882be3edeb07f 100644 --- a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs +++ b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs @@ -25,6 +25,9 @@ use tokio::sync::Mutex; use tracing::info; use uuid::Uuid; +/// Extension factory for adding checks that the query is within configurable limits. +pub(crate) struct QueryLimitsChecker; + /// Only display usage information if this header was in the request. pub(crate) struct ShowUsage; @@ -39,7 +42,7 @@ struct ValidationRes { } #[derive(Debug, Default)] -pub(crate) struct QueryLimitsChecker { +struct QueryLimitsCheckerExt { validation_result: Mutex>, } @@ -64,7 +67,7 @@ impl headers::Header for ShowUsage { impl ExtensionFactory for QueryLimitsChecker { fn create(&self) -> Arc { - Arc::new(QueryLimitsChecker { + Arc::new(QueryLimitsCheckerExt { validation_result: Mutex::new(None), }) } @@ -90,7 +93,7 @@ impl std::ops::Add for ComponentCost { } #[async_trait::async_trait] -impl Extension for QueryLimitsChecker { +impl Extension for QueryLimitsCheckerExt { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { let resp = next.run(ctx).await; let validation_result = self.validation_result.lock().await.take(); @@ -179,7 +182,7 @@ impl Extension for QueryLimitsChecker { } running_costs.depth = 0; - self.analyze_selection_set( + analyze_selection_set( &cfg.limits, &doc.fragments, sel_set, @@ -220,118 +223,113 @@ impl Extension for QueryLimitsChecker { } } -impl QueryLimitsChecker { - /// Parse the selected fields in one operation and check if it conforms to configured limits. - fn analyze_selection_set( - &self, - limits: &Limits, - fragment_defs: &HashMap>, - sel_set: &Positioned, - cost: &mut ComponentCost, - variables: &Variables, - ctx: &ExtensionContext<'_>, - ) -> ServerResult<()> { - // Use BFS to analyze the query and count the number of nodes and the depth of the query - struct ToVisit<'s> { - selection: &'s Positioned, - parent_node_count: u32, - } +/// Parse the selected fields in one operation and check if it conforms to configured limits. +fn analyze_selection_set( + limits: &Limits, + fragment_defs: &HashMap>, + sel_set: &Positioned, + cost: &mut ComponentCost, + variables: &Variables, + ctx: &ExtensionContext<'_>, +) -> ServerResult<()> { + // Use BFS to analyze the query and count the number of nodes and the depth of the query + struct ToVisit<'s> { + selection: &'s Positioned, + parent_node_count: u32, + } - // Queue to store the nodes at each level - let mut que = VecDeque::new(); + // Queue to store the nodes at each level + let mut que = VecDeque::new(); - for selection in sel_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count: 1, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + for selection in sel_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count: 1, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; + } - // Track the number of nodes at first level if any - let mut level_len = que.len(); - - while !que.is_empty() { - // Signifies the start of a new level - cost.depth += 1; - check_limits(limits, cost, None, ctx)?; - while level_len > 0 { - // Ok to unwrap since we checked for empty queue - // and level_len > 0 - let ToVisit { - selection, - parent_node_count, - } = que.pop_front().unwrap(); - - match &selection.node { - Selection::Field(f) => { - check_directives(&f.node.directives)?; - - let current_count = estimate_output_nodes_for_curr_node( - f, - variables, - limits.default_page_size, - ) * parent_node_count; - - cost.output_nodes += current_count; - - for field_sel in f.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection: field_sel, - parent_node_count: current_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(field_sel.pos), ctx)?; - } + // Track the number of nodes at first level if any + let mut level_len = que.len(); + + while !que.is_empty() { + // Signifies the start of a new level + cost.depth += 1; + check_limits(limits, cost, None, ctx)?; + while level_len > 0 { + // Ok to unwrap since we checked for empty queue + // and level_len > 0 + let ToVisit { + selection, + parent_node_count, + } = que.pop_front().unwrap(); + + match &selection.node { + Selection::Field(f) => { + check_directives(&f.node.directives)?; + + let current_count = + estimate_output_nodes_for_curr_node(f, variables, limits.default_page_size) + * parent_node_count; + + cost.output_nodes += current_count; + + for field_sel in f.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection: field_sel, + parent_node_count: current_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(field_sel.pos), ctx)?; } + } - Selection::FragmentSpread(fs) => { - let frag_name = &fs.node.fragment_name.node; - let frag_def = fragment_defs.get(frag_name).ok_or_else(|| { - graphql_error_at_pos( - code::INTERNAL_SERVER_ERROR, - format!( - "Fragment {} not found but present in fragment list", - frag_name - ), - fs.pos, - ) - })?; - - // TODO: this is inefficient as we might loop over same fragment multiple times - // Ideally web should cache the costs of fragments we've seen before - // Will do as enhancement - check_directives(&frag_def.node.directives)?; - for selection in frag_def.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + Selection::FragmentSpread(fs) => { + let frag_name = &fs.node.fragment_name.node; + let frag_def = fragment_defs.get(frag_name).ok_or_else(|| { + graphql_error_at_pos( + code::INTERNAL_SERVER_ERROR, + format!( + "Fragment {} not found but present in fragment list", + frag_name + ), + fs.pos, + ) + })?; + + // TODO: this is inefficient as we might loop over same fragment multiple times + // Ideally web should cache the costs of fragments we've seen before + // Will do as enhancement + check_directives(&frag_def.node.directives)?; + for selection in frag_def.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; } + } - Selection::InlineFragment(fs) => { - check_directives(&fs.node.directives)?; - for selection in fs.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + Selection::InlineFragment(fs) => { + check_directives(&fs.node.directives)?; + for selection in fs.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; } } - level_len -= 1; } - level_len = que.len(); + level_len -= 1; } - - Ok(()) + level_len = que.len(); } + + Ok(()) } fn check_limits( diff --git a/crates/sui-graphql-rpc/src/server/builder.rs b/crates/sui-graphql-rpc/src/server/builder.rs index 5cd601de9039d..f2ef3f408567b 100644 --- a/crates/sui-graphql-rpc/src/server/builder.rs +++ b/crates/sui-graphql-rpc/src/server/builder.rs @@ -476,7 +476,7 @@ impl ServerBuilder { builder = builder.extension(Logger::default()); } if config.internal_features.query_limits_checker { - builder = builder.extension(QueryLimitsChecker::default()); + builder = builder.extension(QueryLimitsChecker); } if config.internal_features.query_timeout { builder = builder.extension(Timeout); @@ -864,7 +864,7 @@ pub mod tests { }; let schema = prep_schema(None, Some(service_config)) - .extension(QueryLimitsChecker::default()) + .extension(QueryLimitsChecker) .build_schema(); schema.execute(query).await } @@ -922,7 +922,7 @@ pub mod tests { }; let schema = prep_schema(None, Some(service_config)) - .extension(QueryLimitsChecker::default()) + .extension(QueryLimitsChecker) .build_schema(); schema.execute(query).await } @@ -1042,7 +1042,7 @@ pub mod tests { let server_builder = prep_schema(None, None); let metrics = server_builder.state.metrics.clone(); let schema = server_builder - .extension(QueryLimitsChecker::default()) // QueryLimitsChecker is where we actually set the metrics + .extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics .build_schema(); schema