Allow overlapping fragments #77

Draft · wants to merge 2 commits into main
11 changes: 10 additions & 1 deletion qenerate/core/plugin.py
@@ -1,5 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, MutableMapping

from graphql import GraphQLSchema

@@ -15,12 +17,19 @@ def save(self):
self.file.write_text(self.content)


@dataclass
class FragmentClass:
class_name: str
fields: MutableMapping[str, FragmentClass]
parent: Optional[FragmentClass]


@dataclass
class Fragment(GeneratedFile):
root_class: FragmentClass
definition: GQLDefinition
import_path: str
fragment_name: str
class_name: str


class Plugin:
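The new FragmentClass dataclass gives each fragment a lightweight tree of the classes it generates: the class name, the nested classes keyed by their Python field name, and a back-reference to the parent. As a rough illustration, the tree that the pydantic_v1 plugin's _traverse_fragment_classes (added below) would build for the NamespaceProtocol test fragment in this PR might look like the following sketch; the nested class names are assumptions, not actual generator output:

from qenerate.core.plugin import FragmentClass

# Sketch only: the nested class names ("Cluster_v1", "Disable_v1") are assumed.
protocol = FragmentClass(class_name="NamespaceProtocol", fields={}, parent=None)
cluster = FragmentClass(class_name="Cluster_v1", fields={}, parent=protocol)
disable = FragmentClass(class_name="Disable_v1", fields={}, parent=cluster)
cluster.fields["disable"] = disable
protocol.fields["cluster"] = cluster
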
92 changes: 74 additions & 18 deletions qenerate/plugins/pydantic_v1/plugin.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Mapping
from typing import Mapping, Optional
from functools import reduce
import operator

@@ -20,6 +20,7 @@
)
from qenerate.core.plugin import (
Fragment,
FragmentClass,
Plugin,
GeneratedFile,
)
@@ -132,12 +133,14 @@ def __init__(
type_info: TypeInfo,
definition: GQLDefinition,
feature_flags: FeatureFlags,
fragment_map: Mapping[str, Fragment],
):
Visitor.__init__(self)
self.schema = schema
self.type_info = type_info
self.definition = definition
self.feature_flags = feature_flags
self.fragment_map = fragment_map
self.parsed = ParsedNode(
parent=None,
fields=[],
@@ -197,7 +200,7 @@ def enter_fragment_definition(self, node: FragmentDefinitionNode, *_):
fields=[],
parent=self.parent,
parsed_type=field_type,
class_name=name,
fragment_class_name=name,
fragment_name=node.name.value,
)

@@ -209,6 +212,7 @@ def leave_fragment_definition(self, *_):

def enter_fragment_spread(self, node: FragmentSpreadNode, *_):
fragment_name = graphql_class_name_str_to_python(node.name.value)
self.current_fragment_class = self.fragment_map[fragment_name].root_class
field_type = ParsedFieldType(
is_primitive=False,
unwrapped_python_type=fragment_name,
@@ -217,6 +221,8 @@ def enter_fragment_spread(self, node: FragmentSpreadNode, *_):
)
current = ParsedFragmentSpreadNode(
fields=[],
fragment_root_class=self.fragment_map[fragment_name].root_class,
fragment_name=fragment_name,
parent=self.parent,
parsed_type=field_type,
)
@@ -239,6 +245,7 @@ def enter_field(self, node: FieldNode, *_):
parsed_type=field_type,
py_key=py_key,
gql_key=gql_key,
fragment_base_classes=[],
)

self.parent.fields.append(current)
@@ -303,12 +310,16 @@ def _to_python_type(self, graphql_type: GraphQLOutputType) -> str:
class QueryParser:
@staticmethod
def parse(
definition: GQLDefinition, schema: GraphQLSchema, feature_flags: FeatureFlags
definition: GQLDefinition,
schema: GraphQLSchema,
feature_flags: FeatureFlags,
fragment_map: Mapping[str, Fragment],
) -> ParsedNode:
document_ast = parse(definition.definition)
type_info = TypeInfo(schema)
visitor = FieldToTypeMatcherVisitor(
schema=schema,
fragment_map=fragment_map,
type_info=type_info,
definition=definition,
feature_flags=feature_flags,
@@ -337,6 +348,22 @@ def _traverse(self, node: ParsedNode) -> str:
result = f"{result}{self._traverse(child)}"
return result

def _traverse_fragment_classes(
self, node: ParsedNode, parent: Optional[FragmentClass] = None
) -> FragmentClass:
fragment_class = FragmentClass(
class_name=node.class_name(),
fields={},
parent=parent,
)
for field in node.fields:
if not isinstance(field, ParsedClassNode):
continue
fragment_class.fields[field.py_key] = self._traverse_fragment_classes(
node=field, parent=fragment_class
)
return fragment_class

def generate_fragments(
self, definitions: list[GQLDefinition], schema: GraphQLSchema
) -> list[Fragment]:
@@ -366,22 +393,24 @@ def generate_fragments(
# this and will try again in the next iteration.
next_to_process.append(definition)
continue
parser = QueryParser()
ast = parser.parse(
fragment_map=processed,
definition=definition,
schema=schema,
feature_flags=definition.feature_flags,
)
fragment_imports = self._fragment_imports(
definition=definition,
fragment_map=processed,
root_node=ast,
)
result = HEADER + IMPORTS
if fragment_imports:
result += "\n"
result += fragment_imports
result += f"\n\n\n{CONF}"
qf = definition.source_file
parser = QueryParser()
ast = parser.parse(
definition=definition,
schema=schema,
feature_flags=definition.feature_flags,
)
fragment = ast.fields[0]
if not isinstance(fragment, ParsedFragmentDefinitionNode):
print(f"[WARNING] {qf} is not a fragment")
@@ -395,7 +424,7 @@
definition=definition,
file=qf.with_suffix(".py"),
content=result,
class_name=fragment.class_name,
root_class=self._traverse_fragment_classes(fragment),
import_path=import_path,
fragment_name=fragment.fragment_name,
)
@@ -413,12 +442,37 @@
return list(processed.values())

def _fragment_imports(
self, definition: GQLDefinition, fragment_map: Mapping[str, Fragment]
self,
definition: GQLDefinition,
fragment_map: Mapping[str, Fragment],
root_node: ParsedNode,
) -> str:
def traverse(node: ParsedNode) -> dict[str, dict[str, str]]:
ans = {}
for field in node.fields:
if isinstance(field, ParsedFragmentSpreadNode):
ans[field.fragment_name] = field.add_fragment_base_classes_to_nodes(
node=node, fragment_class=field.fragment_root_class
)
ans.update(traverse(field))
return ans

fragment_classes = traverse(node=root_node)
imports = ""
for dep in sorted(definition.fragment_dependencies):
fragment = fragment_map[dep]
imports += f"\nfrom {fragment.import_path} import {fragment.class_name}"
imported_cls = fragment_classes[fragment.root_class.class_name]
if len(imported_cls) == 1:
imports += f"\nfrom {fragment.import_path} import "
imports += fragment.root_class.class_name
else:
imports += f"\nfrom {fragment.import_path} import (\n"
for k, v in imported_cls.items():
if k == v:
imports += f"{INDENT}{k},\n"
else:
imports += f"{INDENT}{k} as {v},\n"
imports += ")"
return imports

def _assemble_definition(
@@ -447,9 +501,17 @@ def generate_operations(
fragment_map = {f.fragment_name: f for f in fragments}
fragment_definitions = {f.fragment_name: f.definition for f in fragments}
for definition in definitions:
parser = QueryParser()
ast = parser.parse(
fragment_map=fragment_map,
definition=definition,
schema=schema,
feature_flags=definition.feature_flags,
)
fragment_imports = self._fragment_imports(
definition=definition,
fragment_map=fragment_map,
root_node=ast,
)

result = HEADER + IMPORTS
Expand All @@ -470,12 +532,6 @@ def generate_operations(
)
result += 'DEFINITION = """\n' f"{assembled_definition}" '\n"""'
result += f"\n\n\n{CONF}"
parser = QueryParser()
ast = parser.parse(
definition=definition,
schema=schema,
feature_flags=definition.feature_flags,
)
result += self._traverse(ast)
result += "\n\n"
cls = ast.fields[0].parsed_type.unwrapped_python_type
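With the new root_node argument, _fragment_imports no longer imports only a fragment's root class: it imports every fragment class that the operation's nodes now inherit from, aliasing nested classes to a fragment-prefixed name so they cannot collide with the operation's own classes. A hedged sketch of the import block this would emit for a query that overlaps NamespaceProtocol on its cluster object follows; the module path and class names are assumptions:

# Assumed output; the actual module path and class names depend on the schema
# and on where the fragment's .gql file lives.
from namespace_protocol import (
    NamespaceProtocol,
    Cluster_v1 as NamespaceProtocolCluster_v1,
)
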
58 changes: 44 additions & 14 deletions qenerate/plugins/pydantic_v1/typed_ast.py
@@ -3,6 +3,7 @@
from typing import Any, Optional

from qenerate.core.preprocessor import GQLDefinitionType
from qenerate.core.plugin import FragmentClass


BASE_CLASS_NAME = "ConfiguredBaseModel"
@@ -15,6 +16,9 @@ class ParsedNode:
fields: list[ParsedNode]
parsed_type: ParsedFieldType

def class_name(self) -> str:
return "TODO: implement class_name()"

def class_code_string(self) -> str:
return ""

@@ -72,6 +76,10 @@ def class_code_string(self) -> str:
class ParsedClassNode(ParsedNode):
gql_key: str
py_key: str
fragment_base_classes: list[str]

def class_name(self) -> str:
return self.parsed_type.unwrapped_python_type

def class_code_string(self) -> str:
if not self._needs_class_rendering():
@@ -83,9 +91,9 @@ def class_code_string(self) -> str:
return self._class_code()

def _class_code(self) -> str:
base_classes = ", ".join(self._base_classes())
base_classes = ", ".join(self.fragment_base_classes or [BASE_CLASS_NAME])
lines = ["\n\n"]
lines.append(f"class {self.parsed_type.unwrapped_python_type}({base_classes}):")
lines.append(f"class {self.class_name()}({base_classes}):")
fields_added = False
for field in self.fields:
field_arg = "..., "
@@ -113,16 +121,6 @@ def _enum_code(self) -> str:
lines.append(f"{INDENT}{k} = {val}")
return "\n".join(lines)

def _base_classes(self) -> list[str]:
base_classes: list[str] = []
for field in self.fields:
if not isinstance(field, ParsedFragmentSpreadNode):
continue
base_classes.append(field.parsed_type.unwrapped_python_type)
if not base_classes:
base_classes.append(BASE_CLASS_NAME)
return base_classes

def field_type(self) -> str:
# This is a full (non-partial) fragment spread
if len(self.fields) == 1 and isinstance(
@@ -186,12 +184,15 @@ def class_code_string(self) -> str:

@dataclass
class ParsedFragmentDefinitionNode(ParsedNode):
class_name: str
fragment_class_name: str
fragment_name: str

def class_name(self) -> str:
return self.fragment_class_name

def class_code_string(self) -> str:
lines = ["\n\n"]
lines.append(f"class {self.class_name}({BASE_CLASS_NAME}):")
lines.append(f"class {self.class_name()}({BASE_CLASS_NAME}):")
fields_added = False
for field in self.fields:
if isinstance(field, ParsedClassNode):
Expand All @@ -211,9 +212,38 @@ def class_code_string(self) -> str:

@dataclass
class ParsedFragmentSpreadNode(ParsedNode):
fragment_root_class: FragmentClass
fragment_name: str

def class_code_string(self) -> str:
return ""

def add_fragment_base_classes_to_nodes(
self, node: ParsedNode, fragment_class: FragmentClass
) -> dict[str, str]:
return self._traverse(node=node, fragment_class=fragment_class)

def _traverse(
self, node: ParsedNode, fragment_class: FragmentClass
) -> dict[str, str]:
ans = {}
if isinstance(node, ParsedClassNode):
cls_name = "" if not fragment_class.parent else self.fragment_name
cls_name += fragment_class.class_name
node.fragment_base_classes.append(cls_name)
ans[fragment_class.class_name] = cls_name
for field in node.fields:
if not isinstance(field, ParsedClassNode):
continue
if field.py_key not in fragment_class.fields:
continue
ans.update(
self._traverse(
node=field, fragment_class=fragment_class.fields[field.py_key]
)
)
return ans


@dataclass
class ParsedFieldType:
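add_fragment_base_classes_to_nodes walks the operation's ParsedClassNode tree in lockstep with the fragment's FragmentClass tree. Every overlapping node gets the matching fragment class appended to fragment_base_classes (prefixed with the fragment name for non-root classes), and the returned dict maps each fragment class name to the name it should be imported as, which _fragment_imports turns into the aliased import block. For the merge_fragments test definitions below, the returned mapping might look roughly like this sketch; the class names are assumed, not actual generator output:

# Keys are the class names in the fragment module; values are the names used
# inside the generated operation module.
expected_aliases = {
    "NamespaceProtocol": "NamespaceProtocol",  # root class keeps its name
    "Cluster_v1": "NamespaceProtocolCluster_v1",  # nested classes get the fragment prefix
    "Disable_v1": "NamespaceProtocolDisable_v1",
}
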
@@ -0,0 +1,11 @@
fragment NamespaceProtocol on Namespace_v1 {
name
clusterAdmin
cluster {
name
internal
disable {
integrations
}
}
}
14 changes: 14 additions & 0 deletions tests/generator/definitions/merge_fragments/namespaces.gql
@@ -0,0 +1,14 @@
query NamespacesMinimal {
namespaces: namespaces_v1 {
... NamespaceProtocol
delete
labels
cluster {
serverUrl
insecureSkipTLSVerify
disable {
e2eTests
}
}
}
}
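Together, these two test files exercise the overlap: the query re-selects the cluster and disable objects that NamespaceProtocol already covers, so the generated operation classes are expected to extend the fragment's classes instead of redeclaring their fields. A minimal sketch of the shape of the generated pydantic v1 models, assuming hypothetical class names and field types, and using a bare BaseModel as a stand-in for qenerate's ConfiguredBaseModel:

from typing import Optional
from pydantic import BaseModel, Field

# Fragment module (sketch); classes are shown with the names they carry after the
# aliased import. All names and field types below are assumptions.
class NamespaceProtocolDisable_v1(BaseModel):
    integrations: Optional[list[str]] = Field(..., alias="integrations")

class NamespaceProtocolCluster_v1(BaseModel):
    name: str = Field(..., alias="name")
    internal: Optional[bool] = Field(..., alias="internal")
    disable: Optional[NamespaceProtocolDisable_v1] = Field(..., alias="disable")

class NamespaceProtocol(BaseModel):
    name: str = Field(..., alias="name")
    cluster_admin: Optional[bool] = Field(..., alias="clusterAdmin")
    cluster: NamespaceProtocolCluster_v1 = Field(..., alias="cluster")

# Operation module (sketch): overlapping classes inherit from the fragment classes,
# adding only the fields the query selects on top of the fragment.
class NamespacesMinimalDisable_v1(NamespaceProtocolDisable_v1):
    e2e_tests: Optional[bool] = Field(..., alias="e2eTests")

class NamespacesMinimalCluster_v1(NamespaceProtocolCluster_v1):
    server_url: str = Field(..., alias="serverUrl")
    insecure_skip_tls_verify: Optional[bool] = Field(..., alias="insecureSkipTLSVerify")
    disable: Optional[NamespacesMinimalDisable_v1] = Field(..., alias="disable")

class NamespacesMinimalNamespaces_v1(NamespaceProtocol):
    delete: Optional[bool] = Field(..., alias="delete")
    labels: Optional[dict] = Field(..., alias="labels")
    cluster: NamespacesMinimalCluster_v1 = Field(..., alias="cluster")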