Skip to content

Commit

Permalink
Add macro impact caching (#196)
Browse files Browse the repository at this point in the history
* Fix Uprating doesn't work in some cases #193

* Add caching of variables in microsimulations
Fixes #194
  • Loading branch information
nikhilwoodruff authored May 6, 2024
1 parent bfc7daf commit f5c644e
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 14 deletions.
7 changes: 7 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
- bump: minor
changes:
added:
- Macro impact caching.
- Dictionary-input start-stop date reform handling.
fixed:
- Uprating bugs.
6 changes: 6 additions & 0 deletions policyengine_core/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __init__(

self.values_list: List[ParameterAtInstant] = values_list

self.modified: bool = False

def __repr__(self):
return os.linesep.join(
[
Expand Down Expand Up @@ -208,6 +210,10 @@ def update(self, period=None, start=None, stop=None, value=None):

self.parent.clear_parent_cache()

def mark_as_modified(self):
self.modified = True
self.parent.mark_as_modified()

def get_descendants(self):
return iter(())

Expand Down
7 changes: 7 additions & 0 deletions policyengine_core/parameters/parameter_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def __init__(
child = _parse_child(child_name_expanded, child, file_path)
self.add_child(child_name, child)

self.modified: bool = False

def merge(self, other: "ParameterNode") -> None:
"""
Merges another ParameterNode into the current node.
Expand Down Expand Up @@ -246,3 +248,8 @@ def clear_parent_cache(self):
if self.parent is not None:
self.parent.clear_parent_cache()
self._at_instant_cache.clear()

def mark_as_modified(self):
self.modified = True
if self.parent is not None:
self.parent.mark_as_modified()
21 changes: 17 additions & 4 deletions policyengine_core/reforms/reform.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,24 @@ class reform(Reform):
def apply(self):
for path, period_values in parameter_values.items():
for period, value in period_values.items():
self.modify_parameters(
set_parameter(
path, value, period, return_modifier=True
if "." in period:
start, stop = period.split(".")
self.modify_parameters(
set_parameter(
path,
value,
period=None,
start=start,
stop=stop,
return_modifier=True,
)
)
else:
self.modify_parameters(
set_parameter(
path, value, period, return_modifier=True
)
)
)

reform.country_id = country_id
reform.parameter_values = parameter_values
Expand Down
106 changes: 96 additions & 10 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
SimpleTracer,
TracingParameterNodeAtInstant,
)
import h5py
from pathlib import Path
import shutil

import json

Expand Down Expand Up @@ -75,7 +78,6 @@ def __init__(
reform: Reform = None,
trace: bool = False,
):
self.is_over_dataset = dataset is not None
reform_applied_after = False
if tax_benefit_system is None:
if (
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
if dataset is None:
if self.default_dataset is not None:
dataset = self.default_dataset
self.is_over_dataset = dataset is not None

self.invalidated_caches = set()
self.debug: bool = False
Expand Down Expand Up @@ -523,6 +526,10 @@ def _calculate(
if cached_array is not None:
return cached_array

cache_path = self._get_macro_cache(variable_name, str(period))
if cache_path and cache_path.exists():
return self._get_macro_cache_value(cache_path)

if variable.requires_computation_after is not None:
if variable.requires_computation_after not in [
node.get("name") for node in self.tracer.stack
Expand All @@ -534,18 +541,25 @@ def _calculate(
for node in self.tracer.stack
)
)

alternate_period_handling = False
if variable.definition_period == MONTH and period.unit == YEAR:
if variable.quantity_type == QuantityType.STOCK:
contained_months = period.get_subperiods(MONTH)
return self._calculate(variable_name, contained_months[-1])
values = self._calculate(variable_name, contained_months[-1])
else:
return self.calculate_add(variable_name, period)
values = self.calculate_add(variable_name, period)
alternate_period_handling = True
elif variable.definition_period == YEAR and period.unit == MONTH:
alternate_period_handling = True
if variable.quantity_type == QuantityType.STOCK:
return self._calculate(variable_name, period.this_year)
values = self._calculate(variable_name, period.this_year)
else:
return self.calculate_divide(variable_name, period)
values = self.calculate_divide(variable_name, period)

if alternate_period_handling:
if cache_path is not None:
self._set_macro_cache_value(cache_path, values)
return values

self._check_period_consistency(period, variable)

Expand Down Expand Up @@ -578,11 +592,12 @@ def _calculate(
str(known_period.start)
for known_period in known_periods
if known_period.unit == variable.definition_period
and known_period.start < period.start
]
latest_known_period = known_periods[
np.argmax(start_instants)
]
if latest_known_period.start < period.start:
if len(start_instants) > 0:
latest_known_period = known_periods[
np.argmax(start_instants)
]
try:
uprating_parameter = get_parameter(
self.tax_benefit_system.parameters,
Expand Down Expand Up @@ -642,6 +657,9 @@ def _calculate(
f"RecursionError while calculating {variable_name} for period {period}. The full computation stack is:\n{stack_formatted}"
)

if cache_path is not None:
self._set_macro_cache_value(cache_path, array)

return array

def purge_cache_of_invalid_values(self) -> None:
Expand Down Expand Up @@ -1295,6 +1313,74 @@ def extract_person(

return json.loads(json.dumps(situation, cls=NpEncoder))

def _get_macro_cache(
self,
variable_name: str,
period: str,
):
"""
Get the cache location of a variable for a given period, if it exists.
"""
if not self.is_over_dataset:
return None

variable = self.tax_benefit_system.get_variable(variable_name)
parameter_deps = variable.exhaustive_parameter_dependencies

if parameter_deps is None:
return None

for parameter in parameter_deps:
param = get_parameter(
self.tax_benefit_system.parameters, parameter
)
if param.modified:
return None

storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
storage_folder.mkdir(exist_ok=True)

cache_file_path = (
storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5"
)

return cache_file_path

def clear_macro_cache(self):
"""
Clear the cache of all variables.
"""
storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
if storage_folder.exists():
shutil.rmtree(storage_folder)

def _get_macro_cache_value(
self,
cache_file_path: Path,
):
"""
Get the value of a variable from a cache file.
"""
with h5py.File(cache_file_path, "r") as f:
return f["values"][()]

def _set_macro_cache_value(
self,
cache_file_path: Path,
value: ArrayLike,
):
"""
Set the value of a variable in a cache file.
"""
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)


class NpEncoder(json.JSONEncoder):
def default(self, obj):
Expand Down
11 changes: 11 additions & 0 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class Variable:
requires_computation_after: str = None
"""Name of a variable that must be computed before this variable."""

exhaustive_parameter_dependencies: List[str] = None
"""If these parameters (plus the dataset, branch and period) haven't changed, Core will use caching on this variable."""

def __init__(self, baseline_variable=None):
self.name = self.__class__.__name__
attr = {
Expand Down Expand Up @@ -294,6 +297,14 @@ def __init__(self, baseline_variable=None):
attr, "requires_computation_after", allowed_type=str
)

self.exhaustive_parameter_dependencies = self.set(
attr, "exhaustive_parameter_dependencies"
)
if isinstance(self.exhaustive_parameter_dependencies, str):
self.exhaustive_parameter_dependencies = [
self.exhaustive_parameter_dependencies
]

formulas_attr, unexpected_attrs = helpers._partition(
attr,
lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX),
Expand Down

0 comments on commit f5c644e

Please sign in to comment.