Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Optional field to be not required #82

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo/gql/fragments/fragment1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ class Config:
class VaultSecret(ConfiguredBaseModel):
path: str = Field(..., alias="path")
field: str = Field(..., alias="field")
version: Optional[int] = Field(..., alias="version")
q_format: Optional[str] = Field(..., alias="format")
version: Optional[int] = Field(alias="version")
q_format: Optional[str] = Field(alias="format")
4 changes: 2 additions & 2 deletions demo/gql/queries/example1.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ class ResourceV1(ConfiguredBaseModel):

class JenkinsConfigV1_JenkinsConfigV1(JenkinsConfigV1):
q_type: str = Field(..., alias="type")
config_path: Optional[ResourceV1] = Field(..., alias="config_path")
config_path: Optional[ResourceV1] = Field(alias="config_path")


class JenkinsConfigsQueryData(ConfiguredBaseModel):
jenkins_configs: Optional[list[Union[JenkinsConfigV1_JenkinsConfigV1, JenkinsConfigV1]]] = Field(..., alias="jenkins_configs")
jenkins_configs: Optional[list[Union[JenkinsConfigV1_JenkinsConfigV1, JenkinsConfigV1]]] = Field(alias="jenkins_configs")


def query(query_func: Callable, **kwargs: Any) -> JenkinsConfigsQueryData:
Expand Down
2 changes: 1 addition & 1 deletion demo/gql/queries/example2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GitlabInstanceV1(ConfiguredBaseModel):


class GitlabInstanceQueryData(ConfiguredBaseModel):
instances: Optional[list[GitlabInstanceV1]] = Field(..., alias="instances")
instances: Optional[list[GitlabInstanceV1]] = Field(alias="instances")


def query(query_func: Callable, **kwargs: Any) -> GitlabInstanceQueryData:
Expand Down
11 changes: 11 additions & 0 deletions qenerate/plugins/pydantic_v1/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
unwrapped_python_type="",
wrapped_python_type="",
is_primitive=False,
is_optional=False,
enum_map={},
),
)
Expand Down Expand Up @@ -178,6 +179,7 @@ def enter_operation_definition(self, node: OperationDefinitionNode, *_):
unwrapped_python_type=node.name.value,
wrapped_python_type=node.name.value,
is_primitive=False,
is_optional=False,
enum_map={},
),
)
Expand Down Expand Up @@ -211,6 +213,7 @@ def enter_fragment_spread(self, node: FragmentSpreadNode, *_):
fragment_name = graphql_class_name_str_to_python(node.name.value)
field_type = ParsedFieldType(
is_primitive=False,
is_optional=False,
unwrapped_python_type=fragment_name,
wrapped_python_type=fragment_name,
enum_map={},
Expand Down Expand Up @@ -253,6 +256,7 @@ def _parse_type(self, graphql_type: GraphQLOutputType) -> ParsedFieldType:
unwrapped_type = self._to_python_type(unwrapper_result.inner_gql_type)
is_primitive = unwrapper_result.is_primitive
enum_map = unwrapper_result.enum_map
is_optional = self._is_optional_type(unwrapper_result.wrapper_stack)
wrapped_type = unwrapped_type
for wrapper in reversed(unwrapper_result.wrapper_stack):
if wrapper == WrapperType.LIST:
Expand All @@ -264,9 +268,16 @@ def _parse_type(self, graphql_type: GraphQLOutputType) -> ParsedFieldType:
unwrapped_python_type=unwrapped_type,
wrapped_python_type=wrapped_type,
is_primitive=is_primitive,
is_optional=is_optional,
enum_map=enum_map,
)

@staticmethod
def _is_optional_type(wrapper_stack: list[WrapperType]) -> bool:
if not wrapper_stack:
return False
return wrapper_stack[0] == WrapperType.OPTIONAL

def _to_python_type(self, graphql_type: GraphQLOutputType) -> str:
if isinstance(graphql_type, GraphQLScalarType):
return graphql_primitive_to_python(
Expand Down
42 changes: 15 additions & 27 deletions qenerate/plugins/pydantic_v1/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
INDENT = " "


def _build_field_code_string(field: ParsedClassNode) -> str:
field_arg = (
"" if field.parsed_type.is_optional or field.parsed_type.enum_map else "..., "
)
return (
f"{INDENT}{field.py_key}: {field.field_type()} = "
f'Field({field_arg}alias="{field.gql_key}")'
)


@dataclass
class ParsedNode:
parent: Optional[ParsedNode]
Expand Down Expand Up @@ -54,12 +64,7 @@ def class_code_string(self) -> str:
fields_added = False
for field in self.fields:
if isinstance(field, ParsedClassNode):
lines.append(
(
f"{INDENT}{field.py_key}: {field.field_type()} = "
f'Field(..., alias="{field.gql_key}")'
)
)
lines.append(_build_field_code_string(field))
fields_added = True

if not fields_added:
Expand Down Expand Up @@ -88,16 +93,8 @@ def _class_code(self) -> str:
lines.append(f"class {self.parsed_type.unwrapped_python_type}({base_classes}):")
fields_added = False
for field in self.fields:
field_arg = "..., "
if field.parsed_type.enum_map:
field_arg = ""
if isinstance(field, ParsedClassNode):
lines.append(
(
f"{INDENT}{field.py_key}: {field.field_type()} = "
f'Field({field_arg}alias="{field.gql_key}")'
)
)
lines.append(_build_field_code_string(field))
fields_added = True

if not fields_added:
Expand Down Expand Up @@ -170,12 +167,7 @@ def class_code_string(self) -> str:
fields_added = False
for field in self.fields:
if isinstance(field, ParsedClassNode):
lines.append(
(
f"{INDENT}{field.py_key}: {field.field_type()} = "
f'Field(..., alias="{field.gql_key}")'
)
)
lines.append(_build_field_code_string(field))
fields_added = True

if not fields_added:
Expand All @@ -195,12 +187,7 @@ def class_code_string(self) -> str:
fields_added = False
for field in self.fields:
if isinstance(field, ParsedClassNode):
lines.append(
(
f"{INDENT}{field.py_key}: {field.field_type()} = "
f'Field(..., alias="{field.gql_key}")'
)
)
lines.append(_build_field_code_string(field))
fields_added = True

if not fields_added:
Expand All @@ -220,4 +207,5 @@ class ParsedFieldType:
unwrapped_python_type: str
wrapped_python_type: str
is_primitive: bool
is_optional: bool
enum_map: dict[str, Any]
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class ClusterSpecV1__2(ConfiguredBaseModel):

class ClusterV1__2(ConfiguredBaseModel):
name: str = Field(..., alias="name")
spec: Optional[ClusterSpecV1__2] = Field(..., alias="spec")
internal: Optional[bool] = Field(..., alias="internal")
spec: Optional[ClusterSpecV1__2] = Field(alias="spec")
internal: Optional[bool] = Field(alias="internal")


class ClusterPeeringConnectionClusterRequesterV1(ClusterPeeringConnectionV1):
Expand All @@ -89,27 +89,27 @@ class ClusterSpecV1__3(ConfiguredBaseModel):

class ClusterV1__3(ConfiguredBaseModel):
name: str = Field(..., alias="name")
spec: Optional[ClusterSpecV1__3] = Field(..., alias="spec")
internal: Optional[bool] = Field(..., alias="internal")
spec: Optional[ClusterSpecV1__3] = Field(alias="spec")
internal: Optional[bool] = Field(alias="internal")


class ClusterPeeringConnectionClusterAccepterV1(ClusterPeeringConnectionV1):
cluster: ClusterV1__3 = Field(..., alias="cluster")


class ClusterPeeringV1(ConfiguredBaseModel):
connections: Optional[list[Union[ClusterPeeringConnectionClusterRequesterV1, ClusterPeeringConnectionClusterAccepterV1, ClusterPeeringConnectionV1]]] = Field(..., alias="connections")
connections: Optional[list[Union[ClusterPeeringConnectionClusterRequesterV1, ClusterPeeringConnectionClusterAccepterV1, ClusterPeeringConnectionV1]]] = Field(alias="connections")


class ClusterV1(ConfiguredBaseModel):
name: str = Field(..., alias="name")
spec: Optional[ClusterSpecV1] = Field(..., alias="spec")
internal: Optional[bool] = Field(..., alias="internal")
peering: Optional[ClusterPeeringV1] = Field(..., alias="peering")
spec: Optional[ClusterSpecV1] = Field(alias="spec")
internal: Optional[bool] = Field(alias="internal")
peering: Optional[ClusterPeeringV1] = Field(alias="peering")


class EnumerateCollisionsQueryData(ConfiguredBaseModel):
clusters: Optional[list[Optional[ClusterV1]]] = Field(..., alias="clusters")
clusters: Optional[list[Optional[ClusterV1]]] = Field(alias="clusters")


def query(query_func: Callable, **kwargs: Any) -> EnumerateCollisionsQueryData:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ class ClusterAuthGithubOrgTeamV1(ClusterAuthV1):


class ClusterV1(ConfiguredBaseModel):
auth: Optional[Union[ClusterAuthGithubOrgTeamV1, ClusterAuthGithubOrgV1, ClusterAuthV1]] = Field(..., alias="auth")
auth: Optional[Union[ClusterAuthGithubOrgTeamV1, ClusterAuthGithubOrgV1, ClusterAuthV1]] = Field(alias="auth")


class OcpReleaseMirrorV1(ConfiguredBaseModel):
hive_cluster: ClusterV1 = Field(..., alias="hiveCluster")


class OCPWithInterfaceQueryData(ConfiguredBaseModel):
ocp_release_mirror: Optional[list[Optional[OcpReleaseMirrorV1]]] = Field(..., alias="ocp_release_mirror")
ocp_release_mirror: Optional[list[Optional[OcpReleaseMirrorV1]]] = Field(alias="ocp_release_mirror")


def query(query_func: Callable, **kwargs: Any) -> OCPWithInterfaceQueryData:
Expand Down
Loading