From df48ff1792bd912087ba644cae695181e6f26e67 Mon Sep 17 00:00:00 2001 From: Yusuke Oda Date: Sat, 10 Sep 2022 08:45:20 +0900 Subject: [PATCH] Refactor FeatureType (#462) * fix bugs. * add test * move utils.serialization --- .../test-kg-prediction-user-defined.json | 3 +- ...ser-defined-features-prediction-short.json | 4 +- .../fb15k-237/user-defined.json | 3 +- .../sst2_tabclass/sst2-tabclass-dataset.json | 1 + .../sst2_tabreg/sst2-tabreg-dataset.json | 1 + explainaboard/analysis/analyses.py | 12 +- explainaboard/analysis/feature.py | 408 +++++++++++++----- explainaboard/analysis/feature_test.py | 201 +++++++++ explainaboard/info.py | 10 +- explainaboard/loaders/file_loader.py | 16 +- .../processors/argument_pair_extraction.py | 6 +- explainaboard/processors/cloze_generative.py | 6 +- .../processors/cloze_multiple_choice.py | 10 +- .../processors/conditional_generation.py | 6 +- explainaboard/processors/extractive_qa.py | 10 +- .../grammatical_error_correction.py | 8 +- .../processors/kg_link_tail_prediction.py | 10 +- explainaboard/processors/language_modeling.py | 4 +- explainaboard/processors/processor.py | 19 +- .../processors/qa_multiple_choice.py | 10 +- explainaboard/processors/qa_open_domain.py | 4 +- explainaboard/processors/qa_tat.py | 12 +- explainaboard/processors/sequence_labeling.py | 6 +- .../resources/dataset_custom_features.json | 9 +- .../legacy.py} | 3 + .../with_custom_feature.json | 3 +- .../output_with_features.json | 1 + .../output_fig_qa_customized_features.json | 3 +- integration_tests/resources_test.py | 6 +- 29 files changed, 625 insertions(+), 170 deletions(-) create mode 100644 explainaboard/analysis/feature_test.py rename explainaboard/{utils/serialization.py => serialization/legacy.py} (87%) diff --git a/data/system_outputs/fb15k-237/test-kg-prediction-user-defined.json b/data/system_outputs/fb15k-237/test-kg-prediction-user-defined.json index 251a77c3..25fcd22e 100644 --- a/data/system_outputs/fb15k-237/test-kg-prediction-user-defined.json +++ b/data/system_outputs/fb15k-237/test-kg-prediction-user-defined.json @@ -2,6 +2,7 @@ "metadata": { "custom_features": { "rel_type": { + "cls_name": "Value", "dtype": "string", "description": "symmetric or asymmetric", "num_buckets": 2 @@ -160,4 +161,4 @@ "example_id": "10" } ] -} \ No newline at end of file +} diff --git a/data/system_outputs/fb15k-237/test-user-defined-features-prediction-short.json b/data/system_outputs/fb15k-237/test-user-defined-features-prediction-short.json index cfb6c335..21e2a093 100644 --- a/data/system_outputs/fb15k-237/test-user-defined-features-prediction-short.json +++ b/data/system_outputs/fb15k-237/test-user-defined-features-prediction-short.json @@ -2,11 +2,13 @@ "metadata": { "custom_features": { "toy_string_feature": { + "cls_name": "Value", "dtype": "string", "description": "most specific (highest) entity type level of true tail entity", "num_buckets": 8 }, "toy_float_feature": { + "cls_name": "Value", "dtype": "float", "description": "just a toy float feature for testing", "num_buckets": 4 @@ -1215,4 +1217,4 @@ "example_id": "/m/099c8n\t/award/award_category/nominees./award/award_nomination/nominated_for\t/m/064lsn" } ] -} \ No newline at end of file +} diff --git a/data/system_outputs/fb15k-237/user-defined.json b/data/system_outputs/fb15k-237/user-defined.json index 251a77c3..25fcd22e 100644 --- a/data/system_outputs/fb15k-237/user-defined.json +++ b/data/system_outputs/fb15k-237/user-defined.json @@ -2,6 +2,7 @@ "metadata": { "custom_features": { "rel_type": { + "cls_name": "Value", "dtype": "string", "description": "symmetric or asymmetric", "num_buckets": 2 @@ -160,4 +161,4 @@ "example_id": "10" } ] -} \ No newline at end of file +} diff --git a/data/system_outputs/sst2_tabclass/sst2-tabclass-dataset.json b/data/system_outputs/sst2_tabclass/sst2-tabclass-dataset.json index a83487ec..d17a742a 100644 --- a/data/system_outputs/sst2_tabclass/sst2-tabclass-dataset.json +++ b/data/system_outputs/sst2_tabclass/sst2-tabclass-dataset.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "total_words": { + "cls_name": "Value", "dtype": "float", "description": "total number of words" } diff --git a/data/system_outputs/sst2_tabreg/sst2-tabreg-dataset.json b/data/system_outputs/sst2_tabreg/sst2-tabreg-dataset.json index 74d5d4bd..6fcc6e45 100644 --- a/data/system_outputs/sst2_tabreg/sst2-tabreg-dataset.json +++ b/data/system_outputs/sst2_tabreg/sst2-tabreg-dataset.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "total_words": { + "cls_name": "Value", "dtype": "float", "description": "total number of words" } diff --git a/explainaboard/analysis/analyses.py b/explainaboard/analysis/analyses.py index 41212179..e1340c9d 100644 --- a/explainaboard/analysis/analyses.py +++ b/explainaboard/analysis/analyses.py @@ -8,12 +8,12 @@ import explainaboard.analysis.bucketing from explainaboard.analysis.case import AnalysisCase, AnalysisCaseCollection -from explainaboard.analysis.feature import FeatureType +from explainaboard.analysis.feature import FeatureType, get_feature_type_serializer from explainaboard.analysis.performance import BucketPerformance, Performance from explainaboard.metrics.metric import Metric, MetricConfig, MetricStats from explainaboard.metrics.registry import metric_config_from_dict from explainaboard.utils.logging import get_logger -from explainaboard.utils.typing_utils import unwrap, unwrap_generator +from explainaboard.utils.typing_utils import narrow, unwrap, unwrap_generator @dataclass @@ -325,7 +325,13 @@ class AnalysisLevel: @staticmethod def from_dict(dikt: dict): - features = {k: FeatureType.from_dict(v) for k, v in dikt['features'].items()} + ft_serializer = get_feature_type_serializer() + + features = { + # See https://github.com/python/mypy/issues/4717 + k: narrow(FeatureType, ft_serializer.deserialize(v)) # type: ignore + for k, v in dikt['features'].items() + } metric_configs = [metric_config_from_dict(v) for v in dikt['metric_configs']] return AnalysisLevel( name=dikt['name'], diff --git a/explainaboard/analysis/feature.py b/explainaboard/analysis/feature.py index f9f941be..feecceed 100644 --- a/explainaboard/analysis/feature.py +++ b/explainaboard/analysis/feature.py @@ -1,123 +1,339 @@ from __future__ import annotations +from abc import ABCMeta, abstractmethod from collections.abc import Callable -import copy -from dataclasses import dataclass, field -from typing import Optional +from typing import Any, final, TypeVar +from explainaboard.serialization.registry import TypeRegistry +from explainaboard.serialization.serializers import PrimitiveSerializer +from explainaboard.serialization.types import Serializable, SerializableData +from explainaboard.utils.logging import get_logger +from explainaboard.utils.typing_utils import narrow + +_feature_type_registry = TypeRegistry[Serializable]() + +T = TypeVar("T") -def is_dataclass_dict(obj): - """ - this function is used to judge if the input dictionary contains 'cls_name' and - the value of 'cls_name' is in the feature type registry - :param obj: a python object with different potential type - :return: boolean variable - """ - return isinstance(obj, dict) and obj.get('cls_name') in FEATURETYPE_REGISTRY +def get_feature_type_serializer() -> PrimitiveSerializer: + """Returns a serializer object for FeatureTypes. -def _fromdict_inner(obj): + Returns: + A serializer object. """ - This function aim to construct a dataclass based on a potentially nested - dictionary (obj) recursively - :param obj: python object - :return: an object with dataclass + return PrimitiveSerializer(_feature_type_registry) + + +def _get_value(cls: type[T], data: dict[str, SerializableData], key: str) -> T | None: + """Helper to obtain typed value in the SerializableData dict. + + Args: + cls: Type to obtain. + data: Dict containing the target value. + key: Key of the target value. + + Returs: + Typed target value, or None if it does not exist. """ - # reconstruct the dataclass using the type tag - if is_dataclass_dict(obj): - result = {} - for name, data in obj.items(): - result[name] = _fromdict_inner(data) - return FEATURETYPE_REGISTRY[obj["cls_name"]](**result) - - # exactly the same as before (without the tuple clause) - elif isinstance(obj, (list, tuple)): - return type(obj)(_fromdict_inner(v) for v in obj) - elif isinstance(obj, dict): - return type(obj)( - (_fromdict_inner(k), _fromdict_inner(v)) for k, v in obj.items() + value = data.get(key) + return narrow(cls, value) if value is not None else None + + +class FeatureType(Serializable, metaclass=ABCMeta): + def __init__( + self, + *, + dtype: str | None = None, + description: str | None = None, + func: Callable[..., Any] | None = None, + require_training_set: bool | None = None, + ) -> None: + """Initializes FeatureType object. + + Args: + dtype: Data type specifier. + description: Description of this feature. + func: Function to calculate this feature from other features. + require_training_set: Whether this feature relies on the training samples. + """ + self._dtype = dtype + self._description = description + self._func = func + self._require_training_set = ( + require_training_set if require_training_set is not None else False ) - else: - return copy.deepcopy(obj) - - -@dataclass -class FeatureType: - # dtype: declare the data type of a feature, e.g. dict, list, float - dtype: Optional[str] = None - # cls_name: declare the class type of the feature: Sequence, Position - cls_name: Optional[str] = None - # description: descriptive information of a feature - description: Optional[str] = None - # func: the function that is used to calculate the feature - func: Optional[Callable] = None - # require_training_set: whether calculating this feature - # relies on the training samples - require_training_set: bool = False - @classmethod - def from_dict(cls, obj: dict) -> FeatureType: - # If the type is not specified use Value by default - if not isinstance(obj, dict): - raise ValueError(f'called from_dict on non-dict object "{obj}"') - elif not is_dataclass_dict(obj): - obj = copy.deepcopy(obj) - obj['cls_name'] = 'Value' - return _fromdict_inner(obj) + @abstractmethod + def __eq__(self, other: object) -> bool: + """Checks if two FeatureTypes are the same. + + Args: + other: FeatureType to compare. - def __post_init__(self): - self.cls_name: str = self.__class__.__name__ + Returns: + True if `other` can be treated as the same value with `self`, False + otherwise. + """ + ... + @final + def _eq_base(self, other: FeatureType) -> bool: + """Helper to compare two FeatureTypes have the same base members. -@dataclass + Args: + other: FeatureType to compare. + + Returns: + True if `other` has the same base members with `self`, False otherwise. + """ + return ( + self._dtype == other._dtype + and self._description == other._description + and self._func is other._func + and self._require_training_set == other._require_training_set + ) + + @final + @property + def dtype(self) -> str | None: + return self._dtype + + @final + @property + def description(self) -> str | None: + return self._description + + @final + @property + def func(self) -> Callable[..., Any] | None: + return self._func + + @final + @property + def require_training_set(self) -> bool: + return self._require_training_set + + def _serialize_base(self) -> dict[str, SerializableData]: + """Helper to serialize base members. + + Returns: + Serialized object containing base members. + """ + if self.func is not None: + # TODO(odashi): FeatureTypes with `func` can't be restored correctly from + # the serialized data. If you met this warning, it seems there could be + # potential bugs. + # Remove `func` member from FeatureType to correctly serialize these + # objects. + get_logger(__name__).warning("`func` member is not serializable.") + + return { + "dtype": self._dtype, + "description": self._description, + "require_training_set": self._require_training_set, + } + + +@final +@_feature_type_registry.register("Sequence") class Sequence(FeatureType): - feature: FeatureType = field(default_factory=FeatureType) + def __init__( + self, + *, + description: str | None = None, + func: Callable[..., Any] | None = None, + require_training_set: bool | None = None, + feature: FeatureType, + ) -> None: + """Initializes Sequence object. - def __post_init__(self): - super().__post_init__() - self.dtype = "list" + Args: + description: See FeatureType.__init__. + func: See FeatureType.__init__. + require_training_set: See FeatureType.__init__. + feature: Feature type of elements. + """ + super().__init__( + dtype="list", + description=description, + func=func, + require_training_set=require_training_set, + ) + self._feature = feature + def __eq__(self, other: object) -> bool: + """See FeatureType.__eq__.""" + return ( + isinstance(other, Sequence) + and self._eq_base(other) + and self._feature == other._feature + ) + + @property + def feature(self) -> FeatureType: + return self._feature + + def serialize(self) -> dict[str, SerializableData]: + """See Serializable.serialize.""" + data = self._serialize_base() + data["feature"] = self._feature + return data + + @classmethod + def deserialize(cls, data: dict[str, SerializableData]) -> Serializable: + """See Serializable.deserialize.""" + return cls( + description=_get_value(str, data, "description"), + func=None, + require_training_set=_get_value(bool, data, "require_training_set"), + # See https://github.com/python/mypy/issues/4717 + feature=narrow(FeatureType, data["feature"]), # type: ignore + ) -@dataclass + +@final +@_feature_type_registry.register("Dict") class Dict(FeatureType): - feature: dict[str, FeatureType] = field(default_factory=dict) + def __init__( + self, + *, + description: str | None = None, + func: Callable[..., Any] | None = None, + require_training_set: bool | None = None, + feature: dict[str, FeatureType], + ) -> None: + """Initializes Dict object. + + Args: + description: See FeatureType.__init__. + func: See FeatureType.__init__. + require_training_set: See FeatureType.__init__. + feature: Definitions of member types. + """ + super().__init__( + dtype="dict", + description=description, + func=func, + require_training_set=require_training_set, + ) + self._feature = feature - def __post_init__(self): - super().__post_init__() - self.dtype = "dict" + def __eq__(self, other: object) -> bool: + """See FeatureType.__eq__.""" + return ( + isinstance(other, Dict) + and self._eq_base(other) + and self._feature == other._feature + ) + + @property + def feature(self) -> dict[str, FeatureType]: + return self._feature + def serialize(self) -> dict[str, SerializableData]: + """See Serializable.serialize.""" + data = self._serialize_base() + data["feature"] = self._feature + return data -@dataclass -class Position(FeatureType): - positions: Optional[list] = None + @classmethod + def deserialize(cls, data: dict[str, SerializableData]) -> Serializable: + """See Serializable.deserialize.""" + feature = { + # See https://github.com/python/mypy/issues/4717 + k: narrow(FeatureType, v) # type: ignore + for k, v in narrow(dict, data["feature"]).items() + } - def __post_init__(self): - super().__post_init__() - self.cls_name: str = "Position" + return cls( + description=_get_value(str, data, "description"), + func=None, + require_training_set=_get_value(bool, data, "require_training_set"), + feature=feature, + ) -@dataclass +@final +@_feature_type_registry.register("Value") class Value(FeatureType): + def __init__( + self, + *, + dtype: str | None = None, + description: str | None = None, + func: Callable[..., Any] | None = None, + require_training_set: bool | None = None, + max_value: int | float | None = None, + min_value: int | float | None = None, + ) -> None: + """Initializes Value object. + + Args: + dtype: See FeatureType.__init__. + description: See FeatureType.__init__. + func: See FeatureType.__init__. + require_training_set: See FeatureType.__init__. + max_value: The maximum value (inclusive) of values with int/float dtype. + min_value: The minimum value (inclusive) of values with int/float dtype. + """ + # Fix inferred types. + if dtype == "double": + dtype = "float64" + elif dtype == "float": + dtype = "float32" + + super().__init__( + dtype=dtype, + description=description, + func=func, + require_training_set=require_training_set, + ) + self._max_value = max_value + self._min_value = min_value + + def __eq__(self, other: object) -> bool: + """See FeatureType.__eq__.""" + return ( + isinstance(other, Value) + and self._eq_base(other) + and self._max_value == other._max_value + and self._min_value == other._min_value + ) - # the maximum value (inclusive) of a feature with the - # dtype of `float` or `int` - max_value: Optional[float | int] = None - # the minimum value (inclusive) of a feature with the - # dtype of `float` or `int` - min_value: Optional[float | int] = None - - def __post_init__(self): - super().__post_init__() - if self.dtype == "double": # fix inferred type - self.dtype = "float64" - if self.dtype == "float": # fix inferred type - self.dtype = "float32" - - -FEATURETYPE_REGISTRY = { - "FeatureType": FeatureType, - "Sequence": Sequence, - "Dict": Dict, - "Position": Position, - "Value": Value, -} + @property + def max_value(self) -> int | float | None: + return self._max_value + + @property + def min_value(self) -> int | float | None: + return self._min_value + + def serialize(self) -> dict[str, SerializableData]: + """See Serializable.serialize.""" + data = self._serialize_base() + data["max_value"] = self._max_value + data["min_value"] = self._min_value + return data + + @classmethod + def deserialize(cls, data: dict[str, SerializableData]) -> Serializable: + """See Serializable.deserialize.""" + max_value = data.get("max_value") + min_value = data.get("min_value") + if max_value is not None and not isinstance(max_value, (int, float)): + raise ValueError( + f"Unexpected type of `max_value`: {type(max_value).__name__}" + ) + if min_value is not None and not isinstance(min_value, (int, float)): + raise ValueError( + f"Unexpected type of `min_value`: {type(min_value).__name__}" + ) + + return cls( + dtype=_get_value(str, data, "dtype"), + description=_get_value(str, data, "description"), + func=None, + require_training_set=_get_value(bool, data, "require_training_set"), + max_value=max_value, + min_value=min_value, + ) diff --git a/explainaboard/analysis/feature_test.py b/explainaboard/analysis/feature_test.py new file mode 100644 index 00000000..49c769cc --- /dev/null +++ b/explainaboard/analysis/feature_test.py @@ -0,0 +1,201 @@ +"""Tests for explainaboard.analysis.feature.""" + +import unittest + +from explainaboard.analysis.feature import ( + Dict, + get_feature_type_serializer, + Sequence, + Value, +) + + +class SequenceTest(unittest.TestCase): + def test_members(self) -> None: + def dummy_fn(): + return 123 + + feature = Sequence( + feature=Value(dtype="string"), + description="test", + func=dummy_fn, + require_training_set=True, + ) + self.assertEqual(feature.dtype, "list") + self.assertEqual(feature.description, "test") + self.assertIs(feature.func, dummy_fn) + self.assertEqual(feature.require_training_set, True) + self.assertEqual(feature.feature, Value(dtype="string")) + + def test_serialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Sequence( + feature=Value(dtype="string"), + description="test", + require_training_set=True, + ) + serialized = { + "cls_name": "Sequence", + "dtype": "list", + "description": "test", + "require_training_set": True, + "feature": { + "cls_name": "Value", + "dtype": "string", + "max_value": None, + "min_value": None, + "description": None, + "require_training_set": False, + }, + } + self.assertEqual(serializer.serialize(feature), serialized) + + def test_deserialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Sequence( + feature=Value(dtype="string"), + description="test", + require_training_set=True, + ) + serialized = { + "cls_name": "Sequence", + "dtype": "list", + "description": "test", + "require_training_set": True, + "feature": { + "cls_name": "Value", + "dtype": "string", + "max_value": None, + "min_value": None, + "description": None, + "require_training_set": False, + }, + } + self.assertEqual(serializer.deserialize(serialized), feature) + + +class DictTest(unittest.TestCase): + def test_members(self) -> None: + def dummy_fn(): + return 123 + + feature = Dict( + feature={"foo": Value(dtype="string")}, + description="test", + func=dummy_fn, + require_training_set=True, + ) + self.assertEqual(feature.dtype, "dict") + self.assertEqual(feature.description, "test") + self.assertIs(feature.func, dummy_fn) + self.assertEqual(feature.require_training_set, True) + self.assertEqual(feature.feature, {"foo": Value(dtype="string")}) + + def test_serialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Dict( + feature={"foo": Value(dtype="string")}, + description="test", + require_training_set=True, + ) + serialized = { + "cls_name": "Dict", + "dtype": "dict", + "description": "test", + "require_training_set": True, + "feature": { + "foo": { + "cls_name": "Value", + "dtype": "string", + "max_value": None, + "min_value": None, + "description": None, + "require_training_set": False, + }, + }, + } + self.assertEqual(serializer.serialize(feature), serialized) + + def test_deserialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Dict( + feature={"foo": Value(dtype="string")}, + description="test", + require_training_set=True, + ) + serialized = { + "cls_name": "Dict", + "dtype": "dict", + "description": "test", + "require_training_set": True, + "feature": { + "foo": { + "cls_name": "Value", + "dtype": "string", + "max_value": None, + "min_value": None, + "description": None, + "require_training_set": False, + }, + }, + } + self.assertEqual(serializer.deserialize(serialized), feature) + + +class ValueTest(unittest.TestCase): + def test_members(self) -> None: + def dummy_fn(): + return 123 + + feature = Value( + dtype="string", + description="test", + func=dummy_fn, + require_training_set=True, + max_value=123, + min_value=45, + ) + self.assertEqual(feature.dtype, "string") + self.assertEqual(feature.description, "test") + self.assertIs(feature.func, dummy_fn) + self.assertEqual(feature.require_training_set, True) + self.assertEqual(feature.max_value, 123) + self.assertEqual(feature.min_value, 45) + + def test_serialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Value( + dtype="string", + description="test", + require_training_set=True, + max_value=123, + min_value=45, + ) + serialized = { + "cls_name": "Value", + "dtype": "string", + "max_value": 123, + "min_value": 45, + "description": "test", + "require_training_set": True, + } + self.assertEqual(serializer.serialize(feature), serialized) + + def test_deserialize(self) -> None: + serializer = get_feature_type_serializer() + feature = Value( + dtype="string", + description="test", + require_training_set=True, + max_value=123, + min_value=45, + ) + serialized = { + "cls_name": "Value", + "dtype": "string", + "max_value": 123, + "min_value": 45, + "description": "test", + "require_training_set": True, + } + self.assertEqual(serializer.deserialize(serialized), feature) diff --git a/explainaboard/info.py b/explainaboard/info.py index 45d5ed3e..7661af4a 100644 --- a/explainaboard/info.py +++ b/explainaboard/info.py @@ -13,8 +13,8 @@ from explainaboard.analysis.case import AnalysisCase from explainaboard.analysis.result import Result from explainaboard.metrics.metric import MetricStats +from explainaboard.serialization.legacy import general_to_dict from explainaboard.utils.logging import get_logger -from explainaboard.utils.serialization import general_to_dict from explainaboard.utils.tokenizer import get_tokenizer_serializer, Tokenizer logger = get_logger(__name__) @@ -161,7 +161,7 @@ def print_as_json(self, file=None): data_dict = self.to_dict() self.replace_nonstring_keys(data_dict) try: - json.dump(data_dict, fp=file, indent=2, default=lambda x: x.json_repr()) + json.dump(data_dict, fp=file, indent=2) except TypeError as e: raise e @@ -169,11 +169,7 @@ def _dump_info(self, file): """SystemOutputInfo => JSON""" data_dict = self.to_dict() self.replace_nonstring_keys(data_dict) - file.write( - json.dumps(data_dict, indent=2, default=lambda x: x.json_repr()).encode( - "utf-8" - ) - ) + file.write(json.dumps(data_dict, indent=2).encode("utf-8")) @classmethod def from_directory(cls, sys_output_info_dir: str) -> "SysOutputInfo": diff --git a/explainaboard/loaders/file_loader.py b/explainaboard/loaders/file_loader.py index 2ede9961..8d518f41 100644 --- a/explainaboard/loaders/file_loader.py +++ b/explainaboard/loaders/file_loader.py @@ -24,7 +24,7 @@ from datalabs.features.features import ClassLabel, Sequence from explainaboard.analysis.analyses import Analysis -from explainaboard.analysis.feature import FeatureType +from explainaboard.analysis.feature import FeatureType, get_feature_type_serializer from explainaboard.constants import Source from explainaboard.utils.load_resources import get_customized_features from explainaboard.utils.preprocessor import Preprocessor @@ -134,8 +134,15 @@ def from_dict(cls, data: dict) -> FileLoaderMetadata: custom_features: dict[str, dict[str, FeatureType]] | None = None custom_analyses: list[Analysis] | None = None if 'custom_features' in data: + ft_serializer = get_feature_type_serializer() custom_features = { - k1: {k2: FeatureType.from_dict(v2) for k2, v2 in v1.items()} + k1: { + # See https://github.com/python/mypy/issues/4717 + k2: narrow( + FeatureType, ft_serializer.deserialize(v2) # type: ignore + ) + for k2, v2 in v1.items() + } for k1, v1 in data['custom_features'].items() } if 'custom_analyses' in data: @@ -543,6 +550,7 @@ def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: config = narrow(DatalabLoaderOption, data) + ft_serializer = get_feature_type_serializer() # load customized features from global config files customized_features_from_config = get_customized_features() @@ -558,7 +566,9 @@ def load_raw( f'{level_name}' ) parsed_level_feats = { - k: FeatureType.from_dict(v) for k, v in level_feats.items() + # See https://github.com/python/mypy/issues/4717 + k: narrow(FeatureType, ft_serializer.deserialize(v)) # type: ignore + for k, v in level_feats.items() } new_features = config.custom_features.get(level_name, {}) new_features.update(parsed_level_feats) diff --git a/explainaboard/processors/argument_pair_extraction.py b/explainaboard/processors/argument_pair_extraction.py index 4f91aaf4..54379adc 100644 --- a/explainaboard/processors/argument_pair_extraction.py +++ b/explainaboard/processors/argument_pair_extraction.py @@ -56,9 +56,9 @@ def default_metrics( def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "sentences": feature.Sequence(feature=feature.Value("string")), - "true_tags": feature.Sequence(feature=feature.Value("string")), - "pred_tags": feature.Sequence(feature=feature.Value("string")), + "sentences": feature.Sequence(feature=feature.Value(dtype="string")), + "true_tags": feature.Sequence(feature=feature.Value(dtype="string")), + "pred_tags": feature.Sequence(feature=feature.Value(dtype="string")), "num_sent": feature.Value( dtype="float", description="the number of sentences", diff --git a/explainaboard/processors/cloze_generative.py b/explainaboard/processors/cloze_generative.py index a43c5ef0..19e99b6c 100644 --- a/explainaboard/processors/cloze_generative.py +++ b/explainaboard/processors/cloze_generative.py @@ -33,9 +33,9 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features: dict[str, FeatureType] = { - "context": feature.Value("string"), - "question_mark": feature.Value("string"), - "hint": feature.Value("string"), + "context": feature.Value(dtype="string"), + "question_mark": feature.Value(dtype="string"), + "hint": feature.Value(dtype="string"), "answers": feature.Sequence(feature=feature.Value(dtype="string")), "context_length": feature.Value( dtype="float", diff --git a/explainaboard/processors/cloze_multiple_choice.py b/explainaboard/processors/cloze_multiple_choice.py index 5c266294..a19c8077 100644 --- a/explainaboard/processors/cloze_multiple_choice.py +++ b/explainaboard/processors/cloze_multiple_choice.py @@ -31,14 +31,14 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features: dict[str, FeatureType] = { - "context": feature.Value("string"), - "question_mark": feature.Value("string"), - "options": feature.Sequence(feature=feature.Value("string")), + "context": feature.Value(dtype="string"), + "question_mark": feature.Value(dtype="string"), + "options": feature.Sequence(feature=feature.Value(dtype="string")), "answers": feature.Sequence( feature=feature.Dict( feature={ - "text": feature.Value("string"), - "option_index": feature.Value("int32"), + "text": feature.Value(dtype="string"), + "option_index": feature.Value(dtype="int32"), } ) ), diff --git a/explainaboard/processors/conditional_generation.py b/explainaboard/processors/conditional_generation.py index ac550439..e9c746b2 100644 --- a/explainaboard/processors/conditional_generation.py +++ b/explainaboard/processors/conditional_generation.py @@ -46,9 +46,9 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: examp_features: dict[str, FeatureType] = { - "source": feature.Value("string"), - "reference": feature.Value("string"), - "hypothesis": feature.Value("string"), + "source": feature.Value(dtype="string"), + "reference": feature.Value(dtype="string"), + "hypothesis": feature.Value(dtype="string"), "source_length": feature.Value( dtype="float", description="length of the source", diff --git a/explainaboard/processors/extractive_qa.py b/explainaboard/processors/extractive_qa.py index 32e586f2..fa0b45ec 100644 --- a/explainaboard/processors/extractive_qa.py +++ b/explainaboard/processors/extractive_qa.py @@ -28,11 +28,11 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "context": feature.Value("string"), - "question": feature.Value("string"), - "id": feature.Value("string"), - "answers": feature.Sequence(feature=feature.Value("string")), - "predicted_answers": feature.Value("string"), + "context": feature.Value(dtype="string"), + "question": feature.Value(dtype="string"), + "id": feature.Value(dtype="string"), + "answers": feature.Sequence(feature=feature.Value(dtype="string")), + "predicted_answers": feature.Value(dtype="string"), "context_length": feature.Value( dtype="float", description="context length in tokens", diff --git a/explainaboard/processors/grammatical_error_correction.py b/explainaboard/processors/grammatical_error_correction.py index 3a229961..d5e9cf71 100644 --- a/explainaboard/processors/grammatical_error_correction.py +++ b/explainaboard/processors/grammatical_error_correction.py @@ -23,13 +23,13 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features: dict[str, FeatureType] = { - "text": feature.Value("string"), + "text": feature.Value(dtype="string"), "edits": feature.Dict( feature={ - "start_idx": feature.Sequence(feature=feature.Value("int32")), - "end_idx": feature.Sequence(feature=feature.Value("int32")), + "start_idx": feature.Sequence(feature=feature.Value(dtype="int32")), + "end_idx": feature.Sequence(feature=feature.Value(dtype="int32")), "corrections": feature.Sequence( - feature=feature.Sequence(feature=feature.Value("string")) + feature=feature.Sequence(feature=feature.Value(dtype="string")) ), } ), diff --git a/explainaboard/processors/kg_link_tail_prediction.py b/explainaboard/processors/kg_link_tail_prediction.py index f6ba2829..c47f4ae7 100644 --- a/explainaboard/processors/kg_link_tail_prediction.py +++ b/explainaboard/processors/kg_link_tail_prediction.py @@ -30,13 +30,13 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "true_head": feature.Value("string"), - "true_head_decipher": feature.Value("string"), + "true_head": feature.Value(dtype="string"), + "true_head_decipher": feature.Value(dtype="string"), "true_link": feature.Value(dtype="string", description="the relation type"), "true_tail": feature.Value(dtype="string"), - "true_tail_decipher": feature.Value("string"), - "predict": feature.Value("string"), - "predictions": feature.Sequence(feature=feature.Value("string")), + "true_tail_decipher": feature.Value(dtype="string"), + "predict": feature.Value(dtype="string"), + "predictions": feature.Sequence(feature=feature.Value(dtype="string")), "tail_entity_length": feature.Value( dtype="float", description="length of the tail entity in tokens", diff --git a/explainaboard/processors/language_modeling.py b/explainaboard/processors/language_modeling.py index 231c47b0..5c82964a 100644 --- a/explainaboard/processors/language_modeling.py +++ b/explainaboard/processors/language_modeling.py @@ -34,8 +34,8 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: examp_features: dict[str, FeatureType] = { - "text": feature.Value("string"), - "log_probs": feature.Value("string"), + "text": feature.Value(dtype="string"), + "log_probs": feature.Value(dtype="string"), "text_length": feature.Value( dtype="float", description="text length in tokens", diff --git a/explainaboard/processors/processor.py b/explainaboard/processors/processor.py index 9e1ed6a1..e3253b46 100644 --- a/explainaboard/processors/processor.py +++ b/explainaboard/processors/processor.py @@ -16,7 +16,7 @@ BucketAnalysisResult, ) from explainaboard.analysis.case import AnalysisCase -from explainaboard.analysis.feature import FeatureType +from explainaboard.analysis.feature import FeatureType, get_feature_type_serializer from explainaboard.analysis.performance import BucketPerformance, Performance from explainaboard.analysis.result import Result from explainaboard.info import OverallStatistics, SysOutputInfo @@ -28,7 +28,7 @@ ) from explainaboard.utils.logging import get_logger, progress from explainaboard.utils.tokenizer import get_default_tokenizer -from explainaboard.utils.typing_utils import unwrap, unwrap_generator +from explainaboard.utils.typing_utils import narrow, unwrap, unwrap_generator class Processor(metaclass=abc.ABCMeta): @@ -210,13 +210,16 @@ def _customize_analyses( if custom_analyses is not None: analyses.extend([Analysis.from_dict(v) for v in custom_analyses]) if custom_features is not None: + ft_serializer = get_feature_type_serializer() + for level_name, feature_content in custom_features.items(): - level_map[level_name].features.update( - { - k: (FeatureType.from_dict(v) if isinstance(v, dict) else v) - for k, v in feature_content.items() - } - ) + additional_features = { + k: narrow(FeatureType, ft_serializer.deserialize(v)) # type: ignore + if isinstance(v, dict) + else v + for k, v in feature_content.items() + } + level_map[level_name].features.update(additional_features) return analysis_levels, analyses @final diff --git a/explainaboard/processors/qa_multiple_choice.py b/explainaboard/processors/qa_multiple_choice.py index fe54dbc7..c1821703 100644 --- a/explainaboard/processors/qa_multiple_choice.py +++ b/explainaboard/processors/qa_multiple_choice.py @@ -28,14 +28,14 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "context": feature.Value("string"), - "question": feature.Value("string"), - "options": feature.Sequence(feature=feature.Value("string")), + "context": feature.Value(dtype="string"), + "question": feature.Value(dtype="string"), + "options": feature.Sequence(feature=feature.Value(dtype="string")), "answers": feature.Sequence( feature=feature.Dict( feature={ - "text": feature.Value("string"), - "option_index": feature.Value("int32"), + "text": feature.Value(dtype="string"), + "option_index": feature.Value(dtype="int32"), } ) ), diff --git a/explainaboard/processors/qa_open_domain.py b/explainaboard/processors/qa_open_domain.py index ea4f0c52..58b1d948 100644 --- a/explainaboard/processors/qa_open_domain.py +++ b/explainaboard/processors/qa_open_domain.py @@ -28,9 +28,9 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "question": feature.Value("string"), + "question": feature.Value(dtype="string"), # "question_types": feature.Sequence(feature=feature.Value("string")), - "answers": feature.Sequence(feature=feature.Value("string")), + "answers": feature.Sequence(feature=feature.Value(dtype="string")), "question_length": feature.Value( dtype="float", description="context length in tokens", diff --git a/explainaboard/processors/qa_tat.py b/explainaboard/processors/qa_tat.py index 72fdfcb5..0e6c076c 100644 --- a/explainaboard/processors/qa_tat.py +++ b/explainaboard/processors/qa_tat.py @@ -30,12 +30,12 @@ def task_type(cls) -> TaskType: def default_analysis_levels(self) -> list[AnalysisLevel]: features = { - "question": feature.Value("string"), - "context": feature.Sequence(feature=feature.Value("string")), - "table": feature.Sequence(feature=feature.Value("string")), - "true_answer": feature.Sequence(feature=feature.Value("string")), - "predicted_answer": feature.Sequence(feature=feature.Value("string")), - "predicted_answer_scale": feature.Value("string"), + "question": feature.Value(dtype="string"), + "context": feature.Sequence(feature=feature.Value(dtype="string")), + "table": feature.Sequence(feature=feature.Value(dtype="string")), + "true_answer": feature.Sequence(feature=feature.Value(dtype="string")), + "predicted_answer": feature.Sequence(feature=feature.Value(dtype="string")), + "predicted_answer_scale": feature.Value(dtype="string"), "answer_type": feature.Value( dtype="string", description="type of answer", diff --git a/explainaboard/processors/sequence_labeling.py b/explainaboard/processors/sequence_labeling.py index 66c36afe..649c4ec7 100644 --- a/explainaboard/processors/sequence_labeling.py +++ b/explainaboard/processors/sequence_labeling.py @@ -38,9 +38,9 @@ def __init__(self): def default_analysis_levels(self) -> list[AnalysisLevel]: examp_features: dict[str, FeatureType] = { - "tokens": feature.Sequence(feature=feature.Value("string")), - "true_tags": feature.Sequence(feature=feature.Value("string")), - "pred_tags": feature.Sequence(feature=feature.Value("string")), + "tokens": feature.Sequence(feature=feature.Value(dtype="string")), + "true_tags": feature.Sequence(feature=feature.Value(dtype="string")), + "pred_tags": feature.Sequence(feature=feature.Value(dtype="string")), "text_length": feature.Value( dtype="float", description="text length in tokens", diff --git a/explainaboard/resources/dataset_custom_features.json b/explainaboard/resources/dataset_custom_features.json index ca5d1c86..77db2990 100644 --- a/explainaboard/resources/dataset_custom_features.json +++ b/explainaboard/resources/dataset_custom_features.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "label": { + "cls_name": "Value", "dtype": "string", "description": "the true label" } @@ -22,10 +23,12 @@ "custom_features": { "example": { "answer_from": { + "cls_name": "Value", "dtype": "string", "description": "where does the answer from" }, "q_order": { + "cls_name": "Value", "dtype": "float", "description": "the order of question" } @@ -52,18 +55,22 @@ "custom_features": { "example": { "overall_setting": { + "cls_name": "Value", "dtype": "string", "description": "overall setting" }, "dataset_name": { + "cls_name": "Value", "dtype": "string", "description": "dataset" }, "model_name": { + "cls_name": "Value", "dtype": "string", "description": "model name" }, "target_lang_data_size": { + "cls_name": "Value", "dtype": "float", "description": "dataset size for target language" } @@ -100,4 +107,4 @@ } ] } -} \ No newline at end of file +} diff --git a/explainaboard/utils/serialization.py b/explainaboard/serialization/legacy.py similarity index 87% rename from explainaboard/utils/serialization.py rename to explainaboard/serialization/legacy.py index 25a9aee3..f18abf5d 100644 --- a/explainaboard/utils/serialization.py +++ b/explainaboard/serialization/legacy.py @@ -4,11 +4,14 @@ import dataclasses from inspect import getsource +from explainaboard.analysis.feature import FeatureType, get_feature_type_serializer from explainaboard.utils.tokenizer import get_tokenizer_serializer, Tokenizer def general_to_dict(data): """DEPRECATED: do not use this function for new implementations.""" + if isinstance(data, FeatureType): + return get_feature_type_serializer().serialize(data) if isinstance(data, Tokenizer): return get_tokenizer_serializer().serialize(data) elif hasattr(data, 'to_dict'): diff --git a/integration_tests/artifacts/kg_link_tail_prediction/with_custom_feature.json b/integration_tests/artifacts/kg_link_tail_prediction/with_custom_feature.json index 981fbd43..16e962ad 100644 --- a/integration_tests/artifacts/kg_link_tail_prediction/with_custom_feature.json +++ b/integration_tests/artifacts/kg_link_tail_prediction/with_custom_feature.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "rel_type": { + "cls_name": "Value", "dtype": "string", "description": "symmetric or asymmetric" } @@ -180,4 +181,4 @@ "example_id": "10" } ] -} \ No newline at end of file +} diff --git a/integration_tests/artifacts/machine_translation/output_with_features.json b/integration_tests/artifacts/machine_translation/output_with_features.json index 14faa63b..86450dee 100644 --- a/integration_tests/artifacts/machine_translation/output_with_features.json +++ b/integration_tests/artifacts/machine_translation/output_with_features.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "num_capital_letters": { + "cls_name": "Value", "dtype": "float" } } diff --git a/integration_tests/artifacts/qa_multiple_choice/output_fig_qa_customized_features.json b/integration_tests/artifacts/qa_multiple_choice/output_fig_qa_customized_features.json index 120703a8..cfb1c2a4 100644 --- a/integration_tests/artifacts/qa_multiple_choice/output_fig_qa_customized_features.json +++ b/integration_tests/artifacts/qa_multiple_choice/output_fig_qa_customized_features.json @@ -3,6 +3,7 @@ "custom_features": { "example": { "commonsense_category": { + "cls_name": "Value", "dtype": "string", "description": "common sense category" } @@ -68,4 +69,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/integration_tests/resources_test.py b/integration_tests/resources_test.py index 14c33bb1..6374800f 100644 --- a/integration_tests/resources_test.py +++ b/integration_tests/resources_test.py @@ -12,7 +12,11 @@ def test_get_customized_features(self): { "custom_features": { "example": { - "label": {"dtype": "string", "description": "the true label"} + "label": { + "cls_name": "Value", + "dtype": "string", + "description": "the true label", + } } }, "custom_analyses": [