Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Apr 18, 2024
1 parent 7ccaf1c commit 214f591
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 22 deletions.
2 changes: 2 additions & 0 deletions ariadne_graphql_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .remote_schema import get_remote_schema
from .resolvers import set_resolver, unset_resolver
from .selections import merge_selection_sets, merge_selections
from .unwrap_type import unwrap_graphql_type

__all__ = [
"ForeignKeyResolver",
Expand Down Expand Up @@ -95,4 +96,5 @@
"set_resolver",
"setup_root_resolver",
"unset_resolver",
"unwrap_graphql_type",
]
30 changes: 16 additions & 14 deletions ariadne_graphql_proxy/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

from .standard_types import STANDARD_DIRECTIVES, STANDARD_TYPES
from .output_types import unwrap_output_type
from .unwrap_type import unwrap_graphql_type


ROOTS_ARGS_NAMES = {
Expand Down Expand Up @@ -221,10 +221,10 @@ def __init__(
self.exclude_directives_args = exclude_directives_args

def exclude_type(self, type_name: str) -> bool:
return self.exclude_types and type_name in self.exclude_types
return bool(self.exclude_types and type_name in self.exclude_types)

def exclude_type_field(self, type_name: str, field_name: str) -> bool:
return (
return bool(
self.exclude_fields
and type_name in self.exclude_fields
and field_name in self.exclude_fields[type_name]
Expand All @@ -233,18 +233,18 @@ def exclude_type_field(self, type_name: str, field_name: str) -> bool:
def exclude_type_field_arg(
self, type_name: str, field_name: str, arg_name: str
) -> bool:
return (
return bool(
self.exclude_args
and type_name in self.exclude_args
and field_name in self.exclude_args[type_name]
and arg_name in self.exclude_args[type_name][field_name]
)

def exclude_directive(self, type_name: str) -> bool:
return self.exclude_directives and type_name in self.exclude_directives
return bool(self.exclude_directives and type_name in self.exclude_directives)

def exclude_directive_arg(self, type_name: str, arg_name: str) -> bool:
return (
return bool(
self.exclude_directives_args
and type_name in self.exclude_directives_args
and arg_name in self.exclude_directives_args[type_name]
Expand Down Expand Up @@ -283,7 +283,7 @@ def get_dependencies(self, roots: Dict[str, List[str]]) -> List[str]:
)

field = root_type.fields[field_name]
field_type = unwrap_output_type(field.type)
field_type = unwrap_graphql_type(field.type)
if self.exclude_type(field_type.name):
raise ValueError(
f"Field '{field_name}' of type '{root}' that is specified in "
Expand All @@ -307,13 +307,15 @@ def get_dependencies(self, roots: Dict[str, List[str]]) -> List[str]:
dependencies, arg.ast_node.directives
)

arg_type = unwrap_output_type(arg.type)
arg_type = unwrap_graphql_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

return [dep for dep in dependencies if dep not in STANDARD_TYPES]

def find_ast_directives_dependencies(
self, dependencies: Set[str], directives_ast: Tuple[DirectiveNode]
self,
dependencies: Set[str],
directives_ast: Tuple[DirectiveNode, ...],
):
for directive in directives_ast:
directive_name = directive.name.value
Expand All @@ -323,7 +325,7 @@ def find_ast_directives_dependencies(
def find_type_dependencies(
self,
dependencies: Set[str],
type_def: GraphQLNamedType,
type_def: Union[GraphQLNamedType, GraphQLDirective],
):
if type_def.name in dependencies:
return
Expand Down Expand Up @@ -366,7 +368,7 @@ def find_directive_dependencies(
if self.exclude_directive_arg(type_def.name, arg_name):
continue

arg_type = unwrap_output_type(arg.type)
arg_type = unwrap_graphql_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

def find_enum_type_dependencies(
Expand Down Expand Up @@ -400,7 +402,7 @@ def find_input_type_dependencies(
dependencies, field.ast_node.directives
)

field_type = unwrap_output_type(field.type)
field_type = unwrap_graphql_type(field.type)
self.find_type_dependencies(dependencies, field_type)

def find_object_type_dependencies(
Expand Down Expand Up @@ -436,7 +438,7 @@ def find_object_type_field_dependencies(
dependencies, field_def.ast_node.directives
)

field_type = unwrap_output_type(field_def.type)
field_type = unwrap_graphql_type(field_def.type)
self.find_type_dependencies(dependencies, field_type)

for arg_name, arg in field_def.args.items():
Expand All @@ -448,7 +450,7 @@ def find_object_type_field_dependencies(
dependencies, arg.ast_node.directives
)

arg_type = unwrap_output_type(arg.type)
arg_type = unwrap_graphql_type(arg.type)
self.find_type_dependencies(dependencies, arg_type)

def find_scalar_type_dependencies(
Expand Down
8 changes: 0 additions & 8 deletions ariadne_graphql_proxy/output_types.py

This file was deleted.

0 comments on commit 214f591

Please sign in to comment.