Skip to content

Commit

Permalink
Reforms improvement (#230)
Browse files Browse the repository at this point in the history
* Add option for loading from dataframe

* Improve flexibility of reforms

* Versioning
  • Loading branch information
nikhilwoodruff authored Jul 28, 2024
1 parent b3c4de1 commit dd6ac92
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 56 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Reform syntax to increase flexibility.
124 changes: 124 additions & 0 deletions docs/python_api/reforms.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reforms\n",
"\n",
"To define a reform, simply define a class inheriting from `Reform` with an `apply(self)` function. Inside it, `self` is the tax-benefit system attached to the simulation with loaded data `self.simulation: Simulation`. From this, you can run any kind of modification on the `Simulation` instance that you like- modify parameters, variable logic or even adjust simulation data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from policyengine_core.country_template import Microsimulation\n",
"from policyengine_core.model_api import *\n",
"\n",
"baseline = Microsimulation()\n",
"\n",
"\n",
"class reform(Reform):\n",
" def apply(self):\n",
" simulation = self.simulation\n",
"\n",
" # Modify parameters\n",
"\n",
" simulation.tax_benefit_system.parameters.taxes.housing_tax.rate.update(\n",
" 20\n",
" )\n",
"\n",
" # Modify simulation data\n",
"\n",
" salary = simulation.calculate(\"salary\", \"2022-01\")\n",
"\n",
" new_salary = salary * 1.1\n",
"\n",
" simulation.set_input(\"salary\", \"2022-01\", new_salary)\n",
"\n",
"\n",
"reformed = Microsimulation(reform=reform)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"( value weight\n",
" 0 110.0 1000000.0\n",
" 1 0.0 1000000.0\n",
" 2 220.0 1200000.0,\n",
" value weight\n",
" 0 100.0 1000000.0\n",
" 1 0.0 1000000.0\n",
" 2 200.0 1200000.0)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reformed.calculate(\"salary\", \"2022-01\"), baseline.calculate(\n",
" \"salary\", \"2022-01\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"( value weight\n",
" 0 4000.0 1000000.0\n",
" 1 6000.0 1200000.0,\n",
" value weight\n",
" 0 2000.0 1000000.0\n",
" 1 3000.0 1200000.0)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reformed.calculate(\"housing_tax\", 2022), baseline.calculate(\n",
" \"housing_tax\", 2022\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
42 changes: 0 additions & 42 deletions docs/usage/reforms.md

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def generate(self) -> None:
"person_household_id": {ETERNITY: person_household_id},
"person_household_role": {ETERNITY: person_household_role},
"salary": {salary_time_period: salary},
"accommodation_size": {salary_time_period: [200, 300]},
"household_weight": {weight_time_period: weight},
}
self.save_dataset(data)
23 changes: 23 additions & 0 deletions policyengine_core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,26 @@ def from_file(file_path: str, time_period: str = None):
)()

return dataset

@staticmethod
def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
"""Creates a dataset from a DataFrame.
Returns:
Dataset: The dataset.
"""
file_path = Path(file_path)
dataset = type(
"Dataset",
(Dataset,),
{
"name": file_path.stem,
"label": file_path.stem,
"data_format": Dataset.FLAT_FILE,
"file_path": "dataframe",
"time_period": time_period,
"load": lambda: dataframe,
},
)()

return dataset
4 changes: 2 additions & 2 deletions policyengine_core/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def clone(self):
]
return clone

def update(self, period=None, start=None, stop=None, value=None):
def update(self, value=None, period=None, start=None, stop=None):
"""
Change the value for a given period.
Expand All @@ -156,7 +156,7 @@ def update(self, period=None, start=None, stop=None, value=None):
start = period.start
stop = period.stop
if start is None:
raise ValueError("You must provide either a start or a period")
start = "0000-01-01"
start_str = str(start)
stop_str = str(stop.offset(1, "day")) if stop else None

Expand Down
7 changes: 6 additions & 1 deletion policyengine_core/reforms/reform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import copy
from typing import Callable, Union
from typing import Callable, Union, TYPE_CHECKING

from policyengine_core.parameters import ParameterNode, Parameter
from policyengine_core.taxbenefitsystems import TaxBenefitSystem

if TYPE_CHECKING:
from policyengine_core.simulations import Simulation
from policyengine_core.periods import (
period as period_,
instant as instant_,
Expand Down Expand Up @@ -60,6 +63,8 @@ class Reform(TaxBenefitSystem):
parameter_values: dict = None
"""The parameter values of the reform. This is used to inform any calls to the PolicyEngine API."""

simulation: "Simulation" = None

def __init__(self, baseline: TaxBenefitSystem):
"""
:param baseline: Baseline TaxBenefitSystem.
Expand Down
17 changes: 6 additions & 11 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
reform: Reform = None,
trace: bool = False,
):
reform_applied_after = False
if tax_benefit_system is None:
if (
self.default_tax_benefit_system_instance is not None
Expand All @@ -94,20 +93,11 @@ def __init__(
tax_benefit_system = self.default_tax_benefit_system_instance
else:
# If reform is taken as an arg, pass it
try:
tax_benefit_system = self.default_tax_benefit_system(
reform=reform
)
except:
tax_benefit_system = self.default_tax_benefit_system()
reform_applied_after = True
tax_benefit_system = self.default_tax_benefit_system()
self.tax_benefit_system = tax_benefit_system

self.reform = reform
self.tax_benefit_system = tax_benefit_system

if reform_applied_after and reform is not None:
self.apply_reform(reform)
self.branch_name = "default"

if dataset is None:
Expand Down Expand Up @@ -169,6 +159,11 @@ def __init__(
self.dataset = dataset
self.build_from_dataset()

self.tax_benefit_system.simulation = self

if self.reform is not None:
self.tax_benefit_system.apply_reform_set(self.reform)

# Backwards compatibility methods
self.calc = self.calculate
self.df = self.calculate_dataframe
Expand Down

0 comments on commit dd6ac92

Please sign in to comment.