From 69b0029ff909b44953aeeaabc3f071ae846b2a42 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Wed, 9 Aug 2023 19:21:14 +0200 Subject: [PATCH 01/18] First batch of type hints --- altair/vegalite/v5/api.py | 113 +++++++++++++++++------------- altair/vegalite/v5/schema/core.py | 2 +- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 70680984a..75b2aa54d 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,13 +8,13 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any, Iterable +from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal # Have to rename it here as else it overlaps with schema.core.Type from typing import Type as TypingType from typing import Dict as TypingDict -from .schema import core, channels, mixins, Undefined, SCHEMA_URL +from .schema import core, channels, mixins, Undefined, SCHEMA_URL, UndefinedType from .data import data_transformers from ... import utils, expr @@ -35,13 +35,13 @@ # ------------------------------------------------------------------------ # Data Utilities -def _dataset_name(values): +def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str: """Generate a unique hash of the data Parameters ---------- - values : list or dict - A list/dict representation of data values. + values : list, dict, core.InlineDataset + A representation of data values. Returns ------- @@ -136,7 +136,7 @@ class LookupData(core.LookupData): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: """Convert the chart to a dictionary suitable for JSON export.""" copy = self.copy(deep=False) copy.data = _prepare_data(copy.data, kwargs.get("context")) @@ -150,7 +150,7 @@ class FacetMapping(core.FacetMapping): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: copy = self.copy(deep=False) context = kwargs.get("context", {}) data = context.get("data", None) @@ -172,8 +172,8 @@ def to_dict(self, *args, **kwargs): TOPLEVEL_ONLY_KEYS = {"background", "config", "autosize", "padding", "$schema"} -def _get_channels_mapping(): - mapping = {} +def _get_channels_mapping() -> Dict[Type[core.SchemaBase], str]: + mapping: Dict[Type[core.SchemaBase], str] = {} for attr in dir(channels): cls = getattr(channels, attr) if isinstance(cls, type) and issubclass(cls, core.SchemaBase): @@ -189,11 +189,11 @@ class Parameter(expr.core.OperatorMixin, object): _counter = 0 @classmethod - def _get_name(cls): + def _get_name(cls) -> str: cls._counter += 1 return f"param_{cls._counter}" - def __init__(self, name): + def __init__(self, name: Optional[str]) -> None: if name is None: name = self._get_name() self.name = name @@ -201,11 +201,11 @@ def __init__(self, name): @utils.deprecation.deprecated( message="'ref' is deprecated. No need to call '.ref()' anymore." ) - def ref(self): + def ref(self) -> dict: "'ref' is deprecated. No need to call '.ref()' anymore." return self.to_dict() - def to_dict(self): + def to_dict(self) -> dict: if self.param_type == "variable": return {"expr": self.name} elif self.param_type == "selection": @@ -214,6 +214,8 @@ def to_dict(self): if hasattr(self.name, "to_dict") else self.name } + else: + raise ValueError(f"Unrecognized parameter type: {self.param_type}") def __invert__(self): if self.param_type == "selection": @@ -237,16 +239,18 @@ def __or__(self, other): else: return expr.core.OperatorMixin.__or__(self, other) - def __repr__(self): + def __repr__(self) -> str: return "Parameter({0!r}, {1})".format(self.name, self.param) - def _to_expr(self): + def _to_expr(self) -> str: return self.name - def _from_expr(self, expr): + def _from_expr(self, expr) -> "ParameterExpression": return ParameterExpression(expr=expr) - def __getattr__(self, field_name): + def __getattr__( + self, field_name: str + ) -> Union["SelectionExpression", expr.core.GetAttrExpression]: if field_name.startswith("__") and field_name.endswith("__"): raise AttributeError(field_name) _attrexpr = expr.core.GetAttrExpression(self.name, field_name) @@ -258,7 +262,7 @@ def __getattr__(self, field_name): # TODO: Are there any special cases to consider for __getitem__? # This was copied from v4. - def __getitem__(self, field_name): + def __getitem__(self, field_name: str) -> expr.core.GetItemExpression: return expr.core.GetItemExpression(self.name, field_name) @@ -275,34 +279,34 @@ def __or__(self, other): class ParameterExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> Dict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "ParameterExpression": return ParameterExpression(expr=expr) class SelectionExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> Dict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "SelectionExpression": return SelectionExpression(expr=expr) -def check_fields_and_encodings(parameter, field_name): +def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: for prop in ["fields", "encodings"]: try: if field_name in getattr(parameter.param.select, prop): @@ -317,20 +321,24 @@ def check_fields_and_encodings(parameter, field_name): # Top-Level Functions -def value(value, **kwargs): +def value(value, **kwargs) -> dict: """Specify a value for use in an encoding""" return dict(value=value, **kwargs) def param( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, + name: Optional[str] = None, + value: Union[UndefinedType, Any] = Undefined, + bind: Union[UndefinedType, core.Binding] = Undefined, + empty: Union[UndefinedType, bool] = Undefined, + expr: Union[UndefinedType, core.Expr] = Undefined, **kwds, -): - """Create a named parameter. See https://altair-viz.github.io/user_guide/interactions.html for examples. Although both variable parameters and selection parameters can be created using this 'param' function, to create a selection parameter, it is recommended to use either 'selection_point' or 'selection_interval' instead. +) -> Parameter: + """Create a named parameter. + See https://altair-viz.github.io/user_guide/interactions.html for examples. + Although both variable parameters and selection parameters can be created using + this 'param' function, to create a selection parameter, it is recommended to use + either 'selection_point' or 'selection_interval' instead. Parameters ---------- @@ -415,7 +423,9 @@ def param( return parameter -def _selection(type=Undefined, **kwds): +def _selection( + type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds +) -> Parameter: # We separate out the parameter keywords from the selection keywords param_kwds = {} @@ -423,6 +433,7 @@ def _selection(type=Undefined, **kwds): if kwd in kwds: param_kwds[kwd] = kwds.pop(kwd) + select: Union[core.IntervalSelectionConfig, core.PointSelectionConfig] if type == "interval": select = core.IntervalSelectionConfig(type=type, **kwds) elif type == "point": @@ -445,7 +456,9 @@ def _selection(type=Undefined, **kwds): message="""'selection' is deprecated. Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings.""" ) -def selection(type=Undefined, **kwds): +def selection( + type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds +) -> Parameter: """ Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create. @@ -469,20 +482,20 @@ def selection(type=Undefined, **kwds): def selection_interval( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, - encodings=Undefined, - on=Undefined, - clear=Undefined, - resolve=Undefined, - mark=Undefined, - translate=Undefined, - zoom=Undefined, + name: Optional[str] = None, + value: Union[UndefinedType, Any] = Undefined, + bind: Union[UndefinedType, core.Binding] = Undefined, + empty: Union[UndefinedType, bool] = Undefined, + expr: Union[UndefinedType, core.Expr] = Undefined, + encodings: Union[UndefinedType, List[str]] = Undefined, + on: Union[UndefinedType, str] = Undefined, + clear: Union[UndefinedType, str, bool] = Undefined, + resolve: Union[UndefinedType, Literal["global", "union", "intersect"]] = Undefined, + mark: Union[UndefinedType, core.Mark] = Undefined, + translate: Union[UndefinedType, str, bool] = Undefined, + zoom: Union[UndefinedType, str, bool] = Undefined, **kwds, -): +) -> Parameter: """Create an interval selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Interval selection parameters are used to select a continuous range of data values on drag, whereas point selection parameters (`selection_point`) are used to select multiple discrete data values.) Parameters diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index ad1e5afd5..d2f4887e6 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -1,7 +1,7 @@ # The contents of this file are automatically written by # tools/generate_schema_wrapper.py. Do not modify directly. -from altair.utils.schemapi import SchemaBase, Undefined, _subclasses +from altair.utils.schemapi import SchemaBase, Undefined, _subclasses, UndefinedType import pkgutil import json From 911bf5043b60264b3bf782fb0d2f2905580b9729 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 25 Aug 2023 18:52:29 +0200 Subject: [PATCH 02/18] Some more --- altair/vegalite/v5/api.py | 48 +++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 7e10d642a..057876801 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -483,17 +483,17 @@ def selection( def selection_interval( name: Optional[str] = None, - value: Union[UndefinedType, Any] = Undefined, - bind: Union[UndefinedType, core.Binding] = Undefined, - empty: Union[UndefinedType, bool] = Undefined, - expr: Union[UndefinedType, core.Expr] = Undefined, - encodings: Union[UndefinedType, List[str]] = Undefined, - on: Union[UndefinedType, str] = Undefined, - clear: Union[UndefinedType, str, bool] = Undefined, - resolve: Union[UndefinedType, Literal["global", "union", "intersect"]] = Undefined, - mark: Union[UndefinedType, core.Mark] = Undefined, - translate: Union[UndefinedType, str, bool] = Undefined, - zoom: Union[UndefinedType, str, bool] = Undefined, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[core.Expr, UndefinedType] = Undefined, + encodings: Union[List[str], UndefinedType] = Undefined, + on: Union[str, UndefinedType] = Undefined, + clear: Union[str, bool, UndefinedType] = Undefined, + resolve: Union[Literal["global", "union", "intersect"], UndefinedType] = Undefined, + mark: Union[core.Mark, UndefinedType] = Undefined, + translate: Union[str, bool, UndefinedType] = Undefined, + zoom: Union[str, bool, UndefinedType] = Undefined, **kwds, ) -> Parameter: """Create an interval selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Interval selection parameters are used to select a continuous range of data values on drag, whereas point selection parameters (`selection_point`) are used to select multiple discrete data values.) @@ -594,20 +594,20 @@ def selection_interval( def selection_point( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, - encodings=Undefined, - fields=Undefined, - on=Undefined, - clear=Undefined, - resolve=Undefined, - toggle=Undefined, - nearest=Undefined, + name: Optional[str] = None, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[core.Expr, UndefinedType] = Undefined, + encodings: Union[List[str], UndefinedType] = Undefined, + fields: Union[List[str], UndefinedType] = Undefined, + on: Union[str, UndefinedType] = Undefined, + clear: Union[str, bool, UndefinedType] = Undefined, + resolve: Union[Literal["global", "union", "intersect"], UndefinedType] = Undefined, + toggle: Union[str, bool, UndefinedType] = Undefined, + nearest: Union[bool, UndefinedType] = Undefined, **kwds, -): +) -> Parameter: """Create a point selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Point selection parameters are used to select multiple discrete data values; the first value is selected on click and additional values toggled on shift-click. To select a continuous range of data values on drag interval selection parameters (`selection_interval`) can be used instead. Parameters From 937ce55609a97197e9f486cae4e5db99e05dcc30 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 26 Aug 2023 10:16:17 +0200 Subject: [PATCH 03/18] Add type hints to some more transform_methods --- altair/vegalite/v5/api.py | 277 ++++++++++++++++++++++---------------- 1 file changed, 163 insertions(+), 114 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 057876801..d8345f3fc 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,7 +8,7 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal +from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal, IO # Have to rename it here as else it overlaps with schema.core.Type from typing import Type as TypingType @@ -758,12 +758,19 @@ def binding_range(**kwargs): # TODO: update the docstring -def condition(predicate, if_true, if_false, **kwargs): +def condition( + predicate: Union[Parameter, str, expr.Expression, core.PredicateComposition, dict], + # Types of these depends on where the condition is used so we probably + # can't be more specific here. + if_true: Any, + if_false: Any, + **kwargs, +) -> Union[dict, core.SchemaBase]: """A conditional attribute or encoding Parameters ---------- - predicate: Selection, PredicateComposition, expr.Expression, dict, or string + predicate: Parameter, PredicateComposition, expr.Expression, dict, or string the selection predicate or test predicate for the condition. if a string is passed, it will be treated as a test operand. if_true: @@ -833,7 +840,7 @@ def condition(predicate, if_true, if_false, **kwargs): class TopLevelMixin(mixins.ConfigMethodMixin): """Mixin for top-level chart objects such as Chart, LayeredChart, etc.""" - _class_is_valid_at_instantiation = False + _class_is_valid_at_instantiation: bool = False def to_dict( self, @@ -1001,12 +1008,12 @@ def to_json( def to_html( self, - base_url="https://cdn.jsdelivr.net/npm", - output_div="vis", - embed_options=None, - json_kwds=None, - fullhtml=True, - requirejs=False, + base_url: str = "https://cdn.jsdelivr.net/npm", + output_div: str = "vis", + embed_options: Optional[dict] = None, + json_kwds: Optional[dict] = None, + fullhtml: bool = True, + requirejs: bool = False, ) -> str: return utils.spec_to_html( self.to_dict(), @@ -1024,15 +1031,15 @@ def to_html( def save( self, - fp, - format=None, - override_data_transformer=True, - scale_factor=1.0, - vegalite_version=VEGALITE_VERSION, - vega_version=VEGA_VERSION, - vegaembed_version=VEGAEMBED_VERSION, + fp: Union[str, IO], + format: Literal["json", "html", "png", "svg", "pdf"] = None, + override_data_transformer: bool = True, + scale_factor: float = 1.0, + vegalite_version: str = VEGALITE_VERSION, + vega_version: str = VEGA_VERSION, + vegaembed_version: str = VEGAEMBED_VERSION, **kwargs, - ): + ) -> None: """Save a chart to file in a variety of formats Supported formats are json, html, png, svg, pdf; the last three require @@ -1083,32 +1090,32 @@ def save( # Fallback for when rendering fails; the full repr is too long to be # useful in nearly all cases. - def __repr__(self): + def __repr__(self) -> str: return "alt.{}(...)".format(self.__class__.__name__) # Layering and stacking - def __add__(self, other): + def __add__(self, other) -> "LayerChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be layered.") return layer(self, other) - def __and__(self, other): + def __and__(self, other) -> "VConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") return vconcat(self, other) - def __or__(self, other): + def __or__(self, other) -> "HConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") return hconcat(self, other) def repeat( self, - repeat=Undefined, - row=Undefined, - column=Undefined, - layer=Undefined, - columns=Undefined, + repeat: Union[List[str], UndefinedType] = Undefined, + row: Union[List[str], UndefinedType] = Undefined, + column: Union[List[str], UndefinedType] = Undefined, + layer: Union[List[str], UndefinedType] = Undefined, + columns: Union[int, UndefinedType] = Undefined, **kwargs, ) -> "RepeatChart": """Return a RepeatChart built from the chart @@ -1182,25 +1189,25 @@ def properties(self, **kwargs) -> Self: def project( self, - type=Undefined, - center=Undefined, - clipAngle=Undefined, - clipExtent=Undefined, - coefficient=Undefined, - distance=Undefined, - fraction=Undefined, - lobes=Undefined, - parallel=Undefined, - precision=Undefined, - radius=Undefined, - ratio=Undefined, - reflectX=Undefined, - reflectY=Undefined, - rotate=Undefined, - scale=Undefined, - spacing=Undefined, - tilt=Undefined, - translate=Undefined, + type: Union[str, UndefinedType] = Undefined, + center: Union[List[float], UndefinedType] = Undefined, + clipAngle: Union[float, UndefinedType] = Undefined, + clipExtent: Union[List[List[float]], UndefinedType] = Undefined, + coefficient: Union[float, UndefinedType] = Undefined, + distance: Union[float, UndefinedType] = Undefined, + fraction: Union[float, UndefinedType] = Undefined, + lobes: Union[float, UndefinedType] = Undefined, + parallel: Union[float, UndefinedType] = Undefined, + precision: Union[float, UndefinedType] = Undefined, + radius: Union[float, UndefinedType] = Undefined, + ratio: Union[float, UndefinedType] = Undefined, + reflectX: Union[bool, UndefinedType] = Undefined, + reflectY: Union[bool, UndefinedType] = Undefined, + rotate: Union[List[float], UndefinedType] = Undefined, + scale: Union[float, UndefinedType] = Undefined, + spacing: Union[float, UndefinedType] = Undefined, + tilt: Union[float, UndefinedType] = Undefined, + translate: Union[List[float], UndefinedType] = Undefined, **kwds, ) -> Self: """Add a geographic projection to the chart. @@ -1215,7 +1222,7 @@ def project( Parameters ---------- - type : ProjectionType + type : str The cartographic projection to use. This value is case-insensitive, for example `"albers"` and `"Albers"` indicate the same projection type. You can find all valid projection types [in the @@ -1237,16 +1244,28 @@ def project( left-side of the viewport, `y0` is the top, `x1` is the right and `y1` is the bottom. If `null`, no viewport clipping is performed. coefficient : float + The coefficient parameter for the ``hammer`` projection. + **Default value:** ``2`` distance : float + For the ``satellite`` projection, the distance from the center of the sphere to the + point of view, as a proportion of the sphere’s radius. The recommended maximum clip + angle for a given ``distance`` is acos(1 / distance) converted to degrees. If tilt + is also applied, then more conservative clipping may be necessary. + **Default value:** ``2.0`` fraction : float + The fraction parameter for the ``bottomley`` projection. + **Default value:** ``0.5``, corresponding to a sin(ψ) where ψ = π/6. lobes : float - + The number of lobes in projections that support multi-lobe views: ``berghaus``, + ``gingery``, or ``healpix``. The default value varies based on the projection type. parallel : float - - precision : Mapping(required=[length]) + For conic projections, the `two standard parallels + `__ that define the map layout. + The default depends on the specific conic projection used. + precision : float Sets the threshold for the projection’s [adaptive resampling](http://bl.ocks.org/mbostock/3795544) to the specified value in pixels. This value corresponds to the [Douglas–Peucker @@ -1254,13 +1273,15 @@ def project( If precision is not specified, returns the projection’s current resampling precision which defaults to `√0.5 ≅ 0.70710…`. radius : float - + The radius parameter for the ``airy`` or ``gingery`` projection. The default value + varies based on the projection type. ratio : float - + The ratio parameter for the ``hill``, ``hufnagel``, or ``wagner`` projections. The + default value varies based on the projection type. reflectX : boolean - + Sets whether or not the x-dimension is reflected (negated) in the output. reflectY : boolean - + Sets whether or not the y-dimension is reflected (negated) in the output. rotate : List(float) Sets the projection’s three-axis rotation to the specified angles, which must be a two- or three-element array of numbers [`lambda`, `phi`, `gamma`] specifying the @@ -1269,14 +1290,21 @@ def project( **Default value:** `[0, 0, 0]` scale : float - Sets the projection's scale (zoom) value, overriding automatic fitting. - + The projection’s scale (zoom) factor, overriding automatic fitting. The default + scale is projection-specific. The scale factor corresponds linearly to the distance + between projected points; however, scale factor values are not equivalent across + projections. spacing : float + The spacing parameter for the ``lagrange`` projection. + **Default value:** ``0.5`` tilt : float + The tilt angle (in degrees) for the ``satellite`` projection. + **Default value:** ``0``. translate : List(float) - Sets the projection's translation (pan) value, overriding automatic fitting. + The projection’s translation offset as a two-element array ``[tx, ty]``, + overriding automatic fitting. """ projection = core.Projection( @@ -1303,7 +1331,7 @@ def project( ) return self.properties(projection=projection) - def _add_transform(self, *transforms): + def _add_transform(self, *transforms: core.Transform) -> Self: """Copy the chart and add specified transforms to chart.transform""" copy = self.copy(deep=["transform"]) if copy.transform is Undefined: @@ -1312,7 +1340,10 @@ def _add_transform(self, *transforms): return copy def transform_aggregate( - self, aggregate=Undefined, groupby=Undefined, **kwds + self, + aggregate: Union[List[core.AggregatedFieldDef], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + **kwds: Union[Dict[str, Any], str], ) -> Self: """ Add an :class:`AggregateTransform` to the schema. @@ -1324,7 +1355,7 @@ def transform_aggregate( groupby : List(string) The data fields to group by. If not specified, a single group containing all data objects will be used. - **kwds : + **kwds : Union[Dict[str, Any], str] additional keywords are converted to aggregates using standard shorthand parsing. @@ -1387,7 +1418,13 @@ def transform_aggregate( core.AggregateTransform(aggregate=aggregate, groupby=groupby) ) - def transform_bin(self, as_=Undefined, field=Undefined, bin=True, **kwargs) -> Self: + def transform_bin( + self, + as_: Union[str, List[str], UndefinedType] = Undefined, + field: Union[str, UndefinedType] = Undefined, + bin: Union[Literal[True], core.BinParams] = True, + **kwargs, + ) -> Self: """ Add a :class:`BinTransform` to the schema. @@ -1443,7 +1480,12 @@ def transform_bin(self, as_=Undefined, field=Undefined, bin=True, **kwargs) -> S kwargs["field"] = field return self._add_transform(core.BinTransform(**kwargs)) - def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> Self: + def transform_calculate( + self, + as_: Union[str, UndefinedType] = Undefined, + calculate: Union[str, expr.core.Expression, UndefinedType] = Undefined, + **kwargs: Union[str, expr.core.Expression], + ) -> Self: """ Add a :class:`CalculateTransform` to the schema. @@ -1451,8 +1493,8 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S ---------- as_ : string The field for storing the computed formula value. - calculate : string or alt.expr expression - A `expression `__ + calculate : string or alt.expr.Expression + An `expression `__ string. Use the variable ``datum`` to refer to the current data object. **kwargs transforms can also be passed by keyword argument; see Examples @@ -1476,7 +1518,7 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S It's also possible to pass the ``CalculateTransform`` arguments directly: - >>> kwds = {'as': 'y', 'calculate': '2 * sin(datum.x)'} + >>> kwds = {'as_': 'y', 'calculate': '2 * sin(datum.x)'} >>> chart = alt.Chart().transform_calculate(**kwds) >>> chart.transform[0] CalculateTransform({ @@ -1507,16 +1549,16 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S def transform_density( self, - density, - as_=Undefined, - bandwidth=Undefined, - counts=Undefined, - cumulative=Undefined, - extent=Undefined, - groupby=Undefined, - maxsteps=Undefined, - minsteps=Undefined, - steps=Undefined, + density: str, + as_: Union[List[str], UndefinedType] = Undefined, + bandwidth: Union[float, UndefinedType] = Undefined, + counts: Union[bool, UndefinedType] = Undefined, + cumulative: Union[bool, UndefinedType] = Undefined, + extent: Union[List[float], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + maxsteps: Union[int, UndefinedType] = Undefined, + minsteps: Union[int, UndefinedType] = Undefined, + steps: Union[int, UndefinedType] = Undefined, ) -> Self: """Add a :class:`DensityTransform` to the spec. @@ -1546,13 +1588,13 @@ def transform_density( groupby : List(str) The data fields to group by. If not specified, a single group containing all data objects will be used. - maxsteps : float + maxsteps : int The maximum number of samples to take along the extent domain for plotting the density. **Default value:** ``200`` - minsteps : float + minsteps : int The minimum number of samples to take along the extent domain for plotting the density. **Default value:** ``25`` - steps : float + steps : int The exact number of samples to take along the extent domain for plotting the density. If specified, overrides both minsteps and maxsteps to set an exact number of uniform samples. Potentially useful in conjunction with a fixed extent to ensure @@ -1575,12 +1617,14 @@ def transform_density( def transform_impute( self, - impute, - key, - frame=Undefined, - groupby=Undefined, - keyvals=Undefined, - method=Undefined, + impute: str, + key: str, + frame: Union[List[Optional[int]], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + keyvals: Union[List[Any], core.ImputeSequence, UndefinedType] = Undefined, + method: Union[ + Literal["value", "mean", "median", "max", "min"], UndefinedType + ] = Undefined, value=Undefined, ) -> Self: """ @@ -1594,7 +1638,7 @@ def transform_impute( A key field that uniquely identifies data objects within a group. Missing key values (those occurring in the data but not in the current group) will be imputed. - frame : List(anyOf(None, float)) + frame : List(anyOf(None, int)) A frame specification as a two-element array used to control the window over which the specified method is applied. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded @@ -1644,7 +1688,12 @@ def transform_impute( ) def transform_joinaggregate( - self, joinaggregate=Undefined, groupby=Undefined, **kwargs + self, + joinaggregate: Union[ + List[core.JoinAggregateFieldDef], UndefinedType + ] = Undefined, + groupby: Union[str, UndefinedType] = Undefined, + **kwargs: str, ) -> Self: """ Add a :class:`JoinAggregateTransform` to the schema. @@ -1695,7 +1744,7 @@ def transform_joinaggregate( core.JoinAggregateTransform(joinaggregate=joinaggregate, groupby=groupby) ) - def transform_extent(self, extent: str, param: str): + def transform_extent(self, extent: str, param: str) -> Self: """Add a :class:`ExtentTransform` to the spec. Parameters @@ -1714,6 +1763,7 @@ def transform_extent(self, extent: str, param: str): return self._add_transform(core.ExtentTransform(extent=extent, param=param)) # TODO: Update docstring + # TODO: Add type hints. Was too complex to type hint `filter` for now def transform_filter(self, filter, **kwargs) -> Self: """ Add a :class:`FilterTransform` to the schema. @@ -1732,11 +1782,6 @@ def transform_filter(self, filter, **kwargs) -> Self: ------- self : Chart object returns chart to allow for chaining - - See Also - -------- - alt.FilterTransform : underlying transform object - """ if isinstance(filter, Parameter): new_filter = {"param": filter.name} @@ -1747,7 +1792,9 @@ def transform_filter(self, filter, **kwargs) -> Self: filter = new_filter return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) - def transform_flatten(self, flatten, as_=Undefined) -> Self: + def transform_flatten( + self, flatten: List[str], as_: Union[List[str], UndefinedType] = Undefined + ) -> Self: """Add a :class:`FlattenTransform` to the schema. Parameters @@ -1775,7 +1822,9 @@ def transform_flatten(self, flatten, as_=Undefined) -> Self: core.FlattenTransform(flatten=flatten, **{"as": as_}) ) - def transform_fold(self, fold, as_=Undefined) -> Self: + def transform_fold( + self, fold: List[str], as_: Union[List[str], UndefinedType] = Undefined + ) -> Self: """Add a :class:`FoldTransform` to the spec. Parameters @@ -1800,11 +1849,11 @@ def transform_fold(self, fold, as_=Undefined) -> Self: def transform_loess( self, - on, - loess, - as_=Undefined, - bandwidth=Undefined, - groupby=Undefined, + on: str, + loess: str, + as_: Union[List[str], UndefinedType] = Undefined, + bandwidth: Union[float, UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, ) -> Self: """Add a :class:`LoessTransform` to the spec. @@ -1842,10 +1891,10 @@ def transform_loess( def transform_lookup( self, - lookup=Undefined, - from_=Undefined, - as_=Undefined, - default=Undefined, + lookup: Union[str, UndefinedType] = Undefined, + from_: Union[core.LookupData, core.LookupSelection, UndefinedType] = Undefined, + as_: Union[str, List[str], UndefinedType] = Undefined, + default: Union[str, UndefinedType] = Undefined, **kwargs, ) -> Self: """Add a :class:`DataLookupTransform` or :class:`SelectionLookupTransform` to the chart @@ -1897,11 +1946,11 @@ def transform_lookup( def transform_pivot( self, - pivot, - value, - groupby=Undefined, - limit=Undefined, - op=Undefined, + pivot: str, + value: str, + groupby: Union[List[str], UndefinedType]=Undefined, + limit: Union[int, UndefinedType]=Undefined, + op: Union[str, UndefinedType]=Undefined, ) -> Self: """Add a :class:`PivotTransform` to the chart. @@ -1916,7 +1965,7 @@ def transform_pivot( groupby : List(str) The optional data fields to group by. If not specified, a single group containing all data objects will be used. - limit : float + limit : int An optional parameter indicating the maximum number of pivoted fields to generate. The default ( ``0`` ) applies no limit. The pivoted ``pivot`` names are sorted in ascending order prior to enforcing the limit. @@ -1943,11 +1992,11 @@ def transform_pivot( def transform_quantile( self, - quantile, - as_=Undefined, - groupby=Undefined, - probs=Undefined, - step=Undefined, + quantile: str, + as_: Union[List[str], UndefinedType]=Undefined, + groupby: Union[List[str], UndefinedType]=Undefined, + probs: Union[List[float], UndefinedType]=Undefined, + step: Union[float, UndefinedType]=Undefined, ) -> Self: """Add a :class:`QuantileTransform` to the chart From 1cfef3396223aa363ce34b4c87e3434b7bf7c526 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 26 Aug 2023 13:48:37 +0200 Subject: [PATCH 04/18] Add type hints to some more transform_methods --- altair/vegalite/v5/api.py | 71 ++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index d8345f3fc..203520fce 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1948,9 +1948,9 @@ def transform_pivot( self, pivot: str, value: str, - groupby: Union[List[str], UndefinedType]=Undefined, - limit: Union[int, UndefinedType]=Undefined, - op: Union[str, UndefinedType]=Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + limit: Union[int, UndefinedType] = Undefined, + op: Union[str, UndefinedType] = Undefined, ) -> Self: """Add a :class:`PivotTransform` to the chart. @@ -1993,10 +1993,10 @@ def transform_pivot( def transform_quantile( self, quantile: str, - as_: Union[List[str], UndefinedType]=Undefined, - groupby: Union[List[str], UndefinedType]=Undefined, - probs: Union[List[float], UndefinedType]=Undefined, - step: Union[float, UndefinedType]=Undefined, + as_: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + probs: Union[List[float], UndefinedType] = Undefined, + step: Union[float, UndefinedType] = Undefined, ) -> Self: """Add a :class:`QuantileTransform` to the chart @@ -2038,14 +2038,16 @@ def transform_quantile( def transform_regression( self, - on, - regression, - as_=Undefined, - extent=Undefined, - groupby=Undefined, - method=Undefined, - order=Undefined, - params=Undefined, + on: str, + regression: str, + as_: Union[List[str], UndefinedType] = Undefined, + extent: Union[List[float], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + method: Union[ + Literal["linear", "log", "exp", "pow", "quad", "poly"], UndefinedType + ] = Undefined, + order: Union[int, UndefinedType] = Undefined, + params: Union[bool, UndefinedType] = Undefined, ) -> Self: """Add a :class:`RegressionTransform` to the chart. @@ -2067,7 +2069,7 @@ def transform_regression( method : enum('linear', 'log', 'exp', 'pow', 'quad', 'poly') The functional form of the regression model. One of ``"linear"``, ``"log"``, ``"exp"``, ``"pow"``, ``"quad"``, or ``"poly"``. **Default value:** ``"linear"`` - order : float + order : int The polynomial order (number of coefficients) for the 'poly' method. **Default value:** ``3`` params : boolean @@ -2101,13 +2103,13 @@ def transform_regression( ) ) - def transform_sample(self, sample=1000) -> Self: + def transform_sample(self, sample: int = 1000) -> Self: """ Add a :class:`SampleTransform` to the schema. Parameters ---------- - sample : float + sample : int The maximum number of data objects to include in the sample. Default: 1000. Returns @@ -2122,7 +2124,14 @@ def transform_sample(self, sample=1000) -> Self: return self._add_transform(core.SampleTransform(sample)) def transform_stack( - self, as_, stack, groupby, offset=Undefined, sort=Undefined + self, + as_: Union[str, List[str]], + stack: str, + groupby: List[str], + offset: Union[ + Literal["zero", "center", "normalize"], UndefinedType + ] = Undefined, + sort: Union[List[core.SortField], UndefinedType] = Undefined, ) -> Self: """ Add a :class:`StackTransform` to the schema. @@ -2160,10 +2169,10 @@ def transform_stack( def transform_timeunit( self, - as_=Undefined, - field=Undefined, - timeUnit=Undefined, - **kwargs, + as_: Union[str, UndefinedType] = Undefined, + field: Union[str, UndefinedType] = Undefined, + timeUnit: Union[str, core.TimeUnit, UndefinedType] = Undefined, + **kwargs: str, ) -> Self: """ Add a :class:`TimeUnitTransform` to the schema. @@ -2174,7 +2183,7 @@ def transform_timeunit( The output field to write the timeUnit value. field : string The data field to apply time unit. - timeUnit : :class:`TimeUnit` + timeUnit : str or :class:`TimeUnit` The timeUnit. **kwargs transforms can also be passed by keyword argument; see Examples @@ -2244,12 +2253,12 @@ def transform_timeunit( def transform_window( self, - window=Undefined, - frame=Undefined, - groupby=Undefined, - ignorePeers=Undefined, - sort=Undefined, - **kwargs, + window: Union[List[core.WindowFieldDef], UndefinedType] = Undefined, + frame: Union[List[Optional[int]], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + ignorePeers: Union[bool, UndefinedType] = Undefined, + sort: Union[List[core.SortField], UndefinedType] = Undefined, + **kwargs: str, ) -> Self: """Add a :class:`WindowTransform` to the schema @@ -2257,7 +2266,7 @@ def transform_window( ---------- window : List(:class:`WindowFieldDef`) The definition of the fields in the window, and what calculations to use. - frame : List(anyOf(None, float)) + frame : List(anyOf(None, int)) A frame specification as a two-element array indicating how the sliding window should proceed. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded rows preceding or From e0ca5822977c3d3ce9c6aa98f8d0bf5d29b79212 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 26 Aug 2023 14:09:16 +0200 Subject: [PATCH 05/18] Finish first pass of type hints for public objects --- altair/vegalite/v5/api.py | 102 +++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 39 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 203520fce..09d1b6599 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2367,7 +2367,13 @@ def _repr_mimebundle_(self, include=None, exclude=None): else: return renderers.get()(dct) - def display(self, renderer=Undefined, theme=Undefined, actions=Undefined, **kwargs): + def display( + self, + renderer: Union[Literal["canvas", "svg"], UndefinedType] = Undefined, + theme: Union[str, UndefinedType] = Undefined, + actions: Union[bool, dict, UndefinedType] = Undefined, + **kwargs, + ) -> None: """Display chart in Jupyter notebook or JupyterLab Parameters are passed as options to vega-embed within supported frontends. @@ -2463,7 +2469,9 @@ def serve( http_server=http_server, ) - def show(self, embed_opt=None, open_browser=None): + def show( + self, embed_opt: Optional[dict] = None, open_browser: Optional[bool] = None + ) -> None: """Show the chart in an external browser window. This requires a recent version of the altair_viewer package. @@ -2471,7 +2479,7 @@ def show(self, embed_opt=None, open_browser=None): Parameters ---------- embed_opt : dict (optional) - The Vega embed options that control the dispay of the chart. + The Vega embed options that control the display of the chart. open_browser : bool (optional) Specify whether a browser window should be opened. If not specified, a browser window will be opened only if the server is not already @@ -2554,7 +2562,7 @@ def facet( column : string or alt.Column (optional) The data column to use as an encoding for a column facet. May be combined with row argument, but not with facet argument. - row : string or alt.Column (optional) + row : string or alt.Row (optional) The data column to use as an encoding for a row facet. May be combined with column argument, but not with facet argument. data : string or dataframe (optional) @@ -2676,12 +2684,12 @@ def __init__( _counter = 0 @classmethod - def _get_name(cls): + def _get_name(cls) -> str: cls._counter += 1 return f"view_{cls._counter}" @classmethod - def from_dict(cls, dct, validate=True) -> core.SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future + def from_dict(cls, dct: dict, validate: bool = True) -> core.SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future """Construct class from a dictionary representation Parameters @@ -2793,7 +2801,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params: return self @@ -2812,7 +2820,9 @@ def add_selection(self, *params) -> Self: """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*params) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -2839,7 +2849,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: return self.add_params(selection_interval(bind="scales", encodings=encodings)) -def _check_if_valid_subspec(spec, classname): +def _check_if_valid_subspec(spec: Union[dict, core.SchemaBase], classname: str) -> None: """Check if the spec is a valid sub-spec. If it is not, then raise a ValueError @@ -2860,7 +2870,7 @@ def _check_if_valid_subspec(spec, classname): raise ValueError(err.format(attr, classname)) -def _check_if_can_be_layered(spec): +def _check_if_can_be_layered(spec: Union[dict, core.SchemaBase]) -> None: """Check if the spec can be layered.""" def _get(spec, attr): @@ -2994,7 +3004,9 @@ def transformed_data( "transformed_data is not yet implemented for RepeatChart" ) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3017,7 +3029,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: copy.spec = copy.spec.interactive(name=name, bind_x=bind_x, bind_y=bind_y) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or self.spec is Undefined: return self @@ -3033,7 +3045,9 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def repeat(repeater="repeat"): +def repeat( + repeater: Literal["row", "column", "repeat", "layer"] = "repeat" +) -> core.RepeatRef: """Tie a channel to the row or column within a repeated chart The output of this should be passed to the ``field`` attribute of @@ -3067,14 +3081,14 @@ def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs): self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) - def __ior__(self, other): + def __ior__(self, other: core.NonNormalizedSpec) -> Self: _check_if_valid_subspec(other, "ConcatChart") self.concat.append(other) self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) return self - def __or__(self, other): + def __or__(self, other: core.NonNormalizedSpec) -> Self: copy = self.copy(deep=["concat"]) copy |= other return copy @@ -3105,7 +3119,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3131,7 +3147,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.concat: return self @@ -3147,7 +3163,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def concat(*charts, **kwargs): +def concat(*charts: core.NonNormalizedSpec, **kwargs) -> ConcatChart: """Concatenate charts horizontally""" return ConcatChart(concat=charts, **kwargs) @@ -3164,14 +3180,14 @@ def __init__(self, data=Undefined, hconcat=(), **kwargs): self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) - def __ior__(self, other): + def __ior__(self, other: core.NonNormalizedSpec) -> Self: _check_if_valid_subspec(other, "HConcatChart") self.hconcat.append(other) self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) return self - def __or__(self, other): + def __or__(self, other: core.NonNormalizedSpec) -> Self: copy = self.copy(deep=["hconcat"]) copy |= other return copy @@ -3202,7 +3218,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3228,7 +3246,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.hconcat: return self @@ -3244,7 +3262,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def hconcat(*charts, **kwargs): +def hconcat(*charts: core.NonNormalizedSpec, **kwargs) -> HConcatChart: """Concatenate charts horizontally""" return HConcatChart(hconcat=charts, **kwargs) @@ -3261,14 +3279,14 @@ def __init__(self, data=Undefined, vconcat=(), **kwargs): self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) - def __iand__(self, other): + def __iand__(self, other: core.NonNormalizedSpec) -> Self: _check_if_valid_subspec(other, "VConcatChart") self.vconcat.append(other) self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) return self - def __and__(self, other): + def __and__(self, other: core.NonNormalizedSpec) -> Self: copy = self.copy(deep=["vconcat"]) copy &= other return copy @@ -3299,7 +3317,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3325,7 +3345,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.vconcat: return self @@ -3341,7 +3361,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def vconcat(*charts, **kwargs): +def vconcat(*charts: core.NonNormalizedSpec, **kwargs) -> VConcatChart: """Concatenate charts vertically""" return VConcatChart(vconcat=charts, **kwargs) @@ -3395,7 +3415,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def __iadd__(self, other): + def __iadd__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) self.layer.append(other) @@ -3403,18 +3423,20 @@ def __iadd__(self, other): self.params, self.layer = _combine_subchart_params(self.params, self.layer) return self - def __add__(self, other): + def __add__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: copy = self.copy(deep=["layer"]) copy += other return copy - def add_layers(self, *layers) -> Self: + def add_layers(self, *layers: Union[core.LayerSpec, core.UnitSpec]) -> Self: copy = self.copy(deep=["layer"]) for layer in layers: copy += layer return copy - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3443,7 +3465,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: ) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.layer: return self @@ -3459,7 +3481,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def layer(*charts, **kwargs): +def layer(*charts: Union[core.LayerSpec, core.UnitSpec], **kwargs) -> LayerChart: """layer multiple charts""" return LayerChart(layer=charts, **kwargs) @@ -3510,7 +3532,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3533,7 +3557,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: copy.spec = copy.spec.interactive(name=name, bind_x=bind_x, bind_y=bind_y) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or self.spec is Undefined: return self @@ -3549,7 +3573,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def topo_feature(url, feature, **kwargs): +def topo_feature(url: str, feature: str, **kwargs) -> core.UrlData: """A convenience function for extracting features from a topojson url Parameters @@ -3597,7 +3621,7 @@ def remove_data(subchart): return data, subcharts -def _viewless_dict(param): +def _viewless_dict(param: Parameter) -> dict: d = param.to_dict() d.pop("views", None) return d @@ -3859,6 +3883,6 @@ def graticule(**kwds): return core.GraticuleGenerator(graticule=graticule) -def sphere(): +def sphere() -> core.SphereGenerator: """Sphere generator.""" return core.SphereGenerator(sphere=True) From a60d56f6c66d0815ec5b2e6e5b4e633a22eec432 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 26 Aug 2023 14:19:53 +0200 Subject: [PATCH 06/18] Move UndefinedType always to last item of Union --- altair/vegalite/v5/api.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 09d1b6599..9a5590823 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -328,10 +328,10 @@ def value(value, **kwargs) -> dict: def param( name: Optional[str] = None, - value: Union[UndefinedType, Any] = Undefined, - bind: Union[UndefinedType, core.Binding] = Undefined, - empty: Union[UndefinedType, bool] = Undefined, - expr: Union[UndefinedType, core.Expr] = Undefined, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[core.Expr, UndefinedType] = Undefined, **kwds, ) -> Parameter: """Create a named parameter. @@ -424,7 +424,7 @@ def param( def _selection( - type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds + type: Union[Literal["interval", "point"], UndefinedType] = Undefined, **kwds ) -> Parameter: # We separate out the parameter keywords from the selection keywords param_kwds = {} @@ -457,7 +457,7 @@ def _selection( Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings.""" ) def selection( - type: Union[UndefinedType, Literal["interval", "point"]] = Undefined, **kwds + type: Union[Literal["interval", "point"], UndefinedType] = Undefined, **kwds ) -> Parameter: """ Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create. From ad7c174e0d5bb6e68021a200cf90f5eff28061df Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 27 Aug 2023 15:50:47 +0200 Subject: [PATCH 07/18] Type hint transform_filter --- altair/vegalite/v5/api.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 9a5590823..617aef528 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1763,8 +1763,19 @@ def transform_extent(self, extent: str, param: str) -> Self: return self._add_transform(core.ExtentTransform(extent=extent, param=param)) # TODO: Update docstring - # TODO: Add type hints. Was too complex to type hint `filter` for now - def transform_filter(self, filter, **kwargs) -> Self: + def transform_filter( + self, + filter: Union[ + str, + expr.core.Expression, + core.Predicate, + Parameter, + core.PredicateComposition, + # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} + Dict[str, core.Predicate], + ], + **kwargs, + ) -> Self: """ Add a :class:`FilterTransform` to the schema. From c8f0dfb1da4b469d5d64379b051a1c2674a5a047 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 27 Aug 2023 15:58:22 +0200 Subject: [PATCH 08/18] Type hint Chart.__init__ --- altair/vegalite/v5/api.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 617aef528..3d12e2243 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -26,6 +26,7 @@ compile_with_vegafusion as _compile_with_vegafusion, ) from ...utils.core import _DataFrameLike +from ...utils.data import _DataType if sys.version_info >= (3, 11): from typing import Self @@ -2676,13 +2677,15 @@ class Chart( def __init__( self, - data=Undefined, - encoding=Undefined, - mark=Undefined, - width=Undefined, - height=Undefined, + data: Union[ + _DataType, core.Data, str, UndefinedType, core.Generator + ] = Undefined, + encoding: Union[core.FacetedEncoding, UndefinedType] = Undefined, + mark: Union[str, core.AnyMark, UndefinedType] = Undefined, + width: Union[int, str, dict, core.Step, UndefinedType] = Undefined, + height: Union[int, str, dict, core.Step, UndefinedType] = Undefined, **kwargs, - ): + ) -> None: super(Chart, self).__init__( data=data, encoding=encoding, @@ -2692,7 +2695,7 @@ def __init__( **kwargs, ) - _counter = 0 + _counter: int = 0 @classmethod def _get_name(cls) -> str: From d647c973ee43fdf2195609d54e4fd168ee4b6f0c Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 27 Aug 2023 16:09:39 +0200 Subject: [PATCH 09/18] Type hint .facet --- altair/vegalite/v5/api.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 3d12e2243..722f55386 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -33,6 +33,8 @@ else: from typing_extensions import Self +_ChartDataType = Union[_DataType, core.Data, str, UndefinedType, core.Generator] + # ------------------------------------------------------------------------ # Data Utilities @@ -2553,11 +2555,13 @@ def encode(self, *args, **kwargs) -> Self: def facet( self, - facet=Undefined, - row=Undefined, - column=Undefined, - data=Undefined, - columns=Undefined, + facet: Union[str, channels.Facet, UndefinedType] = Undefined, + row: Union[str, core.FacetFieldDef, channels.Row, UndefinedType] = Undefined, + column: Union[ + str, core.FacetFieldDef, channels.Column, UndefinedType + ] = Undefined, + data: Union[_ChartDataType, UndefinedType] = Undefined, + columns: Union[int, UndefinedType] = Undefined, **kwargs, ) -> "FacetChart": """Create a facet chart from the current chart. @@ -2568,13 +2572,13 @@ def facet( Parameters ---------- - facet : string or alt.Facet (optional) + facet : string, Facet (optional) The data column to use as an encoding for a wrapped facet. If specified, then neither row nor column may be specified. - column : string or alt.Column (optional) + column : string, Column, FacetFieldDef (optional) The data column to use as an encoding for a column facet. May be combined with row argument, but not with facet argument. - row : string or alt.Row (optional) + row : string or Row, FacetFieldDef (optional) The data column to use as an encoding for a row facet. May be combined with column argument, but not with facet argument. data : string or dataframe (optional) @@ -2677,9 +2681,7 @@ class Chart( def __init__( self, - data: Union[ - _DataType, core.Data, str, UndefinedType, core.Generator - ] = Undefined, + data: Union[_ChartDataType, UndefinedType] = Undefined, encoding: Union[core.FacetedEncoding, UndefinedType] = Undefined, mark: Union[str, core.AnyMark, UndefinedType] = Undefined, width: Union[int, str, dict, core.Step, UndefinedType] = Undefined, From 0154b075973c11bdb74853afaa42b8e9a9455838 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Mon, 28 Aug 2023 20:19:07 +0200 Subject: [PATCH 10/18] Minor stuff --- altair/vegalite/v5/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 722f55386..edeb1cce3 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -334,7 +334,7 @@ def param( value: Union[Any, UndefinedType] = Undefined, bind: Union[core.Binding, UndefinedType] = Undefined, empty: Union[bool, UndefinedType] = Undefined, - expr: Union[core.Expr, UndefinedType] = Undefined, + expr: Union[core.Expr, expr.core.Expression, UndefinedType] = Undefined, **kwds, ) -> Parameter: """Create a named parameter. @@ -489,7 +489,7 @@ def selection_interval( value: Union[Any, UndefinedType] = Undefined, bind: Union[core.Binding, UndefinedType] = Undefined, empty: Union[bool, UndefinedType] = Undefined, - expr: Union[core.Expr, UndefinedType] = Undefined, + expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined, encodings: Union[List[str], UndefinedType] = Undefined, on: Union[str, UndefinedType] = Undefined, clear: Union[str, bool, UndefinedType] = Undefined, From 9ca1af9612774e03f3f4de9be07c4ac3b2848aba Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 24 Sep 2023 17:30:13 +0200 Subject: [PATCH 11/18] Various mypy error fixes --- altair/vegalite/v5/api.py | 81 ++++++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index ce6cc77d4..f37ef2c7e 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,9 +8,9 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal, IO +from typing import cast, List, Optional, Any, Iterable, Union, Literal, IO -# Have to rename it here as else it overlaps with schema.core.Type +# Have to rename it here as else it overlaps with schema.core.Type and schema.core.Dict from typing import Type as TypingType from typing import Dict as TypingDict @@ -175,8 +175,8 @@ def to_dict(self, *args, **kwargs) -> dict: TOPLEVEL_ONLY_KEYS = {"background", "config", "autosize", "padding", "$schema"} -def _get_channels_mapping() -> Dict[Type[core.SchemaBase], str]: - mapping: Dict[Type[core.SchemaBase], str] = {} +def _get_channels_mapping() -> TypingDict[TypingType[core.SchemaBase], str]: + mapping: TypingDict[TypingType[core.SchemaBase], str] = {} for attr in dir(channels): cls = getattr(channels, attr) if isinstance(cls, type) and issubclass(cls, core.SchemaBase): @@ -299,7 +299,7 @@ class ParameterExpression(expr.core.OperatorMixin, object): def __init__(self, expr) -> None: self.expr = expr - def to_dict(self) -> Dict[str, str]: + def to_dict(self) -> TypingDict[str, str]: return {"expr": repr(self.expr)} def _to_expr(self) -> str: @@ -313,7 +313,7 @@ class SelectionExpression(expr.core.OperatorMixin, object): def __init__(self, expr) -> None: self.expr = expr - def to_dict(self) -> Dict[str, str]: + def to_dict(self) -> TypingDict[str, str]: return {"expr": repr(self.expr)} def _to_expr(self) -> str: @@ -326,7 +326,7 @@ def _from_expr(self, expr) -> "SelectionExpression": def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: for prop in ["fields", "encodings"]: try: - if field_name in getattr(parameter.param.select, prop): + if field_name in getattr(parameter.param.select, prop): # type: ignore[union-attr] return True except (AttributeError, TypeError): pass @@ -501,7 +501,7 @@ def selection( def selection_interval( name: Optional[str] = None, value: Union[Any, UndefinedType] = Undefined, - bind: Union[core.Binding, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, empty: Union[bool, UndefinedType] = Undefined, expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined, encodings: Union[List[str], UndefinedType] = Undefined, @@ -523,7 +523,7 @@ def selection_interval( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) @@ -613,7 +613,7 @@ def selection_interval( def selection_point( name: Optional[str] = None, value: Union[Any, UndefinedType] = Undefined, - bind: Union[core.Binding, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, empty: Union[bool, UndefinedType] = Undefined, expr: Union[core.Expr, UndefinedType] = Undefined, encodings: Union[List[str], UndefinedType] = Undefined, @@ -635,7 +635,7 @@ def selection_point( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) @@ -804,15 +804,21 @@ def condition( """ test_predicates = (str, expr.Expression, core.PredicateComposition) + condition: TypingDict[ + str, Union[bool, str, expr.core.Expression, core.PredicateComposition] + ] if isinstance(predicate, Parameter): - if predicate.param_type == "selection" or predicate.param.expr is Undefined: + if ( + predicate.param_type == "selection" + or getattr(predicate.param, "expr", Undefined) is Undefined + ): condition = {"param": predicate.name} if "empty" in kwargs: condition["empty"] = kwargs.pop("empty") elif isinstance(predicate.empty, bool): condition["empty"] = predicate.empty else: - condition = {"test": predicate.param.expr} + condition = {"test": getattr(predicate.param, "expr", Undefined)} elif isinstance(predicate, test_predicates): condition = {"test": predicate} elif isinstance(predicate, dict): @@ -836,6 +842,7 @@ def condition( if_true.update(kwargs) condition.update(if_true) + selection: Union[dict, core.SchemaBase] if isinstance(if_false, core.SchemaBase): # For the selection, the channel definitions all allow selections # already. So use this SchemaBase wrapper if possible. @@ -1049,7 +1056,7 @@ def to_html( def save( self, fp: Union[str, IO], - format: Literal["json", "html", "png", "svg", "pdf"] = None, + format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None, override_data_transformer: bool = True, scale_factor: float = 1.0, vegalite_version: str = VEGALITE_VERSION, @@ -1114,17 +1121,20 @@ def __repr__(self) -> str: def __add__(self, other) -> "LayerChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be layered.") - return layer(self, other) + # Too difficult to type check this + return layer(self, other) # type: ignore[arg-type] def __and__(self, other) -> "VConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") - return vconcat(self, other) + # Too difficult to type check this + return vconcat(self, other) # type: ignore[arg-type] def __or__(self, other) -> "HConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") - return hconcat(self, other) + # Too difficult to type check this + return hconcat(self, other) # type: ignore[arg-type] def repeat( self, @@ -1174,14 +1184,16 @@ def repeat( elif repeat_specified and layer_specified: raise ValueError("repeat argument cannot be combined with layer argument.") + repeat_arg: Union[List[str], core.LayerRepeatMapping, core.RepeatMapping] if repeat_specified: - repeat = repeat + assert not isinstance(repeat, UndefinedType) # For mypy + repeat_arg = repeat elif layer_specified: - repeat = core.LayerRepeatMapping(layer=layer, row=row, column=column) + repeat_arg = core.LayerRepeatMapping(layer=layer, row=row, column=column) else: - repeat = core.RepeatMapping(row=row, column=column) + repeat_arg = core.RepeatMapping(row=row, column=column) - return RepeatChart(spec=self, repeat=repeat, columns=columns, **kwargs) + return RepeatChart(spec=self, repeat=repeat_arg, columns=columns, **kwargs) def properties(self, **kwargs) -> Self: """Set top-level properties of the Chart. @@ -1350,7 +1362,7 @@ def project( def _add_transform(self, *transforms: core.Transform) -> Self: """Copy the chart and add specified transforms to chart.transform""" - copy = self.copy(deep=["transform"]) + copy = self.copy(deep=["transform"]) # type: ignore[attr-defined] if copy.transform is Undefined: copy.transform = [] copy.transform.extend(transforms) @@ -1360,7 +1372,7 @@ def transform_aggregate( self, aggregate: Union[List[core.AggregatedFieldDef], UndefinedType] = Undefined, groupby: Union[List[str], UndefinedType] = Undefined, - **kwds: Union[Dict[str, Any], str], + **kwds: Union[TypingDict[str, Any], str], ) -> Self: """ Add an :class:`AggregateTransform` to the schema. @@ -1372,7 +1384,7 @@ def transform_aggregate( groupby : List(string) The data fields to group by. If not specified, a single group containing all data objects will be used. - **kwds : Union[Dict[str, Any], str] + **kwds : Union[TypingDict[str, Any], str] additional keywords are converted to aggregates using standard shorthand parsing. @@ -1430,6 +1442,7 @@ def transform_aggregate( "field": parsed.get("field", Undefined), "op": parsed.get("aggregate", Undefined), } + assert not isinstance(aggregate, UndefinedType) # For mypy aggregate.append(core.AggregatedFieldDef(**dct)) return self._add_transform( core.AggregateTransform(aggregate=aggregate, groupby=groupby) @@ -1551,7 +1564,11 @@ def transform_calculate( alt.CalculateTransform : underlying transform object """ if as_ is Undefined: - as_ = kwargs.pop("as", Undefined) + # Ignoring assignment error as passing 'as' as a keyword argument is + # an edge case and it's not worth changing the type annotation + # in this function to account for it as it could be confusing to + # users. + as_ = kwargs.pop("as", Undefined) # type: ignore[assignment] elif "as" in kwargs: raise ValueError( "transform_calculate: both 'as_' and 'as' passed as arguments." @@ -1756,6 +1773,7 @@ def transform_joinaggregate( "field": parsed.get("field", Undefined), "op": parsed.get("aggregate", Undefined), } + assert not isinstance(joinaggregate, UndefinedType) # For mypy joinaggregate.append(core.JoinAggregateFieldDef(**dct)) return self._add_transform( core.JoinAggregateTransform(joinaggregate=joinaggregate, groupby=groupby) @@ -1789,7 +1807,7 @@ def transform_filter( Parameter, core.PredicateComposition, # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} - Dict[str, core.Predicate], + TypingDict[str, Union[core.Predicate, str, bool]], ], **kwargs, ) -> Self: @@ -1817,7 +1835,7 @@ def transform_filter( new_filter["empty"] = kwargs.pop("empty") elif isinstance(filter.empty, bool): new_filter["empty"] = filter.empty - filter = new_filter + filter = new_filter # type: ignore[assignment] return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) def transform_flatten( @@ -2285,7 +2303,7 @@ def transform_window( frame: Union[List[Optional[int]], UndefinedType] = Undefined, groupby: Union[List[str], UndefinedType] = Undefined, ignorePeers: Union[bool, UndefinedType] = Undefined, - sort: Union[List[core.SortField], UndefinedType] = Undefined, + sort: Union[List[Union[core.SortField, dict[str, str]]], UndefinedType] = Undefined, **kwargs: str, ) -> Self: """Add a :class:`WindowTransform` to the schema @@ -2369,6 +2387,7 @@ def transform_window( parse_types=False, ) ) + assert not isinstance(window, UndefinedType) # For mypy window.append(core.WindowFieldDef(**kwds)) return self._add_transform( @@ -3111,14 +3130,16 @@ def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs): self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) - def __ior__(self, other: core.NonNormalizedSpec) -> Self: + # Too difficult to fix override error + def __ior__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override] _check_if_valid_subspec(other, "ConcatChart") self.concat.append(other) self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) return self - def __or__(self, other: core.NonNormalizedSpec) -> Self: + # Too difficult to fix override error + def __or__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override] copy = self.copy(deep=["concat"]) copy |= other return copy From 81f3948d7d9f6369fd870b6d47a097adf6da2c2e Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 24 Sep 2023 18:20:37 +0200 Subject: [PATCH 12/18] Fix remaining mypy errors --- altair/vegalite/v5/api.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index f37ef2c7e..781078170 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1121,8 +1121,7 @@ def __repr__(self) -> str: def __add__(self, other) -> "LayerChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be layered.") - # Too difficult to type check this - return layer(self, other) # type: ignore[arg-type] + return layer(self, other) def __and__(self, other) -> "VConcatChart": if not isinstance(other, TopLevelMixin): @@ -1133,8 +1132,7 @@ def __and__(self, other) -> "VConcatChart": def __or__(self, other) -> "HConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") - # Too difficult to type check this - return hconcat(self, other) # type: ignore[arg-type] + return hconcat(self, other) def repeat( self, @@ -1726,7 +1724,7 @@ def transform_joinaggregate( joinaggregate: Union[ List[core.JoinAggregateFieldDef], UndefinedType ] = Undefined, - groupby: Union[str, UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, **kwargs: str, ) -> Self: """ @@ -1807,7 +1805,7 @@ def transform_filter( Parameter, core.PredicateComposition, # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} - TypingDict[str, Union[core.Predicate, str, bool]], + TypingDict[str, Union[core.Predicate, str, list, bool]], ], **kwargs, ) -> Self: @@ -3214,7 +3212,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def concat(*charts: core.NonNormalizedSpec, **kwargs) -> ConcatChart: +def concat(*charts, **kwargs) -> ConcatChart: """Concatenate charts horizontally""" return ConcatChart(concat=charts, **kwargs) @@ -3313,7 +3311,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def hconcat(*charts: core.NonNormalizedSpec, **kwargs) -> HConcatChart: +def hconcat(*charts, **kwargs) -> HConcatChart: """Concatenate charts horizontally""" return HConcatChart(hconcat=charts, **kwargs) @@ -3532,7 +3530,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def layer(*charts: Union[core.LayerSpec, core.UnitSpec], **kwargs) -> LayerChart: +def layer(*charts, **kwargs) -> LayerChart: """layer multiple charts""" return LayerChart(layer=charts, **kwargs) From 1b843cfa5b6177cf43cef04d989350a274c43599 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Thu, 28 Sep 2023 18:02:40 +0200 Subject: [PATCH 13/18] Add more core Altair classes to type hints --- altair/vegalite/v5/api.py | 159 +++++++++++++++++++++++--------------- 1 file changed, 96 insertions(+), 63 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 781078170..dbcdf1351 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -346,9 +346,9 @@ def value(value, **kwargs) -> dict: def param( name: Optional[str] = None, value: Union[Any, UndefinedType] = Undefined, - bind: Union[core.Binding, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, empty: Union[bool, UndefinedType] = Undefined, - expr: Union[core.Expr, expr.core.Expression, UndefinedType] = Undefined, + expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined, **kwds, ) -> Parameter: """Create a named parameter. @@ -365,14 +365,14 @@ def param( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) For selection parameters, the predicate of empty selections returns True by default. Override this behavior, by setting this property 'empty=False'. - expr : :class:`Expr` (optional) + expr : str, Expression (optional) An expression for the value of the parameter. This expression may include other parameters, in which case the parameter will automatically update in response to upstream parameter changes. @@ -776,7 +776,9 @@ def binding_range(**kwargs): # TODO: update the docstring def condition( - predicate: Union[Parameter, str, expr.Expression, core.PredicateComposition, dict], + predicate: Union[ + Parameter, str, expr.Expression, core.Expr, core.PredicateComposition, dict + ], # Types of these depends on where the condition is used so we probably # can't be more specific here. if_true: Any, @@ -1216,25 +1218,39 @@ def properties(self, **kwargs) -> Self: def project( self, - type: Union[str, UndefinedType] = Undefined, - center: Union[List[float], UndefinedType] = Undefined, - clipAngle: Union[float, UndefinedType] = Undefined, - clipExtent: Union[List[List[float]], UndefinedType] = Undefined, - coefficient: Union[float, UndefinedType] = Undefined, - distance: Union[float, UndefinedType] = Undefined, - fraction: Union[float, UndefinedType] = Undefined, - lobes: Union[float, UndefinedType] = Undefined, - parallel: Union[float, UndefinedType] = Undefined, - precision: Union[float, UndefinedType] = Undefined, - radius: Union[float, UndefinedType] = Undefined, - ratio: Union[float, UndefinedType] = Undefined, - reflectX: Union[bool, UndefinedType] = Undefined, - reflectY: Union[bool, UndefinedType] = Undefined, - rotate: Union[List[float], UndefinedType] = Undefined, - scale: Union[float, UndefinedType] = Undefined, - spacing: Union[float, UndefinedType] = Undefined, - tilt: Union[float, UndefinedType] = Undefined, - translate: Union[List[float], UndefinedType] = Undefined, + type: Union[str, core.ProjectionType, core.ExprRef, UndefinedType] = Undefined, + center: Union[ + List[float], core.Vector2number, core.ExprRef, UndefinedType + ] = Undefined, + clipAngle: Union[float, core.ExprRef, UndefinedType] = Undefined, + clipExtent: Union[ + List[List[float]], core.Vector2Vector2number, core.ExprRef, UndefinedType + ] = Undefined, + coefficient: Union[float, core.ExprRef, UndefinedType] = Undefined, + distance: Union[float, core.ExprRef, UndefinedType] = Undefined, + fraction: Union[float, core.ExprRef, UndefinedType] = Undefined, + lobes: Union[float, core.ExprRef, UndefinedType] = Undefined, + parallel: Union[float, core.ExprRef, UndefinedType] = Undefined, + precision: Union[float, core.ExprRef, UndefinedType] = Undefined, + radius: Union[float, core.ExprRef, UndefinedType] = Undefined, + ratio: Union[float, core.ExprRef, UndefinedType] = Undefined, + reflectX: Union[bool, core.ExprRef, UndefinedType] = Undefined, + reflectY: Union[bool, core.ExprRef, UndefinedType] = Undefined, + rotate: Union[ + List[float], + core.Vector2number, + core.Vector3number, + core.ExprRef, + UndefinedType, + ] = Undefined, + scale: Union[float, core.ExprRef, UndefinedType] = Undefined, + spacing: Union[ + float, core.Vector2number, core.ExprRef, UndefinedType + ] = Undefined, + tilt: Union[float, core.ExprRef, UndefinedType] = Undefined, + translate: Union[ + List[float], core.Vector2number, core.ExprRef, UndefinedType + ] = Undefined, **kwds, ) -> Self: """Add a geographic projection to the chart. @@ -1369,7 +1385,7 @@ def _add_transform(self, *transforms: core.Transform) -> Self: def transform_aggregate( self, aggregate: Union[List[core.AggregatedFieldDef], UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, **kwds: Union[TypingDict[str, Any], str], ) -> Self: """ @@ -1448,8 +1464,10 @@ def transform_aggregate( def transform_bin( self, - as_: Union[str, List[str], UndefinedType] = Undefined, - field: Union[str, UndefinedType] = Undefined, + as_: Union[ + str, core.FieldName, List[Union[str, core.FieldName]], UndefinedType + ] = Undefined, + field: Union[str, core.FieldName, UndefinedType] = Undefined, bin: Union[Literal[True], core.BinParams] = True, **kwargs, ) -> Self: @@ -1510,9 +1528,11 @@ def transform_bin( def transform_calculate( self, - as_: Union[str, UndefinedType] = Undefined, - calculate: Union[str, expr.core.Expression, UndefinedType] = Undefined, - **kwargs: Union[str, expr.core.Expression], + as_: Union[str, core.FieldName, UndefinedType] = Undefined, + calculate: Union[ + str, core.Expr, expr.core.Expression, UndefinedType + ] = Undefined, + **kwargs: Union[str, core.Expr, expr.core.Expression], ) -> Self: """ Add a :class:`CalculateTransform` to the schema. @@ -1581,13 +1601,13 @@ def transform_calculate( def transform_density( self, - density: str, - as_: Union[List[str], UndefinedType] = Undefined, + density: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, bandwidth: Union[float, UndefinedType] = Undefined, counts: Union[bool, UndefinedType] = Undefined, cumulative: Union[bool, UndefinedType] = Undefined, extent: Union[List[float], UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, maxsteps: Union[int, UndefinedType] = Undefined, minsteps: Union[int, UndefinedType] = Undefined, steps: Union[int, UndefinedType] = Undefined, @@ -1649,13 +1669,15 @@ def transform_density( def transform_impute( self, - impute: str, - key: str, + impute: Union[str, core.FieldName], + key: Union[str, core.FieldName], frame: Union[List[Optional[int]], UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, keyvals: Union[List[Any], core.ImputeSequence, UndefinedType] = Undefined, method: Union[ - Literal["value", "mean", "median", "max", "min"], UndefinedType + Literal["value", "mean", "median", "max", "min"], + core.ImputeMethod, + UndefinedType, ] = Undefined, value=Undefined, ) -> Self: @@ -1724,7 +1746,7 @@ def transform_joinaggregate( joinaggregate: Union[ List[core.JoinAggregateFieldDef], UndefinedType ] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, **kwargs: str, ) -> Self: """ @@ -1777,7 +1799,9 @@ def transform_joinaggregate( core.JoinAggregateTransform(joinaggregate=joinaggregate, groupby=groupby) ) - def transform_extent(self, extent: str, param: str) -> Self: + def transform_extent( + self, extent: Union[str, core.FieldName], param: Union[str, core.ParameterName] + ) -> Self: """Add a :class:`ExtentTransform` to the spec. Parameters @@ -1800,6 +1824,7 @@ def transform_filter( self, filter: Union[ str, + core.Expr, expr.core.Expression, core.Predicate, Parameter, @@ -1837,7 +1862,9 @@ def transform_filter( return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) def transform_flatten( - self, flatten: List[str], as_: Union[List[str], UndefinedType] = Undefined + self, + flatten: List[Union[str, core.FieldName]], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, ) -> Self: """Add a :class:`FlattenTransform` to the schema. @@ -1867,7 +1894,9 @@ def transform_flatten( ) def transform_fold( - self, fold: List[str], as_: Union[List[str], UndefinedType] = Undefined + self, + fold: List[Union[str, core.FieldName]], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, ) -> Self: """Add a :class:`FoldTransform` to the spec. @@ -1893,11 +1922,11 @@ def transform_fold( def transform_loess( self, - on: str, - loess: str, - as_: Union[List[str], UndefinedType] = Undefined, + on: Union[str, core.FieldName], + loess: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, bandwidth: Union[float, UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, ) -> Self: """Add a :class:`LoessTransform` to the spec. @@ -1937,7 +1966,9 @@ def transform_lookup( self, lookup: Union[str, UndefinedType] = Undefined, from_: Union[core.LookupData, core.LookupSelection, UndefinedType] = Undefined, - as_: Union[str, List[str], UndefinedType] = Undefined, + as_: Union[ + Union[str, core.FieldName], List[Union[str, core.FieldName]], UndefinedType + ] = Undefined, default: Union[str, UndefinedType] = Undefined, **kwargs, ) -> Self: @@ -1990,11 +2021,11 @@ def transform_lookup( def transform_pivot( self, - pivot: str, - value: str, - groupby: Union[List[str], UndefinedType] = Undefined, + pivot: Union[str, core.FieldName], + value: Union[str, core.FieldName], + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, limit: Union[int, UndefinedType] = Undefined, - op: Union[str, UndefinedType] = Undefined, + op: Union[str, core.AggregateOp, UndefinedType] = Undefined, ) -> Self: """Add a :class:`PivotTransform` to the chart. @@ -2036,9 +2067,9 @@ def transform_pivot( def transform_quantile( self, - quantile: str, - as_: Union[List[str], UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + quantile: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, probs: Union[List[float], UndefinedType] = Undefined, step: Union[float, UndefinedType] = Undefined, ) -> Self: @@ -2082,11 +2113,11 @@ def transform_quantile( def transform_regression( self, - on: str, - regression: str, - as_: Union[List[str], UndefinedType] = Undefined, + on: Union[str, core.FieldName], + regression: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, extent: Union[List[float], UndefinedType] = Undefined, - groupby: Union[List[str], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, method: Union[ Literal["linear", "log", "exp", "pow", "quad", "poly"], UndefinedType ] = Undefined, @@ -2169,9 +2200,9 @@ def transform_sample(self, sample: int = 1000) -> Self: def transform_stack( self, - as_: Union[str, List[str]], - stack: str, - groupby: List[str], + as_: Union[str, core.FieldName, List[str]], + stack: Union[str, core.FieldName], + groupby: List[Union[str, core.FieldName]], offset: Union[ Literal["zero", "center", "normalize"], UndefinedType ] = Undefined, @@ -2213,8 +2244,8 @@ def transform_stack( def transform_timeunit( self, - as_: Union[str, UndefinedType] = Undefined, - field: Union[str, UndefinedType] = Undefined, + as_: Union[str, core.FieldName, UndefinedType] = Undefined, + field: Union[str, core.FieldName, UndefinedType] = Undefined, timeUnit: Union[str, core.TimeUnit, UndefinedType] = Undefined, **kwargs: str, ) -> Self: @@ -2301,7 +2332,9 @@ def transform_window( frame: Union[List[Optional[int]], UndefinedType] = Undefined, groupby: Union[List[str], UndefinedType] = Undefined, ignorePeers: Union[bool, UndefinedType] = Undefined, - sort: Union[List[Union[core.SortField, dict[str, str]]], UndefinedType] = Undefined, + sort: Union[ + List[Union[core.SortField, dict[str, str]]], UndefinedType + ] = Undefined, **kwargs: str, ) -> Self: """Add a :class:`WindowTransform` to the schema From 3e853883b948e7fd26f0fada18f16717cc963c10 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 17:39:00 +0200 Subject: [PATCH 14/18] Add Parameter hints --- altair/vegalite/v5/api.py | 43 +++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index dbcdf1351..d32a9aea4 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1218,38 +1218,45 @@ def properties(self, **kwargs) -> Self: def project( self, - type: Union[str, core.ProjectionType, core.ExprRef, UndefinedType] = Undefined, + type: Union[ + str, core.ProjectionType, core.ExprRef, Parameter, UndefinedType + ] = Undefined, center: Union[ - List[float], core.Vector2number, core.ExprRef, UndefinedType + List[float], core.Vector2number, core.ExprRef, Parameter, UndefinedType ] = Undefined, - clipAngle: Union[float, core.ExprRef, UndefinedType] = Undefined, + clipAngle: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, clipExtent: Union[ - List[List[float]], core.Vector2Vector2number, core.ExprRef, UndefinedType + List[List[float]], + core.Vector2Vector2number, + core.ExprRef, + Parameter, + UndefinedType, ] = Undefined, - coefficient: Union[float, core.ExprRef, UndefinedType] = Undefined, - distance: Union[float, core.ExprRef, UndefinedType] = Undefined, - fraction: Union[float, core.ExprRef, UndefinedType] = Undefined, - lobes: Union[float, core.ExprRef, UndefinedType] = Undefined, - parallel: Union[float, core.ExprRef, UndefinedType] = Undefined, - precision: Union[float, core.ExprRef, UndefinedType] = Undefined, - radius: Union[float, core.ExprRef, UndefinedType] = Undefined, - ratio: Union[float, core.ExprRef, UndefinedType] = Undefined, - reflectX: Union[bool, core.ExprRef, UndefinedType] = Undefined, - reflectY: Union[bool, core.ExprRef, UndefinedType] = Undefined, + coefficient: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + distance: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + fraction: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + lobes: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + parallel: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + precision: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + radius: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + ratio: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + reflectX: Union[bool, core.ExprRef, Parameter, UndefinedType] = Undefined, + reflectY: Union[bool, core.ExprRef, Parameter, UndefinedType] = Undefined, rotate: Union[ List[float], core.Vector2number, core.Vector3number, core.ExprRef, + Parameter, UndefinedType, ] = Undefined, - scale: Union[float, core.ExprRef, UndefinedType] = Undefined, + scale: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, spacing: Union[ - float, core.Vector2number, core.ExprRef, UndefinedType + float, core.Vector2number, core.ExprRef, Parameter, UndefinedType ] = Undefined, - tilt: Union[float, core.ExprRef, UndefinedType] = Undefined, + tilt: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, translate: Union[ - List[float], core.Vector2number, core.ExprRef, UndefinedType + List[float], core.Vector2number, core.ExprRef, Parameter, UndefinedType ] = Undefined, **kwds, ) -> Self: From 1d76d0cf61e85f90e8a895783746e373dfc4e514 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 17:46:27 +0200 Subject: [PATCH 15/18] Minor fix --- altair/vegalite/v5/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index d32a9aea4..57dbe118b 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2340,7 +2340,7 @@ def transform_window( groupby: Union[List[str], UndefinedType] = Undefined, ignorePeers: Union[bool, UndefinedType] = Undefined, sort: Union[ - List[Union[core.SortField, dict[str, str]]], UndefinedType + List[Union[core.SortField, TypingDict[str, str]]], UndefinedType ] = Undefined, **kwargs: str, ) -> Self: From e68e456cea729171eae59573f9439a307d54acdc Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Fri, 29 Sep 2023 18:01:52 +0200 Subject: [PATCH 16/18] Exclude IO from __init__.py --- tools/update_init_file.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/update_init_file.py b/tools/update_init_file.py index d657965bf..c02e63a8c 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -6,7 +6,7 @@ import sys from pathlib import Path from os.path import abspath, dirname, join -from typing import TypeVar, Type, cast, List, Any, Optional, Iterable, Union +from typing import TypeVar, Type, cast, List, Any, Optional, Iterable, Union, IO import black @@ -80,6 +80,7 @@ def _is_relevant_attribute(attr_name): or attr is Optional or attr is Iterable or attr is Union + or attr is IO or attr_name == "TypingDict" ): return False From 343730befb09afbb12ae9bdce31bd58cc4e7511f Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 28 Oct 2023 13:17:23 +0200 Subject: [PATCH 17/18] Apply code suggestion Co-authored-by: Mattijn van Hoek --- altair/vegalite/v5/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 57dbe118b..6b3fafcb1 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -33,7 +33,7 @@ else: from typing_extensions import Self -_ChartDataType = Union[_DataType, core.Data, str, UndefinedType, core.Generator] +_ChartDataType = Union[_DataType, core.Data, str, core.Generator, UndefinedType] # ------------------------------------------------------------------------ From a3244dcf8c9bf1662e45f6b463804eada9af5051 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 28 Oct 2023 13:24:14 +0200 Subject: [PATCH 18/18] Remove ignore statement which is redundant for new versions of mypy --- altair/vegalite/v5/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index a9e33e9e1..259cce0e1 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1593,7 +1593,7 @@ def transform_calculate( # an edge case and it's not worth changing the type annotation # in this function to account for it as it could be confusing to # users. - as_ = kwargs.pop("as", Undefined) # type: ignore[assignment] + as_ = kwargs.pop("as", Undefined) elif "as" in kwargs: raise ValueError( "transform_calculate: both 'as_' and 'as' passed as arguments."