Skip to content

Commit

Permalink
Add type hints to channel mixins
Browse files Browse the repository at this point in the history
  • Loading branch information
binste committed Oct 1, 2023
1 parent 0e803bc commit ebaaee2
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 63 deletions.
47 changes: 33 additions & 14 deletions altair/vegalite/v5/schema/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,21 @@
import pandas as pd
from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters
from altair.utils import parse_shorthand
from typing import Any, overload, Sequence, List, Literal, Union
from typing import Any, overload, Sequence, List, Literal, Union, Optional
from typing import Dict as TypingDict


class FieldChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> Union[dict, List[dict]]:
context = context or {}
shorthand = self._get("shorthand")
field = self._get("field")
ignore = ignore or []
shorthand = self._get("shorthand") # type: ignore[attr-defined]
field = self._get("field") # type: ignore[attr-defined]

if shorthand is not Undefined and field is not Undefined:
raise ValueError(
Expand All @@ -31,10 +38,10 @@ def to_dict(self, validate=True, ignore=(), context=None):

if isinstance(shorthand, (tuple, list)):
# If given a list of shorthands, then transform it to a list of classes
kwds = self._kwds.copy()
kwds = self._kwds.copy() # type: ignore[attr-defined]
kwds.pop("shorthand")
return [
self.__class__(sh, **kwds).to_dict(
self.__class__(sh, **kwds).to_dict( # type: ignore[call-arg]
validate=validate, ignore=ignore, context=context
)
for sh in shorthand
Expand All @@ -44,9 +51,9 @@ def to_dict(self, validate=True, ignore=(), context=None):
parsed = {}
elif isinstance(shorthand, str):
parsed = parse_shorthand(shorthand, data=context.get("data", None))
type_required = "type" in self._kwds
type_required = "type" in self._kwds # type: ignore[attr-defined]
type_in_shorthand = "type" in parsed
type_defined_explicitly = self._get("type") is not Undefined
type_defined_explicitly = self._get("type") is not Undefined # type: ignore[attr-defined]
if not type_required:
# Secondary field names don't require a type argument in VegaLite 3+.
# We still parse it out of the shorthand, but drop it here.
Expand Down Expand Up @@ -80,26 +87,38 @@ def to_dict(self, validate=True, ignore=(), context=None):


class ValueChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> dict:
context = context or {}
condition = self._get("condition", Undefined)
ignore = ignore or []
condition = self._get("condition", Undefined) # type: ignore[attr-defined]
copy = self # don't copy unless we need to
if condition is not Undefined:
if isinstance(condition, core.SchemaBase):
pass
elif "field" in condition and "type" not in condition:
kwds = parse_shorthand(condition["field"], context.get("data", None))
copy = self.copy(deep=["condition"])
copy["condition"].update(kwds)
copy = self.copy(deep=["condition"]) # type: ignore[attr-defined]
copy["condition"].update(kwds) # type: ignore[index]
return super(ValueChannelMixin, copy).to_dict(
validate=validate, ignore=ignore, context=context
)


class DatumChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> dict:
context = context or {}
datum = self._get("datum", Undefined)
ignore = ignore or []
datum = self._get("datum", Undefined) # type: ignore[attr-defined]
copy = self # don't copy unless we need to
if datum is not Undefined:
if isinstance(datum, core.SchemaBase):
Expand Down
13 changes: 7 additions & 6 deletions altair/vegalite/v5/schema/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

from typing import Any, Literal, Union, Protocol, Sequence, List
from typing import Dict as TypingDict

from typing import Generator as TypingGenerator
from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses

import pkgutil
import json


def load_schema():
def load_schema() -> dict:
"""Load the json schema associated with this module's functions"""
return json.loads(
pkgutil.get_data(__name__, "vega-lite-schema.json").decode("utf-8")
)
schema_bytes = pkgutil.get_data(__name__, "vega-lite-schema.json")
if schema_bytes is None:
raise ValueError("Unable to load vega-lite-schema.json")
return json.loads(schema_bytes.decode("utf-8"))


class _ParameterProtocol(Protocol):
Expand All @@ -40,7 +41,7 @@ class VegaLiteSchema(SchemaBase):
_rootschema = load_schema()

@classmethod
def _default_wrapper_classes(cls):
def _default_wrapper_classes(cls) -> TypingGenerator[type, None, None]:
return _subclasses(VegaLiteSchema)


Expand Down
118 changes: 75 additions & 43 deletions tools/generate_schema_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,103 +74,134 @@ def _to_expr(self) -> str:
class {basename}(SchemaBase):
_rootschema = load_schema()
@classmethod
def _default_wrapper_classes(cls):
def _default_wrapper_classes(cls) -> TypingGenerator[type, None, None]:
return _subclasses({basename})
"""

LOAD_SCHEMA: Final = '''
import pkgutil
import json
def load_schema():
def load_schema() -> dict:
"""Load the json schema associated with this module's functions"""
return json.loads(pkgutil.get_data(__name__, '{schemafile}').decode('utf-8'))
schema_bytes = pkgutil.get_data(__name__, "{schemafile}")
if schema_bytes is None:
raise ValueError("Unable to load {schemafile}")
return json.loads(
schema_bytes.decode("utf-8")
)
'''


CHANNEL_MIXINS: Final = """
class FieldChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> Union[dict, List[dict]]:
context = context or {}
shorthand = self._get('shorthand')
field = self._get('field')
ignore = ignore or []
shorthand = self._get("shorthand") # type: ignore[attr-defined]
field = self._get("field") # type: ignore[attr-defined]
if shorthand is not Undefined and field is not Undefined:
raise ValueError("{} specifies both shorthand={} and field={}. "
"".format(self.__class__.__name__, shorthand, field))
raise ValueError(
"{} specifies both shorthand={} and field={}. "
"".format(self.__class__.__name__, shorthand, field)
)
if isinstance(shorthand, (tuple, list)):
# If given a list of shorthands, then transform it to a list of classes
kwds = self._kwds.copy()
kwds.pop('shorthand')
return [self.__class__(sh, **kwds).to_dict(validate=validate, ignore=ignore, context=context)
for sh in shorthand]
kwds = self._kwds.copy() # type: ignore[attr-defined]
kwds.pop("shorthand")
return [
self.__class__(sh, **kwds).to_dict( # type: ignore[call-arg]
validate=validate, ignore=ignore, context=context
)
for sh in shorthand
]
if shorthand is Undefined:
parsed = {}
elif isinstance(shorthand, str):
parsed = parse_shorthand(shorthand, data=context.get('data', None))
type_required = 'type' in self._kwds
type_in_shorthand = 'type' in parsed
type_defined_explicitly = self._get('type') is not Undefined
parsed = parse_shorthand(shorthand, data=context.get("data", None))
type_required = "type" in self._kwds # type: ignore[attr-defined]
type_in_shorthand = "type" in parsed
type_defined_explicitly = self._get("type") is not Undefined # type: ignore[attr-defined]
if not type_required:
# Secondary field names don't require a type argument in VegaLite 3+.
# We still parse it out of the shorthand, but drop it here.
parsed.pop('type', None)
parsed.pop("type", None)
elif not (type_in_shorthand or type_defined_explicitly):
if isinstance(context.get('data', None), pd.DataFrame):
if isinstance(context.get("data", None), pd.DataFrame):
raise ValueError(
'Unable to determine data type for the field "{}";'
" verify that the field name is not misspelled."
" If you are referencing a field from a transform,"
" also confirm that the data type is specified correctly.".format(shorthand)
" also confirm that the data type is specified correctly.".format(
shorthand
)
)
else:
raise ValueError("{} encoding field is specified without a type; "
"the type cannot be automatically inferred because "
"the data is not specified as a pandas.DataFrame."
"".format(shorthand))
raise ValueError(
"{} encoding field is specified without a type; "
"the type cannot be automatically inferred because "
"the data is not specified as a pandas.DataFrame."
"".format(shorthand)
)
else:
# Shorthand is not a string; we pass the definition to field,
# and do not do any parsing.
parsed = {'field': shorthand}
parsed = {"field": shorthand}
context["parsed_shorthand"] = parsed
return super(FieldChannelMixin, self).to_dict(
validate=validate,
ignore=ignore,
context=context
validate=validate, ignore=ignore, context=context
)
class ValueChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> dict:
context = context or {}
condition = self._get('condition', Undefined)
ignore = ignore or []
condition = self._get("condition", Undefined) # type: ignore[attr-defined]
copy = self # don't copy unless we need to
if condition is not Undefined:
if isinstance(condition, core.SchemaBase):
pass
elif 'field' in condition and 'type' not in condition:
kwds = parse_shorthand(condition['field'], context.get('data', None))
copy = self.copy(deep=['condition'])
copy['condition'].update(kwds)
return super(ValueChannelMixin, copy).to_dict(validate=validate,
ignore=ignore,
context=context)
elif "field" in condition and "type" not in condition:
kwds = parse_shorthand(condition["field"], context.get("data", None))
copy = self.copy(deep=["condition"]) # type: ignore[attr-defined]
copy["condition"].update(kwds) # type: ignore[index]
return super(ValueChannelMixin, copy).to_dict(
validate=validate, ignore=ignore, context=context
)
class DatumChannelMixin:
def to_dict(self, validate=True, ignore=(), context=None):
def to_dict(
self,
validate: bool = True,
ignore: Optional[List[str]] = None,
context: Optional[TypingDict[str, Any]] = None,
) -> dict:
context = context or {}
datum = self._get('datum', Undefined)
ignore = ignore or []
datum = self._get("datum", Undefined) # type: ignore[attr-defined]
copy = self # don't copy unless we need to
if datum is not Undefined:
if isinstance(datum, core.SchemaBase):
pass
return super(DatumChannelMixin, copy).to_dict(validate=validate,
ignore=ignore,
context=context)
return super(DatumChannelMixin, copy).to_dict(
validate=validate, ignore=ignore, context=context
)
"""

MARK_METHOD: Final = '''
Expand Down Expand Up @@ -424,7 +455,7 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str:
HEADER,
"from typing import Any, Literal, Union, Protocol, Sequence, List",
"from typing import Dict as TypingDict",
"",
"from typing import Generator as TypingGenerator" "",
"from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses",
LOAD_SCHEMA.format(schemafile="vega-lite-schema.json"),
]
Expand Down Expand Up @@ -459,7 +490,8 @@ def generate_vegalite_channel_wrappers(
"import pandas as pd",
"from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters",
"from altair.utils import parse_shorthand",
"from typing import Any, overload, Sequence, List, Literal, Union",
"from typing import Any, overload, Sequence, List, Literal, Union, Optional",
"from typing import Dict as TypingDict",
]
contents = [HEADER]
contents.append(CHANNEL_MYPY_IGNORE_STATEMENTS)
Expand Down
1 change: 1 addition & 0 deletions tools/update_init_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _is_relevant_attribute(attr_name: str) -> bool:
or attr is Protocol
or attr is Sequence
or attr_name == "TypingDict"
or attr_name == "TypingGenerator"
):
return False
else:
Expand Down

0 comments on commit ebaaee2

Please sign in to comment.