Skip to content

Commit

Permalink
improve metircs.update
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 20, 2024
1 parent b18935f commit fcc57bf
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 44 deletions.
52 changes: 27 additions & 25 deletions danling/metrics/metric_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ class MultiTaskMetricMeters(MultiTaskAverageMeters):
>>> metrics.update({"dataset1.cls": [[0.1, 0.4, 0.6, 0.8], [0, 0, 1, 0]], "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 0, 1]}})
>>> f"{metrics:.4f}"
'dataset1.cls: acc: 0.7500 (0.6250)\ndataset2: acc: 0.7500 (0.5000)'
>>> metrics.update(dict(loss="")) # doctest: +ELLIPSIS
Traceback (most recent call last):
ValueError: Metric loss not found in ...
""" # noqa: E501

def __init__(self, *args, **kwargs):
Expand All @@ -225,31 +228,28 @@ def update( # type: ignore[override] # pylint: disable=W0221
"""

for metric, value in values.items():
if isinstance(value, (Mapping, Sequence)):
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
if isinstance(self[metric], MultiTaskMetricMeters):
for met in self[metric].all_values():
if isinstance(value, Mapping):
met.update(**value)
elif isinstance(value, Sequence):
met.update(*value)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
elif isinstance(self[metric], (MetricMeters, MetricMeter)):
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
if isinstance(self[metric], MultiTaskMetricMeters):
for met in self[metric].all_values():
if isinstance(value, Mapping):
self[metric].update(**value)
met.update(**value)
elif isinstance(value, Sequence):
self[metric].update(*value)
met.update(*value)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
elif isinstance(self[metric], (MetricMeters, MetricMeter)):
if isinstance(value, Mapping):
self[metric].update(**value)
elif isinstance(value, Sequence):
self[metric].update(*value)
else:
raise ValueError(
f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, "
"or MetricMeter, but got {type(self[metric])}"
)
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
else:
raise ValueError(f"Expected values to be a Mapping or Sequence, but got {type(value)}")
raise ValueError(
f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, "
f"or MetricMeter, but got {type(self[metric])}"
)

# MultiTaskAverageMeters.get is hacked
def get(self, name: Any, default=None) -> Any:
Expand All @@ -258,10 +258,12 @@ def get(self, name: Any, default=None) -> Any:
def set( # pylint: disable=W0237
self,
name: str,
meter: MetricMeter | MetricMeters | Callable, # type: ignore[override]
metric: MetricMeter | MetricMeters | Callable, # type: ignore[override]
) -> None:
if callable(meter):
meter = MetricMeter(meter)
if not isinstance(meter, (MetricMeter, MetricMeters)):
raise ValueError(f"Expected meter to be an instance of MetricMeter or MetricMeters, but got {type(meter)}")
super().set(name, meter)
if callable(metric):
metric = MetricMeter(metric)
if not isinstance(metric, (MetricMeter, MetricMeters)):
raise ValueError(
f"Expected {metric} to be an instance of MetricMeter or MetricMeters, but got {type(metric)}"
)
super().set(name, metric)
55 changes: 36 additions & 19 deletions danling/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,14 +476,12 @@ class MultiTaskMetrics(MultiTaskDict):
>>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset1.reg": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0.2, 0.4, 0.6, 0.8]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 1, 0]}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.6667 (0.7000)\tauprc: 0.5000 (0.5556)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)'
>>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8]}}})
Traceback (most recent call last):
ValueError: Expected values to be a flat dictionary, but got <class 'dict'>
This is likely due to nested dictionary in the values.
Nested dictionaries cannot be processed due to the method's design, which uses Mapping to pass both input and target. Ensure your input is a flat dictionary or a single value.
>>> metrics.update(dict(loss=""))
>>> metrics.update({"dataset1": {"cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [1, 0, 1, 0]}}})
>>> f"{metrics:.4f}"
'dataset1.cls: auroc: 0.2500 (0.5286)\tauprc: 0.5000 (0.4789)\ndataset1.reg: pearson: 0.9898 (0.9146)\tspearman: 1.0000 (0.9222)\ndataset2: auroc: 0.6667 (0.7333)\tauprc: 0.5000 (0.7000)'
>>> metrics.update(dict(loss="")) # doctest: +ELLIPSIS
Traceback (most recent call last):
ValueError: Expected values to be a flat dictionary, but got <class 'str'>
ValueError: Metric loss not found in ...
""" # noqa: E501

def __init__(self, *args, **kwargs):
Expand All @@ -501,17 +499,36 @@ def update(self, values: Mapping[str, Mapping[str, Tensor | NestedTensor | Seque
"""

for metric, value in values.items():
if isinstance(value, Mapping):
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
try:
if metric not in self:
raise ValueError(f"Metric {metric} not found in {self}")
if isinstance(self[metric], MultiTaskMetrics):
for name, met in self[metric].items():
if name in value:
val = value[name]
if isinstance(value, Mapping):
met.update(**val)
elif isinstance(value, Sequence):
met.update(*val)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
elif isinstance(self[metric], (Metrics, Metric)):
if isinstance(value, Mapping):
self[metric].update(**value)
except TypeError:
raise ValueError(
f"Expected values to be a flat dictionary, but got {type(value)}\n"
"This is likely due to nested dictionary in the values.\n"
"Nested dictionaries cannot be processed due to the method's design, which uses Mapping "
"to pass both input and target. Ensure your input is a flat dictionary or a single value."
) from None
elif isinstance(value, Sequence):
self[metric].update(*value)
else:
raise ValueError(f"Expected value to be a Mapping or Sequence, but got {type(value)}")
else:
raise ValueError(f"Expected values to be a flat dictionary, but got {type(value)}")
raise ValueError(
f"Expected {metric} to be an instance of MultiTaskMetrics, Metrics, or Metric, "
"but got {type(self[metric])}"
)

def set( # pylint: disable=W0237
self,
name: str,
metric: Metrics | Metric, # type: ignore[override]
) -> None:
if not isinstance(metric, (Metrics, Metric)):
raise ValueError(f"Expected {metric} to be an instance of Metrics or Metric, but got {type(metric)}")
super().set(name, metric)

0 comments on commit fcc57bf

Please sign in to comment.