diff --git a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs index 98acec0add6f4..882be3edeb07f 100644 --- a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs +++ b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs @@ -25,6 +25,9 @@ use tokio::sync::Mutex; use tracing::info; use uuid::Uuid; +/// Extension factory for adding checks that the query is within configurable limits. +pub(crate) struct QueryLimitsChecker; + /// Only display usage information if this header was in the request. pub(crate) struct ShowUsage; @@ -39,7 +42,7 @@ struct ValidationRes { } #[derive(Debug, Default)] -pub(crate) struct QueryLimitsChecker { +struct QueryLimitsCheckerExt { validation_result: Mutex>, } @@ -64,7 +67,7 @@ impl headers::Header for ShowUsage { impl ExtensionFactory for QueryLimitsChecker { fn create(&self) -> Arc { - Arc::new(QueryLimitsChecker { + Arc::new(QueryLimitsCheckerExt { validation_result: Mutex::new(None), }) } @@ -90,7 +93,7 @@ impl std::ops::Add for ComponentCost { } #[async_trait::async_trait] -impl Extension for QueryLimitsChecker { +impl Extension for QueryLimitsCheckerExt { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { let resp = next.run(ctx).await; let validation_result = self.validation_result.lock().await.take(); @@ -179,7 +182,7 @@ impl Extension for QueryLimitsChecker { } running_costs.depth = 0; - self.analyze_selection_set( + analyze_selection_set( &cfg.limits, &doc.fragments, sel_set, @@ -220,118 +223,113 @@ impl Extension for QueryLimitsChecker { } } -impl QueryLimitsChecker { - /// Parse the selected fields in one operation and check if it conforms to configured limits. - fn analyze_selection_set( - &self, - limits: &Limits, - fragment_defs: &HashMap>, - sel_set: &Positioned, - cost: &mut ComponentCost, - variables: &Variables, - ctx: &ExtensionContext<'_>, - ) -> ServerResult<()> { - // Use BFS to analyze the query and count the number of nodes and the depth of the query - struct ToVisit<'s> { - selection: &'s Positioned, - parent_node_count: u32, - } +/// Parse the selected fields in one operation and check if it conforms to configured limits. +fn analyze_selection_set( + limits: &Limits, + fragment_defs: &HashMap>, + sel_set: &Positioned, + cost: &mut ComponentCost, + variables: &Variables, + ctx: &ExtensionContext<'_>, +) -> ServerResult<()> { + // Use BFS to analyze the query and count the number of nodes and the depth of the query + struct ToVisit<'s> { + selection: &'s Positioned, + parent_node_count: u32, + } - // Queue to store the nodes at each level - let mut que = VecDeque::new(); + // Queue to store the nodes at each level + let mut que = VecDeque::new(); - for selection in sel_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count: 1, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + for selection in sel_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count: 1, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; + } - // Track the number of nodes at first level if any - let mut level_len = que.len(); - - while !que.is_empty() { - // Signifies the start of a new level - cost.depth += 1; - check_limits(limits, cost, None, ctx)?; - while level_len > 0 { - // Ok to unwrap since we checked for empty queue - // and level_len > 0 - let ToVisit { - selection, - parent_node_count, - } = que.pop_front().unwrap(); - - match &selection.node { - Selection::Field(f) => { - check_directives(&f.node.directives)?; - - let current_count = estimate_output_nodes_for_curr_node( - f, - variables, - limits.default_page_size, - ) * parent_node_count; - - cost.output_nodes += current_count; - - for field_sel in f.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection: field_sel, - parent_node_count: current_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(field_sel.pos), ctx)?; - } + // Track the number of nodes at first level if any + let mut level_len = que.len(); + + while !que.is_empty() { + // Signifies the start of a new level + cost.depth += 1; + check_limits(limits, cost, None, ctx)?; + while level_len > 0 { + // Ok to unwrap since we checked for empty queue + // and level_len > 0 + let ToVisit { + selection, + parent_node_count, + } = que.pop_front().unwrap(); + + match &selection.node { + Selection::Field(f) => { + check_directives(&f.node.directives)?; + + let current_count = + estimate_output_nodes_for_curr_node(f, variables, limits.default_page_size) + * parent_node_count; + + cost.output_nodes += current_count; + + for field_sel in f.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection: field_sel, + parent_node_count: current_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(field_sel.pos), ctx)?; } + } - Selection::FragmentSpread(fs) => { - let frag_name = &fs.node.fragment_name.node; - let frag_def = fragment_defs.get(frag_name).ok_or_else(|| { - graphql_error_at_pos( - code::INTERNAL_SERVER_ERROR, - format!( - "Fragment {} not found but present in fragment list", - frag_name - ), - fs.pos, - ) - })?; - - // TODO: this is inefficient as we might loop over same fragment multiple times - // Ideally web should cache the costs of fragments we've seen before - // Will do as enhancement - check_directives(&frag_def.node.directives)?; - for selection in frag_def.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + Selection::FragmentSpread(fs) => { + let frag_name = &fs.node.fragment_name.node; + let frag_def = fragment_defs.get(frag_name).ok_or_else(|| { + graphql_error_at_pos( + code::INTERNAL_SERVER_ERROR, + format!( + "Fragment {} not found but present in fragment list", + frag_name + ), + fs.pos, + ) + })?; + + // TODO: this is inefficient as we might loop over same fragment multiple times + // Ideally web should cache the costs of fragments we've seen before + // Will do as enhancement + check_directives(&frag_def.node.directives)?; + for selection in frag_def.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; } + } - Selection::InlineFragment(fs) => { - check_directives(&fs.node.directives)?; - for selection in fs.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } + Selection::InlineFragment(fs) => { + check_directives(&fs.node.directives)?; + for selection in fs.node.selection_set.node.items.iter() { + que.push_back(ToVisit { + selection, + parent_node_count, + }); + cost.input_nodes += 1; + check_limits(limits, cost, Some(selection.pos), ctx)?; } } - level_len -= 1; } - level_len = que.len(); + level_len -= 1; } - - Ok(()) + level_len = que.len(); } + + Ok(()) } fn check_limits( diff --git a/crates/sui-graphql-rpc/src/server/builder.rs b/crates/sui-graphql-rpc/src/server/builder.rs index 5cd601de9039d..f2ef3f408567b 100644 --- a/crates/sui-graphql-rpc/src/server/builder.rs +++ b/crates/sui-graphql-rpc/src/server/builder.rs @@ -476,7 +476,7 @@ impl ServerBuilder { builder = builder.extension(Logger::default()); } if config.internal_features.query_limits_checker { - builder = builder.extension(QueryLimitsChecker::default()); + builder = builder.extension(QueryLimitsChecker); } if config.internal_features.query_timeout { builder = builder.extension(Timeout); @@ -864,7 +864,7 @@ pub mod tests { }; let schema = prep_schema(None, Some(service_config)) - .extension(QueryLimitsChecker::default()) + .extension(QueryLimitsChecker) .build_schema(); schema.execute(query).await } @@ -922,7 +922,7 @@ pub mod tests { }; let schema = prep_schema(None, Some(service_config)) - .extension(QueryLimitsChecker::default()) + .extension(QueryLimitsChecker) .build_schema(); schema.execute(query).await } @@ -1042,7 +1042,7 @@ pub mod tests { let server_builder = prep_schema(None, None); let metrics = server_builder.state.metrics.clone(); let schema = server_builder - .extension(QueryLimitsChecker::default()) // QueryLimitsChecker is where we actually set the metrics + .extension(QueryLimitsChecker) // QueryLimitsChecker is where we actually set the metrics .build_schema(); schema