Skip to content

Commit

Permalink
WE are getting there.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Nov 29, 2023
1 parent eb9ca3c commit 8dc34c4
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 41 deletions.
15 changes: 8 additions & 7 deletions src/dcegm/simulation/sim_final_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from dcegm.simulation.sim_utils import compute_final_utility_for_each_choice
from dcegm.simulation.sim_utils import draw_taste_shocks
from dcegm.simulation.sim_utils import get_state_choice_index_per_state
from jax import numpy as jnp
from jax import vmap

Expand All @@ -12,6 +13,7 @@ def simulate_final_period(
params,
basic_seed,
choice_range,
map_state_choice_to_index,
compute_utility_final_period,
):
(
Expand All @@ -22,13 +24,6 @@ def simulate_final_period(
n_choices = len(choice_range)
n_agents = len(resources_beginning_of_final_period)

# _utility = compute_utility_final_period(
# **states_beginning_of_final_period,
# choice=sim_dict["choice"][period - 1],
# resources=resources_beginning_of_final_period,
# params=params,
# )

utilities_pre_taste_shock = vmap(
vmap(
compute_final_utility_for_each_choice,
Expand All @@ -42,6 +37,12 @@ def simulate_final_period(
params,
compute_utility_final_period,
)
state_choice_indexes = get_state_choice_index_per_state(
map_state_choice_to_index, states_beginning_of_final_period
)
utilities_pre_taste_shock = jnp.where(
state_choice_indexes < 0, np.nan, utilities_pre_taste_shock
)

# Draw taste shocks and calculate final value.
key = jax.random.PRNGKey(basic_seed + period)
Expand Down
6 changes: 5 additions & 1 deletion src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ def interpolate_policy_and_value_for_all_agents(
state_choice_indexes = get_state_choice_index_per_state(
map_state_choice_to_index, states_beginning_of_period
)
value_grid_agent = jnp.take(value_solved, state_choice_indexes, axis=0)

value_grid_agent = jnp.take(
value_solved, state_choice_indexes, axis=0, mode="fill", fill_value=jnp.nan
)
policy_left_grid_agent = jnp.take(policy_left_solved, state_choice_indexes, axis=0)
policy_right_grid_agent = jnp.take(
policy_right_solved, state_choice_indexes, axis=0
Expand All @@ -55,6 +58,7 @@ def interpolate_policy_and_value_for_all_agents(
params,
compute_utility,
)

return policy_agent, value_per_agent_interp


Expand Down
1 change: 1 addition & 0 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def simulate_all_periods(
params=params,
basic_seed=seed,
choice_range=choice_range,
map_state_choice_to_index=map_state_choice_to_index,
compute_utility_final_period=compute_utility_final_period,
)

Expand Down
74 changes: 41 additions & 33 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from dcegm.pre_processing.model_functions import process_model_functions
from dcegm.pre_processing.state_space import create_state_space_and_choice_objects
from dcegm.simulation.sim_final_period import simulate_final_period
from dcegm.simulation.sim_utils import create_simulation_df
from dcegm.simulation.simulate import simulate_all_periods
from dcegm.simulation.simulate import simulate_single_period
Expand All @@ -29,13 +30,6 @@
marginal_utility,
)

# from tests.two_period_models.exog_ltc_and_job_offer.model_functions import (
# budget_dcegm_two_exog_processes,
# )
# from tests.two_period_models.exog_ltc_and_job_offer.model_functions import (
# func_exog_job_offer,
# )

WEALTH_GRID_POINTS = 100


Expand Down Expand Up @@ -72,16 +66,15 @@ def utility_functions_final_period():
def test_simulate(
state_space_functions, utility_functions, utility_functions_final_period
):
n_agents = 100_000
n_agents = 1_000_000

params = {}
params["rho"] = 0.5
params["delta"] = 0.5 * 10
params["delta"] = 0.5
params["interest_rate"] = 0.02
params["ltc_cost"] = 5
params["wage_avg"] = 8
params["sigma"] = 1
params["lambda"] = 1e-16
params["lambda"] = 10
params["beta"] = 0.95

Expand Down Expand Up @@ -173,7 +166,7 @@ def test_simulate(
resources_initial = np.ones(n_agents) * 10
states_and_wealth_beginning_of_period_zero = (initial_states, resources_initial)

_carry, _result = simulate_single_period(
carry_final, sim_dict_0 = simulate_single_period(
states_and_resources_beginning_of_period=states_and_wealth_beginning_of_period_zero,
period=0,
params=params,
Expand All @@ -193,7 +186,17 @@ def test_simulate(
update_endog_state_by_state_and_choice=update_endog_state_by_state_and_choice,
)

sim_dict = simulate_all_periods(
final_period_dict = simulate_final_period(
carry_final,
period=1,
params=params,
basic_seed=111,
choice_range=jnp.arange(2),
map_state_choice_to_index=jnp.array(map_state_choice_to_index),
compute_utility_final_period=model_funcs["compute_utility_final"],
)

result = simulate_all_periods(
states_initial=initial_states,
resources_initial=resources_initial,
n_periods=options["state_space"]["n_periods"],
Expand All @@ -215,36 +218,41 @@ def test_simulate(
compute_utility_final_period=model_funcs["compute_utility_final"],
)

df = create_simulation_df(sim_dict)

period = 0

# absrobing retirement state
# this should contain nobody
sim_dict["choice"][
:, (sim_dict["choice"][period] == 0) & (sim_dict["choice"][period + 1] == 1)
]
df = create_simulation_df(result)

_cond = [df["choice"] == 0, df["choice"] == 1]
_val = [df["taste_shock_0"], df["taste_shock_1"]]
df["taste_shock_selected_choice"] = np.select(_cond, _val)

# taste_shocks_final = df.xs(period + 1, level=0).filter(like="taste_shock_")

value_period_zero = (
df.xs(period, level=0)["utility"].mean()
+ params["beta"] * df.xs(period + 1, level=0)["value"].mean()
df.xs(0, level=0)["utility"].mean()
+ params["beta"]
* (
df.xs(1, level=0)["utility"]
+ df.xs(1, level=0)["taste_shock_selected_choice"]
).mean()
)
expected = (
df.xs(period, level=0)["value"]
- df.xs(period, level=0)["taste_shock_selected_choice"]
).mean()
df.xs(0, level=0)["value"].mean()
- df.xs(0, level=0)["taste_shock_selected_choice"].mean()
)

_cond = [final_period_dict["choice"] == 0, final_period_dict["choice"] == 1]
_val = [
final_period_dict["taste_shocks"][0, :, 0],
final_period_dict["taste_shocks"][0, :, 1],
]
taste_shock_selected_choice_final = np.select(_cond, _val)

_cond = [sim_dict_0["choice"] == 0, sim_dict_0["choice"] == 1]
_val = [sim_dict_0["taste_shocks"][:, 0], sim_dict_0["taste_shocks"][:, 1]]
taste_shock_selected_0 = np.select(_cond, _val)

_value_period_zero = df.xs(period, level=0)["utility"] + params["beta"] * (
df.xs(period + 1, level=0)["value"]
_value_period_zero = (
sim_dict_0["utility"].mean()
+ params["beta"]
* (final_period_dict["utility"] + taste_shock_selected_choice_final).mean()
)
_expected = df.xs(period, level=0)["value"]
_expected = sim_dict_0["value"].mean() - taste_shock_selected_0.mean()

aaae(value_period_zero.mean(), expected.mean(), decimal=2)
# aaae(_value_period_zero, _expected)
# aaae(_value_period_zero.mean(), _expected.mean(), decimal=2)

0 comments on commit 8dc34c4

Please sign in to comment.