From b8d68f36fab484d857d086cf22ddd58048ae2a1c Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:09:54 +0100 Subject: [PATCH] Calibration improvements (#36) * Increase epochs per year to 10k * Update data urls * Add calibration improvements --- .gitignore | 1 + CHANGELOG.md | 8 + changelog.yaml | 6 + .../datasets/frs/enhanced_frs.py | 63 +----- policyengine_uk_data/datasets/frs/frs.py | 23 ++ policyengine_uk_data/datasets/spi.py | 25 ++- .../storage/incomes_projection.csv | 113 ++++++++++ .../utils/incomes_projection.py | 207 ++++++++++++++++++ policyengine_uk_data/utils/loss.py | 11 +- policyengine_uk_data/utils/reweight.py | 65 ++++++ pyproject.toml | 2 +- 11 files changed, 450 insertions(+), 74 deletions(-) create mode 100644 policyengine_uk_data/storage/incomes_projection.csv create mode 100644 policyengine_uk_data/utils/incomes_projection.py create mode 100644 policyengine_uk_data/utils/reweight.py diff --git a/.gitignore b/.gitignore index 4ed1a4e..6a41c42 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ !incomes.csv !tax_benefit.csv !demographics.csv +!incomes_projection.csv **/_build diff --git a/CHANGELOG.md b/CHANGELOG.md index de94c95..b54c200 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.6.0] - 2024-10-18 16:05:10 + +### Added + +- Future year income targeting. +- Random takeup variable values. + ## [1.5.0] - 2024-10-16 17:05:58 ### Added @@ -66,6 +73,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 +[1.6.0]: https://github.com/PolicyEngine/policyengine-us-data/compare/1.5.0...1.6.0 [1.5.0]: https://github.com/PolicyEngine/policyengine-us-data/compare/1.4.0...1.5.0 [1.4.0]: https://github.com/PolicyEngine/policyengine-us-data/compare/1.3.0...1.4.0 [1.3.0]: https://github.com/PolicyEngine/policyengine-us-data/compare/1.2.5...1.3.0 diff --git a/changelog.yaml b/changelog.yaml index ac1d6bb..e141709 100644 --- a/changelog.yaml +++ b/changelog.yaml @@ -54,3 +54,9 @@ added: - Moved epoch count to 10k per year. date: 2024-10-16 17:05:58 +- bump: minor + changes: + added: + - Future year income targeting. + - Random takeup variable values. + date: 2024-10-18 16:05:10 diff --git a/policyengine_uk_data/datasets/frs/enhanced_frs.py b/policyengine_uk_data/datasets/frs/enhanced_frs.py index a8e6ba8..b3af56f 100644 --- a/policyengine_uk_data/datasets/frs/enhanced_frs.py +++ b/policyengine_uk_data/datasets/frs/enhanced_frs.py @@ -11,6 +11,7 @@ try: import torch + from policyengine_uk_data.utils.reweight import reweight except ImportError: torch = None @@ -59,68 +60,6 @@ class EnhancedFRS_2022_23(EnhancedFRS): url = "release://PolicyEngine/ukda/1.5.0/enhanced_frs_2022_23.h5" -def reweight( - original_weights, - loss_matrix, - targets_array, - dropout_rate=0.05, -): - target_names = np.array(loss_matrix.columns) - loss_matrix = torch.tensor(loss_matrix.values, dtype=torch.float32) - targets_array = torch.tensor(targets_array, dtype=torch.float32) - weights = torch.tensor( - np.log(original_weights), requires_grad=True, dtype=torch.float32 - ) - - # TODO: replace this with a call to the python reweight.py package. - def loss(weights): - # Check for Nans in either the weights or the loss matrix - if torch.isnan(weights).any(): - raise ValueError("Weights contain NaNs") - if torch.isnan(loss_matrix).any(): - raise ValueError("Loss matrix contains NaNs") - estimate = weights @ loss_matrix - if torch.isnan(estimate).any(): - raise ValueError("Estimate contains NaNs") - rel_error = ( - ((estimate - targets_array) + 1) / (targets_array + 1) - ) ** 2 - if torch.isnan(rel_error).any(): - raise ValueError("Relative error contains NaNs") - return rel_error.mean() - - def dropout_weights(weights, p): - if p == 0: - return weights - # Replace p% of the weights with the mean value of the rest of them - mask = torch.rand_like(weights) < p - mean = weights[~mask].mean() - masked_weights = weights.clone() - masked_weights[mask] = mean - return masked_weights - - optimizer = torch.optim.Adam([weights], lr=1e-1) - from tqdm import trange - - start_loss = None - - iterator = trange(10_000) - for i in iterator: - optimizer.zero_grad() - weights_ = dropout_weights(weights, dropout_rate) - l = loss(torch.exp(weights_)) - if start_loss is None: - start_loss = l.item() - loss_rel_change = (l.item() - start_loss) / start_loss - l.backward() - iterator.set_postfix( - {"loss": l.item(), "loss_rel_change": loss_rel_change} - ) - optimizer.step() - - return torch.exp(weights).detach().numpy() - - if __name__ == "__main__": ReweightedFRS_2022_23().generate() EnhancedFRS_2022_23().generate() diff --git a/policyengine_uk_data/datasets/frs/frs.py b/policyengine_uk_data/datasets/frs/frs.py index 5cf3a0b..95fdd9c 100644 --- a/policyengine_uk_data/datasets/frs/frs.py +++ b/policyengine_uk_data/datasets/frs/frs.py @@ -92,6 +92,29 @@ def generate(self): self.save_dataset(frs) + self.add_random_variables(frs) + + def add_random_variables(self, frs: dict): + from policyengine_uk import Microsimulation + + simulation = Microsimulation(dataset=self) + RANDOM_VARIABLES = [ + "attends_private_school", + "would_evade_tv_licence_fee", + "would_claim_pc", + "would_claim_uc", + "would_claim_child_benefit", + "main_residential_property_purchased_is_first_home", + "household_owns_tv", + "is_higher_earner", + ] + INPUT_PERIODS = list(range(self.time_period, self.time_period + 10)) + for variable in RANDOM_VARIABLES: + value = simulation.calculate(variable, self.time_period).values + frs[variable] = {period: value for period in INPUT_PERIODS} + + self.save_dataset(frs) + class FRS_2020_21(FRS): dwp_frs = DWP_FRS_2020_21 diff --git a/policyengine_uk_data/datasets/spi.py b/policyengine_uk_data/datasets/spi.py index df83303..b317f04 100644 --- a/policyengine_uk_data/datasets/spi.py +++ b/policyengine_uk_data/datasets/spi.py @@ -64,8 +64,29 @@ def generate(self): data["savings_starter_rate_income"] = np.zeros(len(df)) data["capital_allowances"] = df.CAPALL data["loss_relief"] = df.LOSSBF - data["is_SP_age"] = df.SPA == 1 - data["state_pension"] = df.SRP + + AGE_RANGES = { + -1: (16, 70), + 1: (16, 25), + 2: (25, 35), + 3: (35, 45), + 4: (45, 55), + 5: (55, 65), + 6: (65, 74), + 7: (74, 90), + } + age_range = df.AGERANGE + + # Randomly assign ages in age ranges + + percent_along_age_range = np.random.rand(len(df)) + min_age = np.array([AGE_RANGES[age][0] for age in age_range]) + max_age = np.array([AGE_RANGES[age][1] for age in age_range]) + data["age"] = ( + min_age + (max_age - min_age) * percent_along_age_range + ).astype(int) + + data["state_pension_reported"] = df.SRP data["other_tax_credits"] = df.TAX_CRED data["miscellaneous_income"] = ( df.MOTHINC diff --git a/policyengine_uk_data/storage/incomes_projection.csv b/policyengine_uk_data/storage/incomes_projection.csv new file mode 100644 index 0000000..9ddf68a --- /dev/null +++ b/policyengine_uk_data/storage/incomes_projection.csv @@ -0,0 +1,113 @@ +total_income_lower_bound,total_income_upper_bound,employment_income_count,employment_income_amount,self_employment_income_count,self_employment_income_amount,state_pension_count,state_pension_amount,private_pension_income_count,private_pension_income_amount,property_income_count,property_income_amount,savings_interest_income_count,savings_interest_income_amount,dividend_income_count,dividend_income_amount,year +12570,15000.0,1617313,21188248221,418271,4413467024,1172836,11826062936,1280904,7133494286,123235,915876662,930742,75292588,87090,114735958,2022 +15000,20000.0,4014377,67829340021,710977,8781241045,1933393,20074524141,2157907,18643993417,263525,2117911173,1982684,295609788,409517,1294016065,2022 +20000,30000.0,7638005,182995057109,1007725,16805476353,2007068,20555725192,2426387,34918926059,466492,4542955025,3134286,553976946,849159,4506035072,2022 +30000,40000.0,4812367,160342105127,589889,12992143892,859074,8751501824,1154939,23672423318,367090,4194733040,2013275,408383811,670784,6000584630,2022 +40000,50000.0,3006948,123149855217,346450,9080447336,402000,4094287302,594931,15062584091,321211,4083433533,1291407,307220534,620441,11071531331,2022 +50000,70000.0,2630421,140304691170,274622,8245838410,289939,2943023709,456486,14259907455,345406,5209269108,1127049,321084056,560827,13085662134,2022 +70000,100000.0,1271798,95761122529,129824,5446310728,121174,1273645535,191935,8117995849,199996,3667590937,581997,224268287,310898,10218921999,2022 +100000,150000.0,597630,64680084778,80650,5095131003,48942,534003641,75172,4299979467,117448,2333040513,234468,173374315,197751,7003732097,2022 +150000,200000.0,206654,31836669131,38233,3912105875,13620,167514400,22631,1606409311,44907,1017860070,89484,84351723,80052,3648385291,2022 +200000,300000.0,141418,30272381674,30994,4654653927,10896,212997356,15755,1297059502,33925,978464377,71443,95614320,63005,3870353880,2022 +300000,500000.0,73471,23710330164,18593,4658553470,6747,170643744,8347,1009244028,19441,729769248,43753,86329166,40609,3621278715,2022 +500000,1000000.0,40063,21986754528,13955,6604960307,3610,69570380,4026,515985187,10904,507445127,27446,107022725,27414,4415149392,2022 +1000000,inf,20945,39783464634,9926,21651235668,1151,13592774,1863,321927859,7027,446636040,16075,236121715,18622,16942014951,2022 +12570,inf,26071355,1003838434401,3670108,112341565037,6870448,70687092936,8391284,130859929828,2320606,30744984854,11544052,2968649630,3936170,85792401513,2022 +12570,15000.0,843123,11236073687,261924,2985312131,573270,6295209438,613494,3281912515,66533,523610843,481988,45095368,48755,68304803,2023 +15000,20000.0,2812617,44769791351,674970,8556479810,1833835,20625771473,2067113,16105512445,241655,2017275760,1712442,250305968,329578,1028052357,2023 +20000,30000.0,7325103,167321241244,971036,16168318506,2308559,26070153401,2710874,36780181679,438386,4398726735,3236246,578420205,783261,3784124613,2023 +30000,40000.0,5128594,157730546071,695222,14694800813,1123161,12600652842,1376781,27777000009,377434,4465903127,2105449,451095854,669749,6077939465,2023 +40000,50000.0,3585416,142165783416,401701,10893258804,485785,5459480154,688282,17671918002,313829,4169001683,1378695,334807019,582446,8582336855,2023 +50000,70000.0,3366673,168287105436,321676,9580842849,385986,4320363936,584437,18438376023,402500,6183111336,1412532,404071882,724464,17470410298,2023 +70000,100000.0,1690128,121506914967,148372,6228131128,154292,1756844069,237542,10118944849,220047,4250260072,672094,262587416,324847,10550076980,2023 +100000,150000.0,820537,83596472226,87106,5434939632,63210,744044615,95262,5588728202,136946,2929430041,314427,199229173,225869,9164272261,2023 +150000,200000.0,258145,38533753818,41781,4079902702,16644,214246932,26169,1935242857,50791,1181180429,96639,93980776,92334,3745070823,2023 +200000,300000.0,182638,37334672184,35325,5155059876,12875,243655890,18571,1598760207,38889,1127704046,81779,109999197,73121,4546897775,2023 +300000,500000.0,93522,29355059934,20478,5066696208,7316,181113784,9649,1213864960,22037,921337520,49639,102052560,45056,3993942809,2023 +500000,1000000.0,45500,23723048298,14441,6949979474,4337,87176896,4345,609054214,11711,562976372,30920,120440660,30813,5243819452,2023 +1000000,inf,25026,46374174408,10973,23601745566,1799,22223116,1970,350074947,8060,501453198,17674,258929932,19793,18453370935,2023 +12570,inf,26164351,1071171820250,3685005,119395467498,6971069,78620936546,8434490,141469570911,2328820,33231971162,11590524,3211016010,3950084,92708619426,2023 +12570,15000.0,533323,7108150056,181449,2146373087,318843,3627100876,392586,2458451295,56098,480899728,307255,31427626,35167,54572598,2024 +15000,20000.0,2616662,41080651005,658600,8404482195,1837300,21890115761,2072252,15245919881,232087,1942395036,1671822,235641911,304358,965592535,2024 +20000,30000.0,6910022,155050979947,966351,16044292521,2450769,29599550079,2819527,36444949890,433085,4298115216,3242869,576059894,776200,3666496936,2024 +30000,40000.0,5536028,168485102530,744641,15787662561,1232873,14882183445,1447542,28424432829,384940,4483893690,2150119,451459994,677042,6042908705,2024 +40000,50000.0,3795995,150938282334,423608,11848046987,530409,6390505804,724365,18261455269,310500,4097348122,1435137,345331530,591159,8417973689,2024 +50000,70000.0,3538290,176975142910,352253,10638395867,425589,5105092076,614948,18986728206,419010,6334560538,1502228,415014261,755087,17722018286,2024 +70000,100000.0,1812347,130174043505,157796,6752036852,168984,2049531762,248186,10368669686,225940,4300728261,699635,267516303,329432,10633733432,2024 +100000,150000.0,911929,94055996648,90741,5842428900,68236,858377967,99582,5673183821,140586,2972219639,351336,203284531,228764,9210141628,2024 +150000,200000.0,265977,39936021254,42779,4128145235,17254,235282632,26878,1945439421,52905,1195296237,99106,94210484,96346,3803256247,2024 +200000,300000.0,190733,38995907971,37099,5439994469,13694,258685161,18743,1559425398,40727,1148105719,84588,111242513,75692,4515505248,2024 +300000,500000.0,101684,31868820026,21893,5414970584,8001,198064668,10401,1286357007,23048,937077747,52706,102365266,46658,4003559094,2024 +500000,1000000.0,47851,24856320702,14636,6922573453,4499,96329669,4431,613692305,12197,557586016,32424,121486755,32430,5274912663,2024 +1000000,inf,26610,49461753821,11698,25262632224,1808,23556353,2017,350963301,8300,515042622,18357,259886480,20086,18483842623,2024 +12570,inf,26287452,1108987172709,3703545,124632034936,7078259,85214376252,8481458,141619668309,2339423,33263268570,11647583,3214927548,3968421,92794513683,2024 +12570,15000.0,397648,5431425863,95411,1131586102,149252,1819054391,173827,1103045453,28496,256585748,186798,19019331,20214,32835476,2025 +15000,20000.0,2455344,38471462065,646589,8310950387,1722945,21134674396,1941485,13854551564,213849,1846825315,1542858,207239148,262236,804064757,2025 +20000,30000.0,6725705,150473893962,957759,15956448525,2539152,31746979013,2881443,36593475414,421314,4318144012,3203190,595655533,745073,3449453705,2025 +30000,40000.0,5647922,171143600713,779962,16624354851,1315160,16645151834,1546958,30582606095,390418,4664022006,2219894,482176525,669779,5929346090,2025 +40000,50000.0,3934646,157158903741,439525,12395050887,620689,7760893290,806552,20606229440,303104,4179167617,1510217,366118786,582376,7675351842,2025 +50000,70000.0,3671529,182822780485,393018,12219483438,490784,6142528842,694332,21835031578,442584,6763455075,1589617,461844071,806153,19186649128,2025 +70000,100000.0,1899912,135434403223,173116,7474939194,201154,2529853014,287701,12365283613,248359,4953264063,753133,300700046,360245,11683965831,2025 +100000,150000.0,979383,100335533705,96760,6165403881,79730,1033748119,113892,6683888427,153787,3447777837,381372,230334975,250211,10722457328,2025 +150000,200000.0,273418,41252853110,44425,4247215925,19764,273331393,29302,2208024937,55501,1318172084,102156,102129091,98364,3713878618,2025 +200000,300000.0,199891,40283209266,39127,5776288282,15919,294767447,20817,1828740697,43519,1278760742,89232,119757299,82376,5263725031,2025 +300000,500000.0,106487,33164209023,23341,5738915302,8618,208972717,11131,1386264892,24591,1093724569,55350,118095641,49147,4346410755,2025 +500000,1000000.0,51274,26547428136,14926,6938820193,4886,105852482,4906,765803956,13062,576174766,33775,130803986,34241,5792578949,2025 +1000000,inf,26864,50644887128,12613,27021501681,1844,24614709,2074,373252939,8546,579418267,19187,276368912,20632,19789092756,2025 +12570,inf,26370022,1133164590419,3716570,130000958647,7169895,89720421646,8514422,150186199003,2347129,35275492101,11686779,3410243344,3981047,98389810267,2025 +12570,15000.0,284401,4003725725,17732,147078443,74923,957123578,60306,172132464,6604,46028865,101103,9244587,9311,15965643,2026 +15000,20000.0,2345431,36962242079,636642,8335236226,1588079,19960655075,1800032,12854418041,203802,1844088763,1440144,187088952,232587,689039295,2026 +20000,30000.0,6454559,144192679119,950130,15889059441,2584328,33206640892,2898235,36072630396,411206,4266507687,3142405,598498547,713869,3235052416,2026 +30000,40000.0,5865632,178418051066,807993,17524784854,1371469,17854558698,1642304,32840548892,389091,4817007673,2281036,506808784,663984,5781178872,2026 +40000,50000.0,4022262,161997848269,448463,12705589813,677493,8689817475,873563,22644657779,299408,4156841468,1572801,389404670,577926,7215036289,2026 +50000,70000.0,3698018,186591828331,430830,14058712441,534362,6873284739,745781,23880808783,439976,6938840004,1640515,495656920,784042,17059739216,2026 +70000,100000.0,2047654,141923191246,194535,8284719831,224621,2898156689,322403,13981337392,286911,5808618841,816189,339447195,443917,15677599277,2026 +100000,150000.0,1048299,107045075157,102029,6554381944,89675,1187900462,128334,7677174131,163385,3794166555,418397,250145698,268543,12047927482,2026 +150000,200000.0,285262,43112513157,45828,4326447835,22323,313750528,32581,2533732358,58802,1437466364,105241,108378597,102973,3925199748,2026 +200000,300000.0,208035,41735442414,41421,6088523896,16649,275664226,22139,2035423258,46858,1419157659,94096,130133744,87051,5742374807,2026 +300000,500000.0,110460,34323794901,24798,6213730757,8479,227652293,11078,1292535705,25931,1222514565,56799,119094155,50510,4556981255,2026 +500000,1000000.0,55570,28414611732,16457,7545643026,6358,144108876,6459,1071200071,13957,622353008,36771,151500018,37839,6422119621,2026 +1000000,inf,27446,52360639402,12931,28503616591,1853,25262935,2154,392409201,8660,608407839,19516,290278066,20801,20768336340,2026 +12570,inf,26453030,1161081642598,3729790,136177525098,7200611,92614576467,8545369,157449008472,2354590,36981999292,11725012,3575679933,3993352,103136550262,2026 +12570,15000.0,157614,2226514997,11419,115298820,43049,546722056,33845,99112819,2617,22017366,56940,6193902,4771,8853075,2027 +15000,20000.0,2223214,35402364648,555145,7456997250,1357334,17485274910,1498280,10713990174,164372,1557574943,1250664,149720201,174605,429100441,2027 +20000,30000.0,6203414,138834762178,934623,15659944901,2570589,34042587726,2844152,34572598129,388886,4206868786,3017719,604153765,685734,3108597813,2027 +30000,40000.0,6080671,186203956673,828891,18267316336,1470708,19622929846,1790904,36579876310,386545,4946268425,2376429,542023924,639039,5321395421,2027 +40000,50000.0,4098137,166944727639,460863,13163090850,749099,9869438240,950052,25056491329,297294,4323473061,1641086,421578492,568035,6606459754,2027 +50000,70000.0,3770816,190477817490,469723,15842128343,604335,7986055099,845496,27984743447,458723,7652871105,1748588,563132781,820562,18187476247,2027 +70000,100000.0,2161470,148735998930,219621,9247004089,266348,3513593021,378448,16986140600,316390,6614210731,879024,393824692,495114,17800024336,2027 +100000,150000.0,1106276,112820968990,111000,7045651031,104304,1406770159,149362,9282047900,176753,4401898570,458591,279132462,287755,12864985783,2027 +150000,200000.0,312758,45988801397,49699,4493773309,28000,398892466,39467,3224267421,68129,1802993891,115191,144288878,121151,5665751055,2027 +200000,300000.0,216456,43278263991,44346,6528420120,17254,261553597,24309,2362919534,49781,1616057609,98486,138771405,91377,6165331079,2027 +300000,500000.0,118643,36152473383,26718,6691546980,10106,247770448,13442,1705796458,28050,1013600098,61441,130553805,55996,5304615303,2027 +500000,1000000.0,58957,30020975132,17416,7959157397,7543,198583368,6720,1213391884,15675,1134790642,39848,177287001,40553,7535489345,2027 +1000000,inf,28098,54231566808,13541,30341716033,1865,25972647,2230,426153729,9060,687499196,19947,315639880,21238,22480167023,2027 +12570,inf,26536523,1191319192256,3743005,142812045458,7230490,95605545353,8576662,170207253774,2362275,39980124425,11763909,3866292782,4005930,111478246674,2027 +12570,15000.0,43250,547260875,9518,109572877,33003,411490857,25697,85668640,2203,20131951,20837,4681719,3510,6595012,2028 +15000,20000.0,2113708,34203074274,473365,6525113371,1135871,15013231022,1183587,8042139364,118061,1116652605,1067959,115069303,125690,263334619,2028 +20000,30000.0,5866955,131326932205,919684,15635037381,2546834,34753144224,2849394,34922424472,378287,4246272743,2907344,600277416,649167,2957600481,2028 +30000,40000.0,6291548,194764752870,836628,18874924300,1506390,20610618588,1810232,37158850467,373451,4934259286,2381332,571898095,623578,5097188919,2028 +40000,50000.0,4196513,173682520026,462999,13006979729,811666,10980804880,1026718,27432820140,293261,4503277220,1715980,451631209,547909,6201165846,2028 +50000,70000.0,3817546,192916128409,515705,17901307821,700674,9520368873,977036,33564695948,479924,8393640211,1860701,639502562,860946,19442323963,2028 +70000,100000.0,2262917,156196112423,245648,10431247111,308775,4185404952,437158,20153419279,338461,7430477103,954880,449563463,525723,18956824053,2028 +100000,150000.0,1174062,119171478189,119507,7580696445,121093,1663874523,172912,11130437531,193911,5155472105,501729,317819060,313038,14572185744,2028 +150000,200000.0,333372,48599629156,53366,4669919524,33264,482743102,46419,4029874927,75314,2146127250,124035,161469945,133793,6870049292,2028 +200000,300000.0,226873,45159925456,48221,7097630842,20081,305793949,28288,2970663268,53572,1833254339,105386,155861314,98485,6811024520,2028 +300000,500000.0,124941,38034305310,28261,7055972186,10727,255249853,14026,1827977373,29646,1147236805,64378,132018899,59273,5836785467,2028 +500000,1000000.0,61681,30668522455,18268,8368871254,8284,219026825,7762,1510238989,17570,1361971160,42423,222153751,43429,8765575489,2028 +1000000,inf,29803,57247064345,14240,32288400530,1870,26608901,2307,460217073,9331,765355789,20772,342170594,22064,24250691013,2028 +12570,inf,26543169,1222517705993,3745409,149545673373,7238518,98428186821,8581522,183289207202,2362992,43054128567,11767744,4164106845,4006604,120031344418,2028 +12570,15000.0,41322,522174791,9136,106481624,31865,399334404,24278,81593683,2067,19628044,19690,4421706,4083,7788425,2029 +15000,20000.0,1959559,32427629478,370332,5133019076,976260,13082968505,931291,5590351923,80459,699610104,914889,93930277,97080,190565674,2029 +20000,30000.0,5514866,123647055299,942571,16356359966,2505759,34384205939,2847065,35994008007,364185,4319608486,2775386,581926805,597147,2670847802,2029 +30000,40000.0,6496707,204600695976,808977,18507573314,1507572,20704038655,1763629,36142154559,348427,4781269141,2357532,596070159,600272,4678714220,2029 +40000,50000.0,4259677,178736924469,482111,13612310733,854265,11643337828,1113634,30549560706,304316,4902300404,1785116,495175585,538919,6211204554,2029 +50000,70000.0,3836680,195159795735,550918,19636501335,785823,10706908774,1077388,38119827813,474834,8564045062,1943138,687994382,856363,18916410903,2029 +70000,100000.0,2376795,163646937322,277978,11854892360,355513,4837458772,511127,24159428650,372588,8556670378,1040324,519081890,583141,21426450769,2029 +100000,150000.0,1244473,125716702018,132842,8291630742,142941,1969988420,204496,13605664626,216724,6097252826,554298,371811545,347205,16653158370,2029 +150000,200000.0,357143,51827208117,56289,4870213248,38940,554057655,53551,4880091840,81308,2505862349,134156,185312566,147335,8263579353,2029 +200000,300000.0,232665,46215797479,51648,7581090662,22792,344277284,32038,3517051694,57995,2122204759,110881,176906927,103700,7356179584,2029 +300000,500000.0,132179,39994880890,30191,7479900857,11530,262058418,15230,2221574216,32519,1404396444,68507,145572226,63558,6564150570,2029 +500000,1000000.0,65167,32071479900,19830,9366495732,9284,244865427,8919,1730353411,18375,1451646032,44872,245561015,45612,9234666542,2029 +1000000,inf,31452,59862029985,14649,33904055509,1878,26717273,2394,497368158,9895,873603983,21767,374324090,23610,26883847881,2029 +12570,inf,26548685,1254429311460,3747472,156700525157,7244421,99160217356,8585040,197089029286,2363691,46298098012,11770556,4478089174,4008026,129057564647,2029 diff --git a/policyengine_uk_data/utils/incomes_projection.py b/policyengine_uk_data/utils/incomes_projection.py new file mode 100644 index 0000000..7f65080 --- /dev/null +++ b/policyengine_uk_data/utils/incomes_projection.py @@ -0,0 +1,207 @@ +import numpy as np +import pandas as pd +from policyengine_uk_data.storage import STORAGE_FOLDER +from policyengine_uk_data.utils import uprate_values +import warnings +from policyengine_uk import Microsimulation +from policyengine_uk_data.utils.reweight import reweight +from policyengine_uk_data.datasets import SPI_2020_21 + +warnings.filterwarnings("ignore") + +tax_benefit = pd.read_csv(STORAGE_FOLDER / "tax_benefit.csv") +tax_benefit["name"] = tax_benefit["name"].apply(lambda x: f"obr/{x}") +demographics = pd.read_csv(STORAGE_FOLDER / "demographics.csv") +demographics["name"] = demographics["name"].apply(lambda x: f"ons/{x}") +statistics = pd.concat([tax_benefit, demographics]) +dfs = [] + +MIN_YEAR = 2018 +MAX_YEAR = 2029 + +for time_period in range(MIN_YEAR, MAX_YEAR + 1): + time_period_df = statistics[ + ["name", "unit", "reference", str(time_period)] + ].rename(columns={str(time_period): "value"}) + time_period_df["time_period"] = time_period + dfs.append(time_period_df) + +statistics = pd.concat(dfs) +statistics = statistics[statistics.value.notnull()] + + +def create_target_matrix( + dataset: str, + time_period: str, + reform=None, +) -> np.ndarray: + """ + Create a target matrix A, s.t. for household weights w, the target vector b and a perfectly calibrated PolicyEngine UK: + + A * w = b + + """ + + # First- tax-benefit outcomes from the DWP and OBR. + + from policyengine_uk import Microsimulation + + sim = Microsimulation(dataset=dataset, reform=reform) + sim.default_calculation_period = time_period + + household_from_person = lambda values: sim.map_result( + values, "person", "household" + ) + + df = pd.DataFrame() + + # Finally, incomes from HMRC + + target_names = [] + target_values = [] + + INCOME_VARIABLES = [ + "employment_income", + "self_employment_income", + "state_pension", + "private_pension_income", + "property_income", + "savings_interest_income", + "dividend_income", + ] + + income_df = sim.calculate_dataframe(["total_income"] + INCOME_VARIABLES) + + incomes = pd.read_csv(STORAGE_FOLDER / "incomes.csv") + for variable in INCOME_VARIABLES: + incomes[variable + "_count"] = uprate_values( + incomes[variable + "_count"], "household_weight", 2021, time_period + ) + incomes[variable + "_amount"] = uprate_values( + incomes[variable + "_amount"], variable, 2021, time_period + ) + + for i, row in incomes.iterrows(): + lower = row.total_income_lower_bound + upper = row.total_income_upper_bound + in_income_band = (income_df.total_income >= lower) & ( + income_df.total_income < upper + ) + for variable in INCOME_VARIABLES: + name_amount = ( + "hmrc/" + variable + f"_income_band_{i}_{lower:_}_to_{upper:_}" + ) + df[name_amount] = household_from_person( + income_df[variable] * in_income_band + ) + target_values.append(row[variable + "_amount"]) + target_names.append(name_amount) + name_count = ( + "hmrc/" + + variable + + f"_count_income_band_{i}_{lower:_}_to_{upper:_}" + ) + df[name_count] = household_from_person( + (income_df[variable] > 0) * in_income_band + ) + target_values.append(row[variable + "_count"]) + target_names.append(name_count) + + combined_targets = pd.DataFrame( + { + "value": target_values, + }, + index=target_names, + ) + + return df, combined_targets.value + + +def get_loss_results(dataset, time_period, reform=None): + matrix, targets = create_target_matrix(dataset, time_period, reform) + from policyengine_uk import Microsimulation + + weights = ( + Microsimulation(dataset=dataset, reform=reform) + .calculate("household_weight", time_period) + .values + ) + estimates = weights @ matrix + df = pd.DataFrame( + { + "name": estimates.index, + "estimate": estimates.values, + "target": targets, + }, + ) + df["error"] = df["estimate"] - df["target"] + df["abs_error"] = df["error"].abs() + df["rel_error"] = df["error"] / df["target"] + df["abs_rel_error"] = df["rel_error"].abs() + return df.reset_index(drop=True) + + +def create_income_projections(): + loss_matrix, targets_array = create_target_matrix(SPI_2020_21, 2022) + + sim = Microsimulation(dataset=SPI_2020_21) + household_weights = sim.calculate("household_weight", 2022).values + + reweighted_weights = reweight( + household_weights, + loss_matrix, + targets_array, + epochs=1_000, + ) + + sim = Microsimulation(dataset=SPI_2020_21) + sim.set_input("household_weight", 2022, reweighted_weights) + + INCOME_VARIABLES = [ + "employment_income", + "self_employment_income", + "state_pension", + "private_pension_income", + "property_income", + "savings_interest_income", + "dividend_income", + ] + + incomes = pd.read_csv(STORAGE_FOLDER / "incomes.csv") + + projection_df = pd.DataFrame() + lower_bounds = incomes.total_income_lower_bound + upper_bounds = incomes.total_income_upper_bound + + for year in range(2022, 2030): + year_df = pd.DataFrame() + year_df["total_income_lower_bound"] = lower_bounds + year_df["total_income_upper_bound"] = upper_bounds + for variable in INCOME_VARIABLES: + count_values = [] + amount_values = [] + for i, (lower, upper) in enumerate( + zip(lower_bounds, upper_bounds) + ): + in_band = sim.calculate("total_income", year).between( + lower, upper + ) + value = sim.calculate(variable, year) + count_in_band_with_nonzero_value = round( + ((value > 0) * in_band).sum() + ) + amount_in_band = round(value[in_band].sum()) + count_values.append(count_in_band_with_nonzero_value) + amount_values.append(amount_in_band) + year_df[f"{variable}_count"] = count_values + year_df[f"{variable}_amount"] = amount_values + year_df["year"] = year + projection_df = pd.concat([projection_df, year_df]) + + projection_df.to_csv( + STORAGE_FOLDER / "incomes_projection.csv", index=False + ) + + +if __name__ == "__main__": + create_income_projections() diff --git a/policyengine_uk_data/utils/loss.py b/policyengine_uk_data/utils/loss.py index 163cc7b..ed8bd12 100644 --- a/policyengine_uk_data/utils/loss.py +++ b/policyengine_uk_data/utils/loss.py @@ -196,15 +196,8 @@ def pe_count(*variables): income_df = sim.calculate_dataframe(["total_income"] + INCOME_VARIABLES) - incomes = pd.read_csv(STORAGE_FOLDER / "incomes.csv") - for variable in INCOME_VARIABLES: - incomes[variable + "_count"] = uprate_values( - incomes[variable + "_count"], "household_weight", 2021, time_period - ) - incomes[variable + "_amount"] = uprate_values( - incomes[variable + "_amount"], variable, 2021, time_period - ) - + incomes = pd.read_csv(STORAGE_FOLDER / "incomes_projection.csv") + incomes = incomes[incomes.year == time_period] for i, row in incomes.iterrows(): lower = row.total_income_lower_bound upper = row.total_income_upper_bound diff --git a/policyengine_uk_data/utils/reweight.py b/policyengine_uk_data/utils/reweight.py new file mode 100644 index 0000000..4e0aa89 --- /dev/null +++ b/policyengine_uk_data/utils/reweight.py @@ -0,0 +1,65 @@ +import numpy as np +import torch + + +def reweight( + original_weights, + loss_matrix, + targets_array, + dropout_rate=0.05, + epochs=10_000, +): + target_names = np.array(loss_matrix.columns) + loss_matrix = torch.tensor(loss_matrix.values, dtype=torch.float32) + targets_array = torch.tensor(targets_array, dtype=torch.float32) + weights = torch.tensor( + np.log(original_weights), requires_grad=True, dtype=torch.float32 + ) + + # TODO: replace this with a call to the python reweight.py package. + def loss(weights): + # Check for Nans in either the weights or the loss matrix + if torch.isnan(weights).any(): + raise ValueError("Weights contain NaNs") + if torch.isnan(loss_matrix).any(): + raise ValueError("Loss matrix contains NaNs") + estimate = weights @ loss_matrix + if torch.isnan(estimate).any(): + raise ValueError("Estimate contains NaNs") + rel_error = ( + ((estimate - targets_array) + 1) / (targets_array + 1) + ) ** 2 + if torch.isnan(rel_error).any(): + raise ValueError("Relative error contains NaNs") + return rel_error.mean() + + def dropout_weights(weights, p): + if p == 0: + return weights + # Replace p% of the weights with the mean value of the rest of them + mask = torch.rand_like(weights) < p + mean = weights[~mask].mean() + masked_weights = weights.clone() + masked_weights[mask] = mean + return masked_weights + + optimizer = torch.optim.Adam([weights], lr=1e-1) + from tqdm import trange + + start_loss = None + + iterator = trange(epochs) + for i in iterator: + optimizer.zero_grad() + weights_ = dropout_weights(weights, dropout_rate) + l = loss(torch.exp(weights_)) + if start_loss is None: + start_loss = l.item() + loss_rel_change = (l.item() - start_loss) / start_loss + l.backward() + iterator.set_postfix( + {"loss": l.item(), "loss_rel_change": loss_rel_change} + ) + optimizer.step() + + return torch.exp(weights).detach().numpy() diff --git a/pyproject.toml b/pyproject.toml index 1de951e..cda67e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "policyengine_uk_data" -version = "1.5.0" +version = "1.6.0" description = "A package to create representative microdata for the UK." readme = "README.md" authors = [