From 178b5187cc6cc672d1221d6c0f297cde67498a1c Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 10 Oct 2024 11:47:43 -0700 Subject: [PATCH] Improve json serialization to accomodate numpy float32 --- .../widgets/validation_json_generator.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/nvflare/app_common/widgets/validation_json_generator.py b/nvflare/app_common/widgets/validation_json_generator.py index 03c950e46f..35fce6aa8f 100644 --- a/nvflare/app_common/widgets/validation_json_generator.py +++ b/nvflare/app_common/widgets/validation_json_generator.py @@ -14,6 +14,9 @@ import json import os.path +from functools import singledispatch + +import numpy as np from nvflare.apis.dxo import DataKind, from_shareable, get_leaf_dxos from nvflare.apis.event_type import EventType @@ -23,6 +26,17 @@ from nvflare.widgets.widget import Widget +@singledispatch +def to_serializable(val): + """Default json serializable method.""" + return str(val) + + +@to_serializable.register(np.float32) +def ts_float32(val): + return np.float64(val) + + class ValidationJsonGenerator(Widget): def __init__(self, results_dir=AppConstants.CROSS_VAL_DIR, json_file_name="cross_val_results.json"): """Catches VALIDATION_RESULT_RECEIVED event and generates a results.json containing accuracy of each @@ -58,7 +72,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if val_results: try: dxo = from_shareable(val_results) - dxo.validate() if dxo.data_kind == DataKind.METRICS: if data_client not in self._val_results: @@ -71,7 +84,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): for err in errors: self.log_error(fl_ctx, f"Bad result from {data_client}: {err}") for _sub_data_client, _dxo in leaf_dxos.items(): - _dxo.validate() if _sub_data_client not in self._val_results: self._val_results[_sub_data_client] = {} self._val_results[_sub_data_client][model_owner] = _dxo.data @@ -93,4 +105,4 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): res_file_path = os.path.join(cross_val_res_dir, self._json_file_name) with open(res_file_path, "w") as f: - json.dump(self._val_results, f) + json.dump(self._val_results, f, default=to_serializable)