From dcbd53444b3d45d0e95fb3101c0cd5ec012c182b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 20 Sep 2023 15:54:20 +0200 Subject: [PATCH 1/5] Add insert_field method to ProxySchema --- GUIDE.md | 14 ++ ariadne_graphql_proxy/proxy_schema.py | 20 +- ariadne_graphql_proxy/standard_types.py | 14 ++ ariadne_graphql_proxy/str_to_field.py | 58 ++++++ tests/conftest.py | 8 + tests/test_proxy_schema.py | 46 +++++ tests/test_str_to_field.py | 246 ++++++++++++++++++++++++ 7 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 ariadne_graphql_proxy/str_to_field.py create mode 100644 tests/test_str_to_field.py diff --git a/GUIDE.md b/GUIDE.md index 9d1c92a..c4c52d3 100644 --- a/GUIDE.md +++ b/GUIDE.md @@ -762,6 +762,20 @@ def get_sub_schema(self, schema_id: int) -> GraphQLSchema: Returns sub schema with given id. If schema doesn't exist, raises `IndexError`. +### `insert_field` + +```python +def insert_field(self, type_name: str, field_str: str): +``` + +Inserts field into all schemas with given `type_name`. The field is automatically delayed - excluded from queries run by `root_resolver` against the remote GraphQL APIs. + +#### Required arguments + +- `type_name`: a `str` with the name of the type into which the field will be inserted. +- `field_str`: a `str` with SDL field representation, e.g. `"fieldA(argA: String!) Int"`. + + ### `get_final_schema` ```python diff --git a/ariadne_graphql_proxy/proxy_schema.py b/ariadne_graphql_proxy/proxy_schema.py index bedc458..2647cfb 100644 --- a/ariadne_graphql_proxy/proxy_schema.py +++ b/ariadne_graphql_proxy/proxy_schema.py @@ -18,7 +18,11 @@ from .merge import merge_schemas from .query_filter import QueryFilter from .remote_schema import get_remote_schema -from .standard_types import STANDARD_TYPES +from .standard_types import STANDARD_TYPES, add_missing_scalar_types +from .str_to_field import ( + get_field_definition_from_str, + get_graphql_field_from_field_definition, +) class ProxySchema: @@ -83,6 +87,8 @@ def add_schema( exclude_directives_args=exclude_directives_args, ) + schema.type_map = add_missing_scalar_types(schema.type_map) + self.schemas.append(schema) self.urls.append(url) @@ -135,6 +141,18 @@ def add_delayed_fields(self, delayed_fields: Dict[str, List[str]]): for field_name in type_fields: self.fields_map[type_name].pop(field_name, None) + def insert_field(self, type_name: str, field_str: str): + field_definition = get_field_definition_from_str(field_str=field_str) + field_name = field_definition.name.value + for schema in self.schemas: + type_ = schema.type_map.get(type_name) + if not type_ or not hasattr(type_, "fields"): + continue + + type_.fields[field_name] = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + def get_sub_schema(self, schema_id: int) -> GraphQLSchema: try: return self.schemas[schema_id] diff --git a/ariadne_graphql_proxy/standard_types.py b/ariadne_graphql_proxy/standard_types.py index 80bbdee..d15f8a1 100644 --- a/ariadne_graphql_proxy/standard_types.py +++ b/ariadne_graphql_proxy/standard_types.py @@ -1,3 +1,5 @@ +from graphql import GraphQLBoolean, GraphQLFloat, GraphQLID, GraphQLInt, GraphQLString + STANDARD_TYPES = ( "ID", "Boolean", @@ -13,3 +15,15 @@ "__Directive", "__DirectiveLocation", ) + + +def add_missing_scalar_types(schema_types: dict) -> dict: + scalar_types = { + "ID": GraphQLID, + "Boolean": GraphQLBoolean, + "Float": GraphQLFloat, + "Int": GraphQLInt, + "String": GraphQLString, + } + scalar_types.update(schema_types) + return scalar_types diff --git a/ariadne_graphql_proxy/str_to_field.py b/ariadne_graphql_proxy/str_to_field.py new file mode 100644 index 0000000..6a23957 --- /dev/null +++ b/ariadne_graphql_proxy/str_to_field.py @@ -0,0 +1,58 @@ +from typing import cast + +from graphql import ( + FieldDefinitionNode, + GraphQLArgument, + GraphQLField, + GraphQLInputType, + GraphQLOutputType, + GraphQLSchema, + assert_input_type, + assert_output_type, + parse, + type_from_ast, + value_from_ast, +) + + +def get_field_definition_from_str(field_str: str) -> FieldDefinitionNode: + document = parse(f"type Placeholder{{ {field_str} }}") + + if len(document.definitions) != 1: + raise ValueError("Field str has to define 1 type.") + + definition = document.definitions[0] + + fields = getattr(definition, "fields", []) + if len(fields) != 1: + raise ValueError("Field str has to provide only 1 field.") + + return fields[0] + + +def get_graphql_field_from_field_definition( + schema: GraphQLSchema, field_definition: FieldDefinitionNode +) -> GraphQLField: + field_type = cast(GraphQLOutputType, type_from_ast(schema, field_definition.type)) + assert_output_type(field_type) + + field_args = {} + for arg in field_definition.arguments: + arg_type = cast(GraphQLInputType, type_from_ast(schema, arg.type)) + assert_input_type(arg_type) + arg_default_value = value_from_ast(value_node=arg.default_value, type_=arg_type) + + field_args[arg.name.value] = GraphQLArgument( + type_=arg_type, + default_value=arg_default_value, + ) + + description = ( + None if not field_definition.description else field_definition.description.value + ) + + return GraphQLField( + type_=field_type, + args=field_args, + description=description, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 89448e1..054d6ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,14 @@ def schema(): name(arg: Generic, other: Generic): String! rank(arg: Generic, other: Generic): Int! } + + input InputType { + arg1: Float! + arg2: Boolean! + arg3: String! + arg4: ID! + arg5: Int! + } """ ) diff --git a/tests/test_proxy_schema.py b/tests/test_proxy_schema.py index dd5127c..69226fe 100644 --- a/tests/test_proxy_schema.py +++ b/tests/test_proxy_schema.py @@ -769,3 +769,49 @@ async def test_proxy_schema_splits_variables_from_fragments_between_schemas( """ ).strip(), } + + +def test_insert_field_adds_field_into_local_schemas_with_given_type( + schema, complex_schema +): + proxy_schema = ProxySchema() + proxy_schema.add_schema(schema) + proxy_schema.add_schema(complex_schema) + + proxy_schema.insert_field("Complex", "newField: String!") + + assert proxy_schema.get_sub_schema(0).type_map["Complex"].fields["newField"] + assert proxy_schema.get_sub_schema(1).type_map["Complex"].fields["newField"] + + +def test_insert_field_adds_field_into_remote_schemas_with_given_type( + httpx_mock, schema_json, complex_schema_json +): + httpx_mock.add_response(url="http://graphql.example.com/schema/", json=schema_json) + httpx_mock.add_response( + url="http://graphql.example.com/complex_schema/", json=complex_schema_json + ) + proxy_schema = ProxySchema() + proxy_schema.add_remote_schema("http://graphql.example.com/schema/") + proxy_schema.add_remote_schema("http://graphql.example.com/complex_schema/") + + proxy_schema.insert_field("Complex", "newField: String!") + + assert proxy_schema.get_sub_schema(0).type_map["Complex"].fields["newField"] + assert proxy_schema.get_sub_schema(1).type_map["Complex"].fields["newField"] + + +def test_insert_field_adds_field_into_both_local_and_remote_schema( + httpx_mock, schema, complex_schema_json +): + httpx_mock.add_response( + url="http://graphql.example.com/complex_schema/", json=complex_schema_json + ) + proxy_schema = ProxySchema() + proxy_schema.add_schema(schema) + proxy_schema.add_remote_schema("http://graphql.example.com/complex_schema/") + + proxy_schema.insert_field("Complex", "newField: String!") + + assert proxy_schema.get_sub_schema(0).type_map["Complex"].fields["newField"] + assert proxy_schema.get_sub_schema(1).type_map["Complex"].fields["newField"] diff --git a/tests/test_str_to_field.py b/tests/test_str_to_field.py new file mode 100644 index 0000000..786fc1e --- /dev/null +++ b/tests/test_str_to_field.py @@ -0,0 +1,246 @@ +import pytest +from graphql import ( + BooleanValueNode, + FieldDefinitionNode, + GraphQLArgument, + GraphQLNonNull, + GraphQLScalarType, + GraphQLSyntaxError, + InputValueDefinitionNode, + IntValueNode, + NamedTypeNode, + NonNullTypeNode, + parse, + GraphQLField, +) + +from ariadne_graphql_proxy.str_to_field import ( + get_field_definition_from_str, + get_graphql_field_from_field_definition, +) + + +def test_get_field_definition_from_str_returns_field_definition_node(): + field_definition = get_field_definition_from_str("field: String!") + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "field" + assert isinstance(field_definition.type, NonNullTypeNode) + assert isinstance(field_definition.type.type, NamedTypeNode) + assert field_definition.type.type.name.value == "String" + + +def test_get_field_definition_from_str_returns_node_with_nullable_type(): + field_definition = get_field_definition_from_str("nullableField: Int") + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "nullableField" + assert isinstance(field_definition.type, NamedTypeNode) + assert field_definition.type.name.value == "Int" + + +def test_get_field_definition_from_str_returns_node_with_arguments(): + field_definition = get_field_definition_from_str( + "fieldWithArgs(arg1: Int, arg2: Boolean!): Float!" + ) + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "fieldWithArgs" + assert isinstance(field_definition.type, NonNullTypeNode) + assert isinstance(field_definition.type.type, NamedTypeNode) + assert field_definition.type.type.name.value == "Float" + + arg1 = field_definition.arguments[0] + assert arg1.name.value == "arg1" + assert isinstance(arg1, InputValueDefinitionNode) + assert isinstance(arg1.type, NamedTypeNode) + assert arg1.type.name.value == "Int" + + arg2 = field_definition.arguments[1] + assert arg2.name.value == "arg2" + assert isinstance(arg2, InputValueDefinitionNode) + assert isinstance(arg2.type, NonNullTypeNode) + assert isinstance(arg2.type.type, NamedTypeNode) + assert arg2.type.type.name.value == "Boolean" + + +def test_get_field_definition_from_str_returns_node_with_arguments_default_values(): + field_definition = get_field_definition_from_str( + "fieldWithArgs(arg1: Int = 5, arg2: Boolean! = true): Float!" + ) + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "fieldWithArgs" + assert isinstance(field_definition.type, NonNullTypeNode) + assert isinstance(field_definition.type.type, NamedTypeNode) + assert field_definition.type.type.name.value == "Float" + + arg1 = field_definition.arguments[0] + assert arg1.name.value == "arg1" + assert isinstance(arg1, InputValueDefinitionNode) + assert isinstance(arg1.type, NamedTypeNode) + assert arg1.type.name.value == "Int" + assert isinstance(arg1.default_value, IntValueNode) + assert arg1.default_value.value == "5" + + arg2 = field_definition.arguments[1] + assert arg2.name.value == "arg2" + assert isinstance(arg2, InputValueDefinitionNode) + assert isinstance(arg2.type, NonNullTypeNode) + assert isinstance(arg2.type.type, NamedTypeNode) + assert arg2.type.type.name.value == "Boolean" + assert isinstance(arg2.default_value, BooleanValueNode) + assert arg2.default_value.value + + +def test_get_field_definition_from_str_returns_node_with_non_scalar_type(): + field_definition = get_field_definition_from_str("field: TestType!") + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "field" + assert isinstance(field_definition.type, NonNullTypeNode) + assert isinstance(field_definition.type.type, NamedTypeNode) + assert field_definition.type.type.name.value == "TestType" + + +def test_get_field_definition_from_str_returns_node_with_input_type_argument(): + field_definition = get_field_definition_from_str("field(arg: TestInput): String") + + assert isinstance(field_definition, FieldDefinitionNode) + assert field_definition.name.value == "field" + assert isinstance(field_definition.type, NamedTypeNode) + assert field_definition.type.name.value == "String" + + arg = field_definition.arguments[0] + assert arg.name.value == "arg" + assert isinstance(arg, InputValueDefinitionNode) + assert isinstance(arg.type, NamedTypeNode) + assert arg.type.name.value == "TestInput" + + +@pytest.mark.parametrize("invalid_str", ["", "field", "field String!"]) +def test_get_field_definition_from_str_raises_graphql_syntax_error_for_invalid_str( + invalid_str, +): + with pytest.raises(GraphQLSyntaxError): + get_field_definition_from_str(invalid_str) + + +def test_get_field_definition_from_str_raises_value_error_for_multiple_fields(): + with pytest.raises(ValueError): + get_field_definition_from_str("fieldA: String!\nfieldB: Float!") + + +def test_get_graphql_field_from_field_definition_returns_graphql_field(schema): + field_definition = parse("type A{ field: String! }").definitions[0].fields[0] + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert isinstance(graphql_field.type, GraphQLNonNull) + assert isinstance(graphql_field.type.of_type, GraphQLScalarType) + assert graphql_field.type.of_type.name == "String" + + +def test_get_graphql_field_from_field_definition_returns_field_with_nullable_type( + schema, +): + field_definition = parse("type A{ nullableField: Float }").definitions[0].fields[0] + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert isinstance(graphql_field.type, GraphQLScalarType) + assert graphql_field.type.name == "Float" + + +def test_get_graphql_field_from_field_definition_returns_field_with_arguments( + schema, +): + field_definition = ( + parse("type A{ fieldWithArgs(arg1: Int, arg2: Boolean!): Float }") + .definitions[0] + .fields[0] + ) + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert isinstance(graphql_field.type, GraphQLScalarType) + assert graphql_field.type.name == "Float" + + arg1 = graphql_field.args["arg1"] + assert isinstance(arg1, GraphQLArgument) + assert isinstance(arg1.type, GraphQLScalarType) + assert arg1.type.name == "Int" + + arg2 = graphql_field.args["arg2"] + assert isinstance(arg2, GraphQLArgument) + assert isinstance(arg2.type, GraphQLNonNull) + assert isinstance(arg2.type.of_type, GraphQLScalarType) + assert arg2.type.of_type.name == "Boolean" + + +def test_get_graphql_field_from_field_definition_returns_arguments_with_default_values( + schema, +): + field_definition = ( + parse("type A{ fieldWithArgs(arg1: Int = 5, arg2: Boolean! = true): Float }") + .definitions[0] + .fields[0] + ) + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert isinstance(graphql_field.type, GraphQLScalarType) + assert graphql_field.type.name == "Float" + + arg1 = graphql_field.args["arg1"] + assert isinstance(arg1, GraphQLArgument) + assert isinstance(arg1.type, GraphQLScalarType) + assert arg1.type.name == "Int" + assert arg1.default_value == 5 + + arg2 = graphql_field.args["arg2"] + assert isinstance(arg2, GraphQLArgument) + assert isinstance(arg2.type, GraphQLNonNull) + assert isinstance(arg2.type.of_type, GraphQLScalarType) + assert arg2.type.of_type.name == "Boolean" + assert arg2.default_value is True + + +def test_get_graphql_field_from_field_definition_returns_field_with_non_scalar_type( + schema, +): + field_definition = parse("type A{ field: Complex }").definitions[0].fields[0] + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert graphql_field.type is schema.type_map["Complex"] + + +def test_get_graphql_field_from_field_definition_returns_field_with_non_scalar_argument( + schema, +): + field_definition = ( + parse("type A{ field(arg: InputType): String }").definitions[0].fields[0] + ) + + graphql_field = get_graphql_field_from_field_definition( + schema=schema, field_definition=field_definition + ) + + assert isinstance(graphql_field, GraphQLField) + assert graphql_field.args["arg"].type is schema.type_map["InputType"] From 1d5189fb4c5cbb8cf0669a02c17173b8a643d9ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 20 Sep 2023 16:02:41 +0200 Subject: [PATCH 2/5] Add get_query_params_resolver as factory for imgix resolvers. --- GUIDE.md | 77 ++++++++++++++++ .../contrib/imgix/__init__.py | 3 + .../contrib/imgix/query_params_resolver.py | 59 ++++++++++++ tests/contrib/imgix/__init__.py | 0 .../imgix/test_query_params_resolver.py | 90 +++++++++++++++++++ 5 files changed, 229 insertions(+) create mode 100644 ariadne_graphql_proxy/contrib/imgix/__init__.py create mode 100644 ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py create mode 100644 tests/contrib/imgix/__init__.py create mode 100644 tests/contrib/imgix/test_query_params_resolver.py diff --git a/GUIDE.md b/GUIDE.md index c4c52d3..ac9b6ab 100644 --- a/GUIDE.md +++ b/GUIDE.md @@ -446,6 +446,83 @@ app = GraphQL( ``` +## imgix query params resolver + +`get_query_params_resolver` returns a preconfigured resolver that takes URL string and passed arguments to generate a URL with arguments as query params. It can be used to add [rendering options](https://docs.imgix.com/apis/rendering) to [imgix.com](https://imgix.com) image URL. + +### Function arguments: + +- `get_url`: a `str` or `Callable` which returns `str`. If `get_url` is a `str` then the resolver will split it by `.` and use substrings as keys to get value from `obj` dict, e.g. with `get_url` set to `"imageData.url"` the resolver will use `obj["imageData"]["url"]` as URL string. If `get_url` is a callable, then resolver will call it with `obj`, `info` and `**kwargs` and use result as URL string. +- `extra_params`: an optional `dict` of query params to be added to the URL string. These can be overridden by kwargs passed to the resolver. +- `get_params`: an optional `Callable` to be called on passed `**kwargs` before they are added to the URL string. +- `serialize_url`: an optional `Callable` to be called on URL string with query params already added. Result is returned directly by the resolver. + +### Example with `insert_field` + +In this example we assume there is a graphql server which provides following schema: + +```gql +type Query { + product: Product! +} + +type Product { + imageUrl: String! +} +``` + +`imageUrl` returns URL string served by [imgix.com](https://imgix.com) and we want to add another field with thumbnail URL. + +```python +from ariadne_graphql_proxy import ProxySchema, get_context_value, set_resolver +from ariadne_graphql_proxy.contrib.imgix import get_query_params_resolver + + +proxy_schema = ProxySchema() +proxy_schema.add_remote_schema("https://remote-schema.local") +proxy_schema.insert_field( + type_name="Product", + field_str="thumbnailUrl(w: Int, h: Int): String!", +) + +final_schema = proxy_schema.get_final_schema() + +set_resolver( + final_schema, + "Product", + "thumbnailUrl", + get_query_params_resolver( + "imageUrl", + extra_params={"h": 128, "w": 128, "fit": "min"}, + ), +) +``` + +With an added resolver, `thumbnailUrl` will return `imageUrl` with additional query parameters. `fit` is always set to `min`. `w` and `h` are set to `128` by default, but can be changed by query argument, e.g. + +```gql +query getProduct { + product { + imageUrl + thumbnailUrl + smallThumbnailUrl: thumbnailUrl(w: 32, h: 32) + } +} +``` + +```json +{ + "data": { + "product": { + "imageUrl": "https://test-imageix.com/product-image.jpg", + "thumbnailUrl": "https://test-imageix.com/product-image.jpg?h=128&w=128&fit=min", + "smallThumbnailUrl": "https://test-imageix.com/product-image.jpg?h=32&w=32&fit=min" + } + } +} +``` + + ## Proxying headers Ariadne GraphQL Proxy requires that `GraphQLResolveInfo.context` attribute is a dictionary containing `headers` key, which in itself is a `Dict[str, str]` dictionary. diff --git a/ariadne_graphql_proxy/contrib/imgix/__init__.py b/ariadne_graphql_proxy/contrib/imgix/__init__.py new file mode 100644 index 0000000..74116ff --- /dev/null +++ b/ariadne_graphql_proxy/contrib/imgix/__init__.py @@ -0,0 +1,3 @@ +from .query_params_resolver import get_query_params_resolver + +__all__ = ["get_query_params_resolver"] diff --git a/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py b/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py new file mode 100644 index 0000000..0c14c72 --- /dev/null +++ b/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py @@ -0,0 +1,59 @@ +from functools import partial +from typing import Any, Callable, Optional, Union, cast +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from graphql import GraphQLResolveInfo + + +def get_attribute_value( + obj: Any, info: GraphQLResolveInfo, attribute_str: str, **kwargs +) -> Any: + value = obj + for attr in attribute_str.split("."): + value = value.get(attr) + return value + + +def get_query_params_resolver( + get_url: Union[str, Callable[..., str]], + extra_params: Optional[dict[str, Any]] = None, + get_params: Optional[Callable[..., dict[str, Any]]] = None, + serialize_url: Optional[Callable[[str], Any]] = None, +): + get_source_url = cast( + Callable[..., str], + ( + get_url + if callable(get_url) + else partial(get_attribute_value, attribute_str=get_url) + ), + ) + params = cast(dict[str, Any], extra_params if extra_params is not None else {}) + get_params_from_kwargs = cast( + Callable[..., dict[str, Any]], + get_params if get_params is not None else lambda **kwargs: kwargs, + ) + serialize = cast( + Callable[[str], Any], + serialize_url if serialize_url is not None else lambda url: url, + ) + + def resolver(obj: Any, info: GraphQLResolveInfo, **kwargs): + source_url = get_source_url(obj, info, **kwargs) + parse_result = urlparse(source_url) + query_params = parse_qs(parse_result.query) + query_params.update(params) + query_params.update(get_params_from_kwargs(**kwargs)) + result_url = urlunparse( + ( + parse_result.scheme, + parse_result.netloc, + parse_result.path, + parse_result.params, + urlencode(query_params), + parse_result.fragment, + ) + ) + return serialize(result_url) + + return resolver diff --git a/tests/contrib/imgix/__init__.py b/tests/contrib/imgix/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/contrib/imgix/test_query_params_resolver.py b/tests/contrib/imgix/test_query_params_resolver.py new file mode 100644 index 0000000..de9721d --- /dev/null +++ b/tests/contrib/imgix/test_query_params_resolver.py @@ -0,0 +1,90 @@ +import pytest + +from ariadne_graphql_proxy.contrib.imgix import get_query_params_resolver + + +def test_resolver_returns_url_from_given_attribute(): + resolver = get_query_params_resolver(get_url="a.b.c.url") + + assert ( + resolver({"a": {"b": {"c": {"url": "http://test.url"}}}}, None) + == "http://test.url" + ) + + +def test_resolver_calls_get_url_callable(mocker): + get_url = mocker.MagicMock(side_effect=lambda obj, info, **kwargs: obj["url"]) + resolver = get_query_params_resolver(get_url=get_url) + + assert resolver({"url": "http://test.url"}, None) == "http://test.url" + assert get_url.call_count == 1 + + +def test_resolver_passes_all_args_to_get_url_callable(mocker): + get_url = mocker.MagicMock(side_effect=lambda obj, info, **kwargs: obj["url"]) + resolver = get_query_params_resolver(get_url=get_url) + obj = {"url": "http://test.url"} + info = {"xyz": "XYZ"} + + resolver(obj, info, a="AA", b="BB") + assert get_url.call_count == 1 + assert get_url.call_args.args == (obj, info) + assert get_url.call_args.kwargs == {"a": "AA", "b": "BB"} + + +@pytest.mark.parametrize("get_url", ("url", lambda obj, info, **kwargs: obj["url"])) +def test_resolver_adds_predefined_query_param_to_url(get_url): + resolver = get_query_params_resolver( + get_url=get_url, extra_params={"abc": "test_value"} + ) + + assert ( + resolver({"url": "http://test.url"}, None) == "http://test.url?abc=test_value" + ) + + +@pytest.mark.parametrize("get_url", ("url", lambda obj, info, **kwargs: obj["url"])) +def test_resolver_adds_params_from_kwargs(get_url): + resolver = get_query_params_resolver(get_url=get_url) + + assert ( + resolver({"url": "http://test.url"}, None, xyz="XYZ") + == "http://test.url?xyz=XYZ" + ) + + +@pytest.mark.parametrize("get_url", ("url", lambda obj, info, **kwargs: obj["url"])) +def test_resolver_calls_get_params_on_given_kwargs(get_url, mocker): + get_params = mocker.MagicMock(side_effect=lambda **kwargs: kwargs["params"]) + resolver = get_query_params_resolver(get_url=get_url, get_params=get_params) + + assert ( + resolver({"url": "http://test.url"}, None, params={"a": "AAA"}) + == "http://test.url?a=AAA" + ) + + +@pytest.mark.parametrize("get_url", ("url", lambda obj, info, **kwargs: obj["url"])) +def test_resolver_adds_both_predefined_and_provided_in_kwargs_params(get_url): + resolver = get_query_params_resolver( + get_url=get_url, extra_params={"a": "AAA", "b": "BBB"} + ) + + assert ( + resolver({"url": "http://test.url"}, None, b="REPLACED-BBB", c="CCC") + == "http://test.url?a=AAA&b=REPLACED-BBB&c=CCC" + ) + + +@pytest.mark.parametrize("get_url", ("url", lambda obj, info, **kwargs: obj["url"])) +def test_resolver_calls_provided_serialize_url_on_modified_url(get_url, mocker): + serialize_url = mocker.MagicMock(side_effect=lambda url: url + "-serialized") + + resolver = get_query_params_resolver( + get_url=get_url, extra_params={"a": "AAA"}, serialize_url=serialize_url + ) + + assert ( + resolver({"url": "http://test.url"}, None, b="BBB") + == "http://test.url?a=AAA&b=BBB-serialized" + ) From 0150699225e7955431ad1c64f1b280c754aadc6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 20 Sep 2023 16:02:56 +0200 Subject: [PATCH 3/5] Add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a82116e..8c104cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Added `CloudflareCacheBackend`. - Added `DynamoDBCacheBackend`. - Changed `QueryFilter` and `root_resolver` to split variables between schemas. +- Added `insert_field` utility to `ProxySchema`. Added `get_query_params_resolver` as factory for `imgix` resolvers. ## 0.1.0 (2023-06-13) From 1de040d8740d087ad1be11270d398e6edacb64d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Thu, 21 Sep 2023 10:27:01 +0200 Subject: [PATCH 4/5] Fix formatting --- GUIDE.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/GUIDE.md b/GUIDE.md index ac9b6ab..5e430b0 100644 --- a/GUIDE.md +++ b/GUIDE.md @@ -450,6 +450,7 @@ app = GraphQL( `get_query_params_resolver` returns a preconfigured resolver that takes URL string and passed arguments to generate a URL with arguments as query params. It can be used to add [rendering options](https://docs.imgix.com/apis/rendering) to [imgix.com](https://imgix.com) image URL. + ### Function arguments: - `get_url`: a `str` or `Callable` which returns `str`. If `get_url` is a `str` then the resolver will split it by `.` and use substrings as keys to get value from `obj` dict, e.g. with `get_url` set to `"imageData.url"` the resolver will use `obj["imageData"]["url"]` as URL string. If `get_url` is a callable, then resolver will call it with `obj`, `info` and `**kwargs` and use result as URL string. @@ -457,6 +458,7 @@ app = GraphQL( - `get_params`: an optional `Callable` to be called on passed `**kwargs` before they are added to the URL string. - `serialize_url`: an optional `Callable` to be called on URL string with query params already added. Result is returned directly by the resolver. + ### Example with `insert_field` In this example we assume there is a graphql server which provides following schema: @@ -847,6 +849,7 @@ def insert_field(self, type_name: str, field_str: str): Inserts field into all schemas with given `type_name`. The field is automatically delayed - excluded from queries run by `root_resolver` against the remote GraphQL APIs. + #### Required arguments - `type_name`: a `str` with the name of the type into which the field will be inserted. From bea63fabf336675883d86f332dd1557b9d519c3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Thu, 21 Sep 2023 10:29:59 +0200 Subject: [PATCH 5/5] Change get_attribute_value to handle non dict objects --- GUIDE.md | 2 +- .../contrib/imgix/query_params_resolver.py | 5 +- .../imgix/test_query_params_resolver.py | 73 ++++++++++++++++++- 3 files changed, 77 insertions(+), 3 deletions(-) diff --git a/GUIDE.md b/GUIDE.md index 5e430b0..a0ded8c 100644 --- a/GUIDE.md +++ b/GUIDE.md @@ -453,7 +453,7 @@ app = GraphQL( ### Function arguments: -- `get_url`: a `str` or `Callable` which returns `str`. If `get_url` is a `str` then the resolver will split it by `.` and use substrings as keys to get value from `obj` dict, e.g. with `get_url` set to `"imageData.url"` the resolver will use `obj["imageData"]["url"]` as URL string. If `get_url` is a callable, then resolver will call it with `obj`, `info` and `**kwargs` and use result as URL string. +- `get_url`: a `str` or `Callable` which returns `str`. If `get_url` is a `str` then the resolver will split it by `.` and use substrings as keys to get value from `obj` dict or as attribute names for non dict objects, e.g. with `get_url` set to `"imageData.url"` the resolver will use one of: `obj["imageData"]["url"]`, `obj["imageData"].url`, `obj.imageData["url"]`, `obj.imageData.url` as URL string. If `get_url` is a callable, then resolver will call it with `obj`, `info` and `**kwargs` and use result as URL string. - `extra_params`: an optional `dict` of query params to be added to the URL string. These can be overridden by kwargs passed to the resolver. - `get_params`: an optional `Callable` to be called on passed `**kwargs` before they are added to the URL string. - `serialize_url`: an optional `Callable` to be called on URL string with query params already added. Result is returned directly by the resolver. diff --git a/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py b/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py index 0c14c72..8495b8f 100644 --- a/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py +++ b/ariadne_graphql_proxy/contrib/imgix/query_params_resolver.py @@ -10,7 +10,10 @@ def get_attribute_value( ) -> Any: value = obj for attr in attribute_str.split("."): - value = value.get(attr) + try: + value = value.get(attr) + except AttributeError: + value = getattr(value, attr, None) return value diff --git a/tests/contrib/imgix/test_query_params_resolver.py b/tests/contrib/imgix/test_query_params_resolver.py index de9721d..a92f0ce 100644 --- a/tests/contrib/imgix/test_query_params_resolver.py +++ b/tests/contrib/imgix/test_query_params_resolver.py @@ -1,6 +1,77 @@ +from dataclasses import dataclass + import pytest -from ariadne_graphql_proxy.contrib.imgix import get_query_params_resolver +from ariadne_graphql_proxy.contrib.imgix.query_params_resolver import ( + get_attribute_value, + get_query_params_resolver, +) + + +def test_get_attribute_value_returns_value_from_dict(): + assert get_attribute_value({"key": "value"}, None, attribute_str="key") == "value" + + +def test_get_attribute_value_returns_nested_value_from_dict(): + assert ( + get_attribute_value( + {"keyA": {"keyB": {"keyC": "valueC"}}}, None, attribute_str="keyA.keyB.keyC" + ) + == "valueC" + ) + + +def test_get_attribute_value_returns_attribute_of_not_dict_object(): + @dataclass + class TypeA: + value_a: str + + assert ( + get_attribute_value(TypeA(value_a="valueA"), None, attribute_str="value_a") + == "valueA" + ) + + +def test_get_attribute_value_returns_nested_attribute_of_not_dict_object(): + @dataclass + class TypeC: + value_c: str + + @dataclass + class TypeB: + key_c: TypeC + + @dataclass + class TypeA: + key_b: TypeB + + assert ( + get_attribute_value( + TypeA(key_b=TypeB(key_c=TypeC(value_c="value_c"))), + None, + attribute_str="key_b.key_c.value_c", + ) + == "value_c" + ) + + +def test_get_attribute_value_returns_value_from_both_dict_and_non_dict_objects(): + @dataclass + class TypeB: + value_b: dict + + @dataclass + class TypeA: + key_a: TypeB + + assert ( + get_attribute_value( + {"xyz": {"a": TypeA(key_a=TypeB(value_b={"c": "value_c"}))}}, + None, + attribute_str="xyz.a.key_a.value_b.c", + ) + == "value_c" + ) def test_resolver_returns_url_from_given_attribute():