Skip to content

Commit

Permalink
Merge pull request #38 from mirumee/split_variables_query_filter
Browse files Browse the repository at this point in the history
Split variables in query filter
  • Loading branch information
mat-sop authored Sep 1, 2023
2 parents f896aef + 4997a22 commit 83f314f
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Added `CloudflareCacheBackend`.
- Added `DynamoDBCacheBackend`.
- Changed `QueryFilter` and `root_resolver` to split variables between schemas.


## 0.1.0 (2023-06-13)
Expand Down
6 changes: 4 additions & 2 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,12 @@ async def root_resolver(
{
"operationName": operation_name,
"query": print_ast(query_document),
"variables": variables,
"variables": variables
if not variables
else {key: variables[key] for key in query_variables},
},
)
for schema_id, query_document in queries
for schema_id, query_document, query_variables in queries
if self.urls[schema_id]
]
)
Expand Down
39 changes: 29 additions & 10 deletions ariadne_graphql_proxy/query_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,27 @@
from graphql import (
DocumentNode,
FieldNode,
FragmentSpreadNode,
FragmentDefinitionNode,
FragmentSpreadNode,
GraphQLSchema,
InlineFragmentNode,
NameNode,
OperationDefinitionNode,
SelectionNode,
SelectionSetNode,
VariableNode,
)


class QueryFilterContext:
schema_id: int
fragments: Dict[str, FragmentDefinitionNode]
variables: Set[str]

def __init__(self, schema_id: int):
self.schema_id = schema_id
self.fragments = {}
self.variables = set()


class QueryFilter:
Expand All @@ -38,21 +41,25 @@ def __init__(
self.fields_types = fields_types
self.foreign_keys = foreign_keys

def split_query(self, document: DocumentNode) -> List[Tuple[int, DocumentNode]]:
queries: List[Tuple[int, DocumentNode]] = []
def split_query(
self, document: DocumentNode
) -> List[Tuple[int, DocumentNode, Set[str]]]:
queries: List[Tuple[int, DocumentNode, Set[str]]] = []

for schema_id in range(len(self.schemas)):
schema_query = self.get_schema_query(schema_id, document)
schema_query, used_variables = self.get_schema_query_with_used_variables(
schema_id, document
)
if schema_query:
queries.append((schema_id, schema_query))
queries.append((schema_id, schema_query, used_variables))

return queries

def get_schema_query(
def get_schema_query_with_used_variables(
self,
schema_id: int,
document: DocumentNode,
) -> Optional[DocumentNode]:
) -> Tuple[Optional[DocumentNode], Set[str]]:
context = QueryFilterContext(schema_id)
definitions = []

Expand All @@ -72,9 +79,9 @@ def get_schema_query(
definitions.append(new_operation)

if not definitions:
return None
return None, context.variables

return DocumentNode(definitions=tuple(definitions))
return DocumentNode(definitions=tuple(definitions)), context.variables

def filter_operation_node(
self,
Expand Down Expand Up @@ -117,12 +124,18 @@ def filter_operation_node(
if not new_selections:
return None

used_variable_definitions = [
variable_definition
for variable_definition in operation_node.variable_definitions
if variable_definition.variable.name.value in context.variables
]

return OperationDefinitionNode(
loc=operation_node.loc,
operation=operation_node.operation,
name=operation_node.name,
directives=operation_node.directives,
variable_definitions=operation_node.variable_definitions,
variable_definitions=tuple(used_variable_definitions),
selection_set=SelectionSetNode(
selections=tuple(new_selections),
),
Expand All @@ -134,6 +147,12 @@ def filter_field_node(
schema_obj: str,
context: QueryFilterContext,
) -> Optional[FieldNode]:
context.variables.update(
argument.value.name.value
for argument in field_node.arguments
if isinstance(argument.value, VariableNode)
)

if not field_node.selection_set:
return field_node

Expand Down
166 changes: 166 additions & 0 deletions tests/test_proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,3 +603,169 @@ async def test_proxy_schema_unpacks_fragment_in_query(
"""
).strip(),
}


@pytest.mark.asyncio
async def test_proxy_schema_splits_variables_between_schemas(
httpx_mock,
search_schema_json,
store_schema_json,
search_root_value,
store_root_value,
):
httpx_mock.add_response(
url="http://graphql.example.com/search/", json=search_schema_json
)
httpx_mock.add_response(
url="http://graphql.example.com/store/", json=store_schema_json
)
httpx_mock.add_response(
url="http://graphql.example.com/search/",
json={"data": search_root_value},
)
httpx_mock.add_response(
url="http://graphql.example.com/store/",
json={"data": store_root_value},
)

proxy_schema = ProxySchema()
proxy_schema.add_remote_schema("http://graphql.example.com/search/")
proxy_schema.add_remote_schema("http://graphql.example.com/store/")
proxy_schema.get_final_schema()

await proxy_schema.root_resolver(
{},
"TestQuery",
{"searchQuery": "test", "orderId": "testId"},
parse(
"""
query TestQuery($searchQuery: String!, $orderId: ID!) {
search(query: $searchQuery) {
id
}
order(id: $orderId) {
id
}
}
"""
),
)

search_request = httpx_mock.get_requests(url="http://graphql.example.com/search/")[
-1
]
store_request = httpx_mock.get_requests(url="http://graphql.example.com/store/")[-1]

assert json.loads(search_request.content) == {
"operationName": "TestQuery",
"variables": {"searchQuery": "test"},
"query": dedent(
"""
query TestQuery($searchQuery: String!) {
search(query: $searchQuery) {
id
}
}
"""
).strip(),
}
assert json.loads(store_request.content) == {
"operationName": "TestQuery",
"variables": {"orderId": "testId"},
"query": dedent(
"""
query TestQuery($orderId: ID!) {
order(id: $orderId) {
id
}
}
"""
).strip(),
}


@pytest.mark.asyncio
async def test_proxy_schema_splits_variables_from_fragments_between_schemas(
httpx_mock,
search_schema_json,
store_schema_json,
search_root_value,
store_root_value,
):
httpx_mock.add_response(
url="http://graphql.example.com/search/", json=search_schema_json
)
httpx_mock.add_response(
url="http://graphql.example.com/store/", json=store_schema_json
)
httpx_mock.add_response(
url="http://graphql.example.com/search/",
json={"data": search_root_value},
)
httpx_mock.add_response(
url="http://graphql.example.com/store/",
json={"data": store_root_value},
)

proxy_schema = ProxySchema()
proxy_schema.add_remote_schema("http://graphql.example.com/search/")
proxy_schema.add_remote_schema("http://graphql.example.com/store/")
proxy_schema.get_final_schema()

await proxy_schema.root_resolver(
{},
"TestQuery",
{"searchQuery": "test", "orderId": "testId"},
parse(
"""
query TestQuery($searchQuery: String!, $orderId: ID!) {
...searchFragment
...orderFragment
}
fragment searchFragment on Query {
search(query: $searchQuery) {
id
}
}
fragment orderFragment on Query {
order(id: $orderId) {
id
}
}
"""
),
)

search_request = httpx_mock.get_requests(url="http://graphql.example.com/search/")[
-1
]
store_request = httpx_mock.get_requests(url="http://graphql.example.com/store/")[-1]

assert json.loads(search_request.content) == {
"operationName": "TestQuery",
"variables": {"searchQuery": "test"},
"query": dedent(
"""
query TestQuery($searchQuery: String!) {
search(query: $searchQuery) {
id
}
}
"""
).strip(),
}
assert json.loads(store_request.content) == {
"operationName": "TestQuery",
"variables": {"orderId": "testId"},
"query": dedent(
"""
query TestQuery($orderId: ID!) {
order(id: $orderId) {
id
}
}
"""
).strip(),
}

0 comments on commit 83f314f

Please sign in to comment.