Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model_rebuild calls for top level fragments #258

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## UNRELEASED

- Restored `model_rebuild` calls for top level fragment models.


## 0.11.0 (2023-12-05)

- Removed `model_rebuild` calls for generated input, fragment and result models.
Expand Down
4 changes: 4 additions & 0 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
```

### Init file
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
MODEL_VALIDATE_METHOD = "model_validate"
PLAIN_SERIALIZER = "PlainSerializer"
BEFORE_VALIDATOR = "BeforeValidator"
MODEL_REBUILD_METHOD = "model_rebuild"

ENUM_MODULE = "enum"
ENUM_CLASS = "Enum"
Expand Down
32 changes: 28 additions & 4 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from graphql import FragmentDefinitionNode, GraphQLSchema

from ..codegen import generate_module
from ..codegen import generate_expr, generate_method_call, generate_module
from ..plugins.manager import PluginManager
from .constants import BASE_MODEL_IMPORT
from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD
from .result_types import ResultTypesGenerator
from .scalars import ScalarData

Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict: Dict[str, List[ast.ClassDef]] = {}
imports: List[ast.ImportFrom] = []
top_level_class_names: List[str] = []
dependencies_dict: Dict[str, Set[str]] = {}

names_to_exclude = exclude_names or set()
Expand All @@ -53,7 +54,10 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
plugin_manager=self.plugin_manager,
)
imports.extend(generator.get_imports())
class_defs_dict[name] = generator.get_classes()
class_defs = generator.get_classes()
class_defs_dict[name] = class_defs
if class_defs:
top_level_class_names.append(class_defs[0].name)
dependencies_dict[name] = generator.get_fragments_used_as_mixins()
self._generated_public_names.extend(generator.get_generated_public_names())
self._used_enums.extend(generator.get_used_enums())
Expand All @@ -62,7 +66,15 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict
)
module = generate_module(
body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs)
body=cast(List[ast.stmt], imports)
+ cast(List[ast.stmt], sorted_class_defs)
+ cast(
List[ast.stmt],
self._get_model_rebuild_calls(
top_level_fragments_names=top_level_class_names,
class_defs=sorted_class_defs,
),
)
)
if self.plugin_manager:
module = self.plugin_manager.generate_fragments_module(
Expand Down Expand Up @@ -108,3 +120,15 @@ def visit(name):
visit(name)

return sorted_names

def _get_model_rebuild_calls(
self, top_level_fragments_names: List[str], class_defs: List[ast.ClassDef]
) -> List[ast.Call]:
class_names = [c.name for c in class_defs]
sorted_fragments_names = sorted(
top_level_fragments_names, key=lambda n: class_names.index(n)
)
return [
generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD))
for name in sorted_fragments_names
]
4 changes: 4 additions & 0 deletions tests/main/clients/example/expected_client/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ class GetQueryAFragment(BaseModel):

class GetQueryAFragmentQueryA(BaseModel, MixinA, CommonMixin):
field_a: int = Field(alias="fieldA")


FragmentA.model_rebuild()
FragmentB.model_rebuild()
GetQueryAFragment.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class FragmentA(BaseModel):
class FragmentB(BaseModel):
id: str
value_b: str = Field(alias="valueB")


FragmentA.model_rebuild()
FragmentB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ class FragmentOnQueryWithUnionQueryUTypeC(BaseModel):
class UnusedFragmentOnTypeA(BaseModel):
id: str
field_a: str = Field(alias="fieldA")


FragmentOnQueryWithInterface.model_rebuild()
FragmentOnQueryWithUnion.model_rebuild()
UnusedFragmentOnTypeA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ class MinimalA(BaseModel):

class MinimalAFieldB(MinimalB):
pass


CompleteA.model_rebuild()
FullB.model_rebuild()
FullA.model_rebuild()
MinimalB.model_rebuild()
MinimalA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class FragmentG(BaseModel):

class FragmentGG(BaseModel):
val: EnumGG


FragmentG.model_rebuild()
FragmentGG.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/operations/expected_client/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class FragmentB(BaseModel):

class FragmentY(BaseModel):
value_y: int = Field(alias="valueY")


FragmentB.model_rebuild()
FragmentY.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ class ListAnimalsFragment(BaseModel):
class ListAnimalsFragmentListAnimals(BaseModel):
typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename")
name: str


FragmentWithSingleField.model_rebuild()
ListAnimalsFragment.model_rebuild()
Loading