Skip to content

Commit

Permalink
fix: Add graphql max depth and aliases limits (#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
suejung-sentry authored Nov 6, 2024
1 parent 0273319 commit 63124e2
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 2 deletions.
4 changes: 4 additions & 0 deletions codecov/settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@

GRAPHQL_INTROSPECTION_ENABLED = False

GRAPHQL_MAX_DEPTH = get_config("setup", "graphql", "max_depth", default=20)

GRAPHQL_MAX_ALIASES = get_config("setup", "graphql", "max_aliases", default=10)

# Database
# https://docs.djangoproject.com/en/2.1/ref/settings/#databases

Expand Down
100 changes: 100 additions & 0 deletions graphql_api/tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from graphql import (
GraphQLField,
GraphQLObjectType,
GraphQLSchema,
GraphQLString,
parse,
validate,
)

from ..validation import (
create_max_aliases_rule,
create_max_depth_rule,
)


def resolve_field(*args):
return "test"


QueryType = GraphQLObjectType(
"Query", {"field": GraphQLField(GraphQLString, resolve=resolve_field)}
)
schema = GraphQLSchema(query=QueryType)


def validate_query(query, *rules):
ast = parse(query)
return validate(schema, ast, rules=rules)


def test_max_depth_rule_allows_within_depth():
query = """
query {
field
}
"""
errors = validate_query(query, create_max_depth_rule(2))
assert not errors, "Expected no errors for depth within the limit"


def test_max_depth_rule_rejects_exceeding_depth():
query = """
query {
field {
field {
field
}
}
}
"""
errors = validate_query(query, create_max_depth_rule(2))
assert errors, "Expected errors for exceeding depth limit"
assert any(
"Query depth exceeds the maximum allowed depth" in str(e) for e in errors
)


def test_max_depth_rule_exact_depth():
query = """
query {
field
}
"""
errors = validate_query(query, create_max_depth_rule(2))
assert not errors, "Expected no errors when query depth matches the limit"


def test_max_aliases_rule_allows_within_alias_limit():
query = """
query {
alias1: field
alias2: field
}
"""
errors = validate_query(query, create_max_aliases_rule(2))
assert not errors, "Expected no errors for alias count within the limit"


def test_max_aliases_rule_rejects_exceeding_alias_limit():
query = """
query {
alias1: field
alias2: field
alias3: field
}
"""
errors = validate_query(query, create_max_aliases_rule(2))
assert errors, "Expected errors for exceeding alias limit"
assert any("Query uses too many aliases" in str(e) for e in errors)


def test_max_aliases_rule_exact_alias_limit():
query = """
query {
alias1: field
alias2: field
}
"""
errors = validate_query(query, create_max_aliases_rule(2))
assert not errors, "Expected no errors when alias count matches the limit"
65 changes: 65 additions & 0 deletions graphql_api/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Any, Type

from graphql import GraphQLError, ValidationRule
from graphql.language.ast import DocumentNode, FieldNode, OperationDefinitionNode
from graphql.validation import ValidationContext


def create_max_depth_rule(max_depth: int) -> Type[ValidationRule]:
class MaxDepthRule(ValidationRule):
def __init__(self, context: ValidationContext) -> None:
super().__init__(context)
self.operation_depth: int = 1
self.max_depth_reached: bool = False
self.max_depth: int = max_depth

def enter_operation_definition(
self, node: OperationDefinitionNode, *_args: Any
) -> None:
self.operation_depth = 1
self.max_depth_reached = False

def enter_field(self, node: FieldNode, *_args: Any) -> None:
self.operation_depth += 1

if self.operation_depth > self.max_depth and not self.max_depth_reached:
self.max_depth_reached = True
self.report_error(
GraphQLError(
"Query depth exceeds the maximum allowed depth",
node,
)
)

def leave_field(self, node: FieldNode, *_args: Any) -> None:
self.operation_depth -= 1

return MaxDepthRule


def create_max_aliases_rule(max_aliases: int) -> Type[ValidationRule]:
class MaxAliasesRule(ValidationRule):
def __init__(self, context: ValidationContext) -> None:
super().__init__(context)
self.alias_count: int = 0
self.has_reported_error: bool = False
self.max_aliases: int = max_aliases

def enter_document(self, node: DocumentNode, *_args: Any) -> None:
self.alias_count = 0
self.has_reported_error = False

def enter_field(self, node: FieldNode, *_args: Any) -> None:
if node.alias:
self.alias_count += 1

if self.alias_count > self.max_aliases and not self.has_reported_error:
self.has_reported_error = True
self.report_error(
GraphQLError(
"Query uses too many aliases",
node,
)
)

return MaxAliasesRule
7 changes: 5 additions & 2 deletions graphql_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from services.redis_configuration import get_redis_connection

from .schema import schema
from .validation import create_max_aliases_rule, create_max_depth_rule

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -188,7 +189,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
class AsyncGraphqlView(GraphQLAsyncView):
schema = schema
extensions = [QueryMetricsExtension]
introspection = getattr(settings, "GRAPHQL_INTROSPECTION_ENABLED", False)
introspection = settings.GRAPHQL_INTROSPECTION_ENABLED

def get_validation_rules(
self,
Expand All @@ -197,11 +198,13 @@ def get_validation_rules(
data: dict,
) -> Optional[Collection]:
return [
create_max_aliases_rule(max_aliases=settings.GRAPHQL_MAX_ALIASES),
create_max_depth_rule(max_depth=settings.GRAPHQL_MAX_DEPTH),
cost_validator(
maximum_cost=settings.GRAPHQL_QUERY_COST_THRESHOLD,
default_cost=1,
variables=data.get("variables"),
)
),
]

validation_rules = get_validation_rules # type: ignore
Expand Down

0 comments on commit 63124e2

Please sign in to comment.