Skip to content

Commit

Permalink
Add guide, improve args validation
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Mar 20, 2024
1 parent fa2287d commit a2fd2e9
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 24 deletions.
93 changes: 93 additions & 0 deletions GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,71 @@ If `proxy_headers` is a callable, it will be called with single argument (`conte
If `proxy_headers` is `None` or `False`, no headers are proxied to the other service.


## Fields dependencies

In situations where field depends on data from sibling fields in order to be resolved, `ProxySchema` can be configured to include those additional fields in root value query sent to remote schema.

Below example pulls a remote schema that defines `Product` type, extends this type with `image: String` field, and then uses `ProxySchema.add_field_dependencies` to configure `{ metadata { thumb} }` as additional fields to retrieve when `image` field is queried. It also includes custom resolver for `image` field that uses this additional data:


```python
from ariadne.asgi import GraphQL
from ariadne_graphql_proxy import (
ProxySchema,
get_context_value,
set_resolver,
)
from graphql import build_ast_schema, parse


proxy_schema = ProxySchema()

# Store schema ID for remote schema
remote_schema_id = proxy_schema.add_remote_schema(
"https://example.com/graphql/",
)

# Extend Product type with additional image field
proxy_schema.add_schema(
build_ast_schema(
parse(
"""
type Product {
image: String
}
"""
)
)
)

# Configure proxy schema to retrieve thumb from metadata
# from remote schema when image is queried
proxy_schema.add_field_dependencies(
remote_schema_id, "Product", "image", "{ metadata { thumb } }"
)

# Create schema instance
final_schema = proxy_schema.get_final_schema()


# Add product image resolver
def resolve_product_image(obj, info):
return obj["metadata"]["thumb"]


set_resolver(final_schema, "Product", "image", resolve_product_image)


# Setup Ariadne ASGI GraphQL application
app = GraphQL(
final_schema,
context_value=get_context_value,
root_value=proxy_schema.root_resolver,
debug=True,
)
```


## Cache framework

Ariadne GraphQL Proxy implements basic cache framework that enables of caching parts of GraphQL queries.
Expand Down Expand Up @@ -855,6 +920,34 @@ def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
Sets specific fields in schema as delayed. Delayed fields are excluded from queries ran by `root_resolver` against the remote GraphQL APIs.


#### `delayed_fields`

This is a dict of type name and fields names lists:

```python
{"Type": ["field", "otherField"], "OtherType": ["field"]}
```


### `add_field_dependencies`

```python
def add_field_dependencies(
self, schema_id: int, type_name: str, field_name: str, query: str
):
```

Adds fields specified in `query` as dependencies for `field_name` of `type_name` that should be retrieved from schema with `schema_id`.


#### Required arguments

- `schema_id`: an `int` with ID of schema returned by `add_remote_schema` or `add_schema`.
- `type_name`: a `str` with name of type for which dependencies will be set.
- `field_name`: a `str` with name of field which dependencies will be set.
- `query`: a `str` with additional fields to fetch when `field_name` is included, eg. `{ metadata { key value} }`.


### `add_foreign_key`

```python
Expand Down
28 changes: 23 additions & 5 deletions ariadne_graphql_proxy/proxy_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,8 @@ def add_field_dependencies(
f"Type '{type_name}' in schema with ID '{schema_id}' is not "
"an object type."
)
if field_name not in schema_type.fields:
raise ValueError(
f"Type '{type_name}' doesn't define the '{field_name}' field."
)

self.validate_field_with_dependencies(type_name, field_name)

if schema_id not in self.dependencies:
self.dependencies[schema_id] = {}
Expand All @@ -248,7 +246,14 @@ def add_field_dependencies(
)

def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetNode:
ast = parse(query)
clean_query = query.strip()
if not clean_query.startswith("{") or not clean_query.endswith("}"):
raise ValueError(
f"'{field_name}' field dependencies should be defined as a single "
"GraphQL operation, e.g.: '{ field other { subfield } }'."
)

ast = parse(clean_query)

if (
not len(ast.definitions) == 1
Expand All @@ -262,6 +267,19 @@ def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetN

return ast.definitions[0].selection_set

def validate_field_with_dependencies(self, type_name: str, field_name: str) -> None:
for schema in self.schemas:
if (
type_name in schema.type_map
and isinstance(schema.type_map[type_name], GraphQLObjectType)
and field_name in schema.type_map[type_name].fields
):
return

raise ValueError(
f"Type '{type_name}' doesn't define the '{field_name}' field in any of schemas."
)

def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
for type_name, type_fields in delayed_fields.items():
if type_name not in self.fields_map:
Expand Down
53 changes: 53 additions & 0 deletions ariadne_graphql_proxy/selections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Dict, Sequence, List, cast

from graphql import FieldNode, SelectionNode, SelectionSetNode


def merge_selection_sets(
set_a: SelectionSetNode, set_b: SelectionSetNode
) -> SelectionSetNode:
return SelectionSetNode(
selections=tuple(merge_selections(set_a.selections, set_b.selections)),
)


def merge_selections(
set_a: Sequence[SelectionNode], set_b: Sequence[SelectionNode]
) -> List[SelectionNode]:
final_set: List[SelectionNode] = list(set_a)

index: Dict[str, int] = {}
for i, field in enumerate(final_set):
if isinstance(field, FieldNode):
index[(field.alias or field.name).value] = i

for field in set_b:
if isinstance(field, FieldNode):
field_name = (field.alias or field.name).value
if field_name in index:
field_index = index[field_name]
other_field = cast(FieldNode, final_set[field_index])
if other_field.selection_set and field.selection_set:
final_set[field_index] = FieldNode(
directives=other_field.directives,
alias=other_field.alias,
name=field.name,
arguments=other_field.arguments,
selection_set=merge_selection_sets(
other_field.selection_set, field.selection_set
),
)
elif other_field.selection_set or field.selection_set:
final_set[field_index] = FieldNode(
directives=other_field.directives,
alias=other_field.alias,
name=field.name,
arguments=other_field.arguments,
selection_set=(
other_field.selection_set or field.selection_set
),
)
else:
final_set.append(field)

return final_set
82 changes: 82 additions & 0 deletions tests/test_merge_selection_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from textwrap import dedent

from graphql import parse, print_ast

from ariadne_graphql_proxy import merge_selection_sets


def test_merge_selection_sets_merges_two_flat_sets():
set_a = parse("{ hello }").definitions[0].selection_set
set_b = parse("{ world }").definitions[0].selection_set

result = merge_selection_sets(set_a, set_b)
assert (
print_ast(result)
== dedent(
"""
{
hello
world
}
"""
).strip()
)


def test_merge_selection_sets_merges_two_overlapping_flat_sets():
set_a = parse("{ hello world }").definitions[0].selection_set
set_b = parse("{ world }").definitions[0].selection_set

result = merge_selection_sets(set_a, set_b)
assert (
print_ast(result)
== dedent(
"""
{
hello
world
}
"""
).strip()
)


def test_merge_selection_sets_keeps_nested_selections():
set_a = parse("{ hello { sub } }").definitions[0].selection_set
set_b = parse("{ world }").definitions[0].selection_set

result = merge_selection_sets(set_a, set_b)
assert (
print_ast(result)
== dedent(
"""
{
hello {
sub
}
world
}
"""
).strip()
)


def test_merge_selection_sets_merges_selection_sets_recursively():
set_a = parse("{ hello { sub } }").definitions[0].selection_set
set_b = parse("{ hello { set } world }").definitions[0].selection_set

result = merge_selection_sets(set_a, set_b)
assert (
print_ast(result)
== dedent(
"""
{
hello {
sub
set
}
world
}
"""
).strip()
)
82 changes: 82 additions & 0 deletions tests/test_merge_selections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from textwrap import dedent

from graphql import SelectionSetNode, parse, print_ast

from ariadne_graphql_proxy import merge_selections


def test_merge_selections_merges_two_flat_sets():
set_a = parse("{ hello }").definitions[0].selection_set.selections
set_b = parse("{ world }").definitions[0].selection_set.selections

result = merge_selections(set_a, set_b)
assert (
print_ast(SelectionSetNode(selections=result))
== dedent(
"""
{
hello
world
}
"""
).strip()
)


def test_merge_selections_merges_two_overlapping_flat_sets():
set_a = parse("{ hello world }").definitions[0].selection_set.selections
set_b = parse("{ world }").definitions[0].selection_set.selections

result = merge_selections(set_a, set_b)
assert (
print_ast(SelectionSetNode(selections=result))
== dedent(
"""
{
hello
world
}
"""
).strip()
)


def test_merge_selections_keeps_nested_selections():
set_a = parse("{ hello { sub } }").definitions[0].selection_set.selections
set_b = parse("{ world }").definitions[0].selection_set.selections

result = merge_selections(set_a, set_b)
assert (
print_ast(SelectionSetNode(selections=result))
== dedent(
"""
{
hello {
sub
}
world
}
"""
).strip()
)


def test_merge_selections_merges_selection_sets_recursively():
set_a = parse("{ hello { sub } }").definitions[0].selection_set.selections
set_b = parse("{ hello { set } world }").definitions[0].selection_set.selections

result = merge_selections(set_a, set_b)
assert (
print_ast(SelectionSetNode(selections=result))
== dedent(
"""
{
hello {
sub
set
}
world
}
"""
).strip()
)
Loading

0 comments on commit a2fd2e9

Please sign in to comment.