code | \n", + "name | \n", + "total_income_count | \n", + "total_income_amount | \n", + "
---|---|---|---|
\n",
+ " \n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "Loading ITables v2.2.3 from the init_notebook_mode cell...\n",
+ "(need help?) | \n",
+ "\n",
+ "
code | \n", + "name | \n", + "all | \n", + "0 | \n", + "1 | \n", + "2 | \n", + "3 | \n", + "4 | \n", + "5 | \n", + "6 | \n", + "7 | \n", + "8 | \n", + "9 | \n", + "10 | \n", + "11 | \n", + "12 | \n", + "13 | \n", + "14 | \n", + "15 | \n", + "16 | \n", + "17 | \n", + "18 | \n", + "19 | \n", + "20 | \n", + "21 | \n", + "22 | \n", + "23 | \n", + "24 | \n", + "25 | \n", + "26 | \n", + "27 | \n", + "28 | \n", + "29 | \n", + "30 | \n", + "31 | \n", + "32 | \n", + "33 | \n", + "34 | \n", + "35 | \n", + "36 | \n", + "37 | \n", + "38 | \n", + "39 | \n", + "40 | \n", + "41 | \n", + "42 | \n", + "43 | \n", + "44 | \n", + "45 | \n", + "46 | \n", + "47 | \n", + "48 | \n", + "49 | \n", + "50 | \n", + "51 | \n", + "52 | \n", + "53 | \n", + "54 | \n", + "55 | \n", + "56 | \n", + "57 | \n", + "58 | \n", + "59 | \n", + "60 | \n", + "61 | \n", + "62 | \n", + "63 | \n", + "64 | \n", + "65 | \n", + "66 | \n", + "67 | \n", + "68 | \n", + "69 | \n", + "70 | \n", + "71 | \n", + "72 | \n", + "73 | \n", + "74 | \n", + "75 | \n", + "76 | \n", + "77 | \n", + "78 | \n", + "79 | \n", + "80 | \n", + "81 | \n", + "82 | \n", + "83 | \n", + "84 | \n", + "85 | \n", + "86 | \n", + "87 | \n", + "88 | \n", + "89 | \n", + "90+ | \n", + "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
\n",
+ " \n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "Loading ITables v2.2.3 from the init_notebook_mode cell...\n",
+ "(need help?) | \n",
+ "\n",
+ "
code | \n", + "name | \n", + "employment_income_lower_bound | \n", + "employment_income_upper_bound | \n", + "employment_income_count | \n", + "employment_income_amount | \n", + "
---|---|---|---|---|---|
\n",
+ " \n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "Loading ITables v2.2.3 from the init_notebook_mode cell...\n",
+ "(need help?) | \n",
+ "\n",
+ "
Unnamed: 0 | \n", + "E14001063 | \n", + "E14001064 | \n", + "E14001065 | \n", + "E14001066 | \n", + "E14001067 | \n", + "E14001294 | \n", + "E14001366 | \n", + "E14001599 | \n", + "E14001068 | \n", + "E14001140 | \n", + "E14001069 | \n", + "E14001570 | \n", + "E14001070 | \n", + "E14001352 | \n", + "E14001071 | \n", + "E14001360 | \n", + "E14001600 | \n", + "E14001072 | \n", + "E14001090 | \n", + "E14001073 | \n", + "E14001189 | \n", + "E14001074 | \n", + "E14001075 | \n", + "E14001076 | \n", + "E14001077 | \n", + "E14001078 | \n", + "E14001392 | \n", + "E14001403 | \n", + "E14001079 | \n", + "E14001375 | \n", + "E14001080 | \n", + "E14001196 | \n", + "E14001506 | \n", + "E14001081 | \n", + "E14001434 | \n", + "E14001082 | \n", + "E14001162 | \n", + "E14001083 | \n", + "E14001137 | \n", + "E14001084 | \n", + "E14001359 | \n", + "E14001384 | \n", + "E14001085 | \n", + "E14001421 | \n", + "E14001559 | \n", + "E14001285 | \n", + "E14001397 | \n", + "E14001086 | \n", + "E14001525 | \n", + "E14001087 | \n", + "E14001127 | \n", + "E14001088 | \n", + "E14001274 | \n", + "E14001330 | \n", + "E14001533 | \n", + "E14001089 | \n", + "E14001229 | \n", + "E14001414 | \n", + "E14001091 | \n", + "E14001092 | \n", + "E14001096 | \n", + "E14001097 | \n", + "E14001093 | \n", + "E14001094 | \n", + "E14001099 | \n", + "E14001100 | \n", + "E14001095 | \n", + "E14001098 | \n", + "E14001101 | \n", + "E14001382 | \n", + "E14001102 | \n", + "E14001450 | \n", + "E14001103 | \n", + "E14001145 | \n", + "E14001459 | \n", + "E14001104 | \n", + "E14001105 | \n", + "E14001106 | \n", + "E14001244 | \n", + "E14001567 | \n", + "E14001107 | \n", + "E14001183 | \n", + "E14001108 | \n", + "E14001166 | \n", + "E14001109 | \n", + "E14001391 | \n", + "E14001110 | \n", + "E14001111 | \n", + "E14001112 | \n", + "E14001329 | \n", + "E14001113 | \n", + "E14001114 | \n", + "E14001343 | \n", + "E14001288 | \n", + "E14001364 | \n", + "E14001115 | \n", + "E14001116 | \n", + "E14001363 | \n", + "E14001429 | \n", + "N05000012 | \n", + "N05000006 | \n", + "N05000007 | \n", + "N05000010 | \n", + "N05000008 | \n", + "N05000018 | \n", + "N05000009 | \n", + "N05000015 | \n", + "N05000011 | \n", + "N05000017 | \n", + "N05000016 | \n", + "S14000060 | \n", + "S14000061 | \n", + "S14000063 | \n", + "S14000070 | \n", + "S14000065 | \n", + "S14000066 | \n", + "S14000067 | \n", + "S14000107 | \n", + "S14000062 | \n", + "S14000091 | \n", + "S14000108 | \n", + "S14000069 | \n", + "S14000109 | \n", + "S14000072 | \n", + "S14000097 | \n", + "S14000073 | \n", + "S14000074 | \n", + "S14000075 | \n", + "S14000071 | \n", + "S14000076 | \n", + "S14000086 | \n", + "S14000077 | \n", + "S14000092 | \n", + "S14000104 | \n", + "S14000078 | \n", + "S14000096 | \n", + "S14000021 | \n", + "S14000080 | \n", + "S14000079 | \n", + "S14000082 | \n", + "S14000081 | \n", + "S14000027 | \n", + "S14000064 | \n", + "S14000083 | \n", + "S14000084 | \n", + "S14000085 | \n", + "S14000087 | \n", + "S14000088 | \n", + "S14000089 | \n", + "S14000106 | \n", + "S14000101 | \n", + "S14000090 | \n", + "S14000100 | \n", + "S14000093 | \n", + "S14000094 | \n", + "S14000098 | \n", + "S14000110 | \n", + "S14000099 | \n", + "S14000068 | \n", + "S14000095 | \n", + "S14000045 | \n", + "S14000048 | \n", + "S14000103 | \n", + "S14000105 | \n", + "S14000051 | \n", + "S14000102 | \n", + "S14000111 | \n", + "W07000112 | \n", + "W07000082 | \n", + "W07000094 | \n", + "W07000111 | \n", + "W07000098 | \n", + "W07000097 | \n", + "W07000103 | \n", + "W07000108 | \n", + "W07000081 | \n", + "W07000089 | \n", + "W07000091 | \n", + "W07000090 | \n", + "W07000107 | \n", + "W07000109 | \n", + "W07000101 | \n", + "W07000104 | \n", + "W07000105 | \n", + "W07000083 | \n", + "W07000096 | \n", + "W07000095 | \n", + "W07000102 | \n", + "W07000093 | \n", + "W07000100 | \n", + "W07000087 | \n", + "W07000085 | \n", + "W07000099 | \n", + "W07000106 | \n", + "W07000084 | \n", + "W07000086 | \n", + "W07000092 | \n", + "W07000088 | \n", + "W07000110 | \n", + "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
\n",
+ " \n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "Loading ITables v2.2.3 from the init_notebook_mode cell...\n",
+ "(need help?) | \n",
+ "\n",
+ "
\n", @@ -344,9 +350,9 @@ "\n" ], "text/plain": [ - " index name \\\n", - "11360 E14001373 New Forest East \n", - "4325 E14001488 South Leicestershire \n", - "821 E14001234 Farnham and Bordon \n", - "2676 E14001139 Broxbourne \n", - "10790 E14001453 Rugby \n", - "... ... ... \n", - "6697 E14001260 Hackney South and Shoreditch \n", - "7249 E14001162 Chesham and Amersham \n", - "6599 E14001162 Chesham and Amersham \n", - "7388 E14001301 Ilford South \n", - "6738 E14001301 Ilford South \n", + " index name \\\n", + "10276 E14001589 Wirral West \n", + "5283 E14001146 Bury St Edmunds and Stowmarket \n", + "4154 E14001317 Knowsley \n", + "5855 E14001068 Ashfield \n", + "5197 W07000110 Vale of Glamorgan \n", + "... ... ... \n", + "6842 E14001405 North West Norfolk \n", + "7791 W07000104 Newport East \n", + "7141 W07000104 Newport East \n", + "6643 E14001206 Dunstable and Leighton Buzzard \n", + "7293 E14001206 Dunstable and Leighton Buzzard \n", "\n", " metric estimate target \\\n", - "11360 hmrc/employment_income/amount/30000_40000 2.509288e+08 2.509499e+08 \n", - "4325 age/40_50 1.193506e+04 1.193367e+04 \n", - "821 hmrc/total_income/count 6.737637e+04 6.736803e+04 \n", - "2676 age/20_30 1.162183e+04 1.162016e+04 \n", - "10790 hmrc/employment_income/count/30000_40000 7.828644e+03 7.830460e+03 \n", + "10276 hmrc/employment_income/amount/20000_30000 1.322333e+08 1.322420e+08 \n", + "5283 age/60_70 1.451102e+04 1.450993e+04 \n", + "4154 age/40_50 1.269529e+04 1.269403e+04 \n", + "5855 age/70_80 1.078475e+04 1.078346e+04 \n", + "5197 age/50_60 1.489186e+04 1.489368e+04 \n", "... ... ... ... \n", - "6697 hmrc/employment_income/count/12570_15000 1.436137e+03 4.885362e+01 \n", - "7249 hmrc/employment_income/amount/12570_15000 1.991328e+07 6.456983e+05 \n", - "6599 hmrc/employment_income/count/12570_15000 1.435062e+03 4.649818e+01 \n", - "7388 hmrc/employment_income/amount/12570_15000 1.991574e+07 6.234373e+05 \n", - "6738 hmrc/employment_income/count/12570_15000 1.435239e+03 4.489512e+01 \n", + "6842 hmrc/employment_income/count/12570_15000 8.755749e+02 8.431956e+01 \n", + "7791 hmrc/employment_income/amount/12570_15000 1.544117e+07 1.432437e+06 \n", + "7141 hmrc/employment_income/count/12570_15000 1.118195e+03 1.031530e+02 \n", + "6643 hmrc/employment_income/count/12570_15000 7.247283e+02 6.187779e+01 \n", + "7293 hmrc/employment_income/amount/12570_15000 1.006603e+07 8.592676e+05 \n", "\n", " error abs_error rel_abs_error \n", - "11360 -2.108173e+04 2.108173e+04 0.000084 \n", - "4325 1.389142e+00 1.389142e+00 0.000116 \n", - "821 8.340814e+00 8.340814e+00 0.000124 \n", - "2676 1.670008e+00 1.670008e+00 0.000144 \n", - "10790 -1.816220e+00 1.816220e+00 0.000232 \n", + "10276 -8.699413e+03 8.699413e+03 0.000066 \n", + "5283 1.083593e+00 1.083593e+00 0.000075 \n", + "4154 1.256512e+00 1.256512e+00 0.000099 \n", + "5855 1.288091e+00 1.288091e+00 0.000119 \n", + "5197 -1.817116e+00 1.817116e+00 0.000122 \n", "... ... ... ... \n", - "6697 1.387283e+03 1.387283e+03 28.396738 \n", - "7249 1.926758e+07 1.926758e+07 29.839914 \n", - "6599 1.388564e+03 1.388564e+03 29.862760 \n", - "7388 1.929230e+07 1.929230e+07 30.945050 \n", - "6738 1.390344e+03 1.390344e+03 30.968717 \n", + "6842 7.912553e+02 7.912553e+02 9.384007 \n", + "7791 1.400874e+07 1.400874e+07 9.779649 \n", + "7141 1.015042e+03 1.015042e+03 9.840155 \n", + "6643 6.628505e+02 6.628505e+02 10.712253 \n", + "7293 9.206761e+06 9.206761e+06 10.714661 \n", "\n", "[14300 rows x 8 columns]" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -411,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -433,13 +439,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - " |
---|
\n", @@ -554,9 +560,9 @@ "\n" ], "text/plain": [ @@ -174,14 +174,24 @@ "import itables.options as opt\n", "from pathlib import Path\n", "from policyengine_uk_data.storage import STORAGE_FOLDER\n", + "from policyengine.utils.huggingface import download\n", "\n", "opt.maxBytes = \"1MB\"\n", "init_notebook_mode(all_interactive=True)\n", "\n", "REPO = Path(\".\").resolve().parent\n", "\n", - "with h5py.File(STORAGE_FOLDER / \"local_authority_weights.h5\", \"r\") as f:\n", - " weights = f[\"2025\"][:]\n", + "weights_file_path = STORAGE_FOLDER / \"local_authority_weights.h5\"\n", + "constituency_names_file_path = download(\n", + " repo=\"policyengine/policyengine-uk-data\",\n", + " repo_filename=\"local_authorities_2021.csv\",\n", + " local_folder=None,\n", + " version=None,\n", + ")\n", + "constituencies_2024 = pd.read_csv(constituency_names_file_path)\n", + "\n", + "with h5py.File(weights_file_path, \"r\") as f:\n", + " weights = f[str(2025)][...]\n", "\n", "baseline = Microsimulation()\n", "household_weights = baseline.calculate(\"household_weight\", 2025).values\n", @@ -190,7 +200,6 @@ "\n", "local_authority_target_matrix, local_authority_actuals = create_local_authority_target_matrix(\"enhanced_frs_2022_23\", 2025, None)\n", "national_target_matrix, national_actuals = create_national_target_matrix(\"enhanced_frs_2022_23\", 2025, None)\n", - "constituencies_2024 = pd.read_csv(STORAGE_FOLDER / \"local_authorities_2021.csv\")\n", "\n", "local_authority_wide = weights @ local_authority_target_matrix\n", "local_authority_wide.index = constituencies_2024.code.values\n", @@ -221,13 +230,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ - " |
---|
\n", @@ -344,9 +353,9 @@ "\n" ], "text/plain": [ - " index name \\\n", - "7682 E07000102 Three Rivers \n", - "3227 W06000011 Swansea \n", - "5640 E08000014 Sefton \n", - "6179 E06000063 Cumberland \n", - "7022 E07000193 East Staffordshire \n", - "... ... ... \n", - "4998 S12000023 Orkney Islands \n", - "4943 E09000001 City of London \n", - "769 E06000053 Isles of Scilly \n", - "1129 E06000053 Isles of Scilly \n", - "1849 E06000053 Isles of Scilly \n", - "\n", - " metric estimate target \\\n", - "7682 hmrc/employment_income/amount/50000_70000 6.835184e+08 6.835077e+08 \n", - "3227 age/60_70 2.865268e+04 2.865445e+04 \n", - "5640 hmrc/employment_income/amount/20000_30000 3.690511e+08 3.690074e+08 \n", - "6179 hmrc/employment_income/amount/30000_40000 7.644177e+08 7.643270e+08 \n", - "7022 hmrc/employment_income/amount/40000_50000 3.140948e+08 3.140508e+08 \n", - "... ... ... ... \n", - "4998 hmrc/employment_income/amount/15000_20000 4.427085e+06 1.308275e+06 \n", - "4943 hmrc/employment_income/amount/15000_20000 4.447282e+06 1.308275e+06 \n", - "769 age/0_10 6.712070e+02 1.954337e+02 \n", - "1129 age/10_20 7.329886e+02 2.066304e+02 \n", - "1849 age/30_40 8.760654e+02 2.300417e+02 \n", - "\n", - " error abs_error rel_abs_error \n", - "7682 1.071678e+04 1.071678e+04 0.000016 \n", - "3227 -1.769448e+00 1.769448e+00 0.000062 \n", - "5640 4.374969e+04 4.374969e+04 0.000119 \n", - "6179 9.073156e+04 9.073156e+04 0.000119 \n", - "7022 4.394198e+04 4.394198e+04 0.000140 \n", - "... ... ... ... \n", - "4998 3.118810e+06 3.118810e+06 2.383910 \n", - "4943 3.139007e+06 3.139007e+06 2.399348 \n", - "769 4.757734e+02 4.757734e+02 2.434449 \n", - "1129 5.263582e+02 5.263582e+02 2.547342 \n", - "1849 6.460237e+02 6.460237e+02 2.808289 \n", + " index name ... abs_error rel_abs_error\n", + "1024 N09000009 Mid Ulster ... 7.786641e-01 0.000036\n", + "3485 E08000019 Sheffield ... 1.751454e+00 0.000039\n", + "2392 E08000006 Salford ... 1.939367e+00 0.000058\n", + "174 E07000175 Newark and Sherwood ... 2.067212e+05 0.000072\n", + "6517 E06000040 Windsor and Maidenhead ... 6.069600e-01 0.000077\n", + "... ... ... ... ... ...\n", + "4998 S12000023 Orkney Islands ... 3.384977e+06 2.587359\n", + "4943 E09000001 City of London ... 3.397485e+06 2.596920\n", + "5000 S12000027 Shetland Islands ... 3.436770e+06 2.626947\n", + "1129 E06000053 Isles of Scilly ... 5.483983e+02 2.654006\n", + "1849 E06000053 Isles of Scilly ... 6.821887e+02 2.965500\n", "\n", "[7920 rows x 8 columns]" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } diff --git a/policyengine_uk_data/datasets/frs/local_areas/constituencies/calibrate.py b/policyengine_uk_data/datasets/frs/local_areas/constituencies/calibrate.py index 0dc38ad..f94c758 100644 --- a/policyengine_uk_data/datasets/frs/local_areas/constituencies/calibrate.py +++ b/policyengine_uk_data/datasets/frs/local_areas/constituencies/calibrate.py @@ -4,6 +4,7 @@ import numpy as np from tqdm import tqdm import h5py +import os from policyengine_uk_data.datasets.frs.local_areas.constituencies.transform_constituencies import ( transform_2010_to_2024, ) @@ -57,6 +58,18 @@ def loss(w): return mse_c + mse_n + def pct_close(w, t=0.1): + # Return the percentage of metrics that are within t% of the target + pred_c = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1) + e_c = torch.sum(torch.abs((pred_c / (1 + y) - 1)) < t) + c_c = pred_c.shape[0] * pred_c.shape[1] + + pred_n = (w.sum(axis=0) * matrix_national.T).sum(axis=1) + e_n = torch.sum(torch.abs((pred_n / (1 + y_national) - 1)) < t) + c_n = pred_n.shape[0] + + return (e_c + e_n) / (c_c + c_n) + def dropout_weights(weights, p): if p == 0: return weights @@ -69,7 +82,7 @@ def dropout_weights(weights, p): optimizer = torch.optim.Adam([weights], lr=0.1) - desc = range(512) + desc = range(32) if os.environ.get("DATA_LITE") else range(256) for epoch in desc: optimizer.zero_grad() @@ -77,8 +90,9 @@ def dropout_weights(weights, p): l = loss(torch.exp(weights_)) l.backward() optimizer.step() - if epoch % 50 == 0: - print(f"Loss: {l.item()}, Epoch: {epoch}") + close = pct_close(torch.exp(weights_)) + if epoch % 10 == 0: + print(f"Loss: {l.item()}, Epoch: {epoch}, Within 10%: {close:.2%}") final_weights = torch.exp(weights).detach().numpy() mapping_matrix = pd.read_csv( diff --git a/policyengine_uk_data/datasets/frs/local_areas/local_authorities/calibrate.py b/policyengine_uk_data/datasets/frs/local_areas/local_authorities/calibrate.py index 8142672..3a8d71c 100644 --- a/policyengine_uk_data/datasets/frs/local_areas/local_authorities/calibrate.py +++ b/policyengine_uk_data/datasets/frs/local_areas/local_authorities/calibrate.py @@ -4,10 +4,11 @@ import numpy as np from tqdm import tqdm import h5py +import os from policyengine_uk_data.storage import STORAGE_FOLDER -from loss import ( +from policyengine_uk_data.datasets.frs.local_areas.local_authorities.loss import ( create_local_authority_target_matrix, create_national_target_matrix, ) @@ -50,6 +51,18 @@ def loss(w): return mse_c + mse_n + def pct_close(w, t=0.1): + # Return the percentage of metrics that are within t% of the target + pred_c = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1) + e_c = torch.sum(torch.abs((pred_c / (1 + y) - 1)) < t) + c_c = pred_c.shape[0] * pred_c.shape[1] + + pred_n = (w.sum(axis=0) * matrix_national.T).sum(axis=1) + e_n = torch.sum(torch.abs((pred_n / (1 + y_national) - 1)) < t) + c_n = pred_n.shape[0] + + return (e_c + e_n) / (c_c + c_n) + def dropout_weights(weights, p): if p == 0: return weights @@ -62,7 +75,7 @@ def dropout_weights(weights, p): optimizer = torch.optim.Adam([weights], lr=0.1) - desc = range(512) + desc = range(32) if os.environ.get("DATA_LITE") else range(256) for epoch in desc: optimizer.zero_grad() @@ -70,8 +83,9 @@ def dropout_weights(weights, p): l = loss(torch.exp(weights_)) l.backward() optimizer.step() - if epoch % 50 == 0: - print(f"Loss: {l.item()}, Epoch: {epoch}") + close = pct_close(torch.exp(weights_)) + if epoch % 10 == 0: + print(f"Loss: {l.item()}, Epoch: {epoch}, Within 10%: {close:.2%}") if epoch % 100 == 0: final_weights = torch.exp(weights).detach().numpy() diff --git a/policyengine_uk_data/storage/download_private_prerequisites.py b/policyengine_uk_data/storage/download_private_prerequisites.py index 2094a64..bb390d9 100644 --- a/policyengine_uk_data/storage/download_private_prerequisites.py +++ b/policyengine_uk_data/storage/download_private_prerequisites.py @@ -28,6 +28,5 @@ def extract_zipped_folder(folder): repo_filename=file.name, local_folder=file.parent, ) - print(f"Extracting {file}") extract_zipped_folder(file) file.unlink() diff --git a/policyengine_uk_data/utils/huggingface.py b/policyengine_uk_data/utils/huggingface.py index a46da04..95f2a81 100644 --- a/policyengine_uk_data/utils/huggingface.py +++ b/policyengine_uk_data/utils/huggingface.py @@ -9,7 +9,6 @@ def download( token = os.environ.get( "HUGGING_FACE_TOKEN", ) - login(token=token) hf_hub_download( repo_id=repo, @@ -17,6 +16,7 @@ def download( filename=repo_filename, local_dir=local_folder, revision=version, + token=token, ) diff --git a/policyengine_uk_data/utils/reweight.py b/policyengine_uk_data/utils/reweight.py index 9f25d17..f07d669 100644 --- a/policyengine_uk_data/utils/reweight.py +++ b/policyengine_uk_data/utils/reweight.py @@ -1,5 +1,6 @@ import numpy as np import torch +import os def reweight( @@ -32,6 +33,12 @@ def loss(weights): raise ValueError("Relative error contains NaNs") return rel_error.mean() + def pct_close(weights, t=0.1): + # Return the percentage of metrics that are within t% of the target + estimate = weights @ loss_matrix + abs_error = torch.abs((estimate - targets_array) / (1 + targets_array)) + return (abs_error < t).sum() / abs_error.numel() + def dropout_weights(weights, p): if p == 0: return weights @@ -47,17 +54,20 @@ def dropout_weights(weights, p): start_loss = None - iterator = range(1_000) + iterator = range(128) if os.environ.get("DATA_LITE") else range(2048) for i in iterator: optimizer.zero_grad() weights_ = dropout_weights(weights, dropout_rate) l = loss(torch.exp(weights_)) + close = pct_close(torch.exp(weights_)) if start_loss is None: start_loss = l.item() loss_rel_change = (l.item() - start_loss) / start_loss l.backward() if i % 100 == 0: - print(f"Loss: {l.item()}, Rel change: {loss_rel_change}") + print( + f"Loss: {l.item()}, Rel change: {loss_rel_change}, Epoch: {i}, Within 10%: {close:.2%}" + ) optimizer.step() return torch.exp(weights).detach().numpy() |
---|