diff --git a/CHANGELOG.md b/CHANGELOG.md index f6919a3..a82116e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Added `CloudflareCacheBackend`. - Added `DynamoDBCacheBackend`. +- Changed `QueryFilter` and `root_resolver` to split variables between schemas. ## 0.1.0 (2023-06-13) diff --git a/ariadne_graphql_proxy/proxy_schema.py b/ariadne_graphql_proxy/proxy_schema.py index 2f65717..bedc458 100644 --- a/ariadne_graphql_proxy/proxy_schema.py +++ b/ariadne_graphql_proxy/proxy_schema.py @@ -216,10 +216,12 @@ async def root_resolver( { "operationName": operation_name, "query": print_ast(query_document), - "variables": variables, + "variables": variables + if not variables + else {key: variables[key] for key in query_variables}, }, ) - for schema_id, query_document in queries + for schema_id, query_document, query_variables in queries if self.urls[schema_id] ] ) diff --git a/ariadne_graphql_proxy/query_filter.py b/ariadne_graphql_proxy/query_filter.py index 54b9261..e2a5cb0 100644 --- a/ariadne_graphql_proxy/query_filter.py +++ b/ariadne_graphql_proxy/query_filter.py @@ -3,24 +3,27 @@ from graphql import ( DocumentNode, FieldNode, - FragmentSpreadNode, FragmentDefinitionNode, + FragmentSpreadNode, GraphQLSchema, InlineFragmentNode, NameNode, OperationDefinitionNode, SelectionNode, SelectionSetNode, + VariableNode, ) class QueryFilterContext: schema_id: int fragments: Dict[str, FragmentDefinitionNode] + variables: Set[str] def __init__(self, schema_id: int): self.schema_id = schema_id self.fragments = {} + self.variables = set() class QueryFilter: @@ -38,21 +41,25 @@ def __init__( self.fields_types = fields_types self.foreign_keys = foreign_keys - def split_query(self, document: DocumentNode) -> List[Tuple[int, DocumentNode]]: - queries: List[Tuple[int, DocumentNode]] = [] + def split_query( + self, document: DocumentNode + ) -> List[Tuple[int, DocumentNode, Set[str]]]: + queries: List[Tuple[int, DocumentNode, Set[str]]] = [] for schema_id in range(len(self.schemas)): - schema_query = self.get_schema_query(schema_id, document) + schema_query, used_variables = self.get_schema_query_with_used_variables( + schema_id, document + ) if schema_query: - queries.append((schema_id, schema_query)) + queries.append((schema_id, schema_query, used_variables)) return queries - def get_schema_query( + def get_schema_query_with_used_variables( self, schema_id: int, document: DocumentNode, - ) -> Optional[DocumentNode]: + ) -> Tuple[Optional[DocumentNode], Set[str]]: context = QueryFilterContext(schema_id) definitions = [] @@ -72,9 +79,9 @@ def get_schema_query( definitions.append(new_operation) if not definitions: - return None + return None, context.variables - return DocumentNode(definitions=tuple(definitions)) + return DocumentNode(definitions=tuple(definitions)), context.variables def filter_operation_node( self, @@ -117,12 +124,18 @@ def filter_operation_node( if not new_selections: return None + used_variable_definitions = [ + variable_definition + for variable_definition in operation_node.variable_definitions + if variable_definition.variable.name.value in context.variables + ] + return OperationDefinitionNode( loc=operation_node.loc, operation=operation_node.operation, name=operation_node.name, directives=operation_node.directives, - variable_definitions=operation_node.variable_definitions, + variable_definitions=tuple(used_variable_definitions), selection_set=SelectionSetNode( selections=tuple(new_selections), ), @@ -134,6 +147,12 @@ def filter_field_node( schema_obj: str, context: QueryFilterContext, ) -> Optional[FieldNode]: + context.variables.update( + argument.value.name.value + for argument in field_node.arguments + if isinstance(argument.value, VariableNode) + ) + if not field_node.selection_set: return field_node diff --git a/tests/test_proxy_schema.py b/tests/test_proxy_schema.py index 0bee668..dd5127c 100644 --- a/tests/test_proxy_schema.py +++ b/tests/test_proxy_schema.py @@ -603,3 +603,169 @@ async def test_proxy_schema_unpacks_fragment_in_query( """ ).strip(), } + + +@pytest.mark.asyncio +async def test_proxy_schema_splits_variables_between_schemas( + httpx_mock, + search_schema_json, + store_schema_json, + search_root_value, + store_root_value, +): + httpx_mock.add_response( + url="http://graphql.example.com/search/", json=search_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/store/", json=store_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/search/", + json={"data": search_root_value}, + ) + httpx_mock.add_response( + url="http://graphql.example.com/store/", + json={"data": store_root_value}, + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/search/") + proxy_schema.add_remote_schema("http://graphql.example.com/store/") + proxy_schema.get_final_schema() + + await proxy_schema.root_resolver( + {}, + "TestQuery", + {"searchQuery": "test", "orderId": "testId"}, + parse( + """ + query TestQuery($searchQuery: String!, $orderId: ID!) { + search(query: $searchQuery) { + id + } + order(id: $orderId) { + id + } + } + """ + ), + ) + + search_request = httpx_mock.get_requests(url="http://graphql.example.com/search/")[ + -1 + ] + store_request = httpx_mock.get_requests(url="http://graphql.example.com/store/")[-1] + + assert json.loads(search_request.content) == { + "operationName": "TestQuery", + "variables": {"searchQuery": "test"}, + "query": dedent( + """ + query TestQuery($searchQuery: String!) { + search(query: $searchQuery) { + id + } + } + """ + ).strip(), + } + assert json.loads(store_request.content) == { + "operationName": "TestQuery", + "variables": {"orderId": "testId"}, + "query": dedent( + """ + query TestQuery($orderId: ID!) { + order(id: $orderId) { + id + } + } + """ + ).strip(), + } + + +@pytest.mark.asyncio +async def test_proxy_schema_splits_variables_from_fragments_between_schemas( + httpx_mock, + search_schema_json, + store_schema_json, + search_root_value, + store_root_value, +): + httpx_mock.add_response( + url="http://graphql.example.com/search/", json=search_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/store/", json=store_schema_json + ) + httpx_mock.add_response( + url="http://graphql.example.com/search/", + json={"data": search_root_value}, + ) + httpx_mock.add_response( + url="http://graphql.example.com/store/", + json={"data": store_root_value}, + ) + + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/search/") + proxy_schema.add_remote_schema("http://graphql.example.com/store/") + proxy_schema.get_final_schema() + + await proxy_schema.root_resolver( + {}, + "TestQuery", + {"searchQuery": "test", "orderId": "testId"}, + parse( + """ + query TestQuery($searchQuery: String!, $orderId: ID!) { + ...searchFragment + ...orderFragment + } + + fragment searchFragment on Query { + search(query: $searchQuery) { + id + } + } + + fragment orderFragment on Query { + order(id: $orderId) { + id + } + } + """ + ), + ) + + search_request = httpx_mock.get_requests(url="http://graphql.example.com/search/")[ + -1 + ] + store_request = httpx_mock.get_requests(url="http://graphql.example.com/store/")[-1] + + assert json.loads(search_request.content) == { + "operationName": "TestQuery", + "variables": {"searchQuery": "test"}, + "query": dedent( + """ + query TestQuery($searchQuery: String!) { + search(query: $searchQuery) { + id + } + } + """ + ).strip(), + } + assert json.loads(store_request.content) == { + "operationName": "TestQuery", + "variables": {"orderId": "testId"}, + "query": dedent( + """ + query TestQuery($orderId: ID!) { + order(id: $orderId) { + id + } + } + """ + ).strip(), + }