diff --git a/danling/metrics/metric_meter.py b/danling/metrics/metric_meter.py index 09274078..ecddafd8 100644 --- a/danling/metrics/metric_meter.py +++ b/danling/metrics/metric_meter.py @@ -207,6 +207,9 @@ class MultiTaskMetricMeters(MultiTaskAverageMeters): >>> metrics.update({"dataset1.cls": [[0.1, 0.4, 0.6, 0.8], [0, 0, 1, 0]], "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 0, 1]}}) >>> f"{metrics:.4f}" 'dataset1.cls: acc: 0.7500 (0.6250)\ndataset2: acc: 0.7500 (0.5000)' + >>> metrics.update(dict(loss="")) # doctest: +ELLIPSIS + Traceback (most recent call last): + ValueError: Metric loss not found in ... """ # noqa: E501 def __init__(self, *args, **kwargs): @@ -225,31 +228,28 @@ def update( # type: ignore[override] # pylint: disable=W0221 """ for metric, value in values.items(): - if isinstance(value, (Mapping, Sequence)): - if metric not in self: - raise ValueError(f"Metric {metric} not found in {self}") - if isinstance(self[metric], MultiTaskMetricMeters): - for met in self[metric].all_values(): - if isinstance(value, Mapping): - met.update(**value) - elif isinstance(value, Sequence): - met.update(*value) - else: - raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}") - elif isinstance(self[metric], (MetricMeters, MetricMeter)): + if metric not in self: + raise ValueError(f"Metric {metric} not found in {self}") + if isinstance(self[metric], MultiTaskMetricMeters): + for met in self[metric].all_values(): if isinstance(value, Mapping): - self[metric].update(**value) + met.update(**value) elif isinstance(value, Sequence): - self[metric].update(*value) + met.update(*value) else: raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}") + elif isinstance(self[metric], (MetricMeters, MetricMeter)): + if isinstance(value, Mapping): + self[metric].update(**value) + elif isinstance(value, Sequence): + self[metric].update(*value) else: - raise ValueError( - f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, " - "or MetricMeter, but got {type(self[metric])}" - ) + raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}") else: - raise ValueError(f"Expected values to be a Mapping or Sequence, but got {type(value)}") + raise ValueError( + f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, " + f"or MetricMeter, but got {type(self[metric])}" + ) # MultiTaskAverageMeters.get is hacked def get(self, name: Any, default=None) -> Any: @@ -258,10 +258,12 @@ def get(self, name: Any, default=None) -> Any: def set( # pylint: disable=W0237 self, name: str, - meter: MetricMeter | MetricMeters | Callable, # type: ignore[override] + metric: MetricMeter | MetricMeters | Callable, # type: ignore[override] ) -> None: - if callable(meter): - meter = MetricMeter(meter) - if not isinstance(meter, (MetricMeter, MetricMeters)): - raise ValueError(f"Expected meter to be an instance of MetricMeter or MetricMeters, but got {type(meter)}") - super().set(name, meter) + if callable(metric): + metric = MetricMeter(metric) + if not isinstance(metric, (MetricMeter, MetricMeters)): + raise ValueError( + f"Expected {metric} to be an instance of MetricMeter or MetricMeters, but got {type(metric)}" + ) + super().set(name, metric) diff --git a/danling/metrics/metrics.py b/danling/metrics/metrics.py index 4dd27b1f..5a781e41 100644 --- a/danling/metrics/metrics.py +++ b/danling/metrics/metrics.py @@ -476,14 +476,12 @@ class MultiTaskMetrics(MultiTaskDict): >>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset1.reg": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0.2, 0.4, 0.6, 0.8]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 1, 0]}}) >>> f"{metrics:.4f}" 'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)' - >>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8]}}}) - Traceback (most recent call last): - ValueError: Expected values to be a flat dictionary, but got - This is likely due to nested dictionary in the values. - Nested dictionaries cannot be processed due to the method's design, which uses Mapping to pass both input and target. Ensure your input is a flat dictionary or a single value. - >>> metrics.update(dict(loss="")) + >>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [1, 0, 1, 0]}}}) + >>> f"{metrics:.4f}" + 'dataset1.cls: auroc: 0.2500 (0.5286)\tauprc: 0.5000 (0.4789)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)' + >>> metrics.update(dict(loss="")) # doctest: +ELLIPSIS Traceback (most recent call last): - ValueError: Expected values to be a flat dictionary, but got + ValueError: Metric loss not found in ... """ # noqa: E501 def __init__(self, *args, **kwargs): @@ -501,17 +499,36 @@ def update(self, values: Mapping[str, Mapping[str, Tensor | NestedTensor | Seque """ for metric, value in values.items(): - if isinstance(value, Mapping): - if metric not in self: - raise ValueError(f"Metric {metric} not found in {self}") - try: + if metric not in self: + raise ValueError(f"Metric {metric} not found in {self}") + if isinstance(self[metric], MultiTaskMetrics): + for name, met in self[metric].items(): + if name in value: + val = value[name] + if isinstance(value, Mapping): + met.update(**val) + elif isinstance(value, Sequence): + met.update(*val) + else: + raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}") + elif isinstance(self[metric], (Metrics, Metric)): + if isinstance(value, Mapping): self[metric].update(**value) - except TypeError: - raise ValueError( - f"Expected values to be a flat dictionary, but got {type(value)}\n" - "This is likely due to nested dictionary in the values.\n" - "Nested dictionaries cannot be processed due to the method's design, which uses Mapping " - "to pass both input and target. Ensure your input is a flat dictionary or a single value." - ) from None + elif isinstance(value, Sequence): + self[metric].update(*value) + else: + raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}") else: - raise ValueError(f"Expected values to be a flat dictionary, but got {type(value)}") + raise ValueError( + f"Expected {metric} to be an instance of MultiTaskMetrics, Metrics, or Metric, " + "but got {type(self[metric])}" + ) + + def set( # pylint: disable=W0237 + self, + name: str, + metric: Metrics | Metric, # type: ignore[override] + ) -> None: + if not isinstance(metric, (Metrics, Metric)): + raise ValueError(f"Expected {metric} to be an instance of Metrics or Metric, but got {type(metric)}") + super().set(name, metric)