diff --git a/GUIDE.md b/GUIDE.md
index 7849585..b0aaab1 100644
--- a/GUIDE.md
+++ b/GUIDE.md
@@ -572,6 +572,71 @@ If `proxy_headers` is a callable, it will be called with single argument (`conte
 If `proxy_headers` is `None` or `False`, no headers are proxied to the other service.
 
 
+## Field dependencies
+
+In situations where a field depends on data from sibling fields in order to be resolved, `ProxySchema` can be configured to include those additional fields in the root value query sent to the remote schema.
+
+The example below pulls a remote schema that defines the `Product` type, extends this type with an `image: String` field, and then uses `ProxySchema.add_field_dependencies` to configure `{ metadata { thumb } }` as additional fields to retrieve when the `image` field is queried. It also includes a custom resolver for the `image` field that uses this additional data:
+
+
+```python
+from ariadne.asgi import GraphQL
+from ariadne_graphql_proxy import (
+    ProxySchema,
+    get_context_value,
+    set_resolver,
+)
+from graphql import build_ast_schema, parse
+
+
+proxy_schema = ProxySchema()
+
+# Store schema ID for remote schema
+remote_schema_id = proxy_schema.add_remote_schema(
+    "https://example.com/graphql/",
+)
+
+# Extend Product type with additional image field
+proxy_schema.add_schema(
+    build_ast_schema(
+        parse(
+            """
+            type Product {
+                image: String
+            }
+            """
+        )
+    )
+)
+
+# Configure proxy schema to retrieve thumb from metadata
+# from remote schema when image is queried
+proxy_schema.add_field_dependencies(
+    remote_schema_id, "Product", "image", "{ metadata { thumb } }"
+)
+
+# Create schema instance
+final_schema = proxy_schema.get_final_schema()
+
+
+# Add product image resolver
+def resolve_product_image(obj, info):
+    return obj["metadata"]["thumb"]
+
+
+set_resolver(final_schema, "Product", "image", resolve_product_image)
+
+
+# Setup Ariadne ASGI GraphQL application
+app = GraphQL(
+    final_schema,
+    context_value=get_context_value,
+    root_value=proxy_schema.root_resolver,
+    debug=True,
+)
+```
+
+
 ## Cache framework
 
 Ariadne GraphQL Proxy implements basic cache framework that enables of caching parts of GraphQL queries.
@@ -855,6 +920,43 @@ def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
 Sets specific fields in schema as delayed. Delayed fields are excluded from queries ran by `root_resolver` against the remote GraphQL APIs.
 
 
+#### `delayed_fields`
+
+This is a dict mapping type names to lists of field names:
+
+```python
+{"Type": ["field", "otherField"], "OtherType": ["field"]}
+```
+
+
+### `add_field_dependencies`
+
+```python
+def add_field_dependencies(
+    self, schema_id: int, type_name: str, field_name: str, query: str
+):
+```
+
+Adds the fields specified in `query` as dependencies for `field_name` of `type_name`, to be retrieved from the schema with `schema_id`.
+
+
+#### Required arguments
+
+- `schema_id`: an `int` with the ID of a schema returned by `add_remote_schema` or `add_schema`.
+- `type_name`: a `str` with the name of the type for which dependencies will be set.
+- `field_name`: a `str` with the name of the field for which dependencies will be set.
+- `query`: a `str` with the additional fields to fetch when `field_name` is included, e.g. `{ metadata { key value } }`.
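+
+For example, a minimal call mirroring the `Product.image` example above (assuming `remote_schema_id` holds the ID returned earlier by `add_remote_schema`) could look like this:
+
+```python
+# Illustrative only: names reuse the Product/image example from this guide
+proxy_schema.add_field_dependencies(
+    remote_schema_id, "Product", "image", "{ metadata { thumb } }"
+)
+```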
+
+
 ### `add_foreign_key`
 
 ```python
diff --git a/ariadne_graphql_proxy/proxy_schema.py b/ariadne_graphql_proxy/proxy_schema.py
index bcb803c..587b273 100644
--- a/ariadne_graphql_proxy/proxy_schema.py
+++ b/ariadne_graphql_proxy/proxy_schema.py
@@ -227,10 +227,8 @@ def add_field_dependencies(
                 f"Type '{type_name}' in schema with ID '{schema_id}' is not "
                 "an object type."
             )
-        if field_name not in schema_type.fields:
-            raise ValueError(
-                f"Type '{type_name}' doesn't define the '{field_name}' field."
-            )
+
+        self.validate_field_with_dependencies(type_name, field_name)
 
         if schema_id not in self.dependencies:
             self.dependencies[schema_id] = {}
@@ -248,7 +246,14 @@ def add_field_dependencies(
         )
 
     def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetNode:
-        ast = parse(query)
+        clean_query = query.strip()
+        if not clean_query.startswith("{") or not clean_query.endswith("}"):
+            raise ValueError(
+                f"'{field_name}' field dependencies should be defined as a single "
+                "GraphQL operation, e.g.: '{ field other { subfield } }'."
+            )
+
+        ast = parse(clean_query)
 
         if (
             not len(ast.definitions) == 1
@@ -262,6 +267,19 @@ def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetN
 
         return ast.definitions[0].selection_set
 
+    def validate_field_with_dependencies(self, type_name: str, field_name: str) -> None:
+        for schema in self.schemas:
+            if (
+                type_name in schema.type_map
+                and isinstance(schema.type_map[type_name], GraphQLObjectType)
+                and field_name in schema.type_map[type_name].fields
+            ):
+                return
+
+        raise ValueError(
+            f"Type '{type_name}' doesn't define the '{field_name}' field in any of the schemas."
+        )
+
     def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
         for type_name, type_fields in delayed_fields.items():
             if type_name not in self.fields_map:
diff --git a/ariadne_graphql_proxy/selections.py b/ariadne_graphql_proxy/selections.py
new file mode 100644
index 0000000..875d7aa
--- /dev/null
+++ b/ariadne_graphql_proxy/selections.py
@@ -0,0 +1,53 @@
+from typing import Dict, Sequence, List, cast
+
+from graphql import FieldNode, SelectionNode, SelectionSetNode
+
+
+def merge_selection_sets(
+    set_a: SelectionSetNode, set_b: SelectionSetNode
+) -> SelectionSetNode:
+    return SelectionSetNode(
+        selections=tuple(merge_selections(set_a.selections, set_b.selections)),
+    )
+
+
+def merge_selections(
+    set_a: Sequence[SelectionNode], set_b: Sequence[SelectionNode]
+) -> List[SelectionNode]:
+    final_set: List[SelectionNode] = list(set_a)
+
+    index: Dict[str, int] = {}
+    for i, field in enumerate(final_set):
+        if isinstance(field, FieldNode):
+            index[(field.alias or field.name).value] = i
+
+    for field in set_b:
+        if isinstance(field, FieldNode):
+            field_name = (field.alias or field.name).value
+            if field_name in index:
+                field_index = index[field_name]
+                other_field = cast(FieldNode, final_set[field_index])
+                if other_field.selection_set and field.selection_set:
+                    final_set[field_index] = FieldNode(
+                        directives=other_field.directives,
+                        alias=other_field.alias,
+                        name=field.name,
+                        arguments=other_field.arguments,
+                        selection_set=merge_selection_sets(
+                            other_field.selection_set, field.selection_set
+                        ),
+                    )
+                elif other_field.selection_set or field.selection_set:
+                    final_set[field_index] = FieldNode(
+                        directives=other_field.directives,
+                        alias=other_field.alias,
+                        name=field.name,
+                        arguments=other_field.arguments,
+                        selection_set=(
+                            other_field.selection_set or field.selection_set
+                        ),
+                    )
+            else:
+                final_set.append(field)
+
+    return final_set
diff --git a/tests/test_merge_selection_sets.py b/tests/test_merge_selection_sets.py
new file mode 100644
index 0000000..c091cb4
--- /dev/null
+++ b/tests/test_merge_selection_sets.py
@@ -0,0 +1,82 @@
+from textwrap import dedent
+
+from graphql import parse, print_ast
+
+from ariadne_graphql_proxy import merge_selection_sets
+
+
+def test_merge_selection_sets_merges_two_flat_sets():
+    set_a = parse("{ hello }").definitions[0].selection_set
+    set_b = parse("{ world }").definitions[0].selection_set
+
+    result = merge_selection_sets(set_a, set_b)
+    assert (
+        print_ast(result)
+        == dedent(
+            """
+            {
+              hello
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selection_sets_merges_two_overlapping_flat_sets():
+    set_a = parse("{ hello world }").definitions[0].selection_set
+    set_b = parse("{ world }").definitions[0].selection_set
+
+    result = merge_selection_sets(set_a, set_b)
+    assert (
+        print_ast(result)
+        == dedent(
+            """
+            {
+              hello
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selection_sets_keeps_nested_selections():
+    set_a = parse("{ hello { sub } }").definitions[0].selection_set
+    set_b = parse("{ world }").definitions[0].selection_set
+
+    result = merge_selection_sets(set_a, set_b)
+    assert (
+        print_ast(result)
+        == dedent(
+            """
+            {
+              hello {
+                sub
+              }
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selection_sets_merges_selection_sets_recursively():
+    set_a = parse("{ hello { sub } }").definitions[0].selection_set
+    set_b = parse("{ hello { set } world }").definitions[0].selection_set
+
+    result = merge_selection_sets(set_a, set_b)
+    assert (
+        print_ast(result)
+        == dedent(
+            """
+            {
+              hello {
+                sub
+                set
+              }
+              world
+            }
+            """
+        ).strip()
+    )
diff --git a/tests/test_merge_selections.py b/tests/test_merge_selections.py
new file mode 100644
index 0000000..cfe6253
--- /dev/null
+++ b/tests/test_merge_selections.py
@@ -0,0 +1,82 @@
+from textwrap import dedent
+
+from graphql import SelectionSetNode, parse, print_ast
+
+from ariadne_graphql_proxy import merge_selections
+
+
+def test_merge_selections_merges_two_flat_sets():
+    set_a = parse("{ hello }").definitions[0].selection_set.selections
+    set_b = parse("{ world }").definitions[0].selection_set.selections
+
+    result = merge_selections(set_a, set_b)
+    assert (
+        print_ast(SelectionSetNode(selections=result))
+        == dedent(
+            """
+            {
+              hello
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selections_merges_two_overlapping_flat_sets():
+    set_a = parse("{ hello world }").definitions[0].selection_set.selections
+    set_b = parse("{ world }").definitions[0].selection_set.selections
+
+    result = merge_selections(set_a, set_b)
+    assert (
+        print_ast(SelectionSetNode(selections=result))
+        == dedent(
+            """
+            {
+              hello
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selections_keeps_nested_selections():
+    set_a = parse("{ hello { sub } }").definitions[0].selection_set.selections
+    set_b = parse("{ world }").definitions[0].selection_set.selections
+
+    result = merge_selections(set_a, set_b)
+    assert (
+        print_ast(SelectionSetNode(selections=result))
+        == dedent(
+            """
+            {
+              hello {
+                sub
+              }
+              world
+            }
+            """
+        ).strip()
+    )
+
+
+def test_merge_selections_merges_selection_sets_recursively():
+    set_a = parse("{ hello { sub } }").definitions[0].selection_set.selections
+    set_b = parse("{ hello { set } world }").definitions[0].selection_set.selections
+
+    result = merge_selections(set_a, set_b)
+    assert (
+        print_ast(SelectionSetNode(selections=result))
+        == dedent(
+            """
+            {
+              hello {
+                sub
+                set
+              }
+              world
+            }
+            """
+        ).strip()
+    )
diff --git a/tests/test_proxy_schema.py b/tests/test_proxy_schema.py
index 14d8a27..068b282 100644
--- a/tests/test_proxy_schema.py
+++ b/tests/test_proxy_schema.py
@@ -1265,6 +1265,36 @@ async def test_root_value_for_remote_schema_excludes_extensions(
     }
 
 
+@pytest.mark.asyncio
+async def test_add_field_dependencies_for_nonexisting_schema_raises_error(
+    httpx_mock, schema_json
+):
+    httpx_mock.add_response(json=schema_json)
+
+    proxy_schema = ProxySchema()
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
+
+    with pytest.raises(ValueError) as exc_info:
+        proxy_schema.add_field_dependencies(
+            schema_id + 1, "Complex", "invalid", "{ group { name } }"
+        )
+
+    assert "Schema with ID '1' doesn't exist." == str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_add_field_dependencies_for_local_schema_raises_error(schema):
+    proxy_schema = ProxySchema()
+    schema_id = proxy_schema.add_schema(schema)
+
+    with pytest.raises(ValueError) as exc_info:
+        proxy_schema.add_field_dependencies(
+            schema_id, "Complex", "invalid", "{ group { name } }"
+        )
+
+    assert "Schema with ID '0' is not a remote schema." == str(exc_info.value)
+
+
 @pytest.mark.asyncio
 async def test_add_field_dependencies_for_query_field_raises_error(
     httpx_mock, schema_json
@@ -1272,9 +1302,7 @@ async def test_add_field_dependencies_for_query_field_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(schema_id, "Query", "basic", "{ complex }")
@@ -1291,9 +1319,7 @@ async def test_add_field_dependencies_for_mutation_field_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(
@@ -1312,9 +1338,7 @@ async def test_add_field_dependencies_for_subscription_field_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(
@@ -1333,9 +1357,7 @@ async def test_add_field_dependencies_for_nonexisting_type_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(
@@ -1352,9 +1374,7 @@ async def test_add_field_dependencies_for_invalid_type_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(
@@ -1373,16 +1393,57 @@ async def test_add_field_dependencies_for_nonexisting_type_field_raises_error(
     httpx_mock.add_response(json=schema_json)
 
     proxy_schema = ProxySchema()
-    schema_id = proxy_schema.add_remote_schema(
-        "http://graphql.example.com/", proxy_extensions=False
-    )
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
 
     with pytest.raises(ValueError) as exc_info:
         proxy_schema.add_field_dependencies(
             schema_id, "Complex", "invalid", "{ group { name } }"
         )
 
-    assert "Type 'Complex' doesn't define the 'invalid' field." == str(exc_info.value)
+    assert (
+        "Type 'Complex' doesn't define the 'invalid' field in any of the schemas."
+        == str(exc_info.value)
+    )
+
+
+@pytest.mark.asyncio
+async def test_add_field_dependencies_with_invalid_dependencies_arg_raises_error(
+    httpx_mock, schema_json
+):
+    httpx_mock.add_response(url="http://graphql.example.com/", json=schema_json)
+
+    proxy_schema = ProxySchema()
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
+
+    with pytest.raises(ValueError) as exc_info:
+        proxy_schema.add_field_dependencies(
+            schema_id, "Complex", "class", "group { id }"
+        )
+
+    assert (
+        "'class' field dependencies should be defined as a single GraphQL "
+        "operation, e.g.: '{ field other { subfield } }'."
+    ) == str(exc_info.value)
+
+
+@pytest.mark.asyncio
+async def test_add_field_dependencies_with_invalid_dependencies_arg_op_raises_error(
+    httpx_mock, schema_json
+):
+    httpx_mock.add_response(url="http://graphql.example.com/", json=schema_json)
+
+    proxy_schema = ProxySchema()
+    schema_id = proxy_schema.add_remote_schema("http://graphql.example.com/")
+
+    with pytest.raises(ValueError) as exc_info:
+        proxy_schema.add_field_dependencies(
+            schema_id, "Complex", "class", "mutation { group { id } }"
+        )
+
+    assert (
+        "'class' field dependencies should be defined as a single GraphQL "
+        "operation, e.g.: '{ field other { subfield } }'."
+    ) == str(exc_info.value)
 
 
 @pytest.mark.asyncio