diff --git a/policyengine_uk_data/datasets/frs/enhanced_frs.py b/policyengine_uk_data/datasets/frs/enhanced_frs.py index f023d49..28b873b 100644 --- a/policyengine_uk_data/datasets/frs/enhanced_frs.py +++ b/policyengine_uk_data/datasets/frs/enhanced_frs.py @@ -12,7 +12,7 @@ class EnhancedFRS(Dataset): def generate(self): - data = self.input_frs().load_dataset() + data = self.input_frs(require=True).load_dataset() original_weights = data["household_weight"][str(self.time_period)] + 10 for year in range(self.time_period, self.end_year + 1): loss_matrix, targets_array = create_target_matrix( @@ -100,5 +100,4 @@ def loss(weights): if __name__ == "__main__": - ReweightedFRS_2022_23().generate() EnhancedFRS_2022_23().generate() diff --git a/policyengine_uk_data/utils/imputations/capital_gains.py b/policyengine_uk_data/utils/imputations/capital_gains.py index e53536d..54c7862 100644 --- a/policyengine_uk_data/utils/imputations/capital_gains.py +++ b/policyengine_uk_data/utils/imputations/capital_gains.py @@ -93,7 +93,7 @@ def loss(blend_factor): loss_value.backward() optimiser.step() progress.set_description(f"Loss: {loss_value.item()}") - if loss_value.item() < 1e-5: + if loss_value.item() < 1e-3: break new_household_weight = household_weight.detach().numpy()