From dd5b61a6ecdbd15bb9f29e1f0137d651790f7de2 Mon Sep 17 00:00:00 2001
From: Stefan Binder
Date: Thu, 23 Nov 2023 20:41:25 +0100
Subject: [PATCH] Type hints: Finish type hints and mark package as typed (#3272)

* Type Undefined as UndefinedType. Some minor mypy fixes

* Add entry to changelog

* Make types public so users can use them in their own code if needed

* Add py.typed

* Remove type annotation on inputs to vconcat as too complex to type. Already removed for layer and hconcat in a previous PR

* Move some changes into code generation files

* Remove unused ignore statement

* Make SupportsGeoInterface and DataFrameLike public
---
 altair/__init__.py                |  3 +++
 altair/py.typed                   |  0
 altair/utils/_transformed_data.py |  6 ++---
 altair/utils/_vegafusion_data.py  | 12 ++++-----
 altair/utils/core.py              | 18 ++++++-------
 altair/utils/data.py              | 30 ++++++++++-----------
 altair/utils/schemapi.py          | 10 +++----
 altair/vegalite/data.py           |  6 ++---
 altair/vegalite/v5/api.py         | 44 +++++++++++++++----------------
 doc/releases/changes.rst          |  1 +
 pyproject.toml                    |  1 +
 tools/schemapi/schemapi.py        | 10 +++----
 12 files changed, 69 insertions(+), 72 deletions(-)
 create mode 100644 altair/py.typed

diff --git a/altair/__init__.py b/altair/__init__.py
index 8184d1722..65571f02a 100644
--- a/altair/__init__.py
+++ b/altair/__init__.py
@@ -54,6 +54,7 @@
     "CalculateTransform",
     "Categorical",
     "Chart",
+    "ChartDataType",
     "Color",
     "ColorDatum",
     "ColorDef",
@@ -125,7 +126,9 @@
     "Cyclical",
     "Data",
     "DataFormat",
+    "DataFrameLike",
     "DataSource",
+    "DataType",
     "Datasets",
     "DateTime",
     "DatumChannelMixin",
diff --git a/altair/py.typed b/altair/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py
index 99bfbcde4..c3498d4c3 100644
--- a/altair/utils/_transformed_data.py
+++ b/altair/utils/_transformed_data.py
@@ -25,7 +25,7 @@
     data_transformers,
 )
 from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
-from altair.utils.core import _DataFrameLike
+from altair.utils.core import DataFrameLike
 from altair.utils.schemapi import Undefined
 
 Scope = Tuple[int, ...]
@@ -56,7 +56,7 @@ def transformed_data(
     chart: Union[Chart, FacetChart],
     row_limit: Optional[int] = None,
     exclude: Optional[Iterable[str]] = None,
-) -> Optional[_DataFrameLike]:
+) -> Optional[DataFrameLike]:
     ...
 
 
@@ -65,7 +65,7 @@ def transformed_data(
     chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart],
     row_limit: Optional[int] = None,
     exclude: Optional[Iterable[str]] = None,
-) -> List[_DataFrameLike]:
+) -> List[DataFrameLike]:
     ...
 
 
diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py
index 920082b58..65585e5bf 100644
--- a/altair/utils/_vegafusion_data.py
+++ b/altair/utils/_vegafusion_data.py
@@ -7,15 +7,15 @@
 from typing import TypedDict, Final
 
 from altair.utils._importers import import_vegafusion
-from altair.utils.core import _DataFrameLike
-from altair.utils.data import _DataType, _ToValuesReturnType, MaxRowsError
+from altair.utils.core import DataFrameLike
+from altair.utils.data import DataType, ToValuesReturnType, MaxRowsError
 from altair.vegalite.data import default_data_transformer
 
 # Temporary storage for dataframes that have been extracted
 # from charts by the vegafusion data transformer. Use a WeakValueDictionary
 # rather than a dict so that the Python interpreter is free to garbage
 # collect the stored DataFrames.
-extracted_inline_tables: MutableMapping[str, _DataFrameLike] = WeakValueDictionary() +extracted_inline_tables: MutableMapping[str, DataFrameLike] = WeakValueDictionary() # Special URL prefix that VegaFusion uses to denote that a # dataset in a Vega spec corresponds to an entry in the `inline_datasets` @@ -29,8 +29,8 @@ class _ToVegaFusionReturnUrlDict(TypedDict): @curried.curry def vegafusion_data_transformer( - data: _DataType, max_rows: int = 100000 -) -> Union[_ToVegaFusionReturnUrlDict, _ToValuesReturnType]: + data: DataType, max_rows: int = 100000 +) -> Union[_ToVegaFusionReturnUrlDict, ToValuesReturnType]: """VegaFusion Data Transformer""" if hasattr(data, "__geo_interface__"): # Use default transformer for geo interface objects @@ -95,7 +95,7 @@ def get_inline_table_names(vega_spec: dict) -> Set[str]: return table_names -def get_inline_tables(vega_spec: dict) -> Dict[str, _DataFrameLike]: +def get_inline_tables(vega_spec: dict) -> Dict[str, DataFrameLike]: """Get the inline tables referenced by a Vega specification Note: This function should only be called on a Vega spec that corresponds diff --git a/altair/utils/core.py b/altair/utils/core.py index 4d2f9c7a5..ea8abf1f1 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -41,11 +41,11 @@ if TYPE_CHECKING: from pandas.core.interchange.dataframe_protocol import Column as PandasColumn -_V = TypeVar("_V") -_P = ParamSpec("_P") +V = TypeVar("V") +P = ParamSpec("P") -class _DataFrameLike(Protocol): +class DataFrameLike(Protocol): def __dataframe__(self, *args, **kwargs) -> DfiDataFrame: ... @@ -188,12 +188,12 @@ def __dataframe__(self, *args, **kwargs) -> DfiDataFrame: ] -_InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] +InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] def infer_vegalite_type( data: object, -) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]: +) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]: """ From an array-like input, infer the correct vega typecode ('ordinal', 'nominal', 'quantitative', or 'temporal') @@ -442,7 +442,7 @@ def sanitize_arrow_table(pa_table): def parse_shorthand( shorthand: Union[Dict[str, Any], str], - data: Optional[Union[pd.DataFrame, _DataFrameLike]] = None, + data: Optional[Union[pd.DataFrame, DataFrameLike]] = None, parse_aggregates: bool = True, parse_window_ops: bool = False, parse_timeunits: bool = True, @@ -637,7 +637,7 @@ def parse_shorthand( def infer_vegalite_type_for_dfi_column( column: Union[Column, "PandasColumn"], -) -> Union[_InferredVegaLiteType, Tuple[_InferredVegaLiteType, list]]: +) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]: from pyarrow.interchange.from_dataframe import column_to_array try: @@ -672,10 +672,10 @@ def infer_vegalite_type_for_dfi_column( raise ValueError(f"Unexpected DtypeKind: {kind}") -def use_signature(Obj: Callable[_P, Any]): +def use_signature(Obj: Callable[P, Any]): """Apply call signature and documentation of Obj to the decorated method""" - def decorate(f: Callable[..., _V]) -> Callable[_P, _V]: + def decorate(f: Callable[..., V]) -> Callable[P, V]: # call-signature of f is exposed via __wrapped__. 
# we want it to mimic Obj.__init__ f.__wrapped__ = Obj.__init__ # type: ignore diff --git a/altair/utils/data.py b/altair/utils/data.py index 9437897ec..2a3a3474b 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -10,7 +10,7 @@ from typing import TypeVar from ._importers import import_pyarrow_interchange -from .core import sanitize_dataframe, sanitize_arrow_table, _DataFrameLike +from .core import sanitize_dataframe, sanitize_arrow_table, DataFrameLike from .core import sanitize_geo_interface from .deprecation import AltairDeprecationWarning from .plugin_registry import PluginRegistry @@ -23,15 +23,15 @@ import pyarrow.lib -class _SupportsGeoInterface(Protocol): +class SupportsGeoInterface(Protocol): __geo_interface__: MutableMapping -_DataType = Union[dict, pd.DataFrame, _SupportsGeoInterface, _DataFrameLike] -_TDataType = TypeVar("_TDataType", bound=_DataType) +DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, DataFrameLike] +TDataType = TypeVar("TDataType", bound=DataType) -_VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]] -_ToValuesReturnType = Dict[str, Union[dict, List[dict]]] +VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]] +ToValuesReturnType = Dict[str, Union[dict, List[dict]]] # ============================================================================== @@ -46,7 +46,7 @@ class _SupportsGeoInterface(Protocol): # form. # ============================================================================== class DataTransformerType(Protocol): - def __call__(self, data: _DataType, **kwargs) -> _VegaLiteDataDict: + def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: pass @@ -70,7 +70,7 @@ class MaxRowsError(Exception): @curried.curry -def limit_rows(data: _TDataType, max_rows: Optional[int] = 5000) -> _TDataType: +def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType: """Raise MaxRowsError if the data model has more than max_rows. If max_rows is None, then do not perform any check. 
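# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patch): with _DataType and
# _ToValuesReturnType renamed to the public DataType and ToValuesReturnType in
# the hunks above, downstream code can type its own data transformers against
# them. This is only a sketch; the import path follows this file's hunks, while
# `strict_transformer` and the registry name "strict" are invented names.
import altair as alt
from altair.utils.data import DataType, ToValuesReturnType, limit_rows, to_values


def strict_transformer(data: DataType, max_rows: int = 1000) -> ToValuesReturnType:
    # Inline the data as a list of records, raising MaxRowsError above max_rows.
    return to_values(limit_rows(data, max_rows=max_rows))


alt.data_transformers.register("strict", strict_transformer)
alt.data_transformers.enable("strict")
# ---------------------------------------------------------------------------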
@@ -122,7 +122,7 @@ def raise_max_rows_error(): @curried.curry def sample( - data: _DataType, n: Optional[int] = None, frac: Optional[float] = None + data: DataType, n: Optional[int] = None, frac: Optional[float] = None ) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]: """Reduce the size of the data model by sampling without replacement.""" check_data_type(data) @@ -180,7 +180,7 @@ class _ToCsvReturnUrlDict(TypedDict): @curried.curry def to_json( - data: _DataType, + data: DataType, prefix: str = "altair-data", extension: str = "json", filename: str = "{prefix}-{hash}.{extension}", @@ -199,7 +199,7 @@ def to_json( @curried.curry def to_csv( - data: Union[dict, pd.DataFrame, _DataFrameLike], + data: Union[dict, pd.DataFrame, DataFrameLike], prefix: str = "altair-data", extension: str = "csv", filename: str = "{prefix}-{hash}.{extension}", @@ -215,7 +215,7 @@ def to_csv( @curried.curry -def to_values(data: _DataType) -> _ToValuesReturnType: +def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) if hasattr(data, "__geo_interface__"): @@ -242,7 +242,7 @@ def to_values(data: _DataType) -> _ToValuesReturnType: raise ValueError("Unrecognized data type: {}".format(type(data))) -def check_data_type(data: _DataType) -> None: +def check_data_type(data: DataType) -> None: if not isinstance(data, (dict, pd.DataFrame)) and not any( hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"] ): @@ -260,7 +260,7 @@ def _compute_data_hash(data_str: str) -> str: return hashlib.md5(data_str.encode()).hexdigest() -def _data_to_json_string(data: _DataType) -> str: +def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) if hasattr(data, "__geo_interface__"): @@ -288,7 +288,7 @@ def _data_to_json_string(data: _DataType) -> str: ) -def _data_to_csv_string(data: Union[dict, pd.DataFrame, _DataFrameLike]) -> str: +def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str: """return a CSV string representation of the input data""" check_data_type(data) if hasattr(data, "__geo_interface__"): diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index cf2fba2f6..8b1c2ebb3 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -44,7 +44,7 @@ else: from typing_extensions import Self -_TSchemaBase = TypeVar("_TSchemaBase", bound=Type["SchemaBase"]) +TSchemaBase = TypeVar("TSchemaBase", bound=Type["SchemaBase"]) ValidationErrorList = List[jsonschema.exceptions.ValidationError] GroupedValidationErrors = Dict[str, ValidationErrorList] @@ -733,11 +733,7 @@ def __repr__(self): return "Undefined" -# In the future Altair may implement a more complete set of type hints. -# But for now, we'll add an annotation to indicate that the type checker -# should permit any value passed to a function argument whose default -# value is Undefined. -Undefined: Any = UndefinedType() +Undefined = UndefinedType() class SchemaBase: @@ -1329,7 +1325,7 @@ def __call__(self, *args, **kwargs): return obj -def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase: +def with_property_setters(cls: TSchemaBase) -> TSchemaBase: """ Decorator to add property setters to a Schema class. 
""" diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py index 8a6c0b074..fbeda0fee 100644 --- a/altair/vegalite/data.py +++ b/altair/vegalite/data.py @@ -12,14 +12,14 @@ check_data_type, ) from ..utils.data import DataTransformerRegistry as _DataTransformerRegistry -from ..utils.data import _DataType, _ToValuesReturnType +from ..utils.data import DataType, ToValuesReturnType from ..utils.plugin_registry import PluginEnabler @curried.curry def default_data_transformer( - data: _DataType, max_rows: int = 5000 -) -> _ToValuesReturnType: + data: DataType, max_rows: int = 5000 +) -> ToValuesReturnType: return curried.pipe(data, limit_rows(max_rows=max_rows), to_values) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 2c2a322f1..730c2e107 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -26,15 +26,15 @@ using_vegafusion as _using_vegafusion, compile_with_vegafusion as _compile_with_vegafusion, ) -from ...utils.core import _DataFrameLike -from ...utils.data import _DataType +from ...utils.core import DataFrameLike +from ...utils.data import DataType if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self -_ChartDataType = Union[_DataType, core.Data, str, core.Generator, UndefinedType] +ChartDataType = Union[DataType, core.Data, str, core.Generator, UndefinedType] # ------------------------------------------------------------------------ @@ -816,7 +816,10 @@ def condition( test_predicates = (str, expr.Expression, core.PredicateComposition) condition: TypingDict[ - str, Union[bool, str, _expr_core.Expression, core.PredicateComposition] + str, + Union[ + bool, str, _expr_core.Expression, core.PredicateComposition, UndefinedType + ], ] if isinstance(predicate, Parameter): if ( @@ -1228,7 +1231,7 @@ def __and__(self, other) -> "VConcatChart": if not isinstance(other, TopLevelMixin): raise ValueError("Only Chart objects can be concatenated.") # Too difficult to type check this - return vconcat(self, other) # type: ignore[arg-type] + return vconcat(self, other) def __or__(self, other) -> "HConcatChart": if not isinstance(other, TopLevelMixin): @@ -1694,7 +1697,7 @@ def transform_calculate( # an edge case and it's not worth changing the type annotation # in this function to account for it as it could be confusing to # users. - as_ = kwargs.pop("as", Undefined) + as_ = kwargs.pop("as", Undefined) # type: ignore[assignment] elif "as" in kwargs: raise ValueError( "transform_calculate: both 'as_' and 'as' passed as arguments." @@ -2739,7 +2742,7 @@ def facet( column: Union[ str, core.FacetFieldDef, channels.Column, UndefinedType ] = Undefined, - data: Union[_ChartDataType, UndefinedType] = Undefined, + data: Union[ChartDataType, UndefinedType] = Undefined, columns: Union[int, UndefinedType] = Undefined, **kwargs, ) -> "FacetChart": @@ -2779,10 +2782,8 @@ def facet( "facet argument cannot be combined with row/column argument." ) - # Remove "ignore" statement once Undefined is no longer typed as Any if data is Undefined: - # Remove "ignore" statement once Undefined is no longer typed as Any - if self.data is Undefined: # type: ignore + if self.data is Undefined: # type: ignore[has-type] raise ValueError( "Facet charts require data to be specified at the top level. 
" "If you are trying to facet layered or concatenated charts, " @@ -2791,8 +2792,7 @@ def facet( ) # ignore type as copy comes from another class self = self.copy(deep=False) # type: ignore[attr-defined] - # Remove "ignore" statement once Undefined is no longer typed as Any - data, self.data = self.data, Undefined # type: ignore + data, self.data = self.data, Undefined # type: ignore[has-type] if facet_specified: if isinstance(facet, str): @@ -2863,7 +2863,7 @@ class Chart( def __init__( self, - data: Union[_ChartDataType, UndefinedType] = Undefined, + data: Union[ChartDataType, UndefinedType] = Undefined, encoding: Union[core.FacetedEncoding, UndefinedType] = Undefined, mark: Union[str, core.AnyMark, UndefinedType] = Undefined, width: Union[int, str, dict, core.Step, UndefinedType] = Undefined, @@ -2968,7 +2968,7 @@ def to_dict( # No data specified here or in parent: inject empty data # for easier specification of datum encodings. copy = self.copy(deep=False) - copy.data = core.InlineData(values=[{}]) + copy.data = core.InlineData(values=[{}]) # type: ignore[assignment] return super(Chart, copy).to_dict( validate=validate, format=format, ignore=ignore, context=context ) @@ -2980,7 +2980,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[_DataFrameLike]: + ) -> Optional[DataFrameLike]: """Evaluate a Chart's transforms Evaluate the data transforms associated with a Chart and return the @@ -3183,7 +3183,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[_DataFrameLike]: + ) -> Optional[DataFrameLike]: """Evaluate a RepeatChart's transforms Evaluate the data transforms associated with a RepeatChart and return the @@ -3300,7 +3300,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[_DataFrameLike]: + ) -> List[DataFrameLike]: """Evaluate a ConcatChart's transforms Evaluate the data transforms associated with a ConcatChart and return the @@ -3399,7 +3399,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[_DataFrameLike]: + ) -> List[DataFrameLike]: """Evaluate a HConcatChart's transforms Evaluate the data transforms associated with a HConcatChart and return the @@ -3498,7 +3498,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[_DataFrameLike]: + ) -> List[DataFrameLike]: """Evaluate a VConcatChart's transforms Evaluate the data transforms associated with a VConcatChart and return the @@ -3564,7 +3564,7 @@ def add_selection(self, *selections) -> Self: return self.add_params(*selections) -def vconcat(*charts: core.NonNormalizedSpec, **kwargs) -> VConcatChart: +def vconcat(*charts, **kwargs) -> VConcatChart: """Concatenate charts vertically""" return VConcatChart(vconcat=charts, **kwargs) @@ -3596,7 +3596,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[_DataFrameLike]: + ) -> List[DataFrameLike]: """Evaluate a LayerChart's transforms Evaluate the data transforms associated with a LayerChart and return the @@ -3713,7 +3713,7 @@ def transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[_DataFrameLike]: + ) -> Optional[DataFrameLike]: """Evaluate a FacetChart's transforms Evaluate the data transforms associated with a FacetChart 
and return the diff --git a/doc/releases/changes.rst b/doc/releases/changes.rst index 579a72c1c..7226c80ca 100644 --- a/doc/releases/changes.rst +++ b/doc/releases/changes.rst @@ -11,6 +11,7 @@ Enhancements - Support offline HTML export using vl-convert (#3251) - Support saving charts as PDF files using the vl-convert export engine (#3244) - Support converting charts to sharable Vega editor URLs with ``chart.to_url()`` (#3252) +- Vega-Altair is now a typed package, with type annotations for all public functions and classes and some of the internal code (#2951) Bug Fixes ~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index dc562c780..f322088cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Typing :: Typed", ] [project.urls] diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 356b6b021..67365c762 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -42,7 +42,7 @@ else: from typing_extensions import Self -_TSchemaBase = TypeVar("_TSchemaBase", bound=Type["SchemaBase"]) +TSchemaBase = TypeVar("TSchemaBase", bound=Type["SchemaBase"]) ValidationErrorList = List[jsonschema.exceptions.ValidationError] GroupedValidationErrors = Dict[str, ValidationErrorList] @@ -731,11 +731,7 @@ def __repr__(self): return "Undefined" -# In the future Altair may implement a more complete set of type hints. -# But for now, we'll add an annotation to indicate that the type checker -# should permit any value passed to a function argument whose default -# value is Undefined. -Undefined: Any = UndefinedType() +Undefined = UndefinedType() class SchemaBase: @@ -1327,7 +1323,7 @@ def __call__(self, *args, **kwargs): return obj -def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase: +def with_property_setters(cls: TSchemaBase) -> TSchemaBase: """ Decorator to add property setters to a Schema class. """
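
The bullets in the commit message ("Make types public so users can use them in their own code if needed", "Type Undefined as UndefinedType") and the changelog entry describe the user-facing payoff of this patch. The sketch below is an editor's illustration of what that might look like in downstream code, not part of the change itself: it assumes DataFrameLike, ChartDataType, and Undefined are importable from the top-level altair namespace as the __init__.py hunk suggests, that UndefinedType can still be imported from altair.utils.schemapi, that pandas is recent enough to implement the DataFrame interchange protocol, and that the helper names (row_count, make_chart) are invented for the example.

from typing import Optional, Union

import altair as alt
import pandas as pd

from altair.utils.schemapi import UndefinedType  # alt.Undefined is an instance of this


def row_count(data: alt.DataFrameLike) -> Optional[int]:
    # DataFrameLike only promises __dataframe__(), i.e. the DataFrame
    # interchange protocol, so that is all this helper relies on.
    return data.__dataframe__().num_rows()


def make_chart(
    data: alt.ChartDataType,
    title: Union[str, UndefinedType] = alt.Undefined,
) -> alt.Chart:
    # Mirrors the Union[..., UndefinedType] style used in api.py now that
    # Undefined is no longer annotated as Any.
    return alt.Chart(data, title=title).mark_bar().encode(x="x:N", y="y:Q")


source = pd.DataFrame({"x": list("aabbcc"), "y": [1, 2, 3, 4, 5, 6]})
print(row_count(source))  # 6
chart = make_chart(source, title="counts")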