Skip to content

Commit

Permalink
fix: Fix ClientForwardRefsPlugin imports and add tests
Browse files Browse the repository at this point in the history
- Add missing tests for `ClientForwardRefsPlugin` with and without
  combining with `ShorterResultsPlugin`.
- Fix faulty imports
  - Store name and level separate to allow dots specified either on the
    module or via level. When importing just lookup what level and name
    to use.
  - Always use level 0 for `TYPE_CHECKING_MODULE`.

Fixes mirumee#314
  • Loading branch information
bombsimon committed Dec 25, 2024
1 parent 11bfe35 commit e25f4a1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
19 changes: 11 additions & 8 deletions ariadne_codegen/contrib/client_forward_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None:
# Imported classes are classes imported from local imports. We keep a
# map between name and module so we know how to import them in each
# method.
self.imported_classes: Dict[str, str] = {}
self.imported_classes: Dict[str, tuple[int, str]] = {}

# Imported classes in each method definition.
self.imported_in_method: Set[str] = set()
Expand Down Expand Up @@ -116,9 +116,8 @@ def _store_imported_classes(self, module_body: List[ast.stmt]):
continue

for name in node.names:
from_ = "." * node.level + node.module
if isinstance(name, ast.alias):
self.imported_classes[name.name] = from_
self.imported_classes[name.name] = (node.level, node.module)

def _rewrite_input_args_to_constants(
self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
Expand Down Expand Up @@ -178,12 +177,14 @@ def _insert_import_statement_in_method(
# We add the class to our set of imported in methods - these classes
# don't need to be imported at all in the global scope.
self.imported_in_method.add(import_class_name)

level, module_name = self.imported_classes[import_class_name]
method_def.body.insert(
0,
ast.ImportFrom(
module=self.imported_classes[import_class_name],
module=module_name,
names=[import_class],
level=1,
level=level,
),
)

Expand Down Expand Up @@ -342,10 +343,12 @@ def _add_forward_ref_imports(
"""
type_checking_imports = {}
for cls in self.input_and_return_types:
module_name = self.imported_classes[cls]
level, module_name = self.imported_classes[cls]
if module_name not in type_checking_imports:
type_checking_imports[module_name] = ast.ImportFrom(
module=module_name, names=[], level=1
module=module_name,
names=[],
level=level,
)

type_checking_imports[module_name].names.append(ast.alias(cls))
Expand All @@ -364,7 +367,7 @@ def _add_forward_ref_imports(
ast.ImportFrom(
module=TYPE_CHECKING_MODULE,
names=[ast.alias(TYPE_CHECKING_FLAG)],
level=1,
level=0,
),
)

Expand Down
30 changes: 30 additions & 0 deletions tests/main/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,36 @@ def test_main_shows_version():
"example_client",
CLIENTS_PATH / "custom_sync_query_builder" / "expected_client",
),
(
(
CLIENTS_PATH / "client_forward_refs" / "pyproject.toml",
(
CLIENTS_PATH / "client_forward_refs" / "queries.graphql",
CLIENTS_PATH / "client_forward_refs" / "schema.graphql",
CLIENTS_PATH / "client_forward_refs" / "custom_scalars.py",
),
),
"client_forward_refs",
CLIENTS_PATH / "client_forward_refs" / "expected_client",
),
(
(
CLIENTS_PATH / "client_forward_refs_shorter_results" / "pyproject.toml",
(
CLIENTS_PATH
/ "client_forward_refs_shorter_results"
/ "queries.graphql",
CLIENTS_PATH
/ "client_forward_refs_shorter_results"
/ "schema.graphql",
CLIENTS_PATH
/ "client_forward_refs_shorter_results"
/ "custom_scalars.py",
),
),
"client_forward_refs_shorter_results",
CLIENTS_PATH / "client_forward_refs_shorter_results" / "expected_client",
),
],
indirect=["project_dir"],
)
Expand Down

0 comments on commit e25f4a1

Please sign in to comment.