Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to output ParameterNode as YAML #295

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: minor
changes:
added:
- write_yaml function to output ParameterNode data to a YAML file
- test_write_yaml test to produce a sample output
4 changes: 4 additions & 0 deletions policyengine_core/parameters/at_instant_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ def get_at_instant(self, instant: Instant) -> Any:

@abc.abstractmethod
def _get_at_instant(self, instant): ...

@abc.abstractmethod
def get_attr_dict(self) -> dict:
raise NotImplementedError
28 changes: 23 additions & 5 deletions policyengine_core/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import os
from typing import Dict, List, Optional

from policyengine_core.errors import ParameterParsingError
from .at_instant_like import AtInstantLike
from .parameter_at_instant import ParameterAtInstant
import numpy

from .helpers import _validate_parameter, _compose_name
from .config import COMMON_KEYS
from policyengine_core.commons.misc import empty_clone
from policyengine_core.errors import ParameterParsingError
from policyengine_core.periods import INSTANT_PATTERN, period as get_period
from .at_instant_like import AtInstantLike
from .config import COMMON_KEYS
from .helpers import _validate_parameter, _compose_name
from .parameter_at_instant import ParameterAtInstant


class Parameter(AtInstantLike):
Expand Down Expand Up @@ -45,6 +46,8 @@ class Parameter(AtInstantLike):

"""

_exclusion_list = ["parent", "_at_instant_cache"]

def __init__(
self, name: str, data: dict, file_path: Optional[str] = None
) -> None:
Expand Down Expand Up @@ -233,3 +236,18 @@ def relative_change(self, start_instant, end_instant):
if end_value is None or start_value is None:
return None
return end_value / start_value - 1

def get_attr_dict(self) -> dict:
data = self.__dict__.copy()
for attr in self._exclusion_list:
if attr in data.keys():
del data[attr]
if "values_list" in data.keys():
value_dict = {}
for value_at_instant in data["values_list"]:
value = value_at_instant.value
if type(value) is numpy.float64:
value = float(value)
value_dict[value_at_instant.instant_str] = value
data["values_list"] = value_dict
return data
37 changes: 32 additions & 5 deletions policyengine_core/parameters/parameter_node.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import copy
import os
import typing
from typing import Iterable, List, Type, Union
from pathlib import Path
from typing import Iterable, Union

import yaml

from policyengine_core import commons, parameters, tools
from policyengine_core.data_structures import Reference
from policyengine_core.periods.instant_ import Instant
from policyengine_core.tracers import TracingParameterNodeAtInstant

from .at_instant_like import AtInstantLike
from .parameter import Parameter
from .parameter_node_at_instant import ParameterNodeAtInstant
from .config import COMMON_KEYS, FILE_EXTENSIONS
from .helpers import (
load_parameter_file,
Expand All @@ -19,6 +18,8 @@
_parse_child,
_load_yaml_file,
)
from .parameter import Parameter
from .parameter_node_at_instant import ParameterNodeAtInstant

EXCLUDED_PARAMETER_CHILD_NAMES = ["reference", "__pycache__"]

Expand All @@ -32,6 +33,8 @@ class ParameterNode(AtInstantLike):
None # By default, no restriction on the keys
)

_exclusion_list = ["parent", "_at_instant_cache"]

parent: "ParameterNode" = None
"""The parent of the node, or None if the node is the root of the tree."""

Expand Down Expand Up @@ -274,3 +277,27 @@ def get_child(self, path: str) -> "ParameterNode":
f"Could not find the parameter (failed at {name})."
)
return node

def get_attr_dict(self) -> dict:
data = self.__dict__.copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one minor suggestion: in this function, could we process every attr that isn't in child_dict, then process all the children? This makes keys at the end of the alphabet (like "tracer") more easily readable.

for attr in self._exclusion_list:
if attr in data.keys():
del data[attr]
if "children" in data.keys():
child_dict = data.get("children")
for child_name, child in child_dict.items():
data[child_name] = child.get_attr_dict()
del data["children"]
return data

class NoAliasDumper(yaml.SafeDumper):
def ignore_aliases(self, data):
return True

def write_yaml(self, file_path: Path) -> yaml:
data = self.get_attr_dict()
anth-volk marked this conversation as resolved.
Show resolved Hide resolved
try:
with open(file_path, "w") as f:
yaml.dump(data, f, sort_keys=True, Dumper=self.NoAliasDumper)
except Exception as e:
print(f"Error when writing YAML file: {e}")
16 changes: 16 additions & 0 deletions policyengine_core/parameters/parameter_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ParameterScale(AtInstantLike):
# 'unit' and 'reference' are only listed here for backward compatibility
_allowed_keys = config.COMMON_KEYS.union({"brackets"})

_exclusion_list = ["parent", "_at_instant_cache"]

def __init__(self, name: str, data: dict, file_path: str):
"""
:param name: name of the scale, eg "taxes.some_scale"
Expand Down Expand Up @@ -169,3 +171,17 @@ def _get_at_instant(self, instant: Instant) -> TaxScaleLike:
threshold = bracket.threshold
scale.add_bracket(threshold, rate * base)
return scale

def get_attr_dict(self) -> dict:
data = self.__dict__.copy()
for attr in self._exclusion_list:
if attr in data.keys():
del data[attr]
if "brackets" in data.keys():
node_list = data["brackets"]
i = 0
for node in node_list:
node_list[i] = node.get_attr_dict()
i += 1
data["brackets"] = node_list
return data
43 changes: 42 additions & 1 deletion tests/core/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tempfile

from pathlib import Path
import os
import pytest

from policyengine_core.parameters import (
Expand All @@ -8,6 +9,7 @@
ParameterNotFoundError,
load_parameter_file,
)
from policyengine_core.tools.test_runner import yaml


def test_get_at_instant(tax_benefit_system):
Expand Down Expand Up @@ -141,3 +143,42 @@ def test_name():
}
parameter = ParameterNode("root", data=parameter_data)
assert parameter.children["2010"].name == "root.2010"


def test_write_yaml():
parameter_data = {
"amount": {
"values": {
"2015-01-01": {"value": 550},
"2016-01-01": {"value": 600},
},
"description": "The amount of the basic income",
"documentation": None,
"modified": False,
},
"min_age": {
"values": {
"2015-01-01": {"value": 25},
"2016-01-01": {"value": 18},
},
"description": "The minimum age to receive the basic income",
"documentation": None,
"modified": True,
},
}
parameter = ParameterNode("root", data=parameter_data)
parameter.write_yaml(Path("output.yaml"))

try:
with open("output.yaml", "r") as file:
data = yaml.safe_load(file)
os.remove("output.yaml")
except yaml.YAMLError as e:
pytest.fail(f"Output is not valid YAML: {e}")


# from policyengine_us import Microsimulation
# def test_yaml_us():
# baseline = Microsimulation()
# tbs = baseline.tax_benefit_system
# tbs.parameters.gov.write_yaml("test_output.yaml")
Loading