From 4ba06ed7463743fa46f6b5879b7cd1f4b76da2dc Mon Sep 17 00:00:00 2001 From: Ashok Menon Date: Tue, 16 Jul 2024 17:02:40 +0100 Subject: [PATCH] [GraphQL/Limits] Reimplement QueryLimitsChecker (#18666) ## Description Rewriting query limits checker to land a number of improvements and fixes: - Avoid issues with overflows by counting down from a predefined budget, rather than counting up to the limit and protecting multiplications using `checked_mul`. - Improve detection of paginated fields: - Previously we treated all connections-related fields as appearing as many times as the page size (including the field that introduced the connection, and the `pageInfo` field). This was over-approximated the output size by a large margin. The new approach counts exactly the number of nodes in the output: The connection's root field, and any non-`edges` or `nodes` field will not get multiplied by the page size. - The checker now also detects connections-related fields even if they are obscured by fragment or inline fragment spreads. - Tighter `__schema` query detection: Previously we would skip requests that started with a `__schema` introspection query. Now it's required to be the only operation in the request (not just the first). - Fix metrics collection after limits are hit: Previously, if a limit was hit, we would not observe validation-related metrics in prometheus. Now we will always record such metrics, and if a limit has been hit, it will register as being "at" the limit. ## Test plan ``` sui-graphql-e2e-tests$ cargo nextest run --features pg_integration -- limits/ ``` ## Stack - #18660 - #18661 - #18662 - #18663 - #18664 --- ## Release notes Check each box that your changes affect. If none of the boxes relate to your changes, release notes aren't required. For each box you select, include information after the relevant heading that describes the impact of your changes that a user might notice and any actions they must take to implement updates. - [ ] Protocol: - [ ] Nodes (Validators and Full nodes): - [ ] Indexer: - [ ] JSON-RPC: - [x] GraphQL: Output node estimation has been made more accurate -- the estimate should now track the theoretical max number of nodes on the JSON `data` output. - [ ] CLI: - [ ] Rust SDK: --- .../tests/limits/output_node_estimation.exp | 103 ++- .../tests/limits/output_node_estimation.move | 262 ++++--- .../src/extensions/query_limits_checker.rs | 682 +++++++++--------- crates/sui-graphql-rpc/src/server/builder.rs | 20 +- 4 files changed, 588 insertions(+), 479 deletions(-) diff --git a/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.exp b/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.exp index 90b87a03faaf2..3305ac5a56ce1 100644 --- a/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.exp +++ b/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.exp @@ -1,4 +1,4 @@ -processed 14 tasks +processed 16 tasks task 1 'run-graphql'. lines 6-14: Response: { @@ -16,7 +16,7 @@ Response: { "depth": 3, "variables": 0, "fragments": 0, - "queryPayload": 132 + "queryPayload": 254 } } } @@ -37,11 +37,11 @@ Response: { "extensions": { "usage": { "inputNodes": 4, - "outputNodes": 80, + "outputNodes": 42, "depth": 4, "variables": 0, "fragments": 0, - "queryPayload": 163 + "queryPayload": 359 } } } @@ -68,11 +68,11 @@ Response: { "extensions": { "usage": { "inputNodes": 6, - "outputNodes": 1640, + "outputNodes": 842, "depth": 6, "variables": 0, "fragments": 0, - "queryPayload": 206 + "queryPayload": 484 } } } @@ -108,11 +108,11 @@ Response: { "extensions": { "usage": { "inputNodes": 10, - "outputNodes": 1720, + "outputNodes": 922, "depth": 6, "variables": 0, "fragments": 0, - "queryPayload": 306 + "queryPayload": 735 } } } @@ -142,11 +142,11 @@ Response: { "extensions": { "usage": { "inputNodes": 10, - "outputNodes": 1640, + "outputNodes": 882, "depth": 6, "variables": 0, "fragments": 0, - "queryPayload": 308 + "queryPayload": 733 } } } @@ -171,7 +171,7 @@ Response: { "depth": 4, "variables": 0, "fragments": 0, - "queryPayload": 145 + "queryPayload": 323 } } } @@ -196,7 +196,7 @@ Response: { "depth": 4, "variables": 0, "fragments": 0, - "queryPayload": 143 + "queryPayload": 322 } } } @@ -219,16 +219,16 @@ Response: { "extensions": { "usage": { "inputNodes": 14, - "outputNodes": 3320, + "outputNodes": 1762, "depth": 8, "variables": 0, "fragments": 0, - "queryPayload": 501 + "queryPayload": 1077 } } } -task 9 'run-graphql'. lines 144-170: +task 9 'run-graphql'. lines 144-171: Response: { "data": { "transactionBlocks": { @@ -244,16 +244,16 @@ Response: { "extensions": { "usage": { "inputNodes": 13, - "outputNodes": 3300, + "outputNodes": 1742, "depth": 7, "variables": 0, "fragments": 0, - "queryPayload": 533 + "queryPayload": 1030 } } } -task 10 'run-graphql'. lines 172-221: +task 10 'run-graphql'. lines 173-222: Response: { "data": { "transactionBlocks": { @@ -274,16 +274,16 @@ Response: { "extensions": { "usage": { "inputNodes": 24, - "outputNodes": 86340, + "outputNodes": 46424, "depth": 11, "variables": 0, "fragments": 0, - "queryPayload": 1395 + "queryPayload": 2093 } } } -task 11 'run-graphql'. lines 223-248: +task 11 'run-graphql'. lines 224-249: Response: { "data": { "transactionBlocks": { @@ -300,26 +300,56 @@ Response: { "extensions": { "usage": { "inputNodes": 12, - "outputNodes": 33300, + "outputNodes": 17302, "depth": 11, "variables": 0, "fragments": 0, - "queryPayload": 704 + "queryPayload": 1029 } } } -task 12 'run-graphql'. lines 250-260: +task 12 'run-graphql'. lines 251-272: +Response: { + "data": { + "fragmentSpread": { + "nodes": [ + { + "digest": "EoFwLKRy23XKLkWZbBLiqjTV2vsKPsmpW6dV2caK8ZDH" + } + ] + }, + "inlineFragment": { + "nodes": [ + { + "digest": "EoFwLKRy23XKLkWZbBLiqjTV2vsKPsmpW6dV2caK8ZDH" + } + ] + } + }, + "extensions": { + "usage": { + "inputNodes": 8, + "outputNodes": 44, + "depth": 4, + "variables": 0, + "fragments": 1, + "queryPayload": 562 + } + } +} + +task 13 'run-graphql'. lines 274-286: Response: { "data": null, "extensions": { "usage": { "inputNodes": 4, - "outputNodes": 80, + "outputNodes": 62, "depth": 4, "variables": 0, "fragments": 0, - "queryPayload": 154 + "queryPayload": 394 } }, "errors": [ @@ -327,7 +357,7 @@ Response: { "message": "'first' and 'last' must not be used together", "locations": [ { - "line": 3, + "line": 4, "column": 3 } ], @@ -341,17 +371,17 @@ Response: { ] } -task 13 'run-graphql'. lines 262-272: +task 14 'run-graphql'. lines 288-298: Response: { "data": null, "extensions": { "usage": { "inputNodes": 4, - "outputNodes": 80, + "outputNodes": 42, "depth": 4, "variables": 0, "fragments": 0, - "queryPayload": 147 + "queryPayload": 141 } }, "errors": [ @@ -366,3 +396,16 @@ Response: { } ] } + +task 15 'run-graphql'. lines 300-310: +Response: { + "data": null, + "errors": [ + { + "message": "Estimated output nodes exceeds 100000", + "extensions": { + "code": "BAD_USER_INPUT" + } + } + ] +} diff --git a/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.move b/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.move index 2d23f48fd7c71..62dcb8dbd6b97 100644 --- a/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.move +++ b/crates/sui-graphql-e2e-tests/tests/limits/output_node_estimation.move @@ -6,9 +6,9 @@ //# run-graphql --show-usage # pageInfo does not inherit connection's weights { - transactionBlocks(first: 20) { - pageInfo { - hasPreviousPage + transactionBlocks(first: 20) { # 1 + pageInfo { # 1 + hasPreviousPage # 1 } } } @@ -16,10 +16,10 @@ //# run-graphql --show-usage # if connection does not have 'first' or 'last' set, use default_page_size (20) { - transactionBlocks { - edges { - node { - digest + transactionBlocks { # 1 + edges { # 1 + node { # 20 + digest # 20 } } } @@ -28,12 +28,12 @@ //# run-graphql --show-usage # build on previous example with nested connection { - checkpoints { - nodes { - transactionBlocks { - edges { - txns: node { - digest + checkpoints { # 1 + nodes { # 1 + transactionBlocks { # 20 + edges { # 20 + txns: node { # 400 + digest # 400 } } } @@ -44,19 +44,19 @@ //# run-graphql --show-usage # handles 1 { - checkpoints { - nodes { - notOne: transactionBlocks { - edges { - txns: node { - digest + checkpoints { # 1 + nodes { # 1 + notOne: transactionBlocks { # 20 + edges { # 20 + txns: node { # 400 + digest # 400 } } } - isOne: transactionBlocks(first: 1) { - edges { - txns: node { - digest + isOne: transactionBlocks(first: 1) { # 20 + edges { # 20 + txns: node { # 20 + digest # 20 } } } @@ -67,19 +67,19 @@ //# run-graphql --show-usage # handles 0 { - checkpoints { - nodes { - notZero: transactionBlocks { - edges { - txns: node { - digest + checkpoints { # 1 + nodes { # 1 + notZero: transactionBlocks { # 20 + edges { # 20 + txns: node { # 400 + digest # 400 } } } - isZero: transactionBlocks(first: 0) { - edges { - txns: node { - digest + isZero: transactionBlocks(first: 0) { # 20 + edges { # 20 + txns: node { # 0 + digest # 0 } } } @@ -90,10 +90,10 @@ //# run-graphql --show-usage # if connection does have 'first' set, use it { - transactionBlocks(first: 1) { - edges { - txns: node { - digest + transactionBlocks(first: 1) { # 1 + edges { # 1 + txns: node { # 1 + digest # 1 } } } @@ -102,10 +102,10 @@ //# run-graphql --show-usage # if connection does have 'last' set, use it { - transactionBlocks(last: 1) { - edges { - txns: node { - digest + transactionBlocks(last: 1) { # 1 + edges { # 1 + txns: node { # 1 + digest # 1 } } } @@ -114,24 +114,24 @@ //# run-graphql --show-usage # first and last should behave the same { - transactionBlocks { - edges { - txns: node { - digest - first: expiration { - checkpoints(first: 20) { - edges { - node { - sequenceNumber + transactionBlocks { # 1 + edges { # 1 + txns: node { # 20 + digest # 20 + first: expiration { # 20 + checkpoints(first: 20) { # 20 + edges { # 20 + node { # 400 + sequenceNumber # 400 } } } } - last: expiration { - checkpoints(last: 20) { - edges { - node { - sequenceNumber + last: expiration { # 20 + checkpoints(last: 20) { # 20 + edges { # 20 + node { # 400 + sequenceNumber # 400 } } } @@ -142,28 +142,29 @@ } //# run-graphql --show-usage -# edges incur additional cost over nodes +# edges incur additional cost over nodes, because of the extra level +# of nesting { - transactionBlocks { - nodes { - digest - first: expiration { # 80 cumulative - checkpoints(first: 20) { - edges { - node { - sequenceNumber + transactionBlocks { # 1 + nodes { # 1 + digest # 20 + first: expiration { # 20 + checkpoints(first: 20) { # 20 + edges { # 20 + node { # 400 + sequenceNumber # 400 } } } - } # 1680 cumulative - last: expiration { # 20 + 1680 = 1700 cumulative - checkpoints(last: 20) { - edges { - node { - sequenceNumber + } + last: expiration { # 20 + checkpoints(last: 20) { # 20 + edges { # 20 + node { # 400 + sequenceNumber # 400 } } - } # another 1600, 3300 cumulative + } } } } @@ -174,18 +175,18 @@ # https://docs.github.com/en/graphql/overview/rate-limits-and-node-limits-for-the-graphql-api#node-limit # our costing will be different since we consider all nodes { - transactionBlocks(first: 50) { # 50, 50 - edges { # 50, 100 - txns: node { # 50, 150 - digest # 50, 200 - a: expiration { # 50, 250 - checkpoints(last: 20) { # 50 * 20 = 1000, 1250 - edges { # 1000, 2250 - node { # 1000, 3250 - transactionBlocks(first: 10) { # 50 * 20 * 10 = 10000, 13250 - edges { # 10000, 23250 - node { # 10000, 33250 - digest # 10000, 43250 + transactionBlocks(first: 50) { # 1 + edges { # 1 + txns: node { # 50 + digest # 50 + a: expiration { # 50 + checkpoints(last: 20) { # 50 + edges { # 50 + node { # 50 * 20 + transactionBlocks(first: 10) { # 50 * 20 + edges { # 50 * 20 + node { # 50 * 20 * 10 + digest # 50 * 20 * 10 } } } @@ -193,14 +194,14 @@ } } } - b: expiration { # 50, 43300 - checkpoints(first: 20) { # 50 * 20 = 1000, 44300 - edges { # 1000, 45300 - node { # 1000, 46300 - transactionBlocks(last: 10) { # 50 * 20 * 10 = 10000, 56300 - edges { # 10000, 66300 - node { # 10000, 76300 - digest # 10000, 86300 + b: expiration { # 50 + checkpoints(first: 20) { # 50 + edges { # 50 + node { # 50 * 20 + transactionBlocks(last: 10) { # 50 * 20 + edges { # 50 * 20 + node { # 50 * 20 * 10 + digest # 50 * 20 * 10 } } } @@ -211,30 +212,30 @@ } } } - events(last: 10) { # 10 - edges { - node { - timestamp + events(last: 10) { # 1 + edges { # 1 + node { # 10 + timestamp # 10 } } - } # 40, 86340 + } } //# run-graphql --show-usage # Null value for variable passed to limit will use default_page_size query NullVariableForLimit($howMany: Int) { - transactionBlocks(last: $howMany) { # 20, 20 - edges { # 20, 40 - node { # 20, 60 - digest # 20, 80 - a: expiration { # 20, 100 - checkpoints { # 20 * 20, 500 - edges { # 400, 900 - node { # 400, 1300 - transactionBlocks(first: $howMany) { # 20 * 20 * 20 = 8000, 9300 - edges { # 8000, 17300 - node { # 8000, 25300 - digest # 8000, 33300 + transactionBlocks(last: $howMany) { # 1 + edges { # 1 + node { # 20 + digest # 20 + a: expiration { # 20 + checkpoints { # 20 + edges { # 20 + node { # 400 + transactionBlocks(first: $howMany) { # 400 + edges { # 400 + node { # 8000 + digest # 8000 } } } @@ -248,9 +249,46 @@ query NullVariableForLimit($howMany: Int) { } //# run-graphql --show-usage -# error state - can't use first and last together +# Connection detection needs to be resilient to connection fields +# being obscured by fragments. +fragment Nodes on TransactionBlockConnection { + nodes { + digest + } +} + +{ + fragmentSpread: transactionBlocks { # 1 + ...Nodes # 1 + 20 + } + + inlineFragment: transactionBlocks { # 1 + ... on TransactionBlockConnection { + nodes { # 1 + digest # 20 + } + } + } +} + +//# run-graphql --show-usage + +# error state - can't use first and last together, but we will use the +# max of the two for output node estimation +{ + transactionBlocks(first: 20, last: 30) { # 1 + edges { # 1 + node { # 30 + digest # 30 + } + } + } +} + +//# run-graphql --show-usage +# error state - overflow u64 { - transactionBlocks(first: 20, last: 30) { + transactionBlocks(first: 36893488147419103000) { edges { node { digest @@ -260,9 +298,9 @@ query NullVariableForLimit($howMany: Int) { } //# run-graphql --show-usage -# error state - exceed max integer +# error state, overflow u32 { - transactionBlocks(first: 36893488147419103000) { + transactionBlocks(first: 4294967297) { edges { node { digest diff --git a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs index 79054f833e93a..12fcfa7ca711a 100644 --- a/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs +++ b/crates/sui-graphql-rpc/src/extensions/query_limits_checker.rs @@ -8,17 +8,20 @@ use async_graphql::extensions::NextParseQuery; use async_graphql::extensions::NextRequest; use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory}; use async_graphql::parser::types::{ - ExecutableDocument, Field, FragmentDefinition, Selection, SelectionSet, + DocumentOperations, ExecutableDocument, Field, FragmentDefinition, OperationDefinition, + Selection, }; -use async_graphql::{value, Name, Pos, Positioned, Response, ServerResult, Value, Variables}; -use async_graphql_value::Value as GqlValue; +use async_graphql::{value, Name, Positioned, Response, ServerError, ServerResult, Variables}; +use async_graphql_value::{ConstValue, Value}; +use async_trait::async_trait; use axum::http::HeaderName; -use std::collections::{HashMap, VecDeque}; +use serde::Serialize; +use std::collections::HashMap; +use std::mem; use std::net::SocketAddr; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Instant; use sui_graphql_rpc_headers::LIMITS_HEADER; -use tokio::sync::Mutex; use tracing::info; use uuid::Uuid; @@ -29,60 +32,365 @@ pub(crate) struct QueryLimitsChecker; #[derive(Debug, Default)] struct QueryLimitsCheckerExt { - validation_result: Mutex>, + usage: Mutex>, } /// Only display usage information if this header was in the request. pub(crate) struct ShowUsage; -#[derive(Clone, Debug, Default)] -struct ValidationRes { +/// State for traversing a document to check for limits. Holds on to environments for looking up +/// variables and fragments, limits, and the remainder of the limit that can be used. +struct LimitsTraversal<'a> { + // Environments for resolving lookups in the document + fragments: &'a HashMap>, + variables: &'a Variables, + + // Relevant limits from the service configuration + default_page_size: u32, + max_input_nodes: u32, + max_output_nodes: u32, + max_depth: u32, + + // Remaining budget for the traversal + input_budget: u32, + output_budget: u32, + depth_seen: u32, +} + +#[derive(Clone, Debug, Default, Serialize)] +#[serde(rename_all = "camelCase")] +struct Usage { input_nodes: u32, output_nodes: u32, depth: u32, - num_variables: u32, - num_fragments: u32, + variables: u32, + fragments: u32, query_payload: u32, } -#[derive(Debug)] -struct ComponentCost { - pub input_nodes: u32, - pub output_nodes: u32, - pub depth: u32, -} - impl ShowUsage { pub(crate) fn name() -> &'static HeaderName { &LIMITS_HEADER } } +impl<'a> LimitsTraversal<'a> { + fn new( + limits: &Limits, + fragments: &'a HashMap>, + variables: &'a Variables, + ) -> Self { + Self { + fragments, + variables, + default_page_size: limits.default_page_size, + max_input_nodes: limits.max_query_nodes, + max_output_nodes: limits.max_output_nodes, + max_depth: limits.max_query_depth, + input_budget: limits.max_query_nodes, + output_budget: limits.max_output_nodes, + depth_seen: 0, + } + } + + /// Main entrypoint for checking all limits. + fn check_document(&mut self, doc: &ExecutableDocument) -> ServerResult<()> { + for (_name, op) in doc.operations.iter() { + self.check_input_limits(op)?; + self.check_output_limits(op)?; + } + Ok(()) + } + + /// Test that the operation meets input limits (number of nodes and depth). + fn check_input_limits(&mut self, op: &Positioned) -> ServerResult<()> { + let mut next_level = vec![]; + let mut curr_level = vec![]; + let mut depth_budget = self.max_depth; + + next_level.extend(&op.node.selection_set.node.items); + while let Some(next) = next_level.first() { + if depth_budget == 0 { + return Err(graphql_error_at_pos( + code::BAD_USER_INPUT, + format!("Query nesting is over {}", self.max_depth), + next.pos, + )); + } else { + depth_budget -= 1; + } + + mem::swap(&mut next_level, &mut curr_level); + + for selection in curr_level.drain(..) { + if self.input_budget == 0 { + return Err(graphql_error_at_pos( + code::BAD_USER_INPUT, + format!("Query has over {} nodes", self.max_input_nodes), + selection.pos, + )); + } else { + self.input_budget -= 1; + } + + match &selection.node { + Selection::Field(f) => { + next_level.extend(&f.node.selection_set.node.items); + } + + Selection::InlineFragment(f) => { + next_level.extend(&f.node.selection_set.node.items); + } + + Selection::FragmentSpread(fs) => { + let name = &fs.node.fragment_name.node; + let def = self.fragments.get(name).ok_or_else(|| { + graphql_error_at_pos( + code::INTERNAL_SERVER_ERROR, + format!("Fragment {name} referred to but not found in document"), + fs.pos, + ) + })?; + + next_level.extend(&def.node.selection_set.node.items); + } + } + } + } + + self.depth_seen = self.depth_seen.max(self.max_depth - depth_budget); + Ok(()) + } + + /// Check that the operation's output node estimate will not exceed the service's limit. + /// + /// This check must be done after the input limit check, because it relies on the query depth + /// being bounded to protect it from recursing too deeply. + fn check_output_limits(&mut self, op: &Positioned) -> ServerResult<()> { + for selection in &op.node.selection_set.node.items { + self.traverse_selection_for_output(selection, 1, None)?; + } + Ok(()) + } + + /// Account for the estimated output size of this selection and its children. + /// + /// `multiplicity` is the number of times this selection will be output, on account of being + /// nested within paginated ancestors. + /// + /// If this field is inside a connection, but not inside one of its fields, `page_size` is the + /// size of the connection's page. + fn traverse_selection_for_output( + &mut self, + selection: &Positioned, + multiplicity: u32, + page_size: Option, + ) -> ServerResult<()> { + match &selection.node { + Selection::Field(f) => { + if multiplicity > self.output_budget { + return Err(self.output_node_error()); + } else { + self.output_budget -= multiplicity; + } + + // If the field being traversed is a connection field, increase multiplicity by a + // factor of page size. This operation can fail due to overflow, which will be + // treated as a limits check failure, even if the resulting value does not get used + // for anything. + let name = &f.node.name.node; + let multiplicity = 'm: { + if !CONNECTION_FIELDS.contains(&name.as_str()) { + break 'm multiplicity; + } + + let Some(page_size) = page_size else { + break 'm multiplicity; + }; + + multiplicity + .checked_mul(page_size) + .ok_or_else(|| self.output_node_error())? + }; + + let page_size = self.connection_page_size(f)?; + for selection in &f.node.selection_set.node.items { + self.traverse_selection_for_output(selection, multiplicity, page_size)?; + } + } + + // Just recurse through fragments, because they are inlined into their "call site". + Selection::InlineFragment(f) => { + for selection in f.node.selection_set.node.items.iter() { + self.traverse_selection_for_output(selection, multiplicity, page_size)?; + } + } + + Selection::FragmentSpread(fs) => { + let name = &fs.node.fragment_name.node; + let def = self.fragments.get(name).ok_or_else(|| { + graphql_error_at_pos( + code::INTERNAL_SERVER_ERROR, + format!("Fragment {name} referred to but not found in document"), + fs.pos, + ) + })?; + + for selection in def.node.selection_set.node.items.iter() { + self.traverse_selection_for_output(selection, multiplicity, page_size)?; + } + } + } + + Ok(()) + } + + /// If the field `f` is a connection, extract its page size, otherwise return `None`. + /// Returns an error if the page size cannot be represented as a `u32`. + fn connection_page_size(&mut self, f: &Positioned) -> ServerResult> { + if !self.is_connection(f) { + return Ok(None); + } + + let first = f.node.get_argument("first"); + let last = f.node.get_argument("last"); + + let page_size = match (self.resolve_u64(first), self.resolve_u64(last)) { + (Some(f), Some(l)) => f.max(l), + (Some(p), _) | (_, Some(p)) => p, + (None, None) => self.default_page_size as u64, + }; + + Ok(Some( + page_size.try_into().map_err(|_| self.output_node_error())?, + )) + } + + /// Checks if the given field corresponds to a connection based on whether it contains a + /// selection for `edges` or `nodes`. That selection could be immediately in that field's + /// selection set, or nested within a fragment or inline fragment spread. + fn is_connection(&self, f: &Positioned) -> bool { + f.node + .selection_set + .node + .items + .iter() + .any(|s| self.has_connection_fields(s)) + } + + /// Look for fields that suggest the container for this selection is a connection. Recurses + /// through fragment and inline fragment applications, but does not look recursively through + /// fields, as only the fields requested from the immediate parent are relevant. + fn has_connection_fields(&self, s: &Positioned) -> bool { + match &s.node { + Selection::Field(f) => { + let name = &f.node.name.node; + CONNECTION_FIELDS.contains(&name.as_str()) + } + + Selection::InlineFragment(f) => f + .node + .selection_set + .node + .items + .iter() + .any(|s| self.has_connection_fields(s)), + + Selection::FragmentSpread(fs) => { + let name = &fs.node.fragment_name.node; + let Some(def) = self.fragments.get(name) else { + return false; + }; + + def.node + .selection_set + .node + .items + .iter() + .any(|s| self.has_connection_fields(s)) + } + } + } + + /// Translate a GraphQL value into a u64, if possible, resolving variables if necessary. + fn resolve_u64(&self, value: Option<&Positioned>) -> Option { + match &value?.node { + Value::Number(num) => num, + + Value::Variable(var) => { + if let ConstValue::Number(num) = self.variables.get(var)? { + num + } else { + return None; + } + } + + _ => return None, + } + .as_u64() + } + + /// Error returned if output node estimate exceeds limit. Also sets the output budget to zero, + /// to indicate that it has been spent (This is done because unlike other budgets, the output + /// budget is not decremented one unit at a time, so we can have hit the limit previously but + /// still have budget left over). + fn output_node_error(&mut self) -> ServerError { + self.output_budget = 0; + graphql_error( + code::BAD_USER_INPUT, + format!("Estimated output nodes exceeds {}", self.max_output_nodes), + ) + } + + /// Finish the traversal and report its usage. + fn finish(self, query_payload: u32) -> Usage { + Usage { + input_nodes: self.max_input_nodes - self.input_budget, + output_nodes: self.max_output_nodes - self.output_budget, + depth: self.depth_seen, + variables: self.variables.len() as u32, + fragments: self.fragments.len() as u32, + query_payload, + } + } +} + +impl Usage { + fn report(&self, metrics: &Metrics) { + metrics + .request_metrics + .input_nodes + .observe(self.input_nodes as f64); + metrics + .request_metrics + .output_nodes + .observe(self.output_nodes as f64); + metrics + .request_metrics + .query_depth + .observe(self.depth as f64); + metrics + .request_metrics + .query_payload_size + .observe(self.query_payload as f64); + } +} + impl ExtensionFactory for QueryLimitsChecker { fn create(&self) -> Arc { Arc::new(QueryLimitsCheckerExt { - validation_result: Mutex::new(None), + usage: Mutex::new(None), }) } } -#[async_trait::async_trait] +#[async_trait] impl Extension for QueryLimitsCheckerExt { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { let resp = next.run(ctx).await; - let validation_result = self.validation_result.lock().await.take(); - if let Some(validation_result) = validation_result { - resp.extension( - "usage", - value! ({ - "inputNodes": validation_result.input_nodes, - "outputNodes": validation_result.output_nodes, - "depth": validation_result.depth, - "variables": validation_result.num_variables, - "fragments": validation_result.num_fragments, - "queryPayload": validation_result.query_payload, - }), - ) + let usage = self.usage.lock().unwrap().take(); + if let Some(usage) = usage { + resp.extension("usage", value!(usage)) } else { resp } @@ -100,10 +408,9 @@ impl Extension for QueryLimitsCheckerExt { let query_id: &Uuid = ctx.data_unchecked(); let session_id: &SocketAddr = ctx.data_unchecked(); let metrics: &Metrics = ctx.data_unchecked(); + let cfg: &ServiceConfig = ctx.data_unchecked(); let instant = Instant::now(); - let cfg = ctx - .data::() - .expect("No service config provided in schema data"); + if query.len() > cfg.limits.max_query_payload_size as usize { metrics .request_metrics @@ -129,298 +436,31 @@ impl Extension for QueryLimitsCheckerExt { // Document layout of the query let doc = next.run(ctx, query, variables).await?; - // TODO: Limit the complexity of fragments early on - - let mut running_costs = ComponentCost { - depth: 0, - input_nodes: 0, - output_nodes: 0, - }; - let mut max_depth_seen = 0; - - // An operation is a query, mutation or subscription consisting of a set of selections - for (count, (_name, oper)) in doc.operations.iter().enumerate() { - let sel_set = &oper.node.selection_set; - - // If the query is pure introspection, we don't need to check the limits. - // Pure introspection queries are queries that only have one operation with one field - // and that field is a `__schema` query - if (count == 0) && (sel_set.node.items.len() == 1) { - if let Some(node) = sel_set.node.items.first() { - if let Selection::Field(field) = &node.node { - if field.node.name.node == "__schema" { - continue; - } + // If the query is pure introspection, we don't need to check the limits. Pure introspection + // queries are queries that only have one operation with one field and that field is a + // `__schema` query + if let DocumentOperations::Single(op) = &doc.operations { + if let [field] = &op.node.selection_set.node.items[..] { + if let Selection::Field(f) = &field.node { + if f.node.name.node == "__schema" { + return Ok(doc); } } } - - running_costs.depth = 0; - analyze_selection_set( - &cfg.limits, - &doc.fragments, - sel_set, - &mut running_costs, - variables, - ctx, - )?; - max_depth_seen = max_depth_seen.max(running_costs.depth); } - if ctx.data_opt::().is_some() { - *self.validation_result.lock().await = Some(ValidationRes { - input_nodes: running_costs.input_nodes, - output_nodes: running_costs.output_nodes, - depth: running_costs.depth, - query_payload: query.len() as u32, - num_variables: variables.len() as u32, - num_fragments: doc.fragments.len() as u32, - }); - } - metrics.query_validation_latency(instant.elapsed()); - metrics - .request_metrics - .input_nodes - .observe(running_costs.input_nodes as f64); - metrics - .request_metrics - .output_nodes - .observe(running_costs.output_nodes as f64); - metrics - .request_metrics - .query_depth - .observe(running_costs.depth as f64); - metrics - .request_metrics - .query_payload_size - .observe(query.len() as f64); - Ok(doc) - } -} - -impl std::ops::Add for ComponentCost { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Self { - input_nodes: self.input_nodes + rhs.input_nodes, - output_nodes: self.output_nodes + rhs.output_nodes, - depth: self.depth + rhs.depth, - } - } -} - -/// Parse the selected fields in one operation and check if it conforms to configured limits. -fn analyze_selection_set( - limits: &Limits, - fragment_defs: &HashMap>, - sel_set: &Positioned, - cost: &mut ComponentCost, - variables: &Variables, - ctx: &ExtensionContext<'_>, -) -> ServerResult<()> { - // Use BFS to analyze the query and count the number of nodes and the depth of the query - struct ToVisit<'s> { - selection: &'s Positioned, - parent_node_count: u32, - } - - // Queue to store the nodes at each level - let mut que = VecDeque::new(); - - for selection in sel_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count: 1, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } - // Track the number of nodes at first level if any - let mut level_len = que.len(); - - while !que.is_empty() { - // Signifies the start of a new level - cost.depth += 1; - check_limits(limits, cost, None, ctx)?; - while level_len > 0 { - // Ok to unwrap since we checked for empty queue - // and level_len > 0 - let ToVisit { - selection, - parent_node_count, - } = que.pop_front().unwrap(); - - match &selection.node { - Selection::Field(f) => { - let current_count = - estimate_output_nodes_for_curr_node(f, variables, limits.default_page_size) - * parent_node_count; - - cost.output_nodes += current_count; - - for field_sel in f.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection: field_sel, - parent_node_count: current_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(field_sel.pos), ctx)?; - } - } - - Selection::FragmentSpread(fs) => { - let frag_name = &fs.node.fragment_name.node; - let frag_def = fragment_defs.get(frag_name).ok_or_else(|| { - graphql_error_at_pos( - code::INTERNAL_SERVER_ERROR, - format!( - "Fragment {} not found but present in fragment list", - frag_name - ), - fs.pos, - ) - })?; - - // TODO: this is inefficient as we might loop over same fragment multiple times - // Ideally web should cache the costs of fragments we've seen before - // Will do as enhancement - for selection in frag_def.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } - } + let mut traversal = LimitsTraversal::new(&cfg.limits, &doc.fragments, variables); + let res = traversal.check_document(&doc); + let usage = traversal.finish(query.len() as u32); + metrics.query_validation_latency(instant.elapsed()); + usage.report(metrics); - Selection::InlineFragment(fs) => { - for selection in fs.node.selection_set.node.items.iter() { - que.push_back(ToVisit { - selection, - parent_node_count, - }); - cost.input_nodes += 1; - check_limits(limits, cost, Some(selection.pos), ctx)?; - } - } + res.map(|()| { + if ctx.data_opt::().is_some() { + *self.usage.lock().unwrap() = Some(usage); } - level_len -= 1; - } - level_len = que.len(); - } - - Ok(()) -} - -fn check_limits( - limits: &Limits, - cost: &ComponentCost, - pos: Option, - ctx: &ExtensionContext<'_>, -) -> ServerResult<()> { - let query_id: &Uuid = ctx.data_unchecked(); - let session_id: &SocketAddr = ctx.data_unchecked(); - let error_code = code::BAD_USER_INPUT; - if cost.input_nodes > limits.max_query_nodes { - info!( - query_id = %query_id, - session_id = %session_id, - error_code, - "Query has too many nodes: {}", cost.input_nodes - ); - return Err(graphql_error_at_pos( - error_code, - format!( - "Query has too many nodes {}. The maximum allowed is {}", - cost.input_nodes, limits.max_query_nodes - ), - pos.unwrap_or_default(), - )); - } - if cost.depth > limits.max_query_depth { - info!( - query_id = %query_id, - session_id = %session_id, - error_code, - "Query has too many levels of nesting: {}", cost.depth - ); - return Err(graphql_error_at_pos( - error_code, - format!( - "Query has too many levels of nesting {}. The maximum allowed is {}", - cost.depth, limits.max_query_depth - ), - pos.unwrap_or_default(), - )); - } - - if cost.output_nodes > limits.max_output_nodes { - info!( - query_id = %query_id, - session_id = %session_id, - error_code, - "Query will result in too many output nodes: {}", - cost.output_nodes - ); - return Err(graphql_error_at_pos( - error_code, - format!( - "Query will result in too many output nodes. The maximum allowed is {}, estimated {}", - limits.max_output_nodes, cost.output_nodes - ), - pos.unwrap_or_default(), - )); - } - - Ok(()) -} - -/// Given a node, estimate the number of output nodes it will produce. -fn estimate_output_nodes_for_curr_node( - f: &Positioned, - variables: &Variables, - default_page_size: u32, -) -> u32 { - if !is_connection(f) { - 1 - } else { - // If the args 'first' or 'last' is set, then we should use that as the count - let first_arg = f.node.get_argument("first"); - let last_arg = f.node.get_argument("last"); - - extract_limit(first_arg, variables) - .or_else(|| extract_limit(last_arg, variables)) - .unwrap_or(default_page_size) - } -} - -/// Try to extract a u32 value from the given argument, or return None on failure. -fn extract_limit(value: Option<&Positioned>, variables: &Variables) -> Option { - if let GqlValue::Variable(var) = &value?.node { - return match variables.get(var) { - Some(Value::Number(num)) => num.as_u64().map(|v| v as u32), - _ => None, - }; - } - - let GqlValue::Number(value) = &value?.node else { - return None; - }; - value.as_u64().map(|v| v as u32) -} - -/// Checks if the given field is a connection field by whether it has 'edges' or 'nodes' selected. -/// This should typically not require checking more than the first element of the selection set -fn is_connection(f: &Positioned) -> bool { - for field_sel in f.node.selection_set.node.items.iter() { - if let Selection::Field(field) = &field_sel.node { - if CONNECTION_FIELDS.contains(&field.node.name.node.as_str()) { - return true; - } - } + doc + }) } - false } diff --git a/crates/sui-graphql-rpc/src/server/builder.rs b/crates/sui-graphql-rpc/src/server/builder.rs index 01f70dd20d390..b0afb65b57cd5 100644 --- a/crates/sui-graphql-rpc/src/server/builder.rs +++ b/crates/sui-graphql-rpc/src/server/builder.rs @@ -897,10 +897,7 @@ pub mod tests { .map(|e| e.message) .collect(); - assert_eq!( - errs, - vec!["Query has too many levels of nesting 1. The maximum allowed is 0".to_string()] - ); + assert_eq!(errs, vec!["Query nesting is over 0".to_string()]); let errs: Vec<_> = exec_query_depth_limit( 2, "{ chainIdentifier protocolConfig { configs { value key }} }", @@ -911,10 +908,7 @@ pub mod tests { .into_iter() .map(|e| e.message) .collect(); - assert_eq!( - errs, - vec!["Query has too many levels of nesting 3. The maximum allowed is 2".to_string()] - ); + assert_eq!(errs, vec!["Query nesting is over 2".to_string()]); } pub async fn test_query_node_limit_impl() { @@ -954,10 +948,7 @@ pub mod tests { .into_iter() .map(|e| e.message) .collect(); - assert_eq!( - err, - vec!["Query has too many nodes 1. The maximum allowed is 0".to_string()] - ); + assert_eq!(err, vec!["Query has over 0 nodes".to_string()]); let err: Vec<_> = exec_query_node_limit( 4, @@ -969,10 +960,7 @@ pub mod tests { .into_iter() .map(|e| e.message) .collect(); - assert_eq!( - err, - vec!["Query has too many nodes 5. The maximum allowed is 4".to_string()] - ); + assert_eq!(err, vec!["Query has over 4 nodes".to_string()]); } pub async fn test_query_default_page_limit_impl(connection_config: ConnectionConfig) {