Merge pull request #54 from mirumee/fix-53-fields-dependencies
Add field dependencies
rafalp authored Mar 21, 2024
2 parents c243831 + 783d3d3 commit cedcfc8
Showing 10 changed files with 700 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -29,6 +29,6 @@ jobs:
          pytest
      - name: Linters
        run: |
-         ruff ariadne_graphql_proxy tests
+         ruff check ariadne_graphql_proxy tests
          mypy ariadne_graphql_proxy --ignore-missing-imports --check-untyped-defs
          black --check ariadne_graphql_proxy tests
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,10 +1,11 @@
# CHANGELOG

-## UNRELEASED
+## 0.3.0 (UNRELEASED)

- Added `CacheSerializer`, `NoopCacheSerializer` and `JSONCacheSerializer`. Changed `CacheBackend`, `InMemoryCache`, `CloudflareCacheBackend` and `DynamoDBCacheBackend` to accept `serializer` initialization option.
- Fixed schema proxy returning an error when variable defined in an operation is missing from its variables.
- Improved custom headers handling in `ProxyResolver` and `ProxySchema`.
- Added fields dependencies configuration option to `ProxySchema`.


## 0.2.0 (2023-09-25)
93 changes: 93 additions & 0 deletions GUIDE.md
@@ -572,6 +572,71 @@ If `proxy_headers` is a callable, it will be called with single argument (`context`)
If `proxy_headers` is `None` or `False`, no headers are proxied to the other service.


## Fields dependencies

In situations where a field depends on data from sibling fields in order to be resolved, `ProxySchema` can be configured to include those additional fields in the root value query sent to the remote schema.

The example below pulls a remote schema that defines a `Product` type, extends this type with an `image: String` field, and then uses `ProxySchema.add_field_dependencies` to configure `{ metadata { thumb } }` as additional fields to retrieve when the `image` field is queried. It also includes a custom resolver for the `image` field that uses this additional data:


```python
from ariadne.asgi import GraphQL
from ariadne_graphql_proxy import (
    ProxySchema,
    get_context_value,
    set_resolver,
)
from graphql import build_ast_schema, parse


proxy_schema = ProxySchema()

# Store schema ID for remote schema
remote_schema_id = proxy_schema.add_remote_schema(
    "https://example.com/graphql/",
)

# Extend Product type with additional image field
proxy_schema.add_schema(
    build_ast_schema(
        parse(
            """
            type Product {
                image: String
            }
            """
        )
    )
)

# Configure proxy schema to retrieve thumb from metadata
# from remote schema when image is queried
proxy_schema.add_field_dependencies(
    remote_schema_id, "Product", "image", "{ metadata { thumb } }"
)

# Create schema instance
final_schema = proxy_schema.get_final_schema()


# Add product image resolver
def resolve_product_image(obj, info):
    return obj["metadata"]["thumb"]


set_resolver(final_schema, "Product", "image", resolve_product_image)


# Setup Ariadne ASGI GraphQL application
app = GraphQL(
    final_schema,
    context_value=get_context_value,
    root_value=proxy_schema.root_resolver,
    debug=True,
)
```
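
For illustration, here is a rough sketch of the query rewriting this configuration produces. The `products` and `id` fields are hypothetical and assumed to exist in the remote schema, and the exact formatting of the forwarded query may differ:

```python
# Query received by the proxy from a client; `image` is the locally added field.
client_query = """
{
    products {
        id
        image
    }
}
"""

# Query forwarded to the remote schema: `image` exists only in the local schema,
# so it is dropped from the remote query, and its configured dependency
# `{ metadata { thumb } }` is merged into the `Product` selection instead.
# The returned `metadata.thumb` value is then used by `resolve_product_image`.
remote_query = """
{
    products {
        id
        metadata {
            thumb
        }
    }
}
"""
```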


## Cache framework

Ariadne GraphQL Proxy implements a basic cache framework that enables caching parts of GraphQL queries.
@@ -855,6 +920,34 @@ def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
Sets specific fields in the schema as delayed. Delayed fields are excluded from queries run by `root_resolver` against the remote GraphQL APIs.


#### `delayed_fields`

This is a dict mapping type names to lists of field names:

```python
{"Type": ["field", "otherField"], "OtherType": ["field"]}
```
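
For illustration, a minimal sketch of a call, assuming the `proxy_schema` instance from the guide example above (the field names here are hypothetical):

```python
# Excluded from queries sent to remote schemas; these fields are expected
# to be resolved by local resolvers instead.
proxy_schema.add_delayed_fields({"Product": ["reviews"], "Order": ["items"]})
```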


### `add_field_dependencies`

```python
def add_field_dependencies(
    self, schema_id: int, type_name: str, field_name: str, query: str
):
```

Adds the fields specified in `query` as dependencies of `type_name`'s `field_name`; they will be retrieved from the schema with `schema_id` whenever `field_name` is queried.


#### Required arguments

- `schema_id`: an `int` with the ID of the schema, as returned by `add_remote_schema` or `add_schema`.
- `type_name`: a `str` with the name of the type for which dependencies will be set.
- `field_name`: a `str` with the name of the field for which dependencies will be set.
- `query`: a `str` with the additional fields to fetch when `field_name` is queried, e.g. `{ metadata { key value } }`.
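
For illustration, a sketch of a call that registers two sibling fields as dependencies of `Product.image`, assuming `remote_schema_id` was returned by `add_remote_schema` as in the guide example above (the `slug` field is hypothetical):

```python
# The query argument must be a selection set wrapped in braces.
proxy_schema.add_field_dependencies(
    remote_schema_id, "Product", "image", "{ slug metadata { thumb } }"
)
```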


### `add_foreign_key`

```python
3 changes: 3 additions & 0 deletions ariadne_graphql_proxy/__init__.py
@@ -44,6 +44,7 @@
from .query_filter import QueryFilter, QueryFilterContext
from .remote_schema import get_remote_schema
from .resolvers import set_resolver, unset_resolver
from .selections import merge_selection_sets, merge_selections

__all__ = [
    "ForeignKeyResolver",
@@ -84,6 +85,8 @@
    "merge_objects",
    "merge_scalars",
    "merge_schemas",
    "merge_selection_sets",
    "merge_selections",
    "merge_type_maps",
    "merge_types",
    "merge_unions",
104 changes: 104 additions & 0 deletions ariadne_graphql_proxy/proxy_schema.py
@@ -11,6 +11,10 @@
    GraphQLSchema,
    GraphQLUnionType,
    GraphQLWrappingType,
    OperationDefinitionNode,
    OperationType,
    SelectionSetNode,
    parse,
    print_ast,
)
from httpx import AsyncClient
@@ -20,6 +24,7 @@
from .proxy_root_value import ProxyRootValue
from .query_filter import QueryFilter
from .remote_schema import get_remote_schema
from .selections import merge_selection_sets
from .standard_types import STANDARD_TYPES, add_missing_scalar_types
from .str_to_field import (
    get_field_definition_from_str,
@@ -46,6 +51,7 @@ def __init__(
        self.fields_types: Dict[str, Dict[str, str]] = {}
        self.unions: Dict[str, List[str]] = {}
        self.foreign_keys: Dict[str, Dict[str, List[str]]] = {}
        self.dependencies: Dict[int, Dict[str, Dict[str, SelectionSetNode]]] = {}

        self.proxy_root_value = proxy_root_value

@@ -176,8 +182,105 @@ def add_foreign_key(
        if field_name in self.foreign_keys[type_name]:
            raise ValueError(f"Foreign key already exists on {type_name}.{field_name}")

        for schema_dependencies in self.dependencies.values():
            if (
                type_name in schema_dependencies
                and field_name in schema_dependencies[type_name]
            ):
                raise ValueError(
                    f"Foreign key can't be created for {type_name}.{field_name} because "
                    "field dependencies were previously defined for it."
                )

        self.foreign_keys[type_name][field_name] = [on] if isinstance(on, str) else on

    def add_field_dependencies(
        self, schema_id: int, type_name: str, field_name: str, query: str
    ):
        if type_name in ("Query", "Mutation", "Subscription"):
            raise ValueError(
                f"Defining field dependencies for {type_name} fields is not allowed."
            )

        if (
            type_name in self.foreign_keys
            and field_name in self.foreign_keys[type_name]
        ):
            raise ValueError(
                f"Dependencies can't be created for {type_name}.{field_name} because "
                "foreign key was previously defined for it."
            )

        if schema_id < 0 or schema_id + 1 > len(self.urls):
            raise ValueError(f"Schema with ID '{schema_id}' doesn't exist.")
        if not self.urls[schema_id]:
            raise ValueError(f"Schema with ID '{schema_id}' is not a remote schema.")

        schema = self.schemas[schema_id]
        if type_name not in schema.type_map:
            raise ValueError(
                f"Type '{type_name}' doesn't exist in schema with ID '{schema_id}'."
            )

        schema_type = schema.type_map[type_name]
        if not isinstance(schema_type, GraphQLObjectType):
            raise ValueError(
                f"Type '{type_name}' in schema with ID '{schema_id}' is not "
                "an object type."
            )

        self.validate_field_with_dependencies(type_name, field_name)

        if schema_id not in self.dependencies:
            self.dependencies[schema_id] = {}
        if type_name not in self.dependencies[schema_id]:
            self.dependencies[schema_id][type_name] = {}

        selection_set = self.parse_field_dependencies(field_name, query)

        type_dependencies = self.dependencies[schema_id][type_name]
        if not type_dependencies.get(field_name):
            type_dependencies[field_name] = selection_set
        else:
            type_dependencies[field_name] = merge_selection_sets(
                type_dependencies[field_name], selection_set
            )

    def parse_field_dependencies(self, field_name: str, query: str) -> SelectionSetNode:
        clean_query = query.strip()
        if not clean_query.startswith("{") or not clean_query.endswith("}"):
            raise ValueError(
                f"'{field_name}' field dependencies should be defined as a single "
                "GraphQL operation, e.g.: '{ field other { subfield } }'."
            )

        ast = parse(clean_query)

        if (
            not len(ast.definitions) == 1
            or not isinstance(ast.definitions[0], OperationDefinitionNode)
            or ast.definitions[0].operation != OperationType.QUERY
        ):
            raise ValueError(
                f"'{field_name}' field dependencies should be defined as a single "
                "GraphQL operation, e.g.: '{ field other { subfield } }'."
            )

        return ast.definitions[0].selection_set

    def validate_field_with_dependencies(self, type_name: str, field_name: str) -> None:
        for schema in self.schemas:
            if (
                type_name in schema.type_map
                and isinstance(schema.type_map[type_name], GraphQLObjectType)
                and field_name in schema.type_map[type_name].fields  # type: ignore
            ):
                return

        raise ValueError(
            f"Type '{type_name}' doesn't define the '{field_name}' field in any of schemas."
        )

    def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]):
        for type_name, type_fields in delayed_fields.items():
            if type_name not in self.fields_map:
@@ -227,6 +330,7 @@ def get_final_schema(self) -> GraphQLSchema:
            self.fields_types,
            self.unions,
            self.foreign_keys,
            self.dependencies,
        )

        return self.schema
56 changes: 50 additions & 6 deletions ariadne_graphql_proxy/query_filter.py
@@ -14,6 +14,8 @@
    VariableNode,
)

from .selections import merge_selections


class QueryFilterContext:
    schema_id: int
@@ -35,13 +37,15 @@ def __init__(
        fields_types: Dict[str, Dict[str, str]],
        unions: Dict[str, List[str]],
        foreign_keys: Dict[str, Dict[str, List[str]]],
        dependencies: Dict[int, Dict[str, Dict[str, SelectionSetNode]]],
    ):
        self.schema = schema
        self.schemas = schemas
        self.fields_map = fields_map
        self.fields_types = fields_types
        self.unions = unions
        self.foreign_keys = foreign_keys
        self.dependencies = dependencies

    def split_query(
        self, document: DocumentNode
@@ -189,12 +193,22 @@ def filter_field_node(
        else:
            type_fields = self.fields_map[type_name]

        fields_dependencies = self.get_type_fields_dependencies(
            context.schema_id, type_name
        )

        new_selections: List[SelectionNode] = []
        for selection in field_node.selection_set.selections:
            if isinstance(selection, FieldNode):
                field_name = selection.name.value
                if fields_dependencies and field_name in fields_dependencies:
                    new_selections = merge_selections(
                        new_selections, fields_dependencies[field_name].selections
                    )

                if (
-                   selection.name.value not in type_fields
-                   or context.schema_id not in type_fields[selection.name.value]
+                   field_name not in type_fields
+                   or context.schema_id not in type_fields[field_name]
                ):
                    continue

@@ -244,12 +258,22 @@ def filter_inline_fragment_node(
        type_name = fragment_node.type_condition.name.value
        type_fields = self.fields_map[type_name]

        fields_dependencies = self.get_type_fields_dependencies(
            context.schema_id, type_name
        )

        new_selections: List[SelectionNode] = []
        for selection in fragment_node.selection_set.selections:
            if isinstance(selection, FieldNode):
                field_name = selection.name.value
                if fields_dependencies and field_name in fields_dependencies:
                    new_selections = merge_selections(
                        new_selections, fields_dependencies[field_name].selections
                    )

                if (
-                   selection.name.value not in type_fields
-                   or context.schema_id not in type_fields[selection.name.value]
+                   field_name not in type_fields
+                   or context.schema_id not in type_fields[field_name]
                ):
                    continue

@@ -294,12 +318,22 @@ def filter_fragment_spread_node(
        type_name = fragment.type_condition.name.value
        type_fields = self.fields_map[type_name]

        fields_dependencies = self.get_type_fields_dependencies(
            context.schema_id, type_name
        )

        new_selections: List[SelectionNode] = []
        for selection in fragment.selection_set.selections:
            if isinstance(selection, FieldNode):
                field_name = selection.name.value
                if fields_dependencies and field_name in fields_dependencies:
                    new_selections = merge_selections(
                        new_selections, fields_dependencies[field_name].selections
                    )

                if (
-                   selection.name.value not in type_fields
-                   or context.schema_id not in type_fields[selection.name.value]
+                   field_name not in type_fields
+                   or context.schema_id not in type_fields[field_name]
                ):
                    continue

@@ -347,3 +381,13 @@ def inline_fragment_spread_node(
                selections=tuple(selections),
            ),
        )

    def get_type_fields_dependencies(
        self,
        schema_id: int,
        type_name: str,
    ) -> Optional[Dict[str, SelectionSetNode]]:
        if schema_id in self.dependencies and type_name in self.dependencies[schema_id]:
            return self.dependencies[schema_id][type_name]

        return None