diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..369eef8c 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - write_yaml function to output ParameterNode data to a YAML file + - test_write_yaml test to produce a sample output diff --git a/policyengine_core/parameters/at_instant_like.py b/policyengine_core/parameters/at_instant_like.py index ecde8614..844fa2b6 100644 --- a/policyengine_core/parameters/at_instant_like.py +++ b/policyengine_core/parameters/at_instant_like.py @@ -19,3 +19,7 @@ def get_at_instant(self, instant: Instant) -> Any: @abc.abstractmethod def _get_at_instant(self, instant): ... + + @abc.abstractmethod + def get_attr_dict(self) -> dict: + raise NotImplementedError diff --git a/policyengine_core/parameters/parameter.py b/policyengine_core/parameters/parameter.py index fd792030..5127b3e3 100644 --- a/policyengine_core/parameters/parameter.py +++ b/policyengine_core/parameters/parameter.py @@ -2,14 +2,15 @@ import os from typing import Dict, List, Optional +import numpy +from collections import OrderedDict +from policyengine_core.commons.misc import empty_clone from policyengine_core.errors import ParameterParsingError +from policyengine_core.periods import INSTANT_PATTERN, period as get_period from .at_instant_like import AtInstantLike -from .parameter_at_instant import ParameterAtInstant - -from .helpers import _validate_parameter, _compose_name from .config import COMMON_KEYS -from policyengine_core.commons.misc import empty_clone -from policyengine_core.periods import INSTANT_PATTERN, period as get_period +from .helpers import _validate_parameter, _compose_name +from .parameter_at_instant import ParameterAtInstant class Parameter(AtInstantLike): @@ -45,6 +46,9 @@ class Parameter(AtInstantLike): """ + _exclusion_list = ["parent", "_at_instant_cache"] + """The keys to be excluded from the node when output to a yaml file.""" + def __init__( self, name: str, data: dict, file_path: Optional[str] = None ) -> None: @@ -233,3 +237,19 @@ def relative_change(self, start_instant, end_instant): if end_value is None or start_value is None: return None return end_value / start_value - 1 + + def get_attr_dict(self) -> dict: + data = OrderedDict(self.__dict__.copy()) + for attr in self._exclusion_list: + if attr in data.keys(): + del data[attr] + if "values_list" in data.keys(): + value_dict = {} + for value_at_instant in data["values_list"]: + value = value_at_instant.value + if type(value) is numpy.float64: + value = float(value) + value_dict[value_at_instant.instant_str] = value + data["values_list"] = value_dict + data.move_to_end("values_list") + return dict(data) diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py index 5f05e33c..2dd0fd8c 100644 --- a/policyengine_core/parameters/parameter_node.py +++ b/policyengine_core/parameters/parameter_node.py @@ -1,16 +1,16 @@ import copy import os import typing -from typing import Iterable, List, Type, Union +from pathlib import Path +from typing import Iterable, Union + +import yaml +from collections import OrderedDict from policyengine_core import commons, parameters, tools -from policyengine_core.data_structures import Reference from policyengine_core.periods.instant_ import Instant from policyengine_core.tracers import TracingParameterNodeAtInstant - from .at_instant_like import AtInstantLike -from .parameter import Parameter -from .parameter_node_at_instant import ParameterNodeAtInstant from .config import COMMON_KEYS, FILE_EXTENSIONS from .helpers import ( load_parameter_file, @@ -19,6 +19,8 @@ _parse_child, _load_yaml_file, ) +from .parameter import Parameter +from .parameter_node_at_instant import ParameterNodeAtInstant EXCLUDED_PARAMETER_CHILD_NAMES = ["reference", "__pycache__"] @@ -32,6 +34,9 @@ class ParameterNode(AtInstantLike): None # By default, no restriction on the keys ) + _exclusion_list = ["parent", "_at_instant_cache"] + """The keys to be excluded from the node when output to a yaml file.""" + parent: "ParameterNode" = None """The parent of the node, or None if the node is the root of the tree.""" @@ -274,3 +279,28 @@ def get_child(self, path: str) -> "ParameterNode": f"Could not find the parameter (failed at {name})." ) return node + + def get_attr_dict(self) -> dict: + data = OrderedDict(self.__dict__.copy()) + for attr in self._exclusion_list: + if attr in data.keys(): + del data[attr] + if "children" in data.keys(): + child_dict = data.get("children") + for child_name, child in child_dict.items(): + data[child_name] = child.get_attr_dict() + data.move_to_end(child_name) + del data["children"] + return dict(data) + + class NoAliasDumper(yaml.SafeDumper): + def ignore_aliases(self, data): + return True + + def write_yaml(self, file_path: Path) -> yaml: + data = self.get_attr_dict() + try: + with open(file_path, "w") as f: + yaml.dump(data, f, sort_keys=False, Dumper=self.NoAliasDumper) + except Exception as e: + print(f"Error when writing YAML file: {e}") diff --git a/policyengine_core/parameters/parameter_scale.py b/policyengine_core/parameters/parameter_scale.py index 39dcd229..2d9a1c14 100644 --- a/policyengine_core/parameters/parameter_scale.py +++ b/policyengine_core/parameters/parameter_scale.py @@ -2,7 +2,7 @@ import os import typing from typing import Any, Iterable - +from collections import OrderedDict from policyengine_core import commons, parameters, tools from policyengine_core.errors import ParameterParsingError from policyengine_core.parameters import AtInstantLike, config, helpers @@ -24,6 +24,9 @@ class ParameterScale(AtInstantLike): # 'unit' and 'reference' are only listed here for backward compatibility _allowed_keys = config.COMMON_KEYS.union({"brackets"}) + _exclusion_list = ["parent", "_at_instant_cache"] + """The keys to be excluded from the node when output to a yaml file.""" + def __init__(self, name: str, data: dict, file_path: str): """ :param name: name of the scale, eg "taxes.some_scale" @@ -169,3 +172,18 @@ def _get_at_instant(self, instant: Instant) -> TaxScaleLike: threshold = bracket.threshold scale.add_bracket(threshold, rate * base) return scale + + def get_attr_dict(self) -> dict: + data = OrderedDict(self.__dict__.copy()) + for attr in self._exclusion_list: + if attr in data.keys(): + del data[attr] + if "brackets" in data.keys(): + node_list = data["brackets"] + i = 0 + for node in node_list: + node_list[i] = node.get_attr_dict() + i += 1 + data["brackets"] = node_list + data.move_to_end("brackets") + return dict(data) diff --git a/tests/core/test_parameters.py b/tests/core/test_parameters.py index 2cb71e2d..68369e09 100644 --- a/tests/core/test_parameters.py +++ b/tests/core/test_parameters.py @@ -1,5 +1,6 @@ import tempfile - +from pathlib import Path +import os import pytest from policyengine_core.parameters import ( @@ -8,6 +9,7 @@ ParameterNotFoundError, load_parameter_file, ) +from policyengine_core.tools.test_runner import yaml def test_get_at_instant(tax_benefit_system): @@ -141,3 +143,42 @@ def test_name(): } parameter = ParameterNode("root", data=parameter_data) assert parameter.children["2010"].name == "root.2010" + + +def test_write_yaml(): + parameter_data = { + "amount": { + "values": { + "2015-01-01": {"value": 550}, + "2016-01-01": {"value": 600}, + }, + "description": "The amount of the basic income", + "documentation": None, + "modified": False, + }, + "min_age": { + "values": { + "2015-01-01": {"value": 25}, + "2016-01-01": {"value": 18}, + }, + "description": "The minimum age to receive the basic income", + "documentation": None, + "modified": True, + }, + } + parameter = ParameterNode("root", data=parameter_data) + parameter.write_yaml(Path("output.yaml")) + + try: + with open("output.yaml", "r") as file: + data = yaml.safe_load(file) + os.remove("output.yaml") + except yaml.YAMLError as e: + pytest.fail(f"Output is not valid YAML: {e}") + + +# from policyengine_us import Microsimulation +# def test_yaml_us(): +# baseline = Microsimulation() +# tbs = baseline.tax_benefit_system +# tbs.parameters.gov.write_yaml("test_output.yaml")