Skip to content

Commit

Permalink
Basic subset schema implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Apr 9, 2024
1 parent 111c72c commit dd65c08
Show file tree
Hide file tree
Showing 3 changed files with 357 additions and 15 deletions.
352 changes: 338 additions & 14 deletions ariadne_graphql_proxy/copy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, List, Optional, Tuple
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from graphql import (
DirectiveNode,
GraphQLArgument,
GraphQLBoolean,
GraphQLDirective,
Expand All @@ -24,47 +26,369 @@
)

from .standard_types import STANDARD_TYPES
from .output_types import unwrap_output_type


ROOTS_ARGS_NAMES = {
"Query": "queries",
"Mutation": "mutations",
"Subscription": "subscriptions",
}


def copy_schema(
schema: GraphQLSchema,
*,
queries: Optional[List[str]] = None,
mutations: Optional[List[str]] = None,
subscriptions: Optional[List[str]] = None,
exclude_types: Optional[List[str]] = None,
exclude_args: Optional[Dict[str, Dict[str, List[str]]]] = None,
exclude_fields: Optional[Dict[str, List[str]]] = None,
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
) -> GraphQLSchema:
new_types = copy_schema_types(
schema,
exclude_types=exclude_types,
exclude_args=exclude_args,
exclude_fields=exclude_fields,
)
if queries or mutations or subscriptions:
roots_dependencies = find_roots_dependencies(
schema,
{
"Query": queries or [],
"Mutation": mutations or [],
"Subscription": subscriptions or [],
},
exclude_types,
exclude_args,
exclude_fields,
)

fin_exclude_types = exclude_types[:] if exclude_types else []
for schema_type in schema.type_map:
if (
schema_type not in roots_dependencies
and schema_type not in STANDARD_TYPES
):
fin_exclude_types.append(schema_type)

fin_exclude_fields = deepcopy(exclude_fields) if exclude_fields else {}
if queries:
fin_exclude_fields.setdefault("Query", [])
if schema.query_type:
for field_name in schema.query_type.fields:
if field_name not in queries:
fin_exclude_fields["Query"].append(field_name)

if mutations:
fin_exclude_fields.setdefault("Mutation", [])
if schema.mutation_type:
for field_name in schema.mutation_type.fields:
if field_name not in mutations:
fin_exclude_fields["Mutation"].append(field_name)

new_types = copy_schema_types(
schema,
exclude_types=fin_exclude_types,
exclude_args=exclude_args,
exclude_fields=fin_exclude_fields,
)
else:
new_types = copy_schema_types(
schema,
exclude_types=exclude_types,
exclude_args=exclude_args,
exclude_fields=exclude_fields,
)

query_type = None
if schema.query_type:
query_type = new_types[schema.query_type.name]

mutation_type = None
if schema.mutation_type:
if schema.mutation_type and schema.mutation_type.name in new_types:
mutation_type = new_types[schema.mutation_type.name]

if queries or mutations or subscriptions:
new_directives = tuple()
else:
new_directives = (
copy_directives(
new_types,
schema.directives,
exclude_directives=exclude_directives,
exclude_directives_args=exclude_directives_args,
),
)

new_schema = GraphQLSchema(
query=query_type,
mutation=mutation_type,
types=new_types.values(),
directives=copy_directives(
new_types,
schema.directives,
exclude_directives=exclude_directives,
exclude_directives_args=exclude_directives_args,
),
directives=new_directives,
)
assert_valid_schema(new_schema)
return new_schema


def find_roots_dependencies(
schema: GraphQLSchema,
roots: Dict[str, List[str]],
exclude_types: Optional[List[str]] = None,
exclude_args: Optional[Dict[str, Dict[str, List[str]]]] = None,
exclude_fields: Optional[Dict[str, List[str]]] = None,
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
) -> List[str]:
visitor = TypesDependenciesVisitor(
schema,
exclude_types,
exclude_args,
exclude_fields,
exclude_directives,
exclude_directives_args,
)

return visitor.get_dependencies(roots)


class TypesDependenciesVisitor:
def __init__(
self,
schema: GraphQLSchema,
exclude_types: Optional[List[str]] = None,
exclude_args: Optional[Dict[str, Dict[str, List[str]]]] = None,
exclude_fields: Optional[Dict[str, List[str]]] = None,
exclude_directives: Optional[List[str]] = None,
exclude_directives_args: Optional[Dict[str, List[str]]] = None,
):
self.schema = schema
self.exclude_types = exclude_types
self.exclude_args = exclude_args
self.exclude_fields = exclude_fields
self.exclude_directives = exclude_directives
self.exclude_directives_args = exclude_directives_args

def exclude_type(self, type_name: str) -> bool:
return self.exclude_types and type_name in self.exclude_types

def exclude_type_field(self, type_name: str, field_name: str) -> bool:
return (
self.exclude_fields
and type_name in self.exclude_fields
and field_name in self.exclude_fields[type_name]
)

def exclude_type_field_arg(
self, type_name: str, field_name: str, arg_name: str
) -> bool:
return (
self.exclude_args
and type_name in self.exclude_args
and field_name in self.exclude_args[type_name]
and arg_name in self.exclude_args[type_name][field_name]
)

def exclude_directive(self, type_name: str) -> bool:
return self.exclude_directives and type_name in self.exclude_directives

def exclude_directive_arg(self, type_name: str, arg_name: str) -> bool:
return (
self.exclude_directives_args
and type_name in self.exclude_directives_args
and arg_name in self.exclude_directives_args[type_name]
)

def get_dependencies(self, roots: Dict[str, List[str]]) -> List[str]:
dependencies: Set[str] = set("Query")

for root, fields in roots.items():
if not fields:
continue

arg_name = ROOTS_ARGS_NAMES[root]
dependencies.add(root)

if root not in self.schema.type_map:
raise ValueError(f"Root type '{root}' is not defined by the schema.")

root_type = cast(GraphQLObjectType, self.schema.type_map[root])

if root_type.ast_node:
self.find_ast_directives_dependencies(
dependencies, root_type.ast_node.directives
)

for field_name in fields:
if field_name not in root_type.fields:
raise ValueError(
f"Root type '{root}' is not defining the '{field_name}' field."
)

if self.exclude_type_field(root, field_name):
raise ValueError(
f"Field '{field_name}' for type '{root}' is specified in both "
f"'exclude_fields' and '{arg_name}'."
)

field = root_type.fields[field_name]
field_type = unwrap_output_type(field.type)
if self.exclude_type(field_type.name):
raise ValueError(
f"Field '{field_name}' for type '{root}' that is specified in "
f"'{arg_name}' has a return type '{field_type.name}' that is "
"also specified in 'exclude_types'."
)

self.find_type_dependencies(dependencies, field_type)

if field.ast_node:
self.find_ast_directives_dependencies(
dependencies, field.ast_node.directives
)

for arg_name, arg in field.args.items():
if self.exclude_type_field_arg(root, field_name, arg_name):
continue

if arg.ast_node:
self.find_ast_directives_dependencies(
dependencies, arg.ast_node.directives
)

arg_type = unwrap_output_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

return [dep for dep in dependencies if dep not in STANDARD_TYPES]

def find_ast_directives_dependencies(
self, dependencies: Set[str], directives_ast: Tuple[DirectiveNode]
):
for directive in directives_ast:
directive_name = directive.name.value
directive_type = self.schema.type_map[directive_name]
self.find_type_dependencies(dependencies, directive_type)

def find_type_dependencies(
self,
dependencies: Set[str],
type_def: GraphQLNamedType,
):
if type_def.name in dependencies:
return

if isinstance(type_def, GraphQLDirective):
self.find_directive_dependencies(dependencies, type_def)
return

if self.exclude_types and type_def.name in self.exclude_types:
return

dependencies.add(type_def.name)

if isinstance(type_def, GraphQLInputObjectType):
self.find_input_type_dependencies(dependencies, type_def)

if isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)):
self.find_object_type_dependencies(dependencies, type_def)

if isinstance(type_def, GraphQLUnionType):
self.find_union_type_dependencies(dependencies, type_def)

def find_directive_dependencies(
self,
dependencies: Set[str],
type_def: GraphQLDirective,
):
if self.exclude_directive(type_def.name):
return

dependencies.add(type_def.name)

for arg_name, arg in type_def.args.items():
if self.exclude_directive_arg(type_def.name, arg_name):
continue

arg_type = unwrap_output_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

def find_input_type_dependencies(
self, dependencies: Set[str], type_def: GraphQLInputObjectType
):
if type_def.ast_node:
self.find_ast_directives_dependencies(
dependencies, type_def.ast_node.directives
)

for field_name, field in type_def.fields.items():
if self.exclude_type_field(type_def.name, field_name):
return

if field.ast_node:
self.find_ast_directives_dependencies(
dependencies, field.ast_node.directives
)

field_type = unwrap_output_type(field.type)
self.find_type_dependencies(dependencies, field_type)

def find_object_type_dependencies(
self,
dependencies: Set[str],
type_def: Union[GraphQLObjectType, GraphQLInterfaceType],
):
if type_def.ast_node:
self.find_ast_directives_dependencies(
dependencies, type_def.ast_node.directives
)

for interface in type_def.interfaces:
self.find_type_dependencies(dependencies, interface)

for field_name, field in type_def.fields.items():
self.find_object_type_field_dependencies(
dependencies, type_def, field_name, field
)

def find_object_type_field_dependencies(
self,
dependencies: Set[str],
type_def: Union[GraphQLObjectType, GraphQLInterfaceType],
field_name: str,
field_def: GraphQLField,
):
if self.exclude_type_field(type_def.name, field_name):
return

if field_def.ast_node:
self.find_ast_directives_dependencies(
dependencies, field_def.ast_node.directives
)

field_type = unwrap_output_type(field_def.type)
self.find_type_dependencies(dependencies, field_type)

for arg_name, arg in field_def.args.items():
if self.exclude_type_field_arg(type_def.name, field_name, arg_name):
continue

if arg.ast_node:
self.find_ast_directives_dependencies(
dependencies, arg.ast_node.directives
)

arg_type = unwrap_output_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

def find_union_type_dependencies(
self, dependencies: Set[str], type_def: GraphQLUnionType
):
if type_def.ast_node:
self.find_ast_directives_dependencies(
dependencies, type_def.ast_node.directives
)

for union_type in type_def._types:
self.find_type_dependencies(dependencies, union_type)


def copy_schema_types(
schema: GraphQLSchema,
exclude_types: Optional[List[str]] = None,
Expand Down
Loading

0 comments on commit dd65c08

Please sign in to comment.