Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add macro impact caching #196

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
- bump: minor
changes:
added:
- Macro impact caching.
- Dictionary-input start-stop date reform handling.
fixed:
- Uprating bugs.
6 changes: 6 additions & 0 deletions policyengine_core/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __init__(

self.values_list: List[ParameterAtInstant] = values_list

self.modified: bool = False

def __repr__(self):
return os.linesep.join(
[
Expand Down Expand Up @@ -208,6 +210,10 @@ def update(self, period=None, start=None, stop=None, value=None):

self.parent.clear_parent_cache()

def mark_as_modified(self):
self.modified = True
self.parent.mark_as_modified()

def get_descendants(self):
return iter(())

Expand Down
7 changes: 7 additions & 0 deletions policyengine_core/parameters/parameter_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def __init__(
child = _parse_child(child_name_expanded, child, file_path)
self.add_child(child_name, child)

self.modified: bool = False

def merge(self, other: "ParameterNode") -> None:
"""
Merges another ParameterNode into the current node.
Expand Down Expand Up @@ -246,3 +248,8 @@ def clear_parent_cache(self):
if self.parent is not None:
self.parent.clear_parent_cache()
self._at_instant_cache.clear()

def mark_as_modified(self):
self.modified = True
if self.parent is not None:
self.parent.mark_as_modified()
21 changes: 17 additions & 4 deletions policyengine_core/reforms/reform.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,24 @@ class reform(Reform):
def apply(self):
for path, period_values in parameter_values.items():
for period, value in period_values.items():
self.modify_parameters(
set_parameter(
path, value, period, return_modifier=True
if "." in period:
start, stop = period.split(".")
self.modify_parameters(
set_parameter(
path,
value,
period=None,
start=start,
stop=stop,
return_modifier=True,
)
)
else:
self.modify_parameters(
set_parameter(
path, value, period, return_modifier=True
)
)
)

reform.country_id = country_id
reform.parameter_values = parameter_values
Expand Down
106 changes: 96 additions & 10 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
SimpleTracer,
TracingParameterNodeAtInstant,
)
import h5py
from pathlib import Path
import shutil

import json

Expand Down Expand Up @@ -75,7 +78,6 @@ def __init__(
reform: Reform = None,
trace: bool = False,
):
self.is_over_dataset = dataset is not None
reform_applied_after = False
if tax_benefit_system is None:
if (
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
if dataset is None:
if self.default_dataset is not None:
dataset = self.default_dataset
self.is_over_dataset = dataset is not None

self.invalidated_caches = set()
self.debug: bool = False
Expand Down Expand Up @@ -523,6 +526,10 @@ def _calculate(
if cached_array is not None:
return cached_array

cache_path = self._get_macro_cache(variable_name, str(period))
if cache_path and cache_path.exists():
return self._get_macro_cache_value(cache_path)

if variable.requires_computation_after is not None:
if variable.requires_computation_after not in [
node.get("name") for node in self.tracer.stack
Expand All @@ -534,18 +541,25 @@ def _calculate(
for node in self.tracer.stack
)
)

alternate_period_handling = False
if variable.definition_period == MONTH and period.unit == YEAR:
if variable.quantity_type == QuantityType.STOCK:
contained_months = period.get_subperiods(MONTH)
return self._calculate(variable_name, contained_months[-1])
values = self._calculate(variable_name, contained_months[-1])
else:
return self.calculate_add(variable_name, period)
values = self.calculate_add(variable_name, period)
alternate_period_handling = True
elif variable.definition_period == YEAR and period.unit == MONTH:
alternate_period_handling = True
if variable.quantity_type == QuantityType.STOCK:
return self._calculate(variable_name, period.this_year)
values = self._calculate(variable_name, period.this_year)
else:
return self.calculate_divide(variable_name, period)
values = self.calculate_divide(variable_name, period)

if alternate_period_handling:
if cache_path is not None:
self._set_macro_cache_value(cache_path, values)
return values

self._check_period_consistency(period, variable)

Expand Down Expand Up @@ -578,11 +592,12 @@ def _calculate(
str(known_period.start)
for known_period in known_periods
if known_period.unit == variable.definition_period
and known_period.start < period.start
]
latest_known_period = known_periods[
np.argmax(start_instants)
]
if latest_known_period.start < period.start:
if len(start_instants) > 0:
latest_known_period = known_periods[
np.argmax(start_instants)
]
try:
uprating_parameter = get_parameter(
self.tax_benefit_system.parameters,
Expand Down Expand Up @@ -642,6 +657,9 @@ def _calculate(
f"RecursionError while calculating {variable_name} for period {period}. The full computation stack is:\n{stack_formatted}"
)

if cache_path is not None:
self._set_macro_cache_value(cache_path, array)

return array

def purge_cache_of_invalid_values(self) -> None:
Expand Down Expand Up @@ -1295,6 +1313,74 @@ def extract_person(

return json.loads(json.dumps(situation, cls=NpEncoder))

def _get_macro_cache(
self,
variable_name: str,
period: str,
):
"""
Get the cache location of a variable for a given period, if it exists.
"""
if not self.is_over_dataset:
return None

variable = self.tax_benefit_system.get_variable(variable_name)
parameter_deps = variable.exhaustive_parameter_dependencies

if parameter_deps is None:
return None

for parameter in parameter_deps:
param = get_parameter(
self.tax_benefit_system.parameters, parameter
)
if param.modified:
return None

storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
storage_folder.mkdir(exist_ok=True)

cache_file_path = (
storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5"
)

return cache_file_path

def clear_macro_cache(self):
"""
Clear the cache of all variables.
"""
storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
if storage_folder.exists():
shutil.rmtree(storage_folder)

def _get_macro_cache_value(
self,
cache_file_path: Path,
):
"""
Get the value of a variable from a cache file.
"""
with h5py.File(cache_file_path, "r") as f:
return f["values"][()]

def _set_macro_cache_value(
self,
cache_file_path: Path,
value: ArrayLike,
):
"""
Set the value of a variable in a cache file.
"""
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)


class NpEncoder(json.JSONEncoder):
def default(self, obj):
Expand Down
11 changes: 11 additions & 0 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class Variable:
requires_computation_after: str = None
"""Name of a variable that must be computed before this variable."""

exhaustive_parameter_dependencies: List[str] = None
"""If these parameters (plus the dataset, branch and period) haven't changed, Core will use caching on this variable."""

def __init__(self, baseline_variable=None):
self.name = self.__class__.__name__
attr = {
Expand Down Expand Up @@ -294,6 +297,14 @@ def __init__(self, baseline_variable=None):
attr, "requires_computation_after", allowed_type=str
)

self.exhaustive_parameter_dependencies = self.set(
attr, "exhaustive_parameter_dependencies"
)
if isinstance(self.exhaustive_parameter_dependencies, str):
self.exhaustive_parameter_dependencies = [
self.exhaustive_parameter_dependencies
]

formulas_attr, unexpected_attrs = helpers._partition(
attr,
lambda name, value: name.startswith(config.FORMULA_NAME_PREFIX),
Expand Down
Loading