diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index 638f2f50f..a4af084b9 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -592,6 +592,42 @@ def _validator_values(errors: Iterable[ValidationError], /) -> Iterator[str]: yield cast("str", err.validator_value) +def _iter_channels(tp: type[Any], spec: Mapping[str, Any], /) -> Iterator[type[Any]]: + from altair import vegalite + + for channel_type in ("datum", "value"): + if channel_type in spec: + name = f"{tp.__name__}{channel_type.capitalize()}" + if narrower := getattr(vegalite, name, None): + yield narrower + + +def _is_channel(obj: Any) -> TypeIs[dict[str, Any]]: + props = {"datum", "value"} + return ( + _is_dict(obj) + and all(isinstance(k, str) for k in obj) + and not (props.isdisjoint(obj)) + ) + + +def _maybe_channel(tp: type[Any], spec: Any, /) -> type[Any]: + """ + Replace a channel type with a `more specific`_ one or passthrough unchanged. + + Parameters + ---------- + tp + An imported ``SchemaBase`` class. + spec + The instance that failed validation. + + .. _more specific: + https://github.com/vega/altair/issues/2913#issuecomment-2571762700 + """ + return next(_iter_channels(tp, spec), tp) if _is_channel(spec) else tp + + class SchemaValidationError(jsonschema.ValidationError): _JS_TO_PY: ClassVar[Mapping[str, str]] = { "boolean": "bool", @@ -703,22 +739,19 @@ def _get_altair_class_for_error( Try to get the lowest class possible in the chart hierarchy so it can be displayed in the error message. This should lead to more informative error messages pointing the user closer to the source of the issue. + + If we did not find a suitable class based on traversing the path so we fall + back on the class of the top-level object which created the SchemaValidationError """ from altair import vegalite for prop_name in reversed(error.absolute_path): # Check if str as e.g. first item can be a 0 if isinstance(prop_name, str): - potential_class_name = prop_name[0].upper() + prop_name[1:] - cls = getattr(vegalite, potential_class_name, None) - if cls is not None: - break - else: - # Did not find a suitable class based on traversing the path so we fall - # back on the class of the top-level object which created - # the SchemaValidationError - cls = self.obj.__class__ - return cls + candidate = prop_name[0].upper() + prop_name[1:] + if tp := getattr(vegalite, candidate, None): + return _maybe_channel(tp, self.instance) + return type(self.obj) @staticmethod def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index 1059b35c1..86f7e925b 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -522,6 +522,11 @@ def chart_error_example__additional_datum_argument(): return alt.Chart().mark_point().encode(x=alt.datum(1, wrong_argument=1)) +def chart_error_example__additional_value_argument(): + # Error: `ColorValue` has no parameter named 'predicate' + return alt.Chart().mark_point().encode(color=alt.value("red", predicate=True)) + + def chart_error_example__invalid_value_type(): # Error: Value cannot be an integer in this case return ( @@ -812,15 +817,23 @@ def id_func_chart_error_example(val) -> str: ), ( chart_error_example__additional_datum_argument, - r"""`X` has no parameter named 'wrong_argument' + r"""`XDatum` has no parameter named 'wrong_argument' + + Existing parameter names are: + datum impute title + axis scale type + bandPosition stack + + See the help for `XDatum` to read the full description of these parameters$""", + ), + ( + chart_error_example__additional_value_argument, + r"""`ColorValue` has no parameter named 'predicate' Existing parameter names are: - shorthand bin scale timeUnit - aggregate field sort title - axis impute stack type - bandPosition + value condition - See the help for `X` to read the full description of these parameters$""", + See the help for `ColorValue` to read the full description of these parameters$""", ), ( chart_error_example__invalid_value_type, diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 24de6faae..60e8182a7 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -590,6 +590,42 @@ def _validator_values(errors: Iterable[ValidationError], /) -> Iterator[str]: yield cast("str", err.validator_value) +def _iter_channels(tp: type[Any], spec: Mapping[str, Any], /) -> Iterator[type[Any]]: + from altair import vegalite + + for channel_type in ("datum", "value"): + if channel_type in spec: + name = f"{tp.__name__}{channel_type.capitalize()}" + if narrower := getattr(vegalite, name, None): + yield narrower + + +def _is_channel(obj: Any) -> TypeIs[dict[str, Any]]: + props = {"datum", "value"} + return ( + _is_dict(obj) + and all(isinstance(k, str) for k in obj) + and not (props.isdisjoint(obj)) + ) + + +def _maybe_channel(tp: type[Any], spec: Any, /) -> type[Any]: + """ + Replace a channel type with a `more specific`_ one or passthrough unchanged. + + Parameters + ---------- + tp + An imported ``SchemaBase`` class. + spec + The instance that failed validation. + + .. _more specific: + https://github.com/vega/altair/issues/2913#issuecomment-2571762700 + """ + return next(_iter_channels(tp, spec), tp) if _is_channel(spec) else tp + + class SchemaValidationError(jsonschema.ValidationError): _JS_TO_PY: ClassVar[Mapping[str, str]] = { "boolean": "bool", @@ -701,22 +737,19 @@ def _get_altair_class_for_error( Try to get the lowest class possible in the chart hierarchy so it can be displayed in the error message. This should lead to more informative error messages pointing the user closer to the source of the issue. + + If we did not find a suitable class based on traversing the path so we fall + back on the class of the top-level object which created the SchemaValidationError """ from altair import vegalite for prop_name in reversed(error.absolute_path): # Check if str as e.g. first item can be a 0 if isinstance(prop_name, str): - potential_class_name = prop_name[0].upper() + prop_name[1:] - cls = getattr(vegalite, potential_class_name, None) - if cls is not None: - break - else: - # Did not find a suitable class based on traversing the path so we fall - # back on the class of the top-level object which created - # the SchemaValidationError - cls = self.obj.__class__ - return cls + candidate = prop_name[0].upper() + prop_name[1:] + if tp := getattr(vegalite, candidate, None): + return _maybe_channel(tp, self.instance) + return type(self.obj) @staticmethod def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: