From ebaaee20ebadb5a166bfa4a24a1c6ca86a6a4858 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 1 Oct 2023 11:22:53 +0200 Subject: [PATCH] Add type hints to channel mixins --- altair/vegalite/v5/schema/channels.py | 47 +++++++--- altair/vegalite/v5/schema/core.py | 13 +-- tools/generate_schema_wrapper.py | 118 ++++++++++++++++---------- tools/update_init_file.py | 1 + 4 files changed, 116 insertions(+), 63 deletions(-) diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index d9d3f180c..acc15e7a7 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -14,14 +14,21 @@ import pandas as pd from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters from altair.utils import parse_shorthand -from typing import Any, overload, Sequence, List, Literal, Union +from typing import Any, overload, Sequence, List, Literal, Union, Optional +from typing import Dict as TypingDict class FieldChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> Union[dict, List[dict]]: context = context or {} - shorthand = self._get("shorthand") - field = self._get("field") + ignore = ignore or [] + shorthand = self._get("shorthand") # type: ignore[attr-defined] + field = self._get("field") # type: ignore[attr-defined] if shorthand is not Undefined and field is not Undefined: raise ValueError( @@ -31,10 +38,10 @@ def to_dict(self, validate=True, ignore=(), context=None): if isinstance(shorthand, (tuple, list)): # If given a list of shorthands, then transform it to a list of classes - kwds = self._kwds.copy() + kwds = self._kwds.copy() # type: ignore[attr-defined] kwds.pop("shorthand") return [ - self.__class__(sh, **kwds).to_dict( + self.__class__(sh, **kwds).to_dict( # type: ignore[call-arg] validate=validate, ignore=ignore, context=context ) for sh in shorthand @@ -44,9 +51,9 @@ def to_dict(self, validate=True, ignore=(), context=None): parsed = {} elif isinstance(shorthand, str): parsed = parse_shorthand(shorthand, data=context.get("data", None)) - type_required = "type" in self._kwds + type_required = "type" in self._kwds # type: ignore[attr-defined] type_in_shorthand = "type" in parsed - type_defined_explicitly = self._get("type") is not Undefined + type_defined_explicitly = self._get("type") is not Undefined # type: ignore[attr-defined] if not type_required: # Secondary field names don't require a type argument in VegaLite 3+. # We still parse it out of the shorthand, but drop it here. @@ -80,26 +87,38 @@ def to_dict(self, validate=True, ignore=(), context=None): class ValueChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> dict: context = context or {} - condition = self._get("condition", Undefined) + ignore = ignore or [] + condition = self._get("condition", Undefined) # type: ignore[attr-defined] copy = self # don't copy unless we need to if condition is not Undefined: if isinstance(condition, core.SchemaBase): pass elif "field" in condition and "type" not in condition: kwds = parse_shorthand(condition["field"], context.get("data", None)) - copy = self.copy(deep=["condition"]) - copy["condition"].update(kwds) + copy = self.copy(deep=["condition"]) # type: ignore[attr-defined] + copy["condition"].update(kwds) # type: ignore[index] return super(ValueChannelMixin, copy).to_dict( validate=validate, ignore=ignore, context=context ) class DatumChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> dict: context = context or {} - datum = self._get("datum", Undefined) + ignore = ignore or [] + datum = self._get("datum", Undefined) # type: ignore[attr-defined] copy = self # don't copy unless we need to if datum is not Undefined: if isinstance(datum, core.SchemaBase): diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 60d8e17ff..b46fc23e7 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -3,18 +3,19 @@ from typing import Any, Literal, Union, Protocol, Sequence, List from typing import Dict as TypingDict - +from typing import Generator as TypingGenerator from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses import pkgutil import json -def load_schema(): +def load_schema() -> dict: """Load the json schema associated with this module's functions""" - return json.loads( - pkgutil.get_data(__name__, "vega-lite-schema.json").decode("utf-8") - ) + schema_bytes = pkgutil.get_data(__name__, "vega-lite-schema.json") + if schema_bytes is None: + raise ValueError("Unable to load vega-lite-schema.json") + return json.loads(schema_bytes.decode("utf-8")) class _ParameterProtocol(Protocol): @@ -40,7 +41,7 @@ class VegaLiteSchema(SchemaBase): _rootschema = load_schema() @classmethod - def _default_wrapper_classes(cls): + def _default_wrapper_classes(cls) -> TypingGenerator[type, None, None]: return _subclasses(VegaLiteSchema) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index a7284c357..8654d0edd 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -74,7 +74,7 @@ def _to_expr(self) -> str: class {basename}(SchemaBase): _rootschema = load_schema() @classmethod - def _default_wrapper_classes(cls): + def _default_wrapper_classes(cls) -> TypingGenerator[type, None, None]: return _subclasses({basename}) """ @@ -82,95 +82,126 @@ def _default_wrapper_classes(cls): import pkgutil import json -def load_schema(): +def load_schema() -> dict: """Load the json schema associated with this module's functions""" - return json.loads(pkgutil.get_data(__name__, '{schemafile}').decode('utf-8')) + schema_bytes = pkgutil.get_data(__name__, "{schemafile}") + if schema_bytes is None: + raise ValueError("Unable to load {schemafile}") + return json.loads( + schema_bytes.decode("utf-8") + ) ''' CHANNEL_MIXINS: Final = """ class FieldChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> Union[dict, List[dict]]: context = context or {} - shorthand = self._get('shorthand') - field = self._get('field') + ignore = ignore or [] + shorthand = self._get("shorthand") # type: ignore[attr-defined] + field = self._get("field") # type: ignore[attr-defined] if shorthand is not Undefined and field is not Undefined: - raise ValueError("{} specifies both shorthand={} and field={}. " - "".format(self.__class__.__name__, shorthand, field)) + raise ValueError( + "{} specifies both shorthand={} and field={}. " + "".format(self.__class__.__name__, shorthand, field) + ) if isinstance(shorthand, (tuple, list)): # If given a list of shorthands, then transform it to a list of classes - kwds = self._kwds.copy() - kwds.pop('shorthand') - return [self.__class__(sh, **kwds).to_dict(validate=validate, ignore=ignore, context=context) - for sh in shorthand] + kwds = self._kwds.copy() # type: ignore[attr-defined] + kwds.pop("shorthand") + return [ + self.__class__(sh, **kwds).to_dict( # type: ignore[call-arg] + validate=validate, ignore=ignore, context=context + ) + for sh in shorthand + ] if shorthand is Undefined: parsed = {} elif isinstance(shorthand, str): - parsed = parse_shorthand(shorthand, data=context.get('data', None)) - type_required = 'type' in self._kwds - type_in_shorthand = 'type' in parsed - type_defined_explicitly = self._get('type') is not Undefined + parsed = parse_shorthand(shorthand, data=context.get("data", None)) + type_required = "type" in self._kwds # type: ignore[attr-defined] + type_in_shorthand = "type" in parsed + type_defined_explicitly = self._get("type") is not Undefined # type: ignore[attr-defined] if not type_required: # Secondary field names don't require a type argument in VegaLite 3+. # We still parse it out of the shorthand, but drop it here. - parsed.pop('type', None) + parsed.pop("type", None) elif not (type_in_shorthand or type_defined_explicitly): - if isinstance(context.get('data', None), pd.DataFrame): + if isinstance(context.get("data", None), pd.DataFrame): raise ValueError( 'Unable to determine data type for the field "{}";' " verify that the field name is not misspelled." " If you are referencing a field from a transform," - " also confirm that the data type is specified correctly.".format(shorthand) + " also confirm that the data type is specified correctly.".format( + shorthand + ) ) else: - raise ValueError("{} encoding field is specified without a type; " - "the type cannot be automatically inferred because " - "the data is not specified as a pandas.DataFrame." - "".format(shorthand)) + raise ValueError( + "{} encoding field is specified without a type; " + "the type cannot be automatically inferred because " + "the data is not specified as a pandas.DataFrame." + "".format(shorthand) + ) else: # Shorthand is not a string; we pass the definition to field, # and do not do any parsing. - parsed = {'field': shorthand} + parsed = {"field": shorthand} context["parsed_shorthand"] = parsed return super(FieldChannelMixin, self).to_dict( - validate=validate, - ignore=ignore, - context=context + validate=validate, ignore=ignore, context=context ) class ValueChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> dict: context = context or {} - condition = self._get('condition', Undefined) + ignore = ignore or [] + condition = self._get("condition", Undefined) # type: ignore[attr-defined] copy = self # don't copy unless we need to if condition is not Undefined: if isinstance(condition, core.SchemaBase): pass - elif 'field' in condition and 'type' not in condition: - kwds = parse_shorthand(condition['field'], context.get('data', None)) - copy = self.copy(deep=['condition']) - copy['condition'].update(kwds) - return super(ValueChannelMixin, copy).to_dict(validate=validate, - ignore=ignore, - context=context) + elif "field" in condition and "type" not in condition: + kwds = parse_shorthand(condition["field"], context.get("data", None)) + copy = self.copy(deep=["condition"]) # type: ignore[attr-defined] + copy["condition"].update(kwds) # type: ignore[index] + return super(ValueChannelMixin, copy).to_dict( + validate=validate, ignore=ignore, context=context + ) class DatumChannelMixin: - def to_dict(self, validate=True, ignore=(), context=None): + def to_dict( + self, + validate: bool = True, + ignore: Optional[List[str]] = None, + context: Optional[TypingDict[str, Any]] = None, + ) -> dict: context = context or {} - datum = self._get('datum', Undefined) + ignore = ignore or [] + datum = self._get("datum", Undefined) # type: ignore[attr-defined] copy = self # don't copy unless we need to if datum is not Undefined: if isinstance(datum, core.SchemaBase): pass - return super(DatumChannelMixin, copy).to_dict(validate=validate, - ignore=ignore, - context=context) + return super(DatumChannelMixin, copy).to_dict( + validate=validate, ignore=ignore, context=context + ) """ MARK_METHOD: Final = ''' @@ -424,7 +455,7 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str: HEADER, "from typing import Any, Literal, Union, Protocol, Sequence, List", "from typing import Dict as TypingDict", - "", + "from typing import Generator as TypingGenerator" "", "from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses", LOAD_SCHEMA.format(schemafile="vega-lite-schema.json"), ] @@ -459,7 +490,8 @@ def generate_vegalite_channel_wrappers( "import pandas as pd", "from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters", "from altair.utils import parse_shorthand", - "from typing import Any, overload, Sequence, List, Literal, Union", + "from typing import Any, overload, Sequence, List, Literal, Union, Optional", + "from typing import Dict as TypingDict", ] contents = [HEADER] contents.append(CHANNEL_MYPY_IGNORE_STATEMENTS) diff --git a/tools/update_init_file.py b/tools/update_init_file.py index fe458dce0..497ed9754 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -94,6 +94,7 @@ def _is_relevant_attribute(attr_name: str) -> bool: or attr is Protocol or attr is Sequence or attr_name == "TypingDict" + or attr_name == "TypingGenerator" ): return False else: