From 19eac9377cfae089d13f3d4e906154d2bd01552f Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sat, 28 Oct 2023 13:36:34 +0200 Subject: [PATCH] Type hint api.py (#3143) * First batch of type hints * Some more * Add type hints to some more transform_methods * Add type hints to some more transform_methods * Finish first pass of type hints for public objects * Move UndefinedType always to last item of Union * Type hint transform_filter * Type hint Chart.__init__ * Type hint .facet * Minor stuff * Various mypy error fixes * Fix remaining mypy errors * Add more core Altair classes to type hints * Add Parameter hints * Minor fix * Exclude IO from __init__.py * Apply code suggestion Co-authored-by: Mattijn van Hoek * Remove ignore statement which is redundant for new versions of mypy --------- Co-authored-by: Mattijn van Hoek --- altair/vegalite/v5/api.py | 684 +++++++++++++++++++++++--------------- tools/update_init_file.py | 3 +- 2 files changed, 427 insertions(+), 260 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index a0b4b91bb..259cce0e1 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,9 +8,9 @@ from toolz.curried import pipe as _pipe import itertools import sys -from typing import cast, List, Optional, Any, Iterable, Union, Literal +from typing import cast, List, Optional, Any, Iterable, Union, Literal, IO -# Have to rename it here as else it overlaps with schema.core.Type +# Have to rename it here as else it overlaps with schema.core.Type and schema.core.Dict from typing import Type as TypingType from typing import Dict as TypingDict @@ -26,22 +26,25 @@ compile_with_vegafusion as _compile_with_vegafusion, ) from ...utils.core import _DataFrameLike +from ...utils.data import _DataType if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self +_ChartDataType = Union[_DataType, core.Data, str, core.Generator, UndefinedType] + # ------------------------------------------------------------------------ # Data Utilities -def _dataset_name(values): +def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str: """Generate a unique hash of the data Parameters ---------- - values : list or dict - A list/dict representation of data values. + values : list, dict, core.InlineDataset + A representation of data values. Returns ------- @@ -136,7 +139,7 @@ class LookupData(core.LookupData): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: """Convert the chart to a dictionary suitable for JSON export.""" copy = self.copy(deep=False) copy.data = _prepare_data(copy.data, kwargs.get("context")) @@ -150,7 +153,7 @@ class FacetMapping(core.FacetMapping): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def to_dict(self, *args, **kwargs): + def to_dict(self, *args, **kwargs) -> dict: copy = self.copy(deep=False) context = kwargs.get("context", {}) data = context.get("data", None) @@ -172,8 +175,8 @@ def to_dict(self, *args, **kwargs): TOPLEVEL_ONLY_KEYS = {"background", "config", "autosize", "padding", "$schema"} -def _get_channels_mapping(): - mapping = {} +def _get_channels_mapping() -> TypingDict[TypingType[core.SchemaBase], str]: + mapping: TypingDict[TypingType[core.SchemaBase], str] = {} for attr in dir(channels): cls = getattr(channels, attr) if isinstance(cls, type) and issubclass(cls, core.SchemaBase): @@ -293,37 +296,37 @@ def __or__(self, other): class ParameterExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> TypingDict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "ParameterExpression": return ParameterExpression(expr=expr) class SelectionExpression(expr.core.OperatorMixin, object): - def __init__(self, expr): + def __init__(self, expr) -> None: self.expr = expr - def to_dict(self): + def to_dict(self) -> TypingDict[str, str]: return {"expr": repr(self.expr)} - def _to_expr(self): + def _to_expr(self) -> str: return repr(self.expr) - def _from_expr(self, expr): + def _from_expr(self, expr) -> "SelectionExpression": return SelectionExpression(expr=expr) -def check_fields_and_encodings(parameter, field_name): +def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: for prop in ["fields", "encodings"]: try: - if field_name in getattr(parameter.param.select, prop): + if field_name in getattr(parameter.param.select, prop): # type: ignore[union-attr] return True except (AttributeError, TypeError): pass @@ -335,20 +338,24 @@ def check_fields_and_encodings(parameter, field_name): # Top-Level Functions -def value(value, **kwargs): +def value(value, **kwargs) -> dict: """Specify a value for use in an encoding""" return dict(value=value, **kwargs) def param( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, + name: Optional[str] = None, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined, **kwds, -): - """Create a named parameter. See https://altair-viz.github.io/user_guide/interactions.html for examples. Although both variable parameters and selection parameters can be created using this 'param' function, to create a selection parameter, it is recommended to use either 'selection_point' or 'selection_interval' instead. +) -> Parameter: + """Create a named parameter. + See https://altair-viz.github.io/user_guide/interactions.html for examples. + Although both variable parameters and selection parameters can be created using + this 'param' function, to create a selection parameter, it is recommended to use + either 'selection_point' or 'selection_interval' instead. Parameters ---------- @@ -358,14 +365,14 @@ def param( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) For selection parameters, the predicate of empty selections returns True by default. Override this behavior, by setting this property 'empty=False'. - expr : :class:`Expr` (optional) + expr : str, Expression (optional) An expression for the value of the parameter. This expression may include other parameters, in which case the parameter will automatically update in response to upstream parameter changes. @@ -433,7 +440,9 @@ def param( return parameter -def _selection(type=Undefined, **kwds): +def _selection( + type: Union[Literal["interval", "point"], UndefinedType] = Undefined, **kwds +) -> Parameter: # We separate out the parameter keywords from the selection keywords param_kwds = {} @@ -441,6 +450,7 @@ def _selection(type=Undefined, **kwds): if kwd in kwds: param_kwds[kwd] = kwds.pop(kwd) + select: Union[core.IntervalSelectionConfig, core.PointSelectionConfig] if type == "interval": select = core.IntervalSelectionConfig(type=type, **kwds) elif type == "point": @@ -463,7 +473,9 @@ def _selection(type=Undefined, **kwds): message="""'selection' is deprecated. Use 'selection_point()' or 'selection_interval()' instead; these functions also include more helpful docstrings.""" ) -def selection(type=Undefined, **kwds): +def selection( + type: Union[Literal["interval", "point"], UndefinedType] = Undefined, **kwds +) -> Parameter: """ Users are recommended to use either 'selection_point' or 'selection_interval' instead, depending on the type of parameter they want to create. @@ -487,20 +499,20 @@ def selection(type=Undefined, **kwds): def selection_interval( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, - encodings=Undefined, - on=Undefined, - clear=Undefined, - resolve=Undefined, - mark=Undefined, - translate=Undefined, - zoom=Undefined, + name: Optional[str] = None, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined, + encodings: Union[List[str], UndefinedType] = Undefined, + on: Union[str, UndefinedType] = Undefined, + clear: Union[str, bool, UndefinedType] = Undefined, + resolve: Union[Literal["global", "union", "intersect"], UndefinedType] = Undefined, + mark: Union[core.Mark, UndefinedType] = Undefined, + translate: Union[str, bool, UndefinedType] = Undefined, + zoom: Union[str, bool, UndefinedType] = Undefined, **kwds, -): +) -> Parameter: """Create an interval selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Interval selection parameters are used to select a continuous range of data values on drag, whereas point selection parameters (`selection_point`) are used to select multiple discrete data values.) Parameters @@ -511,7 +523,7 @@ def selection_interval( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) @@ -599,20 +611,20 @@ def selection_interval( def selection_point( - name=None, - value=Undefined, - bind=Undefined, - empty=Undefined, - expr=Undefined, - encodings=Undefined, - fields=Undefined, - on=Undefined, - clear=Undefined, - resolve=Undefined, - toggle=Undefined, - nearest=Undefined, + name: Optional[str] = None, + value: Union[Any, UndefinedType] = Undefined, + bind: Union[core.Binding, str, UndefinedType] = Undefined, + empty: Union[bool, UndefinedType] = Undefined, + expr: Union[core.Expr, UndefinedType] = Undefined, + encodings: Union[List[str], UndefinedType] = Undefined, + fields: Union[List[str], UndefinedType] = Undefined, + on: Union[str, UndefinedType] = Undefined, + clear: Union[str, bool, UndefinedType] = Undefined, + resolve: Union[Literal["global", "union", "intersect"], UndefinedType] = Undefined, + toggle: Union[str, bool, UndefinedType] = Undefined, + nearest: Union[bool, UndefinedType] = Undefined, **kwds, -): +) -> Parameter: """Create a point selection parameter. Selection parameters define data queries that are driven by direct manipulation from user input (e.g., mouse clicks or drags). Point selection parameters are used to select multiple discrete data values; the first value is selected on click and additional values toggled on shift-click. To select a continuous range of data values on drag interval selection parameters (`selection_interval`) can be used instead. Parameters @@ -623,7 +635,7 @@ def selection_point( value : any (optional) The default value of the parameter. If not specified, the parameter will be created without a default value. - bind : :class:`Binding` (optional) + bind : :class:`Binding`, str (optional) Binds the parameter to an external input element such as a slider, selection list or radio button group. empty : boolean (optional) @@ -763,12 +775,21 @@ def binding_range(**kwargs): # TODO: update the docstring -def condition(predicate, if_true, if_false, **kwargs): +def condition( + predicate: Union[ + Parameter, str, expr.Expression, core.Expr, core.PredicateComposition, dict + ], + # Types of these depends on where the condition is used so we probably + # can't be more specific here. + if_true: Any, + if_false: Any, + **kwargs, +) -> Union[dict, core.SchemaBase]: """A conditional attribute or encoding Parameters ---------- - predicate: Selection, PredicateComposition, expr.Expression, dict, or string + predicate: Parameter, PredicateComposition, expr.Expression, dict, or string the selection predicate or test predicate for the condition. if a string is passed, it will be treated as a test operand. if_true: @@ -785,15 +806,21 @@ def condition(predicate, if_true, if_false, **kwargs): """ test_predicates = (str, expr.Expression, core.PredicateComposition) + condition: TypingDict[ + str, Union[bool, str, expr.core.Expression, core.PredicateComposition] + ] if isinstance(predicate, Parameter): - if predicate.param_type == "selection" or predicate.param.expr is Undefined: + if ( + predicate.param_type == "selection" + or getattr(predicate.param, "expr", Undefined) is Undefined + ): condition = {"param": predicate.name} if "empty" in kwargs: condition["empty"] = kwargs.pop("empty") elif isinstance(predicate.empty, bool): condition["empty"] = predicate.empty else: - condition = {"test": predicate.param.expr} + condition = {"test": getattr(predicate.param, "expr", Undefined)} elif isinstance(predicate, test_predicates): condition = {"test": predicate} elif isinstance(predicate, dict): @@ -817,6 +844,7 @@ def condition(predicate, if_true, if_false, **kwargs): if_true.update(kwargs) condition.update(if_true) + selection: Union[dict, core.SchemaBase] if isinstance(if_false, core.SchemaBase): # For the selection, the channel definitions all allow selections # already. So use this SchemaBase wrapper if possible. @@ -838,7 +866,7 @@ def condition(predicate, if_true, if_false, **kwargs): class TopLevelMixin(mixins.ConfigMethodMixin): """Mixin for top-level chart objects such as Chart, LayeredChart, etc.""" - _class_is_valid_at_instantiation = False + _class_is_valid_at_instantiation: bool = False def to_dict( self, @@ -1006,12 +1034,12 @@ def to_json( def to_html( self, - base_url="https://cdn.jsdelivr.net/npm", - output_div="vis", - embed_options=None, - json_kwds=None, - fullhtml=True, - requirejs=False, + base_url: str = "https://cdn.jsdelivr.net/npm", + output_div: str = "vis", + embed_options: Optional[dict] = None, + json_kwds: Optional[dict] = None, + fullhtml: bool = True, + requirejs: bool = False, ) -> str: return utils.spec_to_html( self.to_dict(), @@ -1029,15 +1057,15 @@ def to_html( def save( self, - fp, - format=None, - override_data_transformer=True, - scale_factor=1.0, - vegalite_version=VEGALITE_VERSION, - vega_version=VEGA_VERSION, - vegaembed_version=VEGAEMBED_VERSION, + fp: Union[str, IO], + format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None, + override_data_transformer: bool = True, + scale_factor: float = 1.0, + vegalite_version: str = VEGALITE_VERSION, + vega_version: str = VEGA_VERSION, + vegaembed_version: str = VEGAEMBED_VERSION, **kwargs, - ): + ) -> None: """Save a chart to file in a variety of formats Supported formats are json, html, png, svg, pdf; the last three require @@ -1088,32 +1116,33 @@ def save( # Fallback for when rendering fails; the full repr is too long to be # useful in nearly all cases. - def __repr__(self): + def __repr__(self) -> str: return "alt.{}(...)".format(self.__class__.__name__) # Layering and stacking - def __add__(self, other): + def __add__(self, other) -> "LayerChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be layered.") return layer(self, other) - def __and__(self, other): + def __and__(self, other) -> "VConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") - return vconcat(self, other) + # Too difficult to type check this + return vconcat(self, other) # type: ignore[arg-type] - def __or__(self, other): + def __or__(self, other) -> "HConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") return hconcat(self, other) def repeat( self, - repeat=Undefined, - row=Undefined, - column=Undefined, - layer=Undefined, - columns=Undefined, + repeat: Union[List[str], UndefinedType] = Undefined, + row: Union[List[str], UndefinedType] = Undefined, + column: Union[List[str], UndefinedType] = Undefined, + layer: Union[List[str], UndefinedType] = Undefined, + columns: Union[int, UndefinedType] = Undefined, **kwargs, ) -> "RepeatChart": """Return a RepeatChart built from the chart @@ -1155,14 +1184,16 @@ def repeat( elif repeat_specified and layer_specified: raise ValueError("repeat argument cannot be combined with layer argument.") + repeat_arg: Union[List[str], core.LayerRepeatMapping, core.RepeatMapping] if repeat_specified: - repeat = repeat + assert not isinstance(repeat, UndefinedType) # For mypy + repeat_arg = repeat elif layer_specified: - repeat = core.LayerRepeatMapping(layer=layer, row=row, column=column) + repeat_arg = core.LayerRepeatMapping(layer=layer, row=row, column=column) else: - repeat = core.RepeatMapping(row=row, column=column) + repeat_arg = core.RepeatMapping(row=row, column=column) - return RepeatChart(spec=self, repeat=repeat, columns=columns, **kwargs) + return RepeatChart(spec=self, repeat=repeat_arg, columns=columns, **kwargs) def properties(self, **kwargs) -> Self: """Set top-level properties of the Chart. @@ -1187,25 +1218,46 @@ def properties(self, **kwargs) -> Self: def project( self, - type=Undefined, - center=Undefined, - clipAngle=Undefined, - clipExtent=Undefined, - coefficient=Undefined, - distance=Undefined, - fraction=Undefined, - lobes=Undefined, - parallel=Undefined, - precision=Undefined, - radius=Undefined, - ratio=Undefined, - reflectX=Undefined, - reflectY=Undefined, - rotate=Undefined, - scale=Undefined, - spacing=Undefined, - tilt=Undefined, - translate=Undefined, + type: Union[ + str, core.ProjectionType, core.ExprRef, Parameter, UndefinedType + ] = Undefined, + center: Union[ + List[float], core.Vector2number, core.ExprRef, Parameter, UndefinedType + ] = Undefined, + clipAngle: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + clipExtent: Union[ + List[List[float]], + core.Vector2Vector2number, + core.ExprRef, + Parameter, + UndefinedType, + ] = Undefined, + coefficient: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + distance: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + fraction: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + lobes: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + parallel: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + precision: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + radius: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + ratio: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + reflectX: Union[bool, core.ExprRef, Parameter, UndefinedType] = Undefined, + reflectY: Union[bool, core.ExprRef, Parameter, UndefinedType] = Undefined, + rotate: Union[ + List[float], + core.Vector2number, + core.Vector3number, + core.ExprRef, + Parameter, + UndefinedType, + ] = Undefined, + scale: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + spacing: Union[ + float, core.Vector2number, core.ExprRef, Parameter, UndefinedType + ] = Undefined, + tilt: Union[float, core.ExprRef, Parameter, UndefinedType] = Undefined, + translate: Union[ + List[float], core.Vector2number, core.ExprRef, Parameter, UndefinedType + ] = Undefined, **kwds, ) -> Self: """Add a geographic projection to the chart. @@ -1220,7 +1272,7 @@ def project( Parameters ---------- - type : ProjectionType + type : str The cartographic projection to use. This value is case-insensitive, for example `"albers"` and `"Albers"` indicate the same projection type. You can find all valid projection types [in the @@ -1242,16 +1294,28 @@ def project( left-side of the viewport, `y0` is the top, `x1` is the right and `y1` is the bottom. If `null`, no viewport clipping is performed. coefficient : float + The coefficient parameter for the ``hammer`` projection. + **Default value:** ``2`` distance : float + For the ``satellite`` projection, the distance from the center of the sphere to the + point of view, as a proportion of the sphere’s radius. The recommended maximum clip + angle for a given ``distance`` is acos(1 / distance) converted to degrees. If tilt + is also applied, then more conservative clipping may be necessary. + **Default value:** ``2.0`` fraction : float + The fraction parameter for the ``bottomley`` projection. + **Default value:** ``0.5``, corresponding to a sin(ψ) where ψ = π/6. lobes : float - + The number of lobes in projections that support multi-lobe views: ``berghaus``, + ``gingery``, or ``healpix``. The default value varies based on the projection type. parallel : float - - precision : Mapping(required=[length]) + For conic projections, the `two standard parallels + `__ that define the map layout. + The default depends on the specific conic projection used. + precision : float Sets the threshold for the projection’s [adaptive resampling](http://bl.ocks.org/mbostock/3795544) to the specified value in pixels. This value corresponds to the [Douglas–Peucker @@ -1259,13 +1323,15 @@ def project( If precision is not specified, returns the projection’s current resampling precision which defaults to `√0.5 ≅ 0.70710…`. radius : float - + The radius parameter for the ``airy`` or ``gingery`` projection. The default value + varies based on the projection type. ratio : float - + The ratio parameter for the ``hill``, ``hufnagel``, or ``wagner`` projections. The + default value varies based on the projection type. reflectX : boolean - + Sets whether or not the x-dimension is reflected (negated) in the output. reflectY : boolean - + Sets whether or not the y-dimension is reflected (negated) in the output. rotate : List(float) Sets the projection’s three-axis rotation to the specified angles, which must be a two- or three-element array of numbers [`lambda`, `phi`, `gamma`] specifying the @@ -1274,14 +1340,21 @@ def project( **Default value:** `[0, 0, 0]` scale : float - Sets the projection's scale (zoom) value, overriding automatic fitting. - + The projection’s scale (zoom) factor, overriding automatic fitting. The default + scale is projection-specific. The scale factor corresponds linearly to the distance + between projected points; however, scale factor values are not equivalent across + projections. spacing : float + The spacing parameter for the ``lagrange`` projection. + **Default value:** ``0.5`` tilt : float + The tilt angle (in degrees) for the ``satellite`` projection. + **Default value:** ``0``. translate : List(float) - Sets the projection's translation (pan) value, overriding automatic fitting. + The projection’s translation offset as a two-element array ``[tx, ty]``, + overriding automatic fitting. """ projection = core.Projection( @@ -1308,16 +1381,19 @@ def project( ) return self.properties(projection=projection) - def _add_transform(self, *transforms): + def _add_transform(self, *transforms: core.Transform) -> Self: """Copy the chart and add specified transforms to chart.transform""" - copy = self.copy(deep=["transform"]) + copy = self.copy(deep=["transform"]) # type: ignore[attr-defined] if copy.transform is Undefined: copy.transform = [] copy.transform.extend(transforms) return copy def transform_aggregate( - self, aggregate=Undefined, groupby=Undefined, **kwds + self, + aggregate: Union[List[core.AggregatedFieldDef], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + **kwds: Union[TypingDict[str, Any], str], ) -> Self: """ Add an :class:`AggregateTransform` to the schema. @@ -1329,7 +1405,7 @@ def transform_aggregate( groupby : List(string) The data fields to group by. If not specified, a single group containing all data objects will be used. - **kwds : + **kwds : Union[TypingDict[str, Any], str] additional keywords are converted to aggregates using standard shorthand parsing. @@ -1387,12 +1463,21 @@ def transform_aggregate( "field": parsed.get("field", Undefined), "op": parsed.get("aggregate", Undefined), } + assert not isinstance(aggregate, UndefinedType) # For mypy aggregate.append(core.AggregatedFieldDef(**dct)) return self._add_transform( core.AggregateTransform(aggregate=aggregate, groupby=groupby) ) - def transform_bin(self, as_=Undefined, field=Undefined, bin=True, **kwargs) -> Self: + def transform_bin( + self, + as_: Union[ + str, core.FieldName, List[Union[str, core.FieldName]], UndefinedType + ] = Undefined, + field: Union[str, core.FieldName, UndefinedType] = Undefined, + bin: Union[Literal[True], core.BinParams] = True, + **kwargs, + ) -> Self: """ Add a :class:`BinTransform` to the schema. @@ -1448,7 +1533,14 @@ def transform_bin(self, as_=Undefined, field=Undefined, bin=True, **kwargs) -> S kwargs["field"] = field return self._add_transform(core.BinTransform(**kwargs)) - def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> Self: + def transform_calculate( + self, + as_: Union[str, core.FieldName, UndefinedType] = Undefined, + calculate: Union[ + str, core.Expr, expr.core.Expression, UndefinedType + ] = Undefined, + **kwargs: Union[str, core.Expr, expr.core.Expression], + ) -> Self: """ Add a :class:`CalculateTransform` to the schema. @@ -1456,8 +1548,8 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S ---------- as_ : string The field for storing the computed formula value. - calculate : string or alt.expr expression - A `expression `__ + calculate : string or alt.expr.Expression + An `expression `__ string. Use the variable ``datum`` to refer to the current data object. **kwargs transforms can also be passed by keyword argument; see Examples @@ -1481,7 +1573,7 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S It's also possible to pass the ``CalculateTransform`` arguments directly: - >>> kwds = {'as': 'y', 'calculate': '2 * sin(datum.x)'} + >>> kwds = {'as_': 'y', 'calculate': '2 * sin(datum.x)'} >>> chart = alt.Chart().transform_calculate(**kwds) >>> chart.transform[0] CalculateTransform({ @@ -1497,6 +1589,10 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S alt.CalculateTransform : underlying transform object """ if as_ is Undefined: + # Ignoring assignment error as passing 'as' as a keyword argument is + # an edge case and it's not worth changing the type annotation + # in this function to account for it as it could be confusing to + # users. as_ = kwargs.pop("as", Undefined) elif "as" in kwargs: raise ValueError( @@ -1512,16 +1608,16 @@ def transform_calculate(self, as_=Undefined, calculate=Undefined, **kwargs) -> S def transform_density( self, - density, - as_=Undefined, - bandwidth=Undefined, - counts=Undefined, - cumulative=Undefined, - extent=Undefined, - groupby=Undefined, - maxsteps=Undefined, - minsteps=Undefined, - steps=Undefined, + density: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + bandwidth: Union[float, UndefinedType] = Undefined, + counts: Union[bool, UndefinedType] = Undefined, + cumulative: Union[bool, UndefinedType] = Undefined, + extent: Union[List[float], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + maxsteps: Union[int, UndefinedType] = Undefined, + minsteps: Union[int, UndefinedType] = Undefined, + steps: Union[int, UndefinedType] = Undefined, ) -> Self: """Add a :class:`DensityTransform` to the spec. @@ -1551,13 +1647,13 @@ def transform_density( groupby : List(str) The data fields to group by. If not specified, a single group containing all data objects will be used. - maxsteps : float + maxsteps : int The maximum number of samples to take along the extent domain for plotting the density. **Default value:** ``200`` - minsteps : float + minsteps : int The minimum number of samples to take along the extent domain for plotting the density. **Default value:** ``25`` - steps : float + steps : int The exact number of samples to take along the extent domain for plotting the density. If specified, overrides both minsteps and maxsteps to set an exact number of uniform samples. Potentially useful in conjunction with a fixed extent to ensure @@ -1580,12 +1676,16 @@ def transform_density( def transform_impute( self, - impute, - key, - frame=Undefined, - groupby=Undefined, - keyvals=Undefined, - method=Undefined, + impute: Union[str, core.FieldName], + key: Union[str, core.FieldName], + frame: Union[List[Optional[int]], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + keyvals: Union[List[Any], core.ImputeSequence, UndefinedType] = Undefined, + method: Union[ + Literal["value", "mean", "median", "max", "min"], + core.ImputeMethod, + UndefinedType, + ] = Undefined, value=Undefined, ) -> Self: """ @@ -1599,7 +1699,7 @@ def transform_impute( A key field that uniquely identifies data objects within a group. Missing key values (those occurring in the data but not in the current group) will be imputed. - frame : List(anyOf(None, float)) + frame : List(anyOf(None, int)) A frame specification as a two-element array used to control the window over which the specified method is applied. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded @@ -1649,7 +1749,12 @@ def transform_impute( ) def transform_joinaggregate( - self, joinaggregate=Undefined, groupby=Undefined, **kwargs + self, + joinaggregate: Union[ + List[core.JoinAggregateFieldDef], UndefinedType + ] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + **kwargs: str, ) -> Self: """ Add a :class:`JoinAggregateTransform` to the schema. @@ -1695,12 +1800,15 @@ def transform_joinaggregate( "field": parsed.get("field", Undefined), "op": parsed.get("aggregate", Undefined), } + assert not isinstance(joinaggregate, UndefinedType) # For mypy joinaggregate.append(core.JoinAggregateFieldDef(**dct)) return self._add_transform( core.JoinAggregateTransform(joinaggregate=joinaggregate, groupby=groupby) ) - def transform_extent(self, extent: str, param: str) -> Self: + def transform_extent( + self, extent: Union[str, core.FieldName], param: Union[str, core.ParameterName] + ) -> Self: """Add a :class:`ExtentTransform` to the spec. Parameters @@ -1719,7 +1827,20 @@ def transform_extent(self, extent: str, param: str) -> Self: return self._add_transform(core.ExtentTransform(extent=extent, param=param)) # TODO: Update docstring - def transform_filter(self, filter, **kwargs) -> Self: + def transform_filter( + self, + filter: Union[ + str, + core.Expr, + expr.core.Expression, + core.Predicate, + Parameter, + core.PredicateComposition, + # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])} + TypingDict[str, Union[core.Predicate, str, list, bool]], + ], + **kwargs, + ) -> Self: """ Add a :class:`FilterTransform` to the schema. @@ -1737,11 +1858,6 @@ def transform_filter(self, filter, **kwargs) -> Self: ------- self : Chart object returns chart to allow for chaining - - See Also - -------- - alt.FilterTransform : underlying transform object - """ if isinstance(filter, Parameter): new_filter: TypingDict[str, Union[bool, str]] = {"param": filter.name} @@ -1749,10 +1865,14 @@ def transform_filter(self, filter, **kwargs) -> Self: new_filter["empty"] = kwargs.pop("empty") elif isinstance(filter.empty, bool): new_filter["empty"] = filter.empty - filter = new_filter + filter = new_filter # type: ignore[assignment] return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) - def transform_flatten(self, flatten, as_=Undefined) -> Self: + def transform_flatten( + self, + flatten: List[Union[str, core.FieldName]], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + ) -> Self: """Add a :class:`FlattenTransform` to the schema. Parameters @@ -1780,7 +1900,11 @@ def transform_flatten(self, flatten, as_=Undefined) -> Self: core.FlattenTransform(flatten=flatten, **{"as": as_}) ) - def transform_fold(self, fold, as_=Undefined) -> Self: + def transform_fold( + self, + fold: List[Union[str, core.FieldName]], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + ) -> Self: """Add a :class:`FoldTransform` to the spec. Parameters @@ -1805,11 +1929,11 @@ def transform_fold(self, fold, as_=Undefined) -> Self: def transform_loess( self, - on, - loess, - as_=Undefined, - bandwidth=Undefined, - groupby=Undefined, + on: Union[str, core.FieldName], + loess: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + bandwidth: Union[float, UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, ) -> Self: """Add a :class:`LoessTransform` to the spec. @@ -1847,10 +1971,12 @@ def transform_loess( def transform_lookup( self, - lookup=Undefined, - from_=Undefined, - as_=Undefined, - default=Undefined, + lookup: Union[str, UndefinedType] = Undefined, + from_: Union[core.LookupData, core.LookupSelection, UndefinedType] = Undefined, + as_: Union[ + Union[str, core.FieldName], List[Union[str, core.FieldName]], UndefinedType + ] = Undefined, + default: Union[str, UndefinedType] = Undefined, **kwargs, ) -> Self: """Add a :class:`DataLookupTransform` or :class:`SelectionLookupTransform` to the chart @@ -1902,11 +2028,11 @@ def transform_lookup( def transform_pivot( self, - pivot, - value, - groupby=Undefined, - limit=Undefined, - op=Undefined, + pivot: Union[str, core.FieldName], + value: Union[str, core.FieldName], + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + limit: Union[int, UndefinedType] = Undefined, + op: Union[str, core.AggregateOp, UndefinedType] = Undefined, ) -> Self: """Add a :class:`PivotTransform` to the chart. @@ -1921,7 +2047,7 @@ def transform_pivot( groupby : List(str) The optional data fields to group by. If not specified, a single group containing all data objects will be used. - limit : float + limit : int An optional parameter indicating the maximum number of pivoted fields to generate. The default ( ``0`` ) applies no limit. The pivoted ``pivot`` names are sorted in ascending order prior to enforcing the limit. @@ -1948,11 +2074,11 @@ def transform_pivot( def transform_quantile( self, - quantile, - as_=Undefined, - groupby=Undefined, - probs=Undefined, - step=Undefined, + quantile: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + probs: Union[List[float], UndefinedType] = Undefined, + step: Union[float, UndefinedType] = Undefined, ) -> Self: """Add a :class:`QuantileTransform` to the chart @@ -1994,14 +2120,16 @@ def transform_quantile( def transform_regression( self, - on, - regression, - as_=Undefined, - extent=Undefined, - groupby=Undefined, - method=Undefined, - order=Undefined, - params=Undefined, + on: Union[str, core.FieldName], + regression: Union[str, core.FieldName], + as_: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + extent: Union[List[float], UndefinedType] = Undefined, + groupby: Union[List[Union[str, core.FieldName]], UndefinedType] = Undefined, + method: Union[ + Literal["linear", "log", "exp", "pow", "quad", "poly"], UndefinedType + ] = Undefined, + order: Union[int, UndefinedType] = Undefined, + params: Union[bool, UndefinedType] = Undefined, ) -> Self: """Add a :class:`RegressionTransform` to the chart. @@ -2023,7 +2151,7 @@ def transform_regression( method : enum('linear', 'log', 'exp', 'pow', 'quad', 'poly') The functional form of the regression model. One of ``"linear"``, ``"log"``, ``"exp"``, ``"pow"``, ``"quad"``, or ``"poly"``. **Default value:** ``"linear"`` - order : float + order : int The polynomial order (number of coefficients) for the 'poly' method. **Default value:** ``3`` params : boolean @@ -2057,13 +2185,13 @@ def transform_regression( ) ) - def transform_sample(self, sample=1000) -> Self: + def transform_sample(self, sample: int = 1000) -> Self: """ Add a :class:`SampleTransform` to the schema. Parameters ---------- - sample : float + sample : int The maximum number of data objects to include in the sample. Default: 1000. Returns @@ -2078,7 +2206,14 @@ def transform_sample(self, sample=1000) -> Self: return self._add_transform(core.SampleTransform(sample)) def transform_stack( - self, as_, stack, groupby, offset=Undefined, sort=Undefined + self, + as_: Union[str, core.FieldName, List[str]], + stack: Union[str, core.FieldName], + groupby: List[Union[str, core.FieldName]], + offset: Union[ + Literal["zero", "center", "normalize"], UndefinedType + ] = Undefined, + sort: Union[List[core.SortField], UndefinedType] = Undefined, ) -> Self: """ Add a :class:`StackTransform` to the schema. @@ -2116,10 +2251,10 @@ def transform_stack( def transform_timeunit( self, - as_=Undefined, - field=Undefined, - timeUnit=Undefined, - **kwargs, + as_: Union[str, core.FieldName, UndefinedType] = Undefined, + field: Union[str, core.FieldName, UndefinedType] = Undefined, + timeUnit: Union[str, core.TimeUnit, UndefinedType] = Undefined, + **kwargs: str, ) -> Self: """ Add a :class:`TimeUnitTransform` to the schema. @@ -2130,7 +2265,7 @@ def transform_timeunit( The output field to write the timeUnit value. field : string The data field to apply time unit. - timeUnit : :class:`TimeUnit` + timeUnit : str or :class:`TimeUnit` The timeUnit. **kwargs transforms can also be passed by keyword argument; see Examples @@ -2200,12 +2335,14 @@ def transform_timeunit( def transform_window( self, - window=Undefined, - frame=Undefined, - groupby=Undefined, - ignorePeers=Undefined, - sort=Undefined, - **kwargs, + window: Union[List[core.WindowFieldDef], UndefinedType] = Undefined, + frame: Union[List[Optional[int]], UndefinedType] = Undefined, + groupby: Union[List[str], UndefinedType] = Undefined, + ignorePeers: Union[bool, UndefinedType] = Undefined, + sort: Union[ + List[Union[core.SortField, TypingDict[str, str]]], UndefinedType + ] = Undefined, + **kwargs: str, ) -> Self: """Add a :class:`WindowTransform` to the schema @@ -2213,7 +2350,7 @@ def transform_window( ---------- window : List(:class:`WindowFieldDef`) The definition of the fields in the window, and what calculations to use. - frame : List(anyOf(None, float)) + frame : List(anyOf(None, int)) A frame specification as a two-element array indicating how the sliding window should proceed. The array entries should either be a number indicating the offset from the current data object, or null to indicate unbounded rows preceding or @@ -2288,6 +2425,7 @@ def transform_window( parse_types=False, ) ) + assert not isinstance(window, UndefinedType) # For mypy window.append(core.WindowFieldDef(**kwds)) return self._add_transform( @@ -2314,7 +2452,13 @@ def _repr_mimebundle_(self, include=None, exclude=None): else: return renderers.get()(dct) - def display(self, renderer=Undefined, theme=Undefined, actions=Undefined, **kwargs): + def display( + self, + renderer: Union[Literal["canvas", "svg"], UndefinedType] = Undefined, + theme: Union[str, UndefinedType] = Undefined, + actions: Union[bool, dict, UndefinedType] = Undefined, + **kwargs, + ) -> None: """Display chart in Jupyter notebook or JupyterLab Parameters are passed as options to vega-embed within supported frontends. @@ -2410,7 +2554,9 @@ def serve( http_server=http_server, ) - def show(self, embed_opt=None, open_browser=None): + def show( + self, embed_opt: Optional[dict] = None, open_browser: Optional[bool] = None + ) -> None: """Show the chart in an external browser window. This requires a recent version of the altair_viewer package. @@ -2418,7 +2564,7 @@ def show(self, embed_opt=None, open_browser=None): Parameters ---------- embed_opt : dict (optional) - The Vega embed options that control the dispay of the chart. + The Vega embed options that control the display of the chart. open_browser : bool (optional) Specify whether a browser window should be opened. If not specified, a browser window will be opened only if the server is not already @@ -2480,11 +2626,13 @@ def encode(self, *args, **kwargs) -> Self: def facet( self, - facet=Undefined, - row=Undefined, - column=Undefined, - data=Undefined, - columns=Undefined, + facet: Union[str, channels.Facet, UndefinedType] = Undefined, + row: Union[str, core.FacetFieldDef, channels.Row, UndefinedType] = Undefined, + column: Union[ + str, core.FacetFieldDef, channels.Column, UndefinedType + ] = Undefined, + data: Union[_ChartDataType, UndefinedType] = Undefined, + columns: Union[int, UndefinedType] = Undefined, **kwargs, ) -> "FacetChart": """Create a facet chart from the current chart. @@ -2495,13 +2643,13 @@ def facet( Parameters ---------- - facet : string or alt.Facet (optional) + facet : string, Facet (optional) The data column to use as an encoding for a wrapped facet. If specified, then neither row nor column may be specified. - column : string or alt.Column (optional) + column : string, Column, FacetFieldDef (optional) The data column to use as an encoding for a column facet. May be combined with row argument, but not with facet argument. - row : string or alt.Column (optional) + row : string or Row, FacetFieldDef (optional) The data column to use as an encoding for a row facet. May be combined with column argument, but not with facet argument. data : string or dataframe (optional) @@ -2604,13 +2752,13 @@ class Chart( def __init__( self, - data=Undefined, - encoding=Undefined, - mark=Undefined, - width=Undefined, - height=Undefined, + data: Union[_ChartDataType, UndefinedType] = Undefined, + encoding: Union[core.FacetedEncoding, UndefinedType] = Undefined, + mark: Union[str, core.AnyMark, UndefinedType] = Undefined, + width: Union[int, str, dict, core.Step, UndefinedType] = Undefined, + height: Union[int, str, dict, core.Step, UndefinedType] = Undefined, **kwargs, - ): + ) -> None: super(Chart, self).__init__( data=data, encoding=encoding, @@ -2620,15 +2768,15 @@ def __init__( **kwargs, ) - _counter = 0 + _counter: int = 0 @classmethod - def _get_name(cls): + def _get_name(cls) -> str: cls._counter += 1 return f"view_{cls._counter}" @classmethod - def from_dict(cls, dct, validate=True) -> core.SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future + def from_dict(cls, dct: dict, validate: bool = True) -> core.SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future """Construct class from a dictionary representation Parameters @@ -2740,7 +2888,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params: return self @@ -2759,7 +2907,9 @@ def add_selection(self, *params) -> Self: """'add_selection' is deprecated. Use 'add_params' instead.""" return self.add_params(*params) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -2786,7 +2936,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: return self.add_params(selection_interval(bind="scales", encodings=encodings)) -def _check_if_valid_subspec(spec, classname): +def _check_if_valid_subspec(spec: Union[dict, core.SchemaBase], classname: str) -> None: """Check if the spec is a valid sub-spec. If it is not, then raise a ValueError @@ -2807,7 +2957,7 @@ def _check_if_valid_subspec(spec, classname): raise ValueError(err.format(attr, classname)) -def _check_if_can_be_layered(spec): +def _check_if_can_be_layered(spec: Union[dict, core.SchemaBase]) -> None: """Check if the spec can be layered.""" def _get(spec, attr): @@ -2941,7 +3091,9 @@ def transformed_data( "transformed_data is not yet implemented for RepeatChart" ) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -2964,7 +3116,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: copy.spec = copy.spec.interactive(name=name, bind_x=bind_x, bind_y=bind_y) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or self.spec is Undefined: return self @@ -2980,7 +3132,9 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def repeat(repeater="repeat"): +def repeat( + repeater: Literal["row", "column", "repeat", "layer"] = "repeat" +) -> core.RepeatRef: """Tie a channel to the row or column within a repeated chart The output of this should be passed to the ``field`` attribute of @@ -3014,14 +3168,16 @@ def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs): self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) - def __ior__(self, other): + # Too difficult to fix override error + def __ior__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override] _check_if_valid_subspec(other, "ConcatChart") self.concat.append(other) self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) return self - def __or__(self, other): + # Too difficult to fix override error + def __or__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override] copy = self.copy(deep=["concat"]) copy |= other return copy @@ -3052,7 +3208,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3078,7 +3236,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.concat: return self @@ -3094,7 +3252,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def concat(*charts, **kwargs): +def concat(*charts, **kwargs) -> ConcatChart: """Concatenate charts horizontally""" return ConcatChart(concat=charts, **kwargs) @@ -3111,14 +3269,14 @@ def __init__(self, data=Undefined, hconcat=(), **kwargs): self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) - def __ior__(self, other): + def __ior__(self, other: core.NonNormalizedSpec) -> Self: _check_if_valid_subspec(other, "HConcatChart") self.hconcat.append(other) self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) return self - def __or__(self, other): + def __or__(self, other: core.NonNormalizedSpec) -> Self: copy = self.copy(deep=["hconcat"]) copy |= other return copy @@ -3149,7 +3307,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3175,7 +3335,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.hconcat: return self @@ -3191,7 +3351,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def hconcat(*charts, **kwargs): +def hconcat(*charts, **kwargs) -> HConcatChart: """Concatenate charts horizontally""" return HConcatChart(hconcat=charts, **kwargs) @@ -3208,14 +3368,14 @@ def __init__(self, data=Undefined, vconcat=(), **kwargs): self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) - def __iand__(self, other): + def __iand__(self, other: core.NonNormalizedSpec) -> Self: _check_if_valid_subspec(other, "VConcatChart") self.vconcat.append(other) self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) return self - def __and__(self, other): + def __and__(self, other: core.NonNormalizedSpec) -> Self: copy = self.copy(deep=["vconcat"]) copy &= other return copy @@ -3246,7 +3406,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3272,7 +3434,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: encodings.append("y") return self.add_params(selection_interval(bind="scales", encodings=encodings)) - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.vconcat: return self @@ -3288,7 +3450,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def vconcat(*charts, **kwargs): +def vconcat(*charts: core.NonNormalizedSpec, **kwargs) -> VConcatChart: """Concatenate charts vertically""" return VConcatChart(vconcat=charts, **kwargs) @@ -3342,7 +3504,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def __iadd__(self, other): + def __iadd__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) self.layer.append(other) @@ -3350,18 +3512,20 @@ def __iadd__(self, other): self.params, self.layer = _combine_subchart_params(self.params, self.layer) return self - def __add__(self, other): + def __add__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: copy = self.copy(deep=["layer"]) copy += other return copy - def add_layers(self, *layers) -> Self: + def add_layers(self, *layers: Union[core.LayerSpec, core.UnitSpec]) -> Self: copy = self.copy(deep=["layer"]) for layer in layers: copy += layer return copy - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3390,7 +3554,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: ) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or not self.layer: return self @@ -3406,7 +3570,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def layer(*charts, **kwargs): +def layer(*charts, **kwargs) -> LayerChart: """layer multiple charts""" return LayerChart(layer=charts, **kwargs) @@ -3457,7 +3621,9 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: + def interactive( + self, name: Optional[str] = None, bind_x: bool = True, bind_y: bool = True + ) -> Self: """Make chart axes scales interactive Parameters @@ -3480,7 +3646,7 @@ def interactive(self, name=None, bind_x=True, bind_y=True) -> Self: copy.spec = copy.spec.interactive(name=name, bind_x=bind_x, bind_y=bind_y) return copy - def add_params(self, *params) -> Self: + def add_params(self, *params: Parameter) -> Self: """Add one or more parameters to the chart.""" if not params or self.spec is Undefined: return self @@ -3496,7 +3662,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def topo_feature(url, feature, **kwargs): +def topo_feature(url: str, feature: str, **kwargs) -> core.UrlData: """A convenience function for extracting features from a topojson url Parameters @@ -3544,7 +3710,7 @@ def remove_data(subchart): return data, subcharts -def _viewless_dict(param): +def _viewless_dict(param: Parameter) -> dict: d = param.to_dict() d.pop("views", None) return d @@ -3806,6 +3972,6 @@ def graticule(**kwds): return core.GraticuleGenerator(graticule=graticule) -def sphere(): +def sphere() -> core.SphereGenerator: """Sphere generator.""" return core.SphereGenerator(sphere=True) diff --git a/tools/update_init_file.py b/tools/update_init_file.py index a7a9a2cd8..e90def7b7 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -6,7 +6,7 @@ import sys from pathlib import Path from os.path import abspath, dirname, join -from typing import TypeVar, Type, cast, List, Any, Optional, Iterable, Union +from typing import TypeVar, Type, cast, List, Any, Optional, Iterable, Union, IO import black @@ -80,6 +80,7 @@ def _is_relevant_attribute(attr_name: str) -> bool: or attr is Optional or attr is Iterable or attr is Union + or attr is IO or attr_name == "TypingDict" ): return False