Skip to content

Commit

Permalink
Inherit from GraphQLObject in the Interface type
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianCzajkowski committed Apr 16, 2024
1 parent 2124ac0 commit c3b91fb
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 604 deletions.
50 changes: 0 additions & 50 deletions ariadne_graphql_modules/next/field.py

This file was deleted.

205 changes: 105 additions & 100 deletions ariadne_graphql_modules/next/interfacetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Union,
cast,
Sequence,
NoReturn,
)

from ariadne import InterfaceType
Expand All @@ -21,47 +20,33 @@
GraphQLField,
GraphQLObjectType,
GraphQLSchema,
InputValueDefinitionNode,
NameNode,
NamedTypeNode,
InterfaceTypeDefinitionNode,
)

from .metadata import get_graphql_object_data
from .objectmixin import GraphQLModelHelpersMixin
from .value import get_value_node

from .objecttype import (
GraphQLObject,
GraphQLObjectResolver,
get_field_args_from_resolver,
get_field_args_out_names,
get_field_node_from_obj_field,
get_graphql_object_data,
update_field_args_options,
)

from ..utils import parse_definition
from .base import GraphQLMetadata, GraphQLModel, GraphQLType
from .description import get_description_node
from .typing import get_graphql_type
from .validators import validate_description, validate_name


class GraphQLInterface(GraphQLType, GraphQLModelHelpersMixin):
class GraphQLInterface(GraphQLObject):
__types__: Sequence[Type[GraphQLType]]
__implements__: Optional[Iterable[Union[Type[GraphQLType], Type[Enum]]]]

def __init_subclass__(cls) -> None:
super().__init_subclass__()

if cls.__dict__.get("__abstract__"):
return

cls.__abstract__ = False

if cls.__dict__.get("__schema__"):
validate_interface_type_with_schema(cls)
else:
validate_interface_type(cls)

@classmethod
def __get_graphql_model__(cls, metadata: GraphQLMetadata) -> "GraphQLModel":
name = cls.__get_graphql_name__()
metadata.set_graphql_name(cls, name)

if getattr(cls, "__schema__", None):
return cls.__get_graphql_model_with_schema__(metadata, name)

return cls.__get_graphql_model_without_schema__(metadata, name)
__valid_type__ = InterfaceTypeDefinitionNode

@classmethod
def __get_graphql_model_with_schema__(
Expand All @@ -71,9 +56,75 @@ def __get_graphql_model_with_schema__(
InterfaceTypeDefinitionNode,
parse_definition(InterfaceTypeDefinitionNode, cls.__schema__),
)
resolvers: Dict[str, Resolver] = cls.collect_resolvers_with_schema()

descriptions: Dict[str, str] = {}
args_descriptions: Dict[str, Dict[str, str]] = {}
args_defaults: Dict[str, Dict[str, Any]] = {}
resolvers: Dict[str, Resolver] = {}
out_names: Dict[str, Dict[str, str]] = {}
fields: List[FieldDefinitionNode] = cls.gather_fields_with_schema(definition)

for attr_name in dir(cls):
cls_attr = getattr(cls, attr_name)
if isinstance(cls_attr, GraphQLObjectResolver):
resolvers[cls_attr.field] = cls_attr.resolver
if cls_attr.description:
descriptions[cls_attr.field] = get_description_node(
cls_attr.description
)

field_args = get_field_args_from_resolver(cls_attr.resolver)
if field_args:
args_descriptions[cls_attr.field] = {}
args_defaults[cls_attr.field] = {}

final_args = update_field_args_options(field_args, cls_attr.args)

for arg_name, arg_options in final_args.items():
arg_description = get_description_node(arg_options.description)
if arg_description:
args_descriptions[cls_attr.field][arg_name] = (
arg_description
)

arg_default = arg_options.default_value
if arg_default is not None:
args_defaults[cls_attr.field][arg_name] = get_value_node(
arg_default
)

fields: List[FieldDefinitionNode] = []
for field in definition.fields:
field_args_descriptions = args_descriptions.get(field.name.value, {})
field_args_defaults = args_defaults.get(field.name.value, {})

args: List[InputValueDefinitionNode] = []
for arg in field.arguments:
arg_name = arg.name.value
args.append(
InputValueDefinitionNode(
description=(
arg.description or field_args_descriptions.get(arg_name)
),
name=arg.name,
directives=arg.directives,
type=arg.type,
default_value=(
arg.default_value or field_args_defaults.get(arg_name)
),
)
)

fields.append(
FieldDefinitionNode(
name=field.name,
description=(
field.description or descriptions.get(field.name.value)
),
directives=field.directives,
arguments=tuple(args),
type=field.type,
)
)

return GraphQLInterfaceModel(
name=definition.name.value,
Expand All @@ -94,16 +145,30 @@ def __get_graphql_model_without_schema__(
cls, metadata: GraphQLMetadata, name: str
) -> "GraphQLInterfaceModel":
type_data = get_graphql_object_data(metadata, cls)
type_aliases = getattr(cls, "__aliases__", None) or {}

fields_ast: List[FieldDefinitionNode] = cls.gather_fields_without_schema(
metadata
)
interfaces_ast: List[NamedTypeNode] = cls.gather_interfaces_without_schema(
type_data
)
resolvers: Dict[str, Resolver] = cls.collect_resolvers_without_schema(type_data)
aliases: Dict[str, str] = cls.collect_aliases(type_data)
out_names: Dict[str, Dict[str, str]] = cls.collect_out_names(type_data)
fields_ast: List[FieldDefinitionNode] = []
resolvers: Dict[str, Resolver] = {}
aliases: Dict[str, str] = {}
out_names: Dict[str, Dict[str, str]] = {}

for attr_name, field in type_data.fields.items():
fields_ast.append(get_field_node_from_obj_field(cls, metadata, field))

if attr_name in type_aliases:
aliases[field.name] = type_aliases[attr_name]
elif attr_name != field.name:
aliases[field.name] = attr_name

if field.resolver:
resolvers[field.name] = field.resolver

if field.args:
out_names[field.name] = get_field_args_out_names(field.args)

interfaces_ast: List[NamedTypeNode] = []
for interface_name in type_data.interfaces:
interfaces_ast.append(NamedTypeNode(name=NameNode(value=interface_name)))

return GraphQLInterfaceModel(
name=name,
Expand All @@ -122,44 +187,6 @@ def __get_graphql_model_without_schema__(
out_names=out_names,
)

@classmethod
def __get_graphql_types__(
cls, metadata: "GraphQLMetadata"
) -> Iterable["GraphQLType"]:
"""Returns iterable with GraphQL types associated with this type"""
if getattr(cls, "__schema__", None):
return cls.__get_graphql_types_with_schema__(metadata)

return cls.__get_graphql_types_without_schema__(metadata)

@classmethod
def __get_graphql_types_with_schema__(
cls, metadata: "GraphQLMetadata"
) -> Iterable["GraphQLType"]:
types: List[GraphQLType] = [cls]
types.extend(getattr(cls, "__requires__", []))
return types

@classmethod
def __get_graphql_types_without_schema__(
cls, metadata: "GraphQLMetadata"
) -> Iterable["GraphQLType"]:
types: List[GraphQLType] = [cls]
type_data = get_graphql_object_data(metadata, cls)

for field in type_data.fields.values():
field_type = get_graphql_type(field.type)
if field_type and field_type not in types:
types.append(field_type)

if field.args:
for field_arg in field.args.values():
field_arg_type = get_graphql_type(field_arg.type)
if field_arg_type and field_arg_type not in types:
types.append(field_arg_type)

return types

@staticmethod
def resolve_type(obj: Any, *args) -> str:
if isinstance(obj, GraphQLInterface):
Expand All @@ -170,28 +197,6 @@ def resolve_type(obj: Any, *args) -> str:
)


def validate_interface_type(cls: Type[GraphQLInterface]) -> NoReturn:
pass


def validate_interface_type_with_schema(cls: Type[GraphQLInterface]) -> NoReturn:
definition = cast(
InterfaceTypeDefinitionNode,
parse_definition(InterfaceTypeDefinitionNode, cls.__schema__),
)

if not isinstance(definition, InterfaceTypeDefinitionNode):
raise ValueError(
f"Class '{cls.__name__}' defines a '__schema__' attribute "
"with declaration for an invalid GraphQL type. "
f"('{definition.__class__.__name__}' != "
f"'{InterfaceTypeDefinitionNode.__name__}')"
)

validate_name(cls, definition)
validate_description(cls, definition)


@dataclass(frozen=True)
class GraphQLInterfaceModel(GraphQLModel):
resolvers: Dict[str, Resolver]
Expand Down
Loading

0 comments on commit c3b91fb

Please sign in to comment.