Skip to content

Commit

Permalink
Remaining type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
binste committed Sep 29, 2023
1 parent 26fa128 commit 6a7a2d3
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 54 deletions.
19 changes: 12 additions & 7 deletions tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def recursive_dict_update(schema: dict, root: dict, def_dict: dict) -> None:


def get_field_datum_value_defs(propschema: SchemaInfo, root: dict) -> dict:
def_dict = {k: None for k in ("field", "datum", "value")}
def_dict: Dict[str, Optional[str]] = {k: None for k in ("field", "datum", "value")}
schema = propschema.schema
if propschema.is_reference() and "properties" in schema:
if "field" in schema["properties"]:
Expand Down Expand Up @@ -381,13 +381,14 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str:

for name, schema in definitions.items():
graph[name] = []
for child in schema.subclasses():
child = get_valid_identifier(child)
graph[name].append(child)
child = definitions[child]
for child_name in schema.subclasses():
child_name = get_valid_identifier(child_name)
graph[name].append(child_name)
child: SchemaGenerator = definitions[child_name]
if child.basename == basename:
child.basename = [name]
else:
assert isinstance(child.basename, list)
child.basename.append(name)

contents = [
Expand Down Expand Up @@ -524,9 +525,13 @@ def generate_vegalite_mark_mixin(
arg_info.kwds -= {"type"}

def_args = ["self"] + [
"{}=Undefined".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds))
"{}=Undefined".format(p)
for p in (sorted(arg_info.required) + sorted(arg_info.kwds))
]
dict_args = [
"{0}={0}".format(p)
for p in (sorted(arg_info.required) + sorted(arg_info.kwds))
]
dict_args = ["{0}={0}".format(p) for p in (sorted(arg_info.required) + sorted(arg_info.kwds))]

if arg_info.additional or arg_info.invalid_kwds:
def_args.append("**kwds")
Expand Down
83 changes: 51 additions & 32 deletions tools/schemapi/codegen.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""Code generation utilities"""
import re
import textwrap
from typing import Tuple, Set
from typing import Set, Final, Optional, List, Iterable, Union
from dataclasses import dataclass

from .utils import SchemaInfo, is_valid_identifier, indent_docstring, indent_arglist
from .utils import (
SchemaInfo,
is_valid_identifier,
indent_docstring,
indent_arglist,
SchemaProperties,
)


class CodeSnippet:
Expand Down Expand Up @@ -84,7 +90,7 @@ class SchemaGenerator:
The dictionary defining the schema class
rootschema : dict (optional)
The root schema for the class
basename : string or tuple (default: "SchemaBase")
basename : string or list of strings (default: "SchemaBase")
The name(s) of the base class(es) to use in the class definition
schemarepr : CodeSnippet or object, optional
An object whose repr will be used in the place of the explicit schema.
Expand All @@ -109,47 +115,51 @@ class {classname}({basename}):
'''
)

init_template = textwrap.dedent(
init_template: Final = textwrap.dedent(
"""
def __init__({arglist}):
super({classname}, self).__init__({super_arglist})
"""
).lstrip()

def _process_description(self, description):
def _process_description(self, description: str):
return description

def __init__(
self,
classname,
schema,
rootschema=None,
basename="SchemaBase",
schemarepr=None,
rootschemarepr=None,
nodefault=(),
haspropsetters=False,
classname: str,
schema: dict,
rootschema: Optional[dict] = None,
basename: Union[str, List[str]] = "SchemaBase",
schemarepr: Optional[object] = None,
rootschemarepr: Optional[object] = None,
nodefault: Optional[List[str]] = None,
haspropsetters: bool = False,
**kwargs,
):
) -> None:
self.classname = classname
self.schema = schema
self.rootschema = rootschema
self.basename = basename
self.schemarepr = schemarepr
self.rootschemarepr = rootschemarepr
self.nodefault = nodefault
self.nodefault = nodefault or ()
self.haspropsetters = haspropsetters
self.kwargs = kwargs

def subclasses(self):
def subclasses(self) -> List[str]:
"""Return a list of subclass names, if any."""
info = SchemaInfo(self.schema, self.rootschema)
return [child.refname for child in info.anyOf if child.is_reference()]

def schema_class(self):
def schema_class(self) -> str:
"""Generate code for a schema class"""
rootschema = self.rootschema if self.rootschema is not None else self.schema
schemarepr = self.schemarepr if self.schemarepr is not None else self.schema
rootschema: dict = (
self.rootschema if self.rootschema is not None else self.schema
)
schemarepr: object = (
self.schemarepr if self.schemarepr is not None else self.schema
)
rootschemarepr = self.rootschemarepr
if rootschemarepr is None:
if rootschema is self.schema:
Expand All @@ -171,7 +181,7 @@ def schema_class(self):
**self.kwargs,
)

def docstring(self, indent=0):
def docstring(self, indent: int = 0) -> str:
# TODO: add a general description at the top, derived from the schema.
# for example, a non-object definition should list valid type, enum
# values, etc.
Expand Down Expand Up @@ -207,7 +217,7 @@ def docstring(self, indent=0):
doc += [""]
return indent_docstring(doc, indent_level=indent, width=100, lstrip=True)

def init_code(self, indent=0):
def init_code(self, indent: int = 0) -> str:
"""Return code suitable for the __init__ function of a Schema class"""
info = SchemaInfo(self.schema, rootschema=self.rootschema)
arg_info = get_args(info)
Expand All @@ -216,8 +226,8 @@ def init_code(self, indent=0):
arg_info.required -= nodefault
arg_info.kwds -= nodefault

args = ["self"]
super_args = []
args: List[str] = ["self"]
super_args: List[str] = []

self.init_kwds = sorted(arg_info.kwds)

Expand All @@ -227,10 +237,15 @@ def init_code(self, indent=0):
args.append("*args")
super_args.append("*args")

args.extend("{}=Undefined".format(p) for p in sorted(arg_info.required) + sorted(arg_info.kwds))
args.extend(
"{}=Undefined".format(p)
for p in sorted(arg_info.required) + sorted(arg_info.kwds)
)
super_args.extend(
"{0}={0}".format(p)
for p in sorted(nodefault) + sorted(arg_info.required) + sorted(arg_info.kwds)
for p in sorted(nodefault)
+ sorted(arg_info.required)
+ sorted(arg_info.kwds)
)

if arg_info.additional:
Expand Down Expand Up @@ -261,9 +276,9 @@ def init_code(self, indent=0):
"null": "None",
}

def get_args(self, si):
def get_args(self, si: SchemaInfo) -> List[str]:
contents = ["self"]
props = []
props: Union[List[str], SchemaProperties] = []
if si.is_anyOf():
props = sorted({p for si_sub in si.anyOf for p in si_sub.properties})
elif si.properties:
Expand Down Expand Up @@ -296,7 +311,9 @@ def get_args(self, si):

return contents

def get_signature(self, attr, sub_si, indent, has_overload=False):
def get_signature(
self, attr: str, sub_si: SchemaInfo, indent: int, has_overload: bool = False
) -> List[str]:
lines = []
if has_overload:
lines.append("@overload # type: ignore[no-overload-impl]")
Expand All @@ -305,14 +322,16 @@ def get_signature(self, attr, sub_si, indent, has_overload=False):
lines.append(indent * " " + "...\n")
return lines

def setter_hint(self, attr, indent):
def setter_hint(self, attr: str, indent: int) -> List[str]:
si = SchemaInfo(self.schema, self.rootschema).properties[attr]
if si.is_anyOf():
return self._get_signature_any_of(si, attr, indent)
else:
return self.get_signature(attr, si, indent)

def _get_signature_any_of(self, si: SchemaInfo, attr, indent):
def _get_signature_any_of(
self, si: SchemaInfo, attr: str, indent: int
) -> List[str]:
signatures = []
for sub_si in si.anyOf:
if sub_si.is_anyOf():
Expand All @@ -324,7 +343,7 @@ def _get_signature_any_of(self, si: SchemaInfo, attr, indent):
)
return list(flatten(signatures))

def method_code(self, indent=0):
def method_code(self, indent: int = 0) -> Optional[str]:
"""Return code to assist setter methods"""
if not self.haspropsetters:
return None
Expand All @@ -334,7 +353,7 @@ def method_code(self, indent=0):
return ("\n" + indent * " ").join(type_hints)


def flatten(container):
def flatten(container: Iterable) -> Iterable:
"""Flatten arbitrarily flattened list
From https://stackoverflow.com/a/10824420
Expand Down
30 changes: 15 additions & 15 deletions tools/schemapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import textwrap
import urllib
from typing import Final, Optional, List, Dict, Literal, Union
from typing import Final, Optional, List, Dict, Any

from .schemapi import _resolve_references as resolve_references

Expand Down Expand Up @@ -93,7 +93,10 @@ class SchemaProperties:
"""A wrapper for properties within a schema"""

def __init__(
self, properties: dict, schema: dict, rootschema: Optional[dict] = None
self,
properties: Dict[str, Any],
schema: dict,
rootschema: Optional[dict] = None,
) -> None:
self._properties = properties
self._schema = schema
Expand Down Expand Up @@ -134,7 +137,7 @@ class SchemaInfo:
"""A wrapper for inspecting a JSON schema"""

def __init__(
self, schema: dict, rootschema: Optional[dict] = None
self, schema: Dict[str, Any], rootschema: Optional[Dict[str, Any]] = None
) -> None:
if not rootschema:
rootschema = schema
Expand Down Expand Up @@ -186,11 +189,7 @@ def short_description(self) -> str:

@property
def medium_description(self) -> str:
if self.is_list():
return "[{0}]".format(
", ".join(self.child(s).short_description for s in self.schema)
)
elif self.is_empty():
if self.is_empty():
return "Any"
elif self.is_enum():
return "enum({})".format(", ".join(map(repr, self.enum)))
Expand Down Expand Up @@ -229,6 +228,10 @@ def medium_description(self) -> str:
stacklevel=1,
)
return "any"
else:
raise ValueError(
"No medium_description available for this schema for schema"
)

@property
def properties(self) -> SchemaProperties:
Expand Down Expand Up @@ -259,19 +262,19 @@ def type(self) -> Optional[str]:
return self.schema.get("type", None)

@property
def anyOf(self) -> list:
def anyOf(self) -> List["SchemaInfo"]:
return [self.child(s) for s in self.schema.get("anyOf", [])]

@property
def oneOf(self) -> list:
def oneOf(self) -> List["SchemaInfo"]:
return [self.child(s) for s in self.schema.get("oneOf", [])]

@property
def allOf(self) -> list:
def allOf(self) -> List["SchemaInfo"]:
return [self.child(s) for s in self.schema.get("allOf", [])]

@property
def not_(self) -> dict:
def not_(self) -> "SchemaInfo":
return self.child(self.schema.get("not", {}))

@property
Expand Down Expand Up @@ -314,9 +317,6 @@ def _get_description(self, include_sublevels: bool = False) -> str:
desc = sub_desc
return desc

def is_list(self) -> bool:
return isinstance(self.schema, list)

def is_reference(self) -> bool:
return "$ref" in self.raw_schema

Expand Down

0 comments on commit 6a7a2d3

Please sign in to comment.