Skip to content

Commit

Permalink
Test with API
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff committed Sep 24, 2024
1 parent 8bfcd1a commit 6bb67b0
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,19 @@ def __init__(
reform: Reform = None,
trace: bool = False,
):
self.reform = reform
if tax_benefit_system is None:
if self.default_tax_benefit_system_instance is not None:
if (
self.default_tax_benefit_system_instance is not None
and reform is None
):
tax_benefit_system = self.default_tax_benefit_system_instance
else:
tax_benefit_system = self.default_tax_benefit_system()
tax_benefit_system = self.default_tax_benefit_system(
reform=reform
)
self.tax_benefit_system = tax_benefit_system
if self.reform is not None:
tax_benefit_system = tax_benefit_system.clone()

self.reform = reform
self.tax_benefit_system = tax_benefit_system
self.branch_name = "default"
self.dataset = dataset
Expand Down Expand Up @@ -168,7 +171,7 @@ def __init__(
self.tax_benefit_system.simulation = self

if self.reform is not None:
self.apply_reform(self.reform)
self.tax_benefit_system.apply_reform_set(self.reform)

# Backwards compatibility methods
self.calc = self.calculate
Expand Down Expand Up @@ -1511,10 +1514,11 @@ def subsample(
raise ValueError("Either n or frac must be provided.")
if n is None:
n = int(len(h_ids) * frac)
h_weights = pd.Series(h_df[household_weight_column])
h_weights = pd.Series(h_df[household_weight_column].values)

# Ensure n does not exceed the number of households
n = min(n, len(h_weights))
if n > len(h_weights):
# Don't need to subsample!
return self

# Seed the random number generators for reproducibility
random.seed(key)
Expand All @@ -1523,7 +1527,10 @@ def subsample(

# Sample household IDs based on their weights
chosen_household_ids = np.random.choice(
h_ids, n, p=h_weights / h_weights.sum(), replace=False
h_ids,
n,
p=h_weights.values / h_weights.values.sum(),
replace=False,
)

# Filter DataFrame to include only the chosen households
Expand Down

0 comments on commit 6bb67b0

Please sign in to comment.