diff --git a/graphene_federation/entity.py b/graphene_federation/entity.py index 2383fa2..b3dcc50 100644 --- a/graphene_federation/entity.py +++ b/graphene_federation/entity.py @@ -1,21 +1,19 @@ from __future__ import annotations +import collections.abc from typing import Any, Callable, Dict -from graphene import List, NonNull, Union - +from graphene import Field, List, NonNull, ObjectType, Union from graphene.types.schema import Schema from graphene.types.schema import TypeMap from .types import _Any from .utils import ( - field_name_to_type_attribute, check_fields_exist_on_type, + field_name_to_type_attribute, is_valid_compound_key, ) -import collections.abc - def update(d, u): for k, v in u.items(): @@ -80,7 +78,7 @@ class EntityQuery: required=True, ) - def resolve_entities(self, info, representations): + def resolve_entities(self, info, representations, sub_field_resolution=False): entities = [] for representation in representations: type_ = schema.graphql_schema.get_type(representation["__typename"]) @@ -92,12 +90,53 @@ def resolve_entities(self, info, representations): model_arguments = { get_model_attr(k): v for k, v in model_arguments.items() } + + # convert subfields of models from dict to a corresponding graphql type + for model_field, value in model_arguments.items(): + if not hasattr(model, model_field): + continue + + field = getattr(model, model_field) + if isinstance(field, Field) and isinstance(value, dict): + if value.get("__typename") is None: + value["__typename"] = field.type.of_type._meta.name + model_arguments[model_field] = EntityQuery.resolve_entities( + self, + info, + representations=[value], + sub_field_resolution=True, + ).pop() + elif all( + [ + isinstance(field, List), + isinstance(value, list), + any( + [ + ( + hasattr(field, "of_type") + and issubclass(field.of_type, ObjectType) + ), + ( + hasattr(field, "of_type") + and issubclass(field.of_type, Union) + ), + ] + ), + ] + ): + for sub_value in value: + if sub_value.get("__typename") is None: + sub_value["__typename"] = field.type.of_type._meta.name + model_arguments[model_field] = EntityQuery.resolve_entities( + self, info, representations=value, sub_field_resolution=True + ) + model_instance = model(**model_arguments) resolver = getattr( model, "_%s__resolve_reference" % model.__name__, None ) or getattr(model, "_resolve_reference", None) - if resolver: + if resolver and not sub_field_resolution: model_instance = resolver(model_instance, info) entities.append(model_instance) diff --git a/graphene_federation/service.py b/graphene_federation/service.py index 27655c9..7c55aca 100644 --- a/graphene_federation/service.py +++ b/graphene_federation/service.py @@ -38,10 +38,25 @@ def convert_fields(schema: Schema, fields: List[str]) -> str: return " ".join([get_field_name(field) for field in fields]) +def convert_fields_for_requires(schema: Schema, fields: List[str]) -> str: + """ + Adds __typename for resolving union,sub-field types + """ + get_field_name = type_attribute_to_field_name(schema) + new_fields = [] + for field in fields: + if "typename" not in field.lower(): # skip user defined typename + new_fields.append(get_field_name(field)) + if "{" in field: + new_fields.append("__typename") + + return " ".join(new_fields) + + DECORATORS = { "_external": lambda schema, fields: "@external", "_requires": lambda schema, fields: f'@requires(fields: "{convert_fields(schema, fields)}")', - "_provides": lambda schema, fields: f'@provides(fields: "{convert_fields(schema, fields)}")', + "_provides": lambda schema, fields: f'@provides(fields: "{convert_fields_for_requires(schema, fields)}")', "_shareable": lambda schema, fields: "@shareable", "_inaccessible": lambda schema, fields: "@inaccessible", "_override": lambda schema, from_: f'@override(from: "{from_}")',