From 62c11ad75be2d51dd03e2faf964637f9b5c10b1c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:00:12 +0100 Subject: [PATCH] Bug fixes (#282) * Fix Make `random(entity)` deterministic based on entity IDs #280 * Fix Subsampling doesn't preserve `Dataset.time_period` #281 * Fix Remove openfisca_core dependencies #259 * Versioning * Remove Windows test * Remove Windows CI --- .github/workflows/pr.yaml | 5 +- .github/workflows/push.yaml | 5 +- changelog_entry.yaml | 5 + policyengine_core/commons/formulas.py | 48 +- policyengine_core/simulations/simulation.py | 2 +- policyengine_core/variables/defined_for.py | 8 +- setup.py | 14 +- test.ipynb | 491 ++++++++++++++++++++ 8 files changed, 549 insertions(+), 29 deletions(-) create mode 100644 test.ipynb diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 4c737d784..64cfaf519 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -31,10 +31,7 @@ jobs: - name: Check version number has been properly updated run: .github/is-version-number-acceptable.sh Test: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest, windows-latest] + runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml index f8add96ec..9cf4a8cf5 100644 --- a/.github/workflows/push.yaml +++ b/.github/workflows/push.yaml @@ -43,13 +43,10 @@ jobs: author_name: Github Actions[bot] message: Update PolicyEngine Core Test: - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest if: | (github.repository == 'PolicyEngine/policyengine-core') && (github.event.head_commit.message == 'Update PolicyEngine Core') - strategy: - matrix: - os: [ubuntu-latest, windows-latest] steps: - name: Checkout repo uses: actions/checkout@v3 diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..72671cca7 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - Randomness based on entity IDs as seeds. + - OpenFisca-Core imports. diff --git a/policyengine_core/commons/formulas.py b/policyengine_core/commons/formulas.py index 169319cc5..4695cf9c6 100644 --- a/policyengine_core/commons/formulas.py +++ b/policyengine_core/commons/formulas.py @@ -300,15 +300,45 @@ def amount_between( return clip(amount, threshold_1, threshold_2) - threshold_1 -def random(entity, reset=True): - if reset: - np.random.seed(0) - x = np.random.rand(entity.count) - if entity.simulation.has_axes: - # Generate the same random number for each entity. - random_number = x[0] - return np.array([random_number] * entity.count) - return x +def random(population): + """ + Generate random values for each entity in the population. + + Args: + population: The population object containing simulation data. + + Returns: + np.ndarray: Array of random values for each entity. + """ + # Initialize count of random calls if not already present + if not hasattr(population.simulation, "count_random_calls"): + population.simulation.count_random_calls = 0 + population.simulation.count_random_calls += 1 + + # Get known periods or use default calculation period + known_periods = population.simulation.get_holder( + f"{population.entity.key}_id" + ).get_known_periods() + period = ( + known_periods[0] + if known_periods + else population.simulation.default_calculation_period + ) + + # Get entity IDs for the period + entity_ids = population(f"{population.entity.key}_id", period) + + # Generate random values for each entity + values = np.array( + [ + np.random.default_rng( + seed=id * 100 + population.simulation.count_random_calls + ).random() + for id in entity_ids + ] + ) + + return values def is_in(values: ArrayLike, *targets: list) -> ArrayLike: diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index a7d253e3c..482c90e42 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -1547,7 +1547,7 @@ def subsample( ) # Update the dataset and rebuild the simulation - self.dataset = Dataset.from_dataframe(df) + self.dataset = Dataset.from_dataframe(df, self.dataset.time_period) self.build_from_dataset() return self diff --git a/policyengine_core/variables/defined_for.py b/policyengine_core/variables/defined_for.py index e6e194e93..29a2a2e1a 100644 --- a/policyengine_core/variables/defined_for.py +++ b/policyengine_core/variables/defined_for.py @@ -2,10 +2,10 @@ import numpy as np from numpy.typing import ArrayLike -from openfisca_core.entities import Entity -from openfisca_core.populations import GroupPopulation, Population -from openfisca_core.projectors import EntityToPersonProjector, Projector -from openfisca_core.variables import Variable +from policyengine_core.entities import Entity +from policyengine_core.populations import GroupPopulation, Population +from policyengine_core.projectors import EntityToPersonProjector, Projector +from policyengine_core.variables import Variable class CallableSubset: diff --git a/setup.py b/setup.py index 66a8325e8..753c32583 100644 --- a/setup.py +++ b/setup.py @@ -12,25 +12,25 @@ general_requirements = [ "pytest>=8,<9", "numpy~=1.26.4", - "black", - "linecheck<1", - "yaml-changelog<1", - "coverage", "sortedcontainers<3", "numexpr<3", "dpath<3", "psutil<6", "wheel<1", "h5py>=3,<4", - "requests>=2.27.1,<3", + "requests>=2,<3", "pandas>=1", - "plotly>=5.6.0,<6", - "ipython>=7.17.0,<8", + "plotly>=5,<6", + "ipython>=7,<8", "pyvis>=0.3.2", ] dev_requirements = [ + "black", + "linecheck<1", "jupyter-book<1", + "yaml-changelog<1", + "coverage", "furo<2023", "markupsafe==2.0.1", "coverage", diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 000000000..27c0bb532 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,491 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from policyengine_uk import Microsimulation\n", + "from policyengine_core.reforms import Reform\n", + "\n", + "reform = Reform.from_dict({\n", + " \"gov.hmrc.vat.standard_rate\": {\n", + " \"2024-01-01.2100-12-31\": 0.22\n", + " }\n", + "}, country_id=\"uk\")\n", + "\n", + "\n", + "baseline = Microsimulation()\n", + "reformed = Microsimulation(reform=reform)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | age | \n", + "vat | \n", + "consumption | \n", + "full_rate_vat_expenditure_rate | \n", + "food_and_non_alcoholic_beverages_consumption | \n", + "alcohol_and_tobacco_consumption | \n", + "clothing_and_footwear_consumption | \n", + "housing_water_and_electricity_consumption | \n", + "household_furnishings_consumption | \n", + "health_consumption | \n", + "transport_consumption | \n", + "communication_consumption | \n", + "recreation_consumption | \n", + "education_consumption | \n", + "restaurants_and_hotels_consumption | \n", + "miscellaneous_consumption | \n", + "petrol_spending | \n", + "diesel_spending | \n", + "domestic_energy_consumption | \n", + "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
312 | \n", + "80.0 | \n", + "-929.965881 | \n", + "2118.353027 | \n", + "-0.841608 | \n", + "1270.274780 | \n", + "0.0 | \n", + "0.0 | \n", + "614.104309 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "187.915909 | \n", + "0.0 | \n", + "0.0 | \n", + "46.057823 | \n", + "0.0 | \n", + "0.0 | \n", + "283.716187 | \n", + "
1039 | \n", + "80.0 | \n", + "-129.693314 | \n", + "-1217.768799 | \n", + "0.194851 | \n", + "85.360497 | \n", + "0.0 | \n", + "0.0 | \n", + "-1903.723389 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "291.699554 | \n", + "0.0 | \n", + "0.0 | \n", + "308.894470 | \n", + "0.0 | \n", + "0.0 | \n", + "666.303162 | \n", + "
1040 | \n", + "52.0 | \n", + "-129.693314 | \n", + "-1217.768799 | \n", + "0.194851 | \n", + "85.360497 | \n", + "0.0 | \n", + "0.0 | \n", + "-1903.723389 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "291.699554 | \n", + "0.0 | \n", + "0.0 | \n", + "308.894470 | \n", + "0.0 | \n", + "0.0 | \n", + "666.303162 | \n", + "
1896 | \n", + "80.0 | \n", + "-593.512024 | \n", + "1351.950684 | \n", + "-0.841608 | \n", + "461.806427 | \n", + "0.0 | \n", + "0.0 | \n", + "666.917297 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "187.915909 | \n", + "0.0 | \n", + "0.0 | \n", + "35.310997 | \n", + "0.0 | \n", + "0.0 | \n", + "354.338196 | \n", + "
2057 | \n", + "35.0 | \n", + "-40.791981 | \n", + "-1168.026367 | \n", + "0.058855 | \n", + "692.402588 | \n", + "0.0 | \n", + "0.0 | \n", + "-2319.164795 | \n", + "0.0 | \n", + "0.0 | \n", + "30.09111 | \n", + "0.0 | \n", + "187.915909 | \n", + "0.0 | \n", + "0.0 | \n", + "240.728882 | \n", + "0.0 | \n", + "0.0 | \n", + "-2926.207031 | \n", + "
... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
209307 | \n", + "50.0 | \n", + "-9.732808 | \n", + "-125.584320 | \n", + "0.139750 | \n", + "131.418320 | \n", + "0.0 | \n", + "0.0 | \n", + "-424.346069 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "0.000000 | \n", + "0.0 | \n", + "0.0 | \n", + "167.343414 | \n", + "0.0 | \n", + "0.0 | \n", + "0.000000 | \n", + "
209308 | \n", + "49.0 | \n", + "-9.732808 | \n", + "-125.584320 | \n", + "0.139750 | \n", + "131.418320 | \n", + "0.0 | \n", + "0.0 | \n", + "-424.346069 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "0.000000 | \n", + "0.0 | \n", + "0.0 | \n", + "167.343414 | \n", + "0.0 | \n", + "0.0 | \n", + "0.000000 | \n", + "
209309 | \n", + "13.0 | \n", + "-9.732808 | \n", + "-125.584320 | \n", + "0.139750 | \n", + "131.418320 | \n", + "0.0 | \n", + "0.0 | \n", + "-424.346069 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "0.000000 | \n", + "0.0 | \n", + "0.0 | \n", + "167.343414 | \n", + "0.0 | \n", + "0.0 | \n", + "0.000000 | \n", + "
212169 | \n", + "68.0 | \n", + "-23.210276 | \n", + "-378.288330 | \n", + "0.109076 | \n", + "1337.519165 | \n", + "0.0 | \n", + "0.0 | \n", + "-1903.723389 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "187.915909 | \n", + "0.0 | \n", + "0.0 | \n", + "0.000000 | \n", + "0.0 | \n", + "0.0 | \n", + "638.054382 | \n", + "
212170 | \n", + "65.0 | \n", + "-23.210276 | \n", + "-378.288330 | \n", + "0.109076 | \n", + "1337.519165 | \n", + "0.0 | \n", + "0.0 | \n", + "-1903.723389 | \n", + "0.0 | \n", + "0.0 | \n", + "0.00000 | \n", + "0.0 | \n", + "187.915909 | \n", + "0.0 | \n", + "0.0 | \n", + "0.000000 | \n", + "0.0 | \n", + "0.0 | \n", + "638.054382 | \n", + "
464 rows × 19 columns
\n", + "