From c3b91fbbeebe48234586e6e6051a22e0f151da97 Mon Sep 17 00:00:00 2001 From: Damian Czajkowski Date: Tue, 16 Apr 2024 11:32:06 +0200 Subject: [PATCH] Inherit from GraphQLObject in the Interface type --- ariadne_graphql_modules/next/field.py | 50 --- ariadne_graphql_modules/next/interfacetype.py | 205 ++++----- ariadne_graphql_modules/next/metadata.py | 123 ------ ariadne_graphql_modules/next/objectmixin.py | 196 --------- ariadne_graphql_modules/next/objecttype.py | 392 ++++++++++++++++-- ariadne_graphql_modules/next/resolver.py | 94 ----- .../test_get_field_args_from_resolver.py | 2 +- 7 files changed, 458 insertions(+), 604 deletions(-) delete mode 100644 ariadne_graphql_modules/next/field.py delete mode 100644 ariadne_graphql_modules/next/metadata.py delete mode 100644 ariadne_graphql_modules/next/objectmixin.py delete mode 100644 ariadne_graphql_modules/next/resolver.py diff --git a/ariadne_graphql_modules/next/field.py b/ariadne_graphql_modules/next/field.py deleted file mode 100644 index 01ae10d..0000000 --- a/ariadne_graphql_modules/next/field.py +++ /dev/null @@ -1,50 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from ariadne.types import Resolver - - -@dataclass -class GraphQLObjectFieldArg: - name: str - out_name: str - type: Any - description: Optional[str] = None - default_value: Optional[Any] = None - - -class GraphQLObjectField: - name: Optional[str] - description: Optional[str] - type: Optional[Any] - args: Optional[Dict[str, dict]] - resolver: Optional[Resolver] - default_value: Optional[Any] - - def __init__( - self, - *, - name: Optional[str] = None, - description: Optional[str] = None, - type: Optional[Any] = None, - args: Optional[Dict[str, dict]] = None, - resolver: Optional[Resolver] = None, - default_value: Optional[Any] = None, - ): - self.name = name - self.description = description - self.type = type - self.args = args - self.resolver = resolver - self.default_value = default_value - - def __call__(self, resolver: Resolver): - """Makes GraphQLObjectField instances work as decorators.""" - self.resolver = resolver - if not self.type: - self.type = get_field_type_from_resolver(resolver) - return self - - -def get_field_type_from_resolver(resolver: Resolver) -> Any: - return resolver.__annotations__.get("return") diff --git a/ariadne_graphql_modules/next/interfacetype.py b/ariadne_graphql_modules/next/interfacetype.py index 3f97f28..be7beb5 100644 --- a/ariadne_graphql_modules/next/interfacetype.py +++ b/ariadne_graphql_modules/next/interfacetype.py @@ -11,7 +11,6 @@ Union, cast, Sequence, - NoReturn, ) from ariadne import InterfaceType @@ -21,47 +20,33 @@ GraphQLField, GraphQLObjectType, GraphQLSchema, + InputValueDefinitionNode, NameNode, NamedTypeNode, InterfaceTypeDefinitionNode, ) -from .metadata import get_graphql_object_data -from .objectmixin import GraphQLModelHelpersMixin +from .value import get_value_node + +from .objecttype import ( + GraphQLObject, + GraphQLObjectResolver, + get_field_args_from_resolver, + get_field_args_out_names, + get_field_node_from_obj_field, + get_graphql_object_data, + update_field_args_options, +) from ..utils import parse_definition from .base import GraphQLMetadata, GraphQLModel, GraphQLType from .description import get_description_node -from .typing import get_graphql_type -from .validators import validate_description, validate_name -class GraphQLInterface(GraphQLType, GraphQLModelHelpersMixin): +class GraphQLInterface(GraphQLObject): __types__: Sequence[Type[GraphQLType]] __implements__: Optional[Iterable[Union[Type[GraphQLType], Type[Enum]]]] - - def __init_subclass__(cls) -> None: - super().__init_subclass__() - - if cls.__dict__.get("__abstract__"): - return - - cls.__abstract__ = False - - if cls.__dict__.get("__schema__"): - validate_interface_type_with_schema(cls) - else: - validate_interface_type(cls) - - @classmethod - def __get_graphql_model__(cls, metadata: GraphQLMetadata) -> "GraphQLModel": - name = cls.__get_graphql_name__() - metadata.set_graphql_name(cls, name) - - if getattr(cls, "__schema__", None): - return cls.__get_graphql_model_with_schema__(metadata, name) - - return cls.__get_graphql_model_without_schema__(metadata, name) + __valid_type__ = InterfaceTypeDefinitionNode @classmethod def __get_graphql_model_with_schema__( @@ -71,9 +56,75 @@ def __get_graphql_model_with_schema__( InterfaceTypeDefinitionNode, parse_definition(InterfaceTypeDefinitionNode, cls.__schema__), ) - resolvers: Dict[str, Resolver] = cls.collect_resolvers_with_schema() + + descriptions: Dict[str, str] = {} + args_descriptions: Dict[str, Dict[str, str]] = {} + args_defaults: Dict[str, Dict[str, Any]] = {} + resolvers: Dict[str, Resolver] = {} out_names: Dict[str, Dict[str, str]] = {} - fields: List[FieldDefinitionNode] = cls.gather_fields_with_schema(definition) + + for attr_name in dir(cls): + cls_attr = getattr(cls, attr_name) + if isinstance(cls_attr, GraphQLObjectResolver): + resolvers[cls_attr.field] = cls_attr.resolver + if cls_attr.description: + descriptions[cls_attr.field] = get_description_node( + cls_attr.description + ) + + field_args = get_field_args_from_resolver(cls_attr.resolver) + if field_args: + args_descriptions[cls_attr.field] = {} + args_defaults[cls_attr.field] = {} + + final_args = update_field_args_options(field_args, cls_attr.args) + + for arg_name, arg_options in final_args.items(): + arg_description = get_description_node(arg_options.description) + if arg_description: + args_descriptions[cls_attr.field][arg_name] = ( + arg_description + ) + + arg_default = arg_options.default_value + if arg_default is not None: + args_defaults[cls_attr.field][arg_name] = get_value_node( + arg_default + ) + + fields: List[FieldDefinitionNode] = [] + for field in definition.fields: + field_args_descriptions = args_descriptions.get(field.name.value, {}) + field_args_defaults = args_defaults.get(field.name.value, {}) + + args: List[InputValueDefinitionNode] = [] + for arg in field.arguments: + arg_name = arg.name.value + args.append( + InputValueDefinitionNode( + description=( + arg.description or field_args_descriptions.get(arg_name) + ), + name=arg.name, + directives=arg.directives, + type=arg.type, + default_value=( + arg.default_value or field_args_defaults.get(arg_name) + ), + ) + ) + + fields.append( + FieldDefinitionNode( + name=field.name, + description=( + field.description or descriptions.get(field.name.value) + ), + directives=field.directives, + arguments=tuple(args), + type=field.type, + ) + ) return GraphQLInterfaceModel( name=definition.name.value, @@ -94,16 +145,30 @@ def __get_graphql_model_without_schema__( cls, metadata: GraphQLMetadata, name: str ) -> "GraphQLInterfaceModel": type_data = get_graphql_object_data(metadata, cls) + type_aliases = getattr(cls, "__aliases__", None) or {} - fields_ast: List[FieldDefinitionNode] = cls.gather_fields_without_schema( - metadata - ) - interfaces_ast: List[NamedTypeNode] = cls.gather_interfaces_without_schema( - type_data - ) - resolvers: Dict[str, Resolver] = cls.collect_resolvers_without_schema(type_data) - aliases: Dict[str, str] = cls.collect_aliases(type_data) - out_names: Dict[str, Dict[str, str]] = cls.collect_out_names(type_data) + fields_ast: List[FieldDefinitionNode] = [] + resolvers: Dict[str, Resolver] = {} + aliases: Dict[str, str] = {} + out_names: Dict[str, Dict[str, str]] = {} + + for attr_name, field in type_data.fields.items(): + fields_ast.append(get_field_node_from_obj_field(cls, metadata, field)) + + if attr_name in type_aliases: + aliases[field.name] = type_aliases[attr_name] + elif attr_name != field.name: + aliases[field.name] = attr_name + + if field.resolver: + resolvers[field.name] = field.resolver + + if field.args: + out_names[field.name] = get_field_args_out_names(field.args) + + interfaces_ast: List[NamedTypeNode] = [] + for interface_name in type_data.interfaces: + interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name))) return GraphQLInterfaceModel( name=name, @@ -122,44 +187,6 @@ def __get_graphql_model_without_schema__( out_names=out_names, ) - @classmethod - def __get_graphql_types__( - cls, metadata: "GraphQLMetadata" - ) -> Iterable["GraphQLType"]: - """Returns iterable with GraphQL types associated with this type""" - if getattr(cls, "__schema__", None): - return cls.__get_graphql_types_with_schema__(metadata) - - return cls.__get_graphql_types_without_schema__(metadata) - - @classmethod - def __get_graphql_types_with_schema__( - cls, metadata: "GraphQLMetadata" - ) -> Iterable["GraphQLType"]: - types: List[GraphQLType] = [cls] - types.extend(getattr(cls, "__requires__", [])) - return types - - @classmethod - def __get_graphql_types_without_schema__( - cls, metadata: "GraphQLMetadata" - ) -> Iterable["GraphQLType"]: - types: List[GraphQLType] = [cls] - type_data = get_graphql_object_data(metadata, cls) - - for field in type_data.fields.values(): - field_type = get_graphql_type(field.type) - if field_type and field_type not in types: - types.append(field_type) - - if field.args: - for field_arg in field.args.values(): - field_arg_type = get_graphql_type(field_arg.type) - if field_arg_type and field_arg_type not in types: - types.append(field_arg_type) - - return types - @staticmethod def resolve_type(obj: Any, *args) -> str: if isinstance(obj, GraphQLInterface): @@ -170,28 +197,6 @@ def resolve_type(obj: Any, *args) -> str: ) -def validate_interface_type(cls: Type[GraphQLInterface]) -> NoReturn: - pass - - -def validate_interface_type_with_schema(cls: Type[GraphQLInterface]) -> NoReturn: - definition = cast( - InterfaceTypeDefinitionNode, - parse_definition(InterfaceTypeDefinitionNode, cls.__schema__), - ) - - if not isinstance(definition, InterfaceTypeDefinitionNode): - raise ValueError( - f"Class '{cls.__name__}' defines a '__schema__' attribute " - "with declaration for an invalid GraphQL type. " - f"('{definition.__class__.__name__}' != " - f"'{InterfaceTypeDefinitionNode.__name__}')" - ) - - validate_name(cls, definition) - validate_description(cls, definition) - - @dataclass(frozen=True) class GraphQLInterfaceModel(GraphQLModel): resolvers: Dict[str, Resolver] diff --git a/ariadne_graphql_modules/next/metadata.py b/ariadne_graphql_modules/next/metadata.py deleted file mode 100644 index a507852..0000000 --- a/ariadne_graphql_modules/next/metadata.py +++ /dev/null @@ -1,123 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, List - -from .base import GraphQLMetadata -from .convert_name import convert_python_name_to_graphql -from .field import GraphQLObjectField, GraphQLObjectFieldArg -from .resolver import ( - GraphQLObjectResolver, - get_field_args_from_resolver, - update_field_args_options, -) -from ariadne.types import Resolver - - -@dataclass(frozen=True) -class GraphQLObjectData: - fields: Dict[str, GraphQLObjectField] - interfaces: List[str] - - -def get_graphql_object_data(metadata: GraphQLMetadata, cls): - try: - return metadata.get_data(cls) - except KeyError: - if getattr(cls, "__schema__", None): - raise NotImplementedError( - "'get_graphql_object_data' is not supported for objects with '__schema__'." - ) - else: - return create_graphql_object_data_without_schema(cls) - - -def create_graphql_object_data_without_schema( - cls, -) -> GraphQLObjectData: - fields_types: Dict[str, str] = {} - fields_names: Dict[str, str] = {} - fields_descriptions: Dict[str, str] = {} - fields_args: Dict[str, Dict[str, GraphQLObjectFieldArg]] = {} - fields_resolvers: Dict[str, Resolver] = {} - fields_defaults: Dict[str, Any] = {} - fields_order: List[str] = [] - - type_hints = cls.__annotations__ - - aliases: Dict[str, str] = getattr(cls, "__aliases__", None) or {} - aliases_targets: List[str] = list(aliases.values()) - - interfaces: List[str] = [ - interface.__name__ for interface in getattr(cls, "__implements__", []) - ] - - for attr_name, attr_type in type_hints.items(): - if attr_name.startswith("__"): - continue - - if attr_name in aliases_targets: - # Alias target is not included in schema - # unless its explicit field - cls_attr = getattr(cls, attr_name, None) - if not isinstance(cls_attr, GraphQLObjectField): - continue - - fields_order.append(attr_name) - - fields_names[attr_name] = convert_python_name_to_graphql(attr_name) - fields_types[attr_name] = attr_type - - for attr_name in dir(cls): - if attr_name.startswith("__"): - continue - - cls_attr = getattr(cls, attr_name) - if isinstance(cls_attr, GraphQLObjectField): - if attr_name not in fields_order: - fields_order.append(attr_name) - - fields_names[attr_name] = cls_attr.name or convert_python_name_to_graphql( - attr_name - ) - - if cls_attr.type and attr_name not in fields_types: - fields_types[attr_name] = cls_attr.type - if cls_attr.description: - fields_descriptions[attr_name] = cls_attr.description - if cls_attr.resolver: - fields_resolvers[attr_name] = cls_attr.resolver - field_args = get_field_args_from_resolver(cls_attr.resolver) - if field_args: - fields_args[attr_name] = update_field_args_options( - field_args, cls_attr.args - ) - if cls_attr.default_value: - fields_defaults[attr_name] = cls_attr.default_value - - elif isinstance(cls_attr, GraphQLObjectResolver): - if cls_attr.type and cls_attr.field not in fields_types: - fields_types[cls_attr.field] = cls_attr.type - if cls_attr.description: - fields_descriptions[cls_attr.field] = cls_attr.description - if cls_attr.resolver: - fields_resolvers[cls_attr.field] = cls_attr.resolver - field_args = get_field_args_from_resolver(cls_attr.resolver) - if field_args: - fields_args[cls_attr.field] = update_field_args_options( - field_args, cls_attr.args - ) - - elif attr_name not in aliases_targets and not callable(cls_attr): - fields_defaults[attr_name] = cls_attr - - fields: Dict[str, "GraphQLObjectField"] = {} - for field_name in fields_order: - fields[field_name] = GraphQLObjectField( - name=fields_names[field_name], - description=fields_descriptions.get(field_name), - type=fields_types[field_name], - args=fields_args.get(field_name), - resolver=fields_resolvers.get(field_name), - default_value=fields_defaults.get(field_name), - ) - - return GraphQLObjectData(fields=fields, interfaces=interfaces) diff --git a/ariadne_graphql_modules/next/objectmixin.py b/ariadne_graphql_modules/next/objectmixin.py deleted file mode 100644 index b0b650c..0000000 --- a/ariadne_graphql_modules/next/objectmixin.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Dict, List, Optional -from graphql import ( - FieldDefinitionNode, - InputValueDefinitionNode, - NameNode, - NamedTypeNode, -) - -from .base import GraphQLMetadata -from .description import get_description_node -from .field import GraphQLObjectField, GraphQLObjectFieldArg -from .metadata import ( - GraphQLObjectData, - get_graphql_object_data, -) -from .resolver import ( - GraphQLObjectResolver, - get_field_args_from_resolver, - update_field_args_options, -) -from .typing import get_type_node -from .value import get_value_node - - -class GraphQLModelHelpersMixin: - @classmethod - def gather_fields_without_schema(cls, metadata: GraphQLMetadata): - type_data = get_graphql_object_data(metadata, cls) - fields_ast = [ - cls.create_field_node(metadata, field) - for field in type_data.fields.values() - ] - return fields_ast - - @staticmethod - def collect_resolvers_without_schema(type_data: GraphQLObjectData): - return { - field.name: field.resolver - for field in type_data.fields.values() - if field.resolver - } - - @staticmethod - def collect_aliases(type_data: GraphQLObjectData): - aliases = {} - for field in type_data.fields.values(): - attr_name = field.name # Placeholder for actual attribute name logic - if attr_name != field.name: - aliases[field.name] = attr_name - return aliases - - @staticmethod - def collect_out_names(type_data: GraphQLObjectData): - out_names = {} - for field in type_data.fields.values(): - if field.args: - out_names[field.name] = { - arg.name: arg.out_name for arg in field.args.values() - } - return out_names - - @classmethod - def create_field_node(cls, metadata: GraphQLMetadata, field: GraphQLObjectField): - args_nodes = cls.get_field_args_nodes_from_obj_field_args(metadata, field.args) - return FieldDefinitionNode( - description=get_description_node(field.description), - name=NameNode(value=field.name), - type=get_type_node(metadata, field.type), - arguments=tuple(args_nodes) if args_nodes else None, - ) - - @staticmethod - def get_field_args_nodes_from_obj_field_args( - metadata: GraphQLMetadata, - field_args: Optional[Dict[str, GraphQLObjectFieldArg]], - ): - if not field_args: - return [] - - return [ - InputValueDefinitionNode( - description=get_description_node(arg.description), - name=NameNode(value=arg.name), - type=get_type_node(metadata, arg.type), - default_value=get_value_node(arg.default_value) - if arg.default_value is not None - else None, - ) - for arg in field_args.values() - ] - - @classmethod - def gather_interfaces_without_schema(cls, type_data: GraphQLObjectData): - interfaces_ast: List[NamedTypeNode] = [] - for interface_name in type_data.interfaces: - interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name))) - return interfaces_ast - - @classmethod - def gather_interfaces_with_schema(cls): - return [interface for interface in getattr(cls, "__implements__", [])] - - @classmethod - def gather_fields_with_schema(cls, definition): - descriptions, args_descriptions, args_defaults = ( - cls.collect_descriptions_and_defaults() - ) - - fields = [ - cls.create_field_definition_node( - field, descriptions, args_descriptions, args_defaults - ) - for field in definition.fields - ] - return fields - - @classmethod - def collect_resolvers_with_schema(cls): - resolvers = {} - for attr_name in dir(cls): - cls_attr = getattr(cls, attr_name) - if isinstance(cls_attr, GraphQLObjectResolver): - resolvers[cls_attr.field] = cls_attr.resolver - return resolvers - - @classmethod - def collect_descriptions_and_defaults(cls): - descriptions = {} - args_descriptions = {} - args_defaults = {} - for attr_name in dir(cls): - cls_attr = getattr(cls, attr_name) - if isinstance(cls_attr, GraphQLObjectResolver): - if cls_attr.description: - descriptions[cls_attr.field] = get_description_node( - cls_attr.description - ) - - field_args = get_field_args_from_resolver(cls_attr.resolver) - if field_args: - args_descriptions[cls_attr.field], args_defaults[cls_attr.field] = ( - cls.process_field_args(field_args, cls_attr.args) - ) - - return descriptions, args_descriptions, args_defaults - - @classmethod - def process_field_args(cls, field_args, resolver_args): - descriptions = {} - defaults = {} - final_args = update_field_args_options(field_args, resolver_args) - - for arg_name, arg_options in final_args.items(): - descriptions[arg_name] = ( - get_description_node(arg_options.description) - if arg_options.description - else None - ) - defaults[arg_name] = ( - get_value_node(arg_options.default_value) - if arg_options.default_value is not None - else None - ) - - return descriptions, defaults - - @classmethod - def create_field_definition_node( - cls, field, descriptions, args_descriptions, args_defaults - ): - field_name = field.name.value - field_args_descriptions = args_descriptions.get(field_name, {}) - field_args_defaults = args_defaults.get(field_name, {}) - - args = [ - InputValueDefinitionNode( - description=( - arg.description or field_args_descriptions.get(arg.name.value) - ), - name=arg.name, - directives=arg.directives, - type=arg.type, - default_value=( - arg.default_value or field_args_defaults.get(arg.name.value) - ), - ) - for arg in field.arguments - ] - - return FieldDefinitionNode( - name=field.name, - description=(field.description or descriptions.get(field_name)), - directives=field.directives, - arguments=tuple(args), - type=field.type, - ) diff --git a/ariadne_graphql_modules/next/objecttype.py b/ariadne_graphql_modules/next/objecttype.py index 88508c5..ff803af 100644 --- a/ariadne_graphql_modules/next/objecttype.py +++ b/ariadne_graphql_modules/next/objecttype.py @@ -1,6 +1,7 @@ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum +from inspect import signature from typing import ( Any, Dict, @@ -22,22 +23,9 @@ GraphQLSchema, InputValueDefinitionNode, NameNode, - ObjectTypeDefinitionNode, NamedTypeNode, -) - -from .field import ( - GraphQLObjectField, - GraphQLObjectFieldArg, - get_field_type_from_resolver, -) -from .metadata import ( - get_graphql_object_data, -) -from .objectmixin import GraphQLModelHelpersMixin -from .resolver import ( - GraphQLObjectResolver, - get_field_args_from_resolver, + ObjectTypeDefinitionNode, + TypeDefinitionNode, ) from ..utils import parse_definition @@ -49,7 +37,7 @@ from .value import get_value_node -class GraphQLObject(GraphQLType, GraphQLModelHelpersMixin): +class GraphQLObject(GraphQLType): __kwargs__: Dict[str, Any] __abstract__: bool = True __schema__: Optional[str] @@ -78,10 +66,10 @@ def __init_subclass__(cls) -> None: return cls.__abstract__ = False - # cls.__validate_interfaces__() if cls.__dict__.get("__schema__"): - cls.__kwargs__ = validate_object_type_with_schema(cls) + valid_type = getattr(cls, "__valid_type__", ObjectTypeDefinitionNode) + cls.__kwargs__ = validate_object_type_with_schema(cls, valid_type) else: cls.__kwargs__ = validate_object_type_without_schema(cls) @@ -104,9 +92,74 @@ def __get_graphql_model_with_schema__( parse_definition(ObjectTypeDefinitionNode, cls.__schema__), ) - resolvers: Dict[str, Resolver] = cls.collect_resolvers_with_schema() + descriptions: Dict[str, str] = {} + args_descriptions: Dict[str, Dict[str, str]] = {} + args_defaults: Dict[str, Dict[str, Any]] = {} + resolvers: Dict[str, Resolver] = {} out_names: Dict[str, Dict[str, str]] = {} - fields: List[FieldDefinitionNode] = cls.gather_fields_with_schema(definition) + + for attr_name in dir(cls): + cls_attr = getattr(cls, attr_name) + if isinstance(cls_attr, GraphQLObjectResolver): + resolvers[cls_attr.field] = cls_attr.resolver + if cls_attr.description: + descriptions[cls_attr.field] = get_description_node( + cls_attr.description + ) + + field_args = get_field_args_from_resolver(cls_attr.resolver) + if field_args: + args_descriptions[cls_attr.field] = {} + args_defaults[cls_attr.field] = {} + + final_args = update_field_args_options(field_args, cls_attr.args) + + for arg_name, arg_options in final_args.items(): + arg_description = get_description_node(arg_options.description) + if arg_description: + args_descriptions[cls_attr.field][arg_name] = ( + arg_description + ) + + arg_default = arg_options.default_value + if arg_default is not None: + args_defaults[cls_attr.field][arg_name] = get_value_node( + arg_default + ) + + fields: List[FieldDefinitionNode] = [] + for field in definition.fields: + field_args_descriptions = args_descriptions.get(field.name.value, {}) + field_args_defaults = args_defaults.get(field.name.value, {}) + + args: List[InputValueDefinitionNode] = [] + for arg in field.arguments: + arg_name = arg.name.value + args.append( + InputValueDefinitionNode( + description=( + arg.description or field_args_descriptions.get(arg_name) + ), + name=arg.name, + directives=arg.directives, + type=arg.type, + default_value=( + arg.default_value or field_args_defaults.get(arg_name) + ), + ) + ) + + fields.append( + FieldDefinitionNode( + name=field.name, + description=( + field.description or descriptions.get(field.name.value) + ), + directives=field.directives, + arguments=tuple(args), + type=field.type, + ) + ) return GraphQLObjectModel( name=definition.name.value, @@ -126,16 +179,30 @@ def __get_graphql_model_without_schema__( cls, metadata: GraphQLMetadata, name: str ) -> "GraphQLObjectModel": type_data = get_graphql_object_data(metadata, cls) + type_aliases = getattr(cls, "__aliases__", None) or {} - fields_ast: List[FieldDefinitionNode] = cls.gather_fields_without_schema( - metadata - ) - interfaces_ast: List[NamedTypeNode] = cls.gather_interfaces_without_schema( - type_data - ) - resolvers: Dict[str, Resolver] = cls.collect_resolvers_without_schema(type_data) - aliases: Dict[str, str] = cls.collect_aliases(type_data) - out_names: Dict[str, Dict[str, str]] = cls.collect_out_names(type_data) + fields_ast: List[FieldDefinitionNode] = [] + resolvers: Dict[str, Resolver] = {} + aliases: Dict[str, str] = {} + out_names: Dict[str, Dict[str, str]] = {} + + for attr_name, field in type_data.fields.items(): + fields_ast.append(get_field_node_from_obj_field(cls, metadata, field)) + + if attr_name in type_aliases: + aliases[field.name] = type_aliases[attr_name] + elif attr_name != field.name: + aliases[field.name] = attr_name + + if field.resolver: + resolvers[field.name] = field.resolver + + if field.args: + out_names[field.name] = get_field_args_out_names(field.args) + + interfaces_ast: List[NamedTypeNode] = [] + for interface_name in type_data.interfaces: + interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name))) return GraphQLObjectModel( name=name, @@ -244,12 +311,155 @@ def argument( options["default_value"] = default_value return options - @classmethod - def __validate_interfaces__(cls): - if getattr(cls, "__implements__", None): - for interface in cls.__implements__: - if not issubclass(interface, GraphQLType): - raise TypeError() + +@dataclass(frozen=True) +class GraphQLObjectData: + fields: Dict[str, "GraphQLObjectField"] + interfaces: List[str] + + +def get_graphql_object_data( + metadata: GraphQLMetadata, cls: Type[GraphQLObject] +) -> GraphQLObjectData: + try: + return metadata.get_data(cls) + except KeyError: + if getattr(cls, "__schema__", None): + raise NotImplementedError( + "'get_graphql_object_data' is not supported for " + "objects with '__schema__'." + ) + else: + object_data = create_graphql_object_data_without_schema(cls) + + metadata.set_data(cls, object_data) + return object_data + + +def create_graphql_object_data_without_schema( + cls, +) -> GraphQLObjectData: + fields_types: Dict[str, str] = {} + fields_names: Dict[str, str] = {} + fields_descriptions: Dict[str, str] = {} + fields_args: Dict[str, Dict[str, GraphQLObjectFieldArg]] = {} + fields_resolvers: Dict[str, Resolver] = {} + fields_defaults: Dict[str, Any] = {} + fields_order: List[str] = [] + + type_hints = cls.__annotations__ + + aliases: Dict[str, str] = getattr(cls, "__aliases__", None) or {} + aliases_targets: List[str] = list(aliases.values()) + + interfaces: List[str] = [ + interface.__name__ for interface in getattr(cls, "__implements__", []) + ] + + for attr_name, attr_type in type_hints.items(): + if attr_name.startswith("__"): + continue + + if attr_name in aliases_targets: + # Alias target is not included in schema + # unless its explicit field + cls_attr = getattr(cls, attr_name, None) + if not isinstance(cls_attr, GraphQLObjectField): + continue + + fields_order.append(attr_name) + + fields_names[attr_name] = convert_python_name_to_graphql(attr_name) + fields_types[attr_name] = attr_type + + for attr_name in dir(cls): + if attr_name.startswith("__"): + continue + + cls_attr = getattr(cls, attr_name) + if isinstance(cls_attr, GraphQLObjectField): + if attr_name not in fields_order: + fields_order.append(attr_name) + + fields_names[attr_name] = cls_attr.name or convert_python_name_to_graphql( + attr_name + ) + + if cls_attr.type and attr_name not in fields_types: + fields_types[attr_name] = cls_attr.type + if cls_attr.description: + fields_descriptions[attr_name] = cls_attr.description + if cls_attr.resolver: + fields_resolvers[attr_name] = cls_attr.resolver + field_args = get_field_args_from_resolver(cls_attr.resolver) + if field_args: + fields_args[attr_name] = update_field_args_options( + field_args, cls_attr.args + ) + if cls_attr.default_value: + fields_defaults[attr_name] = cls_attr.default_value + + elif isinstance(cls_attr, GraphQLObjectResolver): + if cls_attr.type and cls_attr.field not in fields_types: + fields_types[cls_attr.field] = cls_attr.type + if cls_attr.description: + fields_descriptions[cls_attr.field] = cls_attr.description + if cls_attr.resolver: + fields_resolvers[cls_attr.field] = cls_attr.resolver + field_args = get_field_args_from_resolver(cls_attr.resolver) + if field_args: + fields_args[cls_attr.field] = update_field_args_options( + field_args, cls_attr.args + ) + + elif attr_name not in aliases_targets and not callable(cls_attr): + fields_defaults[attr_name] = cls_attr + + fields: Dict[str, "GraphQLObjectField"] = {} + for field_name in fields_order: + fields[field_name] = GraphQLObjectField( + name=fields_names[field_name], + description=fields_descriptions.get(field_name), + type=fields_types[field_name], + args=fields_args.get(field_name), + resolver=fields_resolvers.get(field_name), + default_value=fields_defaults.get(field_name), + ) + + return GraphQLObjectData(fields=fields, interfaces=interfaces) + + +class GraphQLObjectField: + name: Optional[str] + description: Optional[str] + type: Optional[Any] + args: Optional[Dict[str, dict]] + resolver: Optional[Resolver] + default_value: Optional[Any] + + def __init__( + self, + *, + name: Optional[str] = None, + description: Optional[str] = None, + type: Optional[Any] = None, + args: Optional[Dict[str, dict]] = None, + resolver: Optional[Resolver] = None, + default_value: Optional[Any] = None, + ): + self.name = name + self.description = description + self.type = type + self.args = args + self.resolver = resolver + self.default_value = default_value + + def __call__(self, resolver: Resolver): + """Makes GraphQLObjectField instances work as decorators.""" + self.resolver = resolver + if not self.type: + self.type = get_field_type_from_resolver(resolver) + return self def object_field( @@ -275,6 +485,19 @@ def object_field( ) +def get_field_type_from_resolver(resolver: Resolver) -> Any: + return resolver.__annotations__.get("return") + + +@dataclass(frozen=True) +class GraphQLObjectResolver: + resolver: Resolver + field: str + description: Optional[str] = None + args: Optional[Dict[str, dict]] = None + type: Optional[Any] = None + + def object_resolver( field: str, type: Optional[Any] = None, @@ -329,6 +552,60 @@ def get_field_node_from_obj_field( ) +@dataclass(frozen=True) +class GraphQLObjectFieldArg: + name: Optional[str] + out_name: Optional[str] + type: Optional[Any] + description: Optional[str] = None + default_value: Optional[Any] = None + + +def get_field_args_from_resolver( + resolver: Resolver, +) -> Dict[str, GraphQLObjectFieldArg]: + resolver_signature = signature(resolver) + type_hints = resolver.__annotations__ + type_hints.pop("return", None) + + field_args: Dict[str, GraphQLObjectFieldArg] = {} + field_args_start = 0 + + # Fist pass: (arg, *_, something, something) or (arg, *, something, something): + for i, param in enumerate(resolver_signature.parameters.values()): + param_repr = str(param) + if param_repr.startswith("*") and not param_repr.startswith("**"): + field_args_start = i + 1 + break + else: + if len(resolver_signature.parameters) < 2: + raise TypeError( + f"Resolver function '{resolver_signature}' should accept at least " + "'obj' and 'info' positional arguments." + ) + + field_args_start = 2 + + args_parameters = tuple(resolver_signature.parameters.items())[field_args_start:] + if not args_parameters: + return field_args + + for param_name, param in args_parameters: + if param.default != param.empty: + param_default = param.default + else: + param_default = None + + field_args[param_name] = GraphQLObjectFieldArg( + name=convert_python_name_to_graphql(param_name), + out_name=param_name, + type=type_hints.get(param_name), + default_value=param_default, + ) + + return field_args + + def get_field_args_out_names( field_args: Dict[str, GraphQLObjectFieldArg], ) -> Dict[str, str]: @@ -367,15 +644,50 @@ def get_field_arg_node_from_obj_field_arg( ) -def validate_object_type_with_schema(cls: Type[GraphQLObject]) -> Dict[str, Any]: - definition = parse_definition(ObjectTypeDefinitionNode, cls.__schema__) +def update_field_args_options( + field_args: Dict[str, GraphQLObjectFieldArg], + args_options: Optional[Dict[str, dict]], +) -> Dict[str, GraphQLObjectFieldArg]: + if not args_options: + return field_args + + updated_args: Dict[str, GraphQLObjectFieldArg] = {} + for arg_name in field_args: + arg_options = args_options.get(arg_name) + if not arg_options: + updated_args[arg_name] = field_args[arg_name] + continue + + args_update = {} + if arg_options.get("name"): + args_update["name"] = arg_options["name"] + if arg_options.get("description"): + args_update["description"] = arg_options["description"] + if arg_options.get("default_value") is not None: + args_update["default_value"] = arg_options["default_value"] + if arg_options.get("type"): + args_update["type"] = arg_options["type"] + + if args_update: + updated_args[arg_name] = replace(field_args[arg_name], **args_update) + else: + updated_args[arg_name] = field_args[arg_name] + + return updated_args + + +def validate_object_type_with_schema( + cls: Type[GraphQLObject], + valid_type: Type[TypeDefinitionNode] = ObjectTypeDefinitionNode, +) -> Dict[str, Any]: + definition = parse_definition(valid_type, cls.__schema__) - if not isinstance(definition, ObjectTypeDefinitionNode): + if not isinstance(definition, valid_type): raise ValueError( f"Class '{cls.__name__}' defines '__schema__' attribute " "with declaration for an invalid GraphQL type. " f"('{definition.__class__.__name__}' != " - f"'{ObjectTypeDefinitionNode.__name__}')" + f"'{valid_type.__name__}')" ) validate_name(cls, definition) diff --git a/ariadne_graphql_modules/next/resolver.py b/ariadne_graphql_modules/next/resolver.py deleted file mode 100644 index 0049793..0000000 --- a/ariadne_graphql_modules/next/resolver.py +++ /dev/null @@ -1,94 +0,0 @@ -from dataclasses import dataclass, replace -from inspect import signature -from typing import Any, Dict, Optional - -from ariadne.types import Resolver - -from .convert_name import convert_python_name_to_graphql -from .field import GraphQLObjectFieldArg - - -@dataclass(frozen=True) -class GraphQLObjectResolver: - resolver: Resolver - field: str - description: Optional[str] = None - args: Optional[Dict[str, dict]] = None - type: Optional[Any] = None - - -def get_field_args_from_resolver( - resolver: Resolver, -) -> Dict[str, GraphQLObjectFieldArg]: - resolver_signature = signature(resolver) - type_hints = resolver.__annotations__ - type_hints.pop("return", None) - - field_args: Dict[str, GraphQLObjectFieldArg] = {} - field_args_start = 0 - - # Fist pass: (arg, *_, something, something) or (arg, *, something, something): - for i, param in enumerate(resolver_signature.parameters.values()): - param_repr = str(param) - if param_repr.startswith("*") and not param_repr.startswith("**"): - field_args_start = i + 1 - break - else: - if len(resolver_signature.parameters) < 2: - raise TypeError( - f"Resolver function '{resolver_signature}' should accept at least " - "'obj' and 'info' positional arguments." - ) - - field_args_start = 2 - - args_parameters = tuple(resolver_signature.parameters.items())[field_args_start:] - if not args_parameters: - return field_args - - for param_name, param in args_parameters: - if param.default != param.empty: - param_default = param.default - else: - param_default = None - - field_args[param_name] = GraphQLObjectFieldArg( - name=convert_python_name_to_graphql(param_name), - out_name=param_name, - type=type_hints.get(param_name), - default_value=param_default, - ) - - return field_args - - -def update_field_args_options( - field_args: Dict[str, GraphQLObjectFieldArg], - args_options: Optional[Dict[str, dict]], -) -> Dict[str, GraphQLObjectFieldArg]: - if not args_options: - return field_args - - updated_args: Dict[str, GraphQLObjectFieldArg] = {} - for arg_name in field_args: - arg_options = args_options.get(arg_name) - if not arg_options: - updated_args[arg_name] = field_args[arg_name] - continue - - args_update = {} - if arg_options.get("name"): - args_update["name"] = arg_options["name"] - if arg_options.get("description"): - args_update["description"] = arg_options["description"] - if arg_options.get("default_value") is not None: - args_update["default_value"] = arg_options["default_value"] - if arg_options.get("type"): - args_update["type"] = arg_options["type"] - - if args_update: - updated_args[arg_name] = replace(field_args[arg_name], **args_update) - else: - updated_args[arg_name] = field_args[arg_name] - - return updated_args diff --git a/tests_next/test_get_field_args_from_resolver.py b/tests_next/test_get_field_args_from_resolver.py index 6e8174c..054fd99 100644 --- a/tests_next/test_get_field_args_from_resolver.py +++ b/tests_next/test_get_field_args_from_resolver.py @@ -1,4 +1,4 @@ -from ariadne_graphql_modules.next.resolver import get_field_args_from_resolver +from ariadne_graphql_modules.next.objecttype import get_field_args_from_resolver def test_field_has_no_args_after_obj_and_info_args():