From f5c644e2c7106d3f4fd10eac9d5e93634293b919 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Mon, 6 May 2024 15:23:52 +0100 Subject: [PATCH] Add macro impact caching (#196) * Fix Uprating doesn't work in some cases #193 * Add caching of variables in microsimulations Fixes #194 --- changelog_entry.yaml | 7 ++ policyengine_core/parameters/parameter.py | 6 + .../parameters/parameter_node.py | 7 ++ policyengine_core/reforms/reform.py | 21 +++- policyengine_core/simulations/simulation.py | 106 ++++++++++++++++-- policyengine_core/variables/variable.py | 11 ++ 6 files changed, 144 insertions(+), 14 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..9fa52f874 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,7 @@ +- bump: minor + changes: + added: + - Macro impact caching. + - Dictionary-input start-stop date reform handling. + fixed: + - Uprating bugs. \ No newline at end of file diff --git a/policyengine_core/parameters/parameter.py b/policyengine_core/parameters/parameter.py index 0406d3f2e..57d04b430 100644 --- a/policyengine_core/parameters/parameter.py +++ b/policyengine_core/parameters/parameter.py @@ -108,6 +108,8 @@ def __init__( self.values_list: List[ParameterAtInstant] = values_list + self.modified: bool = False + def __repr__(self): return os.linesep.join( [ @@ -208,6 +210,10 @@ def update(self, period=None, start=None, stop=None, value=None): self.parent.clear_parent_cache() + def mark_as_modified(self): + self.modified = True + self.parent.mark_as_modified() + def get_descendants(self): return iter(()) diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py index 050b806b0..f497d3e8e 100644 --- a/policyengine_core/parameters/parameter_node.py +++ b/policyengine_core/parameters/parameter_node.py @@ -159,6 +159,8 @@ def __init__( child = _parse_child(child_name_expanded, child, file_path) self.add_child(child_name, child) + self.modified: bool = False + def merge(self, other: "ParameterNode") -> None: """ Merges another ParameterNode into the current node. @@ -246,3 +248,8 @@ def clear_parent_cache(self): if self.parent is not None: self.parent.clear_parent_cache() self._at_instant_cache.clear() + + def mark_as_modified(self): + self.modified = True + if self.parent is not None: + self.parent.mark_as_modified() diff --git a/policyengine_core/reforms/reform.py b/policyengine_core/reforms/reform.py index 363feafd9..6443fc5b2 100644 --- a/policyengine_core/reforms/reform.py +++ b/policyengine_core/reforms/reform.py @@ -135,11 +135,24 @@ class reform(Reform): def apply(self): for path, period_values in parameter_values.items(): for period, value in period_values.items(): - self.modify_parameters( - set_parameter( - path, value, period, return_modifier=True + if "." in period: + start, stop = period.split(".") + self.modify_parameters( + set_parameter( + path, + value, + period=None, + start=start, + stop=stop, + return_modifier=True, + ) + ) + else: + self.modify_parameters( + set_parameter( + path, value, period, return_modifier=True + ) ) - ) reform.country_id = country_id reform.parameter_values = parameter_values diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index d68ecd285..e6dbed5d4 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -20,6 +20,9 @@ SimpleTracer, TracingParameterNodeAtInstant, ) +import h5py +from pathlib import Path +import shutil import json @@ -75,7 +78,6 @@ def __init__( reform: Reform = None, trace: bool = False, ): - self.is_over_dataset = dataset is not None reform_applied_after = False if tax_benefit_system is None: if ( @@ -104,6 +106,7 @@ def __init__( if dataset is None: if self.default_dataset is not None: dataset = self.default_dataset + self.is_over_dataset = dataset is not None self.invalidated_caches = set() self.debug: bool = False @@ -523,6 +526,10 @@ def _calculate( if cached_array is not None: return cached_array + cache_path = self._get_macro_cache(variable_name, str(period)) + if cache_path and cache_path.exists(): + return self._get_macro_cache_value(cache_path) + if variable.requires_computation_after is not None: if variable.requires_computation_after not in [ node.get("name") for node in self.tracer.stack @@ -534,18 +541,25 @@ def _calculate( for node in self.tracer.stack ) ) - + alternate_period_handling = False if variable.definition_period == MONTH and period.unit == YEAR: if variable.quantity_type == QuantityType.STOCK: contained_months = period.get_subperiods(MONTH) - return self._calculate(variable_name, contained_months[-1]) + values = self._calculate(variable_name, contained_months[-1]) else: - return self.calculate_add(variable_name, period) + values = self.calculate_add(variable_name, period) + alternate_period_handling = True elif variable.definition_period == YEAR and period.unit == MONTH: + alternate_period_handling = True if variable.quantity_type == QuantityType.STOCK: - return self._calculate(variable_name, period.this_year) + values = self._calculate(variable_name, period.this_year) else: - return self.calculate_divide(variable_name, period) + values = self.calculate_divide(variable_name, period) + + if alternate_period_handling: + if cache_path is not None: + self._set_macro_cache_value(cache_path, values) + return values self._check_period_consistency(period, variable) @@ -578,11 +592,12 @@ def _calculate( str(known_period.start) for known_period in known_periods if known_period.unit == variable.definition_period + and known_period.start < period.start ] - latest_known_period = known_periods[ - np.argmax(start_instants) - ] - if latest_known_period.start < period.start: + if len(start_instants) > 0: + latest_known_period = known_periods[ + np.argmax(start_instants) + ] try: uprating_parameter = get_parameter( self.tax_benefit_system.parameters, @@ -642,6 +657,9 @@ def _calculate( f"RecursionError while calculating {variable_name} for period {period}. The full computation stack is:\n{stack_formatted}" ) + if cache_path is not None: + self._set_macro_cache_value(cache_path, array) + return array def purge_cache_of_invalid_values(self) -> None: @@ -1295,6 +1313,74 @@ def extract_person( return json.loads(json.dumps(situation, cls=NpEncoder)) + def _get_macro_cache( + self, + variable_name: str, + period: str, + ): + """ + Get the cache location of a variable for a given period, if it exists. + """ + if not self.is_over_dataset: + return None + + variable = self.tax_benefit_system.get_variable(variable_name) + parameter_deps = variable.exhaustive_parameter_dependencies + + if parameter_deps is None: + return None + + for parameter in parameter_deps: + param = get_parameter( + self.tax_benefit_system.parameters, parameter + ) + if param.modified: + return None + + storage_folder = ( + self.dataset.file_path.parent + / f"{self.dataset.name}_variable_cache" + ) + storage_folder.mkdir(exist_ok=True) + + cache_file_path = ( + storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5" + ) + + return cache_file_path + + def clear_macro_cache(self): + """ + Clear the cache of all variables. + """ + storage_folder = ( + self.dataset.file_path.parent + / f"{self.dataset.name}_variable_cache" + ) + if storage_folder.exists(): + shutil.rmtree(storage_folder) + + def _get_macro_cache_value( + self, + cache_file_path: Path, + ): + """ + Get the value of a variable from a cache file. + """ + with h5py.File(cache_file_path, "r") as f: + return f["values"][()] + + def _set_macro_cache_value( + self, + cache_file_path: Path, + value: ArrayLike, + ): + """ + Set the value of a variable in a cache file. + """ + with h5py.File(cache_file_path, "w") as f: + f.create_dataset("values", data=value) + class NpEncoder(json.JSONEncoder): def default(self, obj): diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py index f8b211057..b57d3a9e5 100644 --- a/policyengine_core/variables/variable.py +++ b/policyengine_core/variables/variable.py @@ -129,6 +129,9 @@ class Variable: requires_computation_after: str = None """Name of a variable that must be computed before this variable.""" + exhaustive_parameter_dependencies: List[str] = None + """If these parameters (plus the dataset, branch and period) haven't changed, Core will use caching on this variable.""" + def __init__(self, baseline_variable=None): self.name = self.__class__.__name__ attr = { @@ -294,6 +297,14 @@ def __init__(self, baseline_variable=None): attr, "requires_computation_after", allowed_type=str ) + self.exhaustive_parameter_dependencies = self.set( + attr, "exhaustive_parameter_dependencies" + ) + if isinstance(self.exhaustive_parameter_dependencies, str): + self.exhaustive_parameter_dependencies = [ + self.exhaustive_parameter_dependencies + ] + formulas_attr, unexpected_attrs = helpers._partition( attr, lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX),