diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29b..1955401c 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: minor
+  changes:
+    added:
+    - Simulation helper to extract individual households from a microsimulation.
diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py
index c695424f..f89d91a3 100644
--- a/policyengine_core/simulations/simulation.py
+++ b/policyengine_core/simulations/simulation.py
@@ -1182,3 +1182,133 @@ def derivative(
         new_value = alt_sim.calculate(variable, period)
         difference = new_value - original_value
         return difference / delta
+
+    def sample_person(self) -> dict:
+        """
+        Sample a person from the simulation. Returns a situation JSON with
+        their inputs (including their containing entities).
+
+        The person's index is drawn with ``np.random.randint``.
+
+        Returns:
+            dict: A dictionary containing the person's values.
+        """
+        person_count = self.persons.count
+        index = np.random.randint(person_count)
+        return self.extract_person(index)
+
+    def extract_person(
+        self,
+        index: int = 0,
+        exclude_entities: tuple = ("state",),
+    ) -> dict:
+        """
+        Extract a person from the simulation. Returns a situation JSON with
+        their inputs (including their containing entities).
+
+        Everyone who shares a non-excluded group entity with the person is
+        extracted too. For each input variable, only the value for the first
+        period returned by the holder's ``get_known_periods`` is exported.
+
+        Args:
+            index (int): The index of the person to extract.
+            exclude_entities (tuple): Keys of the group entities to leave
+                out of the situation.
+
+        Returns:
+            dict: A dictionary containing the person's values.
+        """
+        situation = {}
+        # Seed with the requested person so that they are still extracted
+        # when every group entity is excluded.
+        people_indices = {index}
+        people_indices_by_entity = {}
+
+        # Resolve each input variable's entity once, instead of once per
+        # entity and per extracted person.
+        variables_by_entity = {}
+        for variable in self.input_variables:
+            entity_key = self.tax_benefit_system.get_variable(
+                variable
+            ).entity.key
+            variables_by_entity.setdefault(entity_key, []).append(variable)
+
+        for population in self.populations.values():
+            entity = population.entity
+            if entity.is_person or entity.key in exclude_entities:
+                continue
+            group_index = population.members_entity_id[index]
+            member_indices = np.flatnonzero(
+                np.asarray(population.members_entity_id) == group_index
+            ).tolist()
+            people_indices.update(member_indices)
+            people_indices_by_entity[entity.key] = set(member_indices)
+            situation[entity.plural] = {
+                entity.key: {
+                    "members": [],
+                    **self._extract_inputs(
+                        variables_by_entity.get(entity.key, []),
+                        group_index,
+                    ),
+                },
+            }
+
+        person = self.persons.entity
+        person_variables = variables_by_entity.get(person.key, [])
+        situation[person.plural] = {}
+        # Sorted so that the order of the output is deterministic.
+        for person_index in sorted(people_indices):
+            person_name = f"{person.key}_{person_index + 1}"
+            for entity_key, members in people_indices_by_entity.items():
+                if person_index in members:
+                    entity = self.populations[entity_key].entity
+                    situation[entity.plural][entity.key]["members"].append(
+                        person_name
+                    )
+            situation[person.plural][person_name] = self._extract_inputs(
+                person_variables, person_index
+            )
+
+        return json.loads(json.dumps(situation, cls=NpEncoder))
+
+    def _extract_inputs(self, variables: list, index: int) -> dict:
+        """
+        Collect the input values of one entity instance.
+
+        Args:
+            variables (list): The names of the variables to read.
+            index (int): The position of the instance in each array.
+
+        Returns:
+            dict: ``{variable: {period: value}}`` for every variable with at
+                least one known period, using the first known period only.
+        """
+        inputs = {}
+        for variable in variables:
+            holder = self.get_holder(variable)
+            known_periods = holder.get_known_periods()
+            if len(known_periods) > 0:
+                period = known_periods[0]
+                inputs[variable] = {
+                    str(period): holder.get_array(period)[index]
+                }
+        return inputs
+
+
+class NpEncoder(json.JSONEncoder):
+    """JSON encoder that converts NumPy values to built-in Python types."""
+
+    def default(self, obj):
+        """Return a JSON-serialisable equivalent of ``obj``."""
+        if isinstance(obj, np.integer):
+            return int(obj)
+        if isinstance(obj, np.floating):
+            return float(obj)
+        if isinstance(obj, np.bool_):
+            return bool(obj)
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        if isinstance(obj, bytes):
+            # Covers np.bytes_ too; str() would emit "b'...'", not the text.
+            return obj.decode("utf-8", errors="replace")
+        return str(obj)
diff --git a/policyengine_core/tools/__init__.py b/policyengine_core/tools/__init__.py
index 4a8ad4ce..b1d67e64 100644
--- a/policyengine_core/tools/__init__.py
+++ b/policyengine_core/tools/__init__.py
@@ -7,6 +7,8 @@
 
 from policyengine_core.enums import EnumArray
 
+from .test_from_situation import generate_test_from_situation
+
 
 def assert_near(
     value,
diff --git a/policyengine_core/tools/test_from_situation.py b/policyengine_core/tools/test_from_situation.py
new file mode 100644
index 00000000..4ec3bb86
--- /dev/null
+++ b/policyengine_core/tools/test_from_situation.py
@@ -0,0 +1,26 @@
+import yaml
+from pathlib import Path
+import numpy as np
+import json
+
+
+def generate_test_from_situation(situation: dict, file_path: str) -> None:
+    """Generate a YAML test file from a situation.
+
+    The file holds a single test case whose ``input`` is the situation and
+    whose ``output`` is left empty.
+
+    Args:
+        situation (dict): The situation to generate the test from.
+        file_path (str): The path of the YAML file to write.
+    """
+
+    yaml_contents = [
+        {
+            "input": situation,
+            "output": {},
+        }
+    ]
+
+    with open(Path(file_path), "w", encoding="utf-8") as f:
+        yaml.dump(yaml_contents, f)