Skip to content

Commit

Permalink
Various mypy error fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
binste committed Sep 24, 2023
1 parent dbcf461 commit 9ca1af9
Showing 1 changed file with 51 additions and 30 deletions.
81 changes: 51 additions & 30 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from toolz.curried import pipe as _pipe
import itertools
import sys
from typing import cast, List, Optional, Any, Iterable, Union, Type, Dict, Literal, IO
from typing import cast, List, Optional, Any, Iterable, Union, Literal, IO

# Have to rename it here as else it overlaps with schema.core.Type
# Have to rename it here as else it overlaps with schema.core.Type and schema.core.Dict
from typing import Type as TypingType
from typing import Dict as TypingDict

Expand Down Expand Up @@ -175,8 +175,8 @@ def to_dict(self, *args, **kwargs) -> dict:
TOPLEVEL_ONLY_KEYS = {"background", "config", "autosize", "padding", "$schema"}


def _get_channels_mapping() -> Dict[Type[core.SchemaBase], str]:
mapping: Dict[Type[core.SchemaBase], str] = {}
def _get_channels_mapping() -> TypingDict[TypingType[core.SchemaBase], str]:
mapping: TypingDict[TypingType[core.SchemaBase], str] = {}
for attr in dir(channels):
cls = getattr(channels, attr)
if isinstance(cls, type) and issubclass(cls, core.SchemaBase):
Expand Down Expand Up @@ -299,7 +299,7 @@ class ParameterExpression(expr.core.OperatorMixin, object):
def __init__(self, expr) -> None:
self.expr = expr

def to_dict(self) -> Dict[str, str]:
def to_dict(self) -> TypingDict[str, str]:
return {"expr": repr(self.expr)}

def _to_expr(self) -> str:
Expand All @@ -313,7 +313,7 @@ class SelectionExpression(expr.core.OperatorMixin, object):
def __init__(self, expr) -> None:
self.expr = expr

def to_dict(self) -> Dict[str, str]:
def to_dict(self) -> TypingDict[str, str]:
return {"expr": repr(self.expr)}

def _to_expr(self) -> str:
Expand All @@ -326,7 +326,7 @@ def _from_expr(self, expr) -> "SelectionExpression":
def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
for prop in ["fields", "encodings"]:
try:
if field_name in getattr(parameter.param.select, prop):
if field_name in getattr(parameter.param.select, prop): # type: ignore[union-attr]
return True
except (AttributeError, TypeError):
pass
Expand Down Expand Up @@ -501,7 +501,7 @@ def selection(
def selection_interval(
name: Optional[str] = None,
value: Union[Any, UndefinedType] = Undefined,
bind: Union[core.Binding, UndefinedType] = Undefined,
bind: Union[core.Binding, str, UndefinedType] = Undefined,
empty: Union[bool, UndefinedType] = Undefined,
expr: Union[str, core.Expr, expr.core.Expression, UndefinedType] = Undefined,
encodings: Union[List[str], UndefinedType] = Undefined,
Expand All @@ -523,7 +523,7 @@ def selection_interval(
value : any (optional)
The default value of the parameter. If not specified, the parameter
will be created without a default value.
bind : :class:`Binding` (optional)
bind : :class:`Binding`, str (optional)
Binds the parameter to an external input element such as a slider,
selection list or radio button group.
empty : boolean (optional)
Expand Down Expand Up @@ -613,7 +613,7 @@ def selection_interval(
def selection_point(
name: Optional[str] = None,
value: Union[Any, UndefinedType] = Undefined,
bind: Union[core.Binding, UndefinedType] = Undefined,
bind: Union[core.Binding, str, UndefinedType] = Undefined,
empty: Union[bool, UndefinedType] = Undefined,
expr: Union[core.Expr, UndefinedType] = Undefined,
encodings: Union[List[str], UndefinedType] = Undefined,
Expand All @@ -635,7 +635,7 @@ def selection_point(
value : any (optional)
The default value of the parameter. If not specified, the parameter
will be created without a default value.
bind : :class:`Binding` (optional)
bind : :class:`Binding`, str (optional)
Binds the parameter to an external input element such as a slider,
selection list or radio button group.
empty : boolean (optional)
Expand Down Expand Up @@ -804,15 +804,21 @@ def condition(
"""
test_predicates = (str, expr.Expression, core.PredicateComposition)

condition: TypingDict[
str, Union[bool, str, expr.core.Expression, core.PredicateComposition]
]
if isinstance(predicate, Parameter):
if predicate.param_type == "selection" or predicate.param.expr is Undefined:
if (
predicate.param_type == "selection"
or getattr(predicate.param, "expr", Undefined) is Undefined
):
condition = {"param": predicate.name}
if "empty" in kwargs:
condition["empty"] = kwargs.pop("empty")
elif isinstance(predicate.empty, bool):
condition["empty"] = predicate.empty
else:
condition = {"test": predicate.param.expr}
condition = {"test": getattr(predicate.param, "expr", Undefined)}
elif isinstance(predicate, test_predicates):
condition = {"test": predicate}
elif isinstance(predicate, dict):
Expand All @@ -836,6 +842,7 @@ def condition(
if_true.update(kwargs)
condition.update(if_true)

selection: Union[dict, core.SchemaBase]
if isinstance(if_false, core.SchemaBase):
# For the selection, the channel definitions all allow selections
# already. So use this SchemaBase wrapper if possible.
Expand Down Expand Up @@ -1049,7 +1056,7 @@ def to_html(
def save(
self,
fp: Union[str, IO],
format: Literal["json", "html", "png", "svg", "pdf"] = None,
format: Optional[Literal["json", "html", "png", "svg", "pdf"]] = None,
override_data_transformer: bool = True,
scale_factor: float = 1.0,
vegalite_version: str = VEGALITE_VERSION,
Expand Down Expand Up @@ -1114,17 +1121,20 @@ def __repr__(self) -> str:
def __add__(self, other) -> "LayerChart":
if not isinstance(other, TopLevelMixin):
raise ValueError("Only Chart objects can be layered.")
return layer(self, other)
# Too difficult to type check this
return layer(self, other) # type: ignore[arg-type]

def __and__(self, other) -> "VConcatChart":
if not isinstance(other, TopLevelMixin):
raise ValueError("Only Chart objects can be concatenated.")
return vconcat(self, other)
# Too difficult to type check this
return vconcat(self, other) # type: ignore[arg-type]

def __or__(self, other) -> "HConcatChart":
if not isinstance(other, TopLevelMixin):
raise ValueError("Only Chart objects can be concatenated.")
return hconcat(self, other)
# Too difficult to type check this
return hconcat(self, other) # type: ignore[arg-type]

def repeat(
self,
Expand Down Expand Up @@ -1174,14 +1184,16 @@ def repeat(
elif repeat_specified and layer_specified:
raise ValueError("repeat argument cannot be combined with layer argument.")

repeat_arg: Union[List[str], core.LayerRepeatMapping, core.RepeatMapping]
if repeat_specified:
repeat = repeat
assert not isinstance(repeat, UndefinedType) # For mypy
repeat_arg = repeat
elif layer_specified:
repeat = core.LayerRepeatMapping(layer=layer, row=row, column=column)
repeat_arg = core.LayerRepeatMapping(layer=layer, row=row, column=column)
else:
repeat = core.RepeatMapping(row=row, column=column)
repeat_arg = core.RepeatMapping(row=row, column=column)

return RepeatChart(spec=self, repeat=repeat, columns=columns, **kwargs)
return RepeatChart(spec=self, repeat=repeat_arg, columns=columns, **kwargs)

def properties(self, **kwargs) -> Self:
"""Set top-level properties of the Chart.
Expand Down Expand Up @@ -1350,7 +1362,7 @@ def project(

def _add_transform(self, *transforms: core.Transform) -> Self:
"""Copy the chart and add specified transforms to chart.transform"""
copy = self.copy(deep=["transform"])
copy = self.copy(deep=["transform"]) # type: ignore[attr-defined]
if copy.transform is Undefined:
copy.transform = []
copy.transform.extend(transforms)
Expand All @@ -1360,7 +1372,7 @@ def transform_aggregate(
self,
aggregate: Union[List[core.AggregatedFieldDef], UndefinedType] = Undefined,
groupby: Union[List[str], UndefinedType] = Undefined,
**kwds: Union[Dict[str, Any], str],
**kwds: Union[TypingDict[str, Any], str],
) -> Self:
"""
Add an :class:`AggregateTransform` to the schema.
Expand All @@ -1372,7 +1384,7 @@ def transform_aggregate(
groupby : List(string)
The data fields to group by. If not specified, a single group containing all data
objects will be used.
**kwds : Union[Dict[str, Any], str]
**kwds : Union[TypingDict[str, Any], str]
additional keywords are converted to aggregates using standard
shorthand parsing.
Expand Down Expand Up @@ -1430,6 +1442,7 @@ def transform_aggregate(
"field": parsed.get("field", Undefined),
"op": parsed.get("aggregate", Undefined),
}
assert not isinstance(aggregate, UndefinedType) # For mypy
aggregate.append(core.AggregatedFieldDef(**dct))
return self._add_transform(
core.AggregateTransform(aggregate=aggregate, groupby=groupby)
Expand Down Expand Up @@ -1551,7 +1564,11 @@ def transform_calculate(
alt.CalculateTransform : underlying transform object
"""
if as_ is Undefined:
as_ = kwargs.pop("as", Undefined)
# Ignoring assignment error as passing 'as' as a keyword argument is
# an edge case and it's not worth changing the type annotation
# in this function to account for it as it could be confusing to
# users.
as_ = kwargs.pop("as", Undefined) # type: ignore[assignment]
elif "as" in kwargs:
raise ValueError(
"transform_calculate: both 'as_' and 'as' passed as arguments."
Expand Down Expand Up @@ -1756,6 +1773,7 @@ def transform_joinaggregate(
"field": parsed.get("field", Undefined),
"op": parsed.get("aggregate", Undefined),
}
assert not isinstance(joinaggregate, UndefinedType) # For mypy
joinaggregate.append(core.JoinAggregateFieldDef(**dct))
return self._add_transform(
core.JoinAggregateTransform(joinaggregate=joinaggregate, groupby=groupby)
Expand Down Expand Up @@ -1789,7 +1807,7 @@ def transform_filter(
Parameter,
core.PredicateComposition,
# E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
Dict[str, core.Predicate],
TypingDict[str, Union[core.Predicate, str, bool]],
],
**kwargs,
) -> Self:
Expand Down Expand Up @@ -1817,7 +1835,7 @@ def transform_filter(
new_filter["empty"] = kwargs.pop("empty")
elif isinstance(filter.empty, bool):
new_filter["empty"] = filter.empty
filter = new_filter
filter = new_filter # type: ignore[assignment]
return self._add_transform(core.FilterTransform(filter=filter, **kwargs))

def transform_flatten(
Expand Down Expand Up @@ -2285,7 +2303,7 @@ def transform_window(
frame: Union[List[Optional[int]], UndefinedType] = Undefined,
groupby: Union[List[str], UndefinedType] = Undefined,
ignorePeers: Union[bool, UndefinedType] = Undefined,
sort: Union[List[core.SortField], UndefinedType] = Undefined,
sort: Union[List[Union[core.SortField, dict[str, str]]], UndefinedType] = Undefined,
**kwargs: str,
) -> Self:
"""Add a :class:`WindowTransform` to the schema
Expand Down Expand Up @@ -2369,6 +2387,7 @@ def transform_window(
parse_types=False,
)
)
assert not isinstance(window, UndefinedType) # For mypy
window.append(core.WindowFieldDef(**kwds))

return self._add_transform(
Expand Down Expand Up @@ -3111,14 +3130,16 @@ def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs):
self.data, self.concat = _combine_subchart_data(self.data, self.concat)
self.params, self.concat = _combine_subchart_params(self.params, self.concat)

def __ior__(self, other: core.NonNormalizedSpec) -> Self:
# Too difficult to fix override error
def __ior__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override]
_check_if_valid_subspec(other, "ConcatChart")
self.concat.append(other)
self.data, self.concat = _combine_subchart_data(self.data, self.concat)
self.params, self.concat = _combine_subchart_params(self.params, self.concat)
return self

def __or__(self, other: core.NonNormalizedSpec) -> Self:
# Too difficult to fix override error
def __or__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override]
copy = self.copy(deep=["concat"])
copy |= other
return copy
Expand Down

0 comments on commit 9ca1af9

Please sign in to comment.