diff --git a/explainaboard/analysis/analyses.py b/explainaboard/analysis/analyses.py index 75ff1ecf..8536729f 100644 --- a/explainaboard/analysis/analyses.py +++ b/explainaboard/analysis/analyses.py @@ -859,3 +859,22 @@ def deserialize(cls, data: dict[str, SerializableData]) -> Serializable: features=features, metric_configs=metric_configs, ) + + @final + def replace_metric_configs( + self, metric_configs: dict[str, MetricConfig] + ) -> AnalysisLevel: + """Creates a new AnalysisLevel with the set of MetricConfigs replaced. + + Args: + metric_configs: + New dict of MetricConfigs to replace the original member. + + Returns: + A new AnalysisLevel with the replaced metric_configs. + """ + return AnalysisLevel( + name=self.name, + features=self.features, + metric_configs=metric_configs, + ) diff --git a/explainaboard/analysis/analyses_test.py b/explainaboard/analysis/analyses_test.py index 328a1610..8bc4c344 100644 --- a/explainaboard/analysis/analyses_test.py +++ b/explainaboard/analysis/analyses_test.py @@ -525,3 +525,15 @@ def test_serialization(self) -> None: } self.assertEqual(serializer.serialize(level), level_serialized) self.assertEqual(serializer.deserialize(level_serialized), level) + + def test_replace_metric_configs(self) -> None: + level = AnalysisLevel( + name="test", features={}, metric_configs={"foo": AccuracyConfig()} + ) + new_level = level.replace_metric_configs({"bar": AccuracyConfig()}) + self.assertIsNot(level, new_level) + self.assertIn("foo", level.metric_configs) + self.assertNotIn("bar", level.metric_configs) + self.assertNotIn("foo", new_level.metric_configs) + self.assertIn("bar", new_level.metric_configs) + self.assertIsNot(level.metric_configs["foo"], new_level.metric_configs["bar"]) diff --git a/explainaboard/metrics/metric.py b/explainaboard/metrics/metric.py index 712ae6ce..dfddc95e 100644 --- a/explainaboard/metrics/metric.py +++ b/explainaboard/metrics/metric.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +import
copy from dataclasses import dataclass from typing import Any, final, Optional, TypeVar @@ -222,6 +223,28 @@ def to_metric(self) -> Metric: """ ... + @final + def replace_languages( + self, source_language: str | None, target_language: str | None + ) -> MetricConfig: + """Creates a new MetricConfig with specified source/target languages. + + Args: + source_language: New source language. + target_language: New target language. + + Returns: + A new MetricConfig object, in which source/target_language are replaced + with the given values, while other values are maintained. + """ + # NOTE(odashi): Since this class can be inherited, we need to collect every + # member not listed in this class. + # TODO(odashi): Avoid copy. + copied = copy.deepcopy(self) + copied.source_language = source_language + copied.target_language = target_language + return copied + class MetricStats(metaclass=abc.ABCMeta): """Interface of sufficient statistics necessary to calculate a metric.""" diff --git a/explainaboard/metrics/metric_test.py b/explainaboard/metrics/metric_test.py index 1708cbf5..96504b5d 100644 --- a/explainaboard/metrics/metric_test.py +++ b/explainaboard/metrics/metric_test.py @@ -191,6 +191,19 @@ def test_get_value_or_none(self) -> None: self.assertIs(result.get_value_or_none(ConfidenceInterval, "baz"), ci) +class MetricConfigTest(unittest.TestCase): + def test_replace_languages(self) -> None: + config = _DummyMetricConfig(source_language="xx", target_language="yy") + new_config = config.replace_languages( + source_language="aa", target_language="bb" + ) + self.assertIsNot(new_config, config) + self.assertEqual(config.source_language, "xx") + self.assertEqual(config.target_language, "yy") + self.assertEqual(new_config.source_language, "aa") + self.assertEqual(new_config.target_language, "bb") + + class MetricTest(unittest.TestCase): def test_aggregate_stats_1dim(self) -> None: metric = _DummyMetric(_DummyMetricConfig("test")) @@ -406,7 +419,6 @@ def
test_evaluate_from_stats_bootstrap_with_ci(self) -> None: result = metric.evaluate_from_stats(stats, confidence_alpha=0.05) self.assertEqual(result.get_value(Score, "score").value, 3.0) ci = result.get_value(ConfidenceInterval, "score_ci") - print(dataclasses.asdict(ci)) # TODO(odahsi): According to the current default settings of bootstrapping, # estimated confidence intervals tends to become very wide for small data self.assertAlmostEqual(ci.low, 1.8) diff --git a/explainaboard/processors/processor.py b/explainaboard/processors/processor.py index 7d2e2b5e..eaa55524 100644 --- a/explainaboard/processors/processor.py +++ b/explainaboard/processors/processor.py @@ -250,26 +250,35 @@ def _customize_analyses( Args: custom_features: the features to customize - metric_configs: additional metric configurations. Keys are analysis level - name and metric name. + metric_configs: MetricConfigs to replace. + If `metric_configs[analysis_level_name]` has a dict, it is used instead + of the default MetricConfigs associated with `analysis_level_name`. custom_analyses: the analyses to customize Returns: Customized analyses. """ - analysis_levels = self.default_analysis_levels() - analyses = self.default_analyses() - for level in analysis_levels: - for name, config in metric_configs.get(level.name, {}).items(): - level.metric_configs[name] = config - for config in level.metric_configs.values(): - config.source_language = sys_info.source_language - config.target_language = sys_info.target_language + analysis_levels: list[AnalysisLevel] = [] + + # Replaces MetricConfigs for each AnalysisLevel.
+ for level in self.default_analysis_levels(): + metric_configs_orig = metric_configs.get(level.name, level.metric_configs) + metric_configs_replaced = { + name: config.replace_languages( + source_language=sys_info.source_language, + target_language=sys_info.target_language, + ) + for name, config in metric_configs_orig.items() + } + analysis_levels.append( + level.replace_metric_configs(metric_configs_replaced) + ) level_map = {x.name: x for x in analysis_levels} serializer = PrimitiveSerializer() + analyses = self.default_analyses() analyses.extend( [ narrow(Analysis, serializer.deserialize(v)) # type: ignore @@ -516,12 +525,19 @@ def get_overall_statistics( custom_features: dict = metadata.get('custom_features', {}) custom_analyses: list = metadata.get('custom_analyses', []) - metric_configs: dict[str, dict[str, MetricConfig]] = { - "example": metadata.get('metric_configs', {}) - } + metric_configs = metadata.get("metric_configs") + if metric_configs is not None: + metric_configs_dict = { + "example": { + narrow(str, k): narrow(MetricConfig, v) # type: ignore + for k, v in metric_configs.items() + } + } + else: + metric_configs_dict = {} sys_info.analysis_levels, sys_info.analyses = self._customize_analyses( - sys_info, custom_features, metric_configs, custom_analyses + sys_info, custom_features, metric_configs_dict, custom_analyses ) # get scoring statistics