Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Processor._customize_analyses #581

Merged
merged 2 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions explainaboard/analysis/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,3 +859,22 @@ def deserialize(cls, data: dict[str, SerializableData]) -> Serializable:
features=features,
metric_configs=metric_configs,
)

@final
def replace_metric_configs(
    self, metric_configs: dict[str, MetricConfig]
) -> AnalysisLevel:
    """Creates a new AnalysisLevel with replacing the set of MetricConfigs.

    The original object is not modified: `name` and `features` are carried
    over by reference into the new object.

    Args:
        metric_configs:
            New dict of MetricConfigs to replace the original member.

    Returns:
        A new AnalysisLevel with the replaced metric_configs, while `name`
        and `features` are taken from this object.
    """
    return AnalysisLevel(
        name=self.name,
        features=self.features,
        metric_configs=metric_configs,
    )
12 changes: 12 additions & 0 deletions explainaboard/analysis/analyses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,15 @@ def test_serialization(self) -> None:
}
self.assertEqual(serializer.serialize(level), level_serialized)
self.assertEqual(serializer.deserialize(level_serialized), level)

def test_replace_metric_configs(self) -> None:
    """replace_metric_configs returns a new level and leaves the original intact."""
    level = AnalysisLevel(
        name="test", features={}, metric_configs={"foo": AccuracyConfig()}
    )
    new_level = level.replace_metric_configs({"bar": AccuracyConfig()})
    # A fresh AnalysisLevel object is returned rather than mutating in place.
    self.assertIsNot(level, new_level)
    # The original keeps exactly its old config set...
    self.assertIn("foo", level.metric_configs)
    self.assertNotIn("bar", level.metric_configs)
    # ...while the new level holds only the replacement set.
    self.assertNotIn("foo", new_level.metric_configs)
    self.assertIn("bar", new_level.metric_configs)
    # The two levels do not share config objects.
    self.assertIsNot(level.metric_configs["foo"], new_level.metric_configs["bar"])
23 changes: 23 additions & 0 deletions explainaboard/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import abc
import copy
from dataclasses import dataclass
from typing import Any, final, Optional, TypeVar

Expand Down Expand Up @@ -222,6 +223,28 @@ def to_metric(self) -> Metric:
"""
...

@final
def replace_languages(
    self, source_language: str | None, target_language: str | None
) -> MetricConfig:
    """Creates a new MetricConfig with specified source/target languages.

    Args:
        source_language: New source language, or None to unset it.
        target_language: New target language, or None to unset it.

    Returns:
        A new MetricConfig object, in which source/target_language are replaced to
        the new config, while other values are maintained.
    """
    # NOTE(odashi): Since this class can be inherited, we need to collect every
    # member not listed in this class.
    # TODO(odashi): Avoid copy.
    copied = copy.deepcopy(self)
    copied.source_language = source_language
    copied.target_language = target_language
    return copied


class MetricStats(metaclass=abc.ABCMeta):
"""Interface of sufficient statistics necessary to calculate a metric."""
Expand Down
14 changes: 13 additions & 1 deletion explainaboard/metrics/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,19 @@ def test_get_value_or_none(self) -> None:
self.assertIs(result.get_value_or_none(ConfidenceInterval, "baz"), ci)


class MetricConfigTest(unittest.TestCase):
    """Tests for base MetricConfig behavior."""

    def test_replace_languages(self) -> None:
        """replace_languages returns a modified copy without mutating the original."""
        config = _DummyMetricConfig(source_language="xx", target_language="yy")
        new_config = config.replace_languages(
            source_language="aa", target_language="bb"
        )
        # A distinct object is returned...
        self.assertIsNot(new_config, config)
        # ...the original languages are untouched...
        self.assertEqual(config.source_language, "xx")
        self.assertEqual(config.target_language, "yy")
        # ...and the copy carries the replacement languages.
        self.assertEqual(new_config.source_language, "aa")
        self.assertEqual(new_config.target_language, "bb")


class MetricTest(unittest.TestCase):
def test_aggregate_stats_1dim(self) -> None:
metric = _DummyMetric(_DummyMetricConfig("test"))
Expand Down Expand Up @@ -406,7 +419,6 @@ def test_evaluate_from_stats_bootstrap_with_ci(self) -> None:
result = metric.evaluate_from_stats(stats, confidence_alpha=0.05)
self.assertEqual(result.get_value(Score, "score").value, 3.0)
ci = result.get_value(ConfidenceInterval, "score_ci")
print(dataclasses.asdict(ci))
# TODO(odahsi): According to the current default settings of bootstrapping,
# estimated confidence intervals tends to become very wide for small data
self.assertAlmostEqual(ci.low, 1.8)
Expand Down
44 changes: 30 additions & 14 deletions explainaboard/processors/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,26 +250,35 @@ def _customize_analyses(

Args:
custom_features: the features to customize
metric_configs: additional metric configurations. Keys are analysis level
name and metric name.
metric_configs: MetricConfgs to replace.
If `metric_configs[analysis_level_name]` has a dict, it is used instead
of the default MetricConfigs associated to `analysis_level_name`.
custom_analyses: the analyses to customize

Returns:
Customized analyses.
"""
analysis_levels = self.default_analysis_levels()
analyses = self.default_analyses()
for level in analysis_levels:
for name, config in metric_configs.get(level.name, {}).items():
level.metric_configs[name] = config
for config in level.metric_configs.values():
config.source_language = sys_info.source_language
config.target_language = sys_info.target_language
analysis_levels: list[AnalysisLevel] = []

# Replaces MetricConfigs for each AnalysisLevel.
for level in self.default_analysis_levels():
metric_configs_orig = metric_configs.get(level.name, level.metric_configs)
metric_configs_replaced = {
name: config.replace_languages(
source_language=sys_info.source_language,
target_language=sys_info.target_language,
)
for name, config in metric_configs_orig.items()
}
analysis_levels.append(
level.replace_metric_configs(metric_configs_replaced)
)

level_map = {x.name: x for x in analysis_levels}

serializer = PrimitiveSerializer()

analyses = self.default_analyses()
analyses.extend(
[
narrow(Analysis, serializer.deserialize(v)) # type: ignore
Expand Down Expand Up @@ -516,12 +525,19 @@ def get_overall_statistics(
custom_features: dict = metadata.get('custom_features', {})
custom_analyses: list = metadata.get('custom_analyses', [])

metric_configs: dict[str, dict[str, MetricConfig]] = {
"example": metadata.get('metric_configs', {})
}
metric_configs = metadata.get("metric_configs")
if metric_configs is not None:
metric_configs_dict = {
"example": {
lyuyangh marked this conversation as resolved.
Show resolved Hide resolved
narrow(str, k): narrow(MetricConfig, v) # type: ignore
for k, v in metric_configs.items()
}
}
else:
metric_configs_dict = {}

sys_info.analysis_levels, sys_info.analyses = self._customize_analyses(
sys_info, custom_features, metric_configs, custom_analyses
sys_info, custom_features, metric_configs_dict, custom_analyses
)

# get scoring statistics
Expand Down