Skip to content

Commit

Permalink
Don't run all epochs in test mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff committed Dec 22, 2024
1 parent d0c8c85 commit ef50800
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pull_request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
- name: Build datasets
run: make data
env:
TEST_LITE: true
- name: Run tests
run: pytest
- name: Test documentation builds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from tqdm import tqdm
import h5py
import os
from policyengine_uk_data.datasets.frs.local_areas.constituencies.transform_constituencies import (
transform_2010_to_2024,
)
Expand Down Expand Up @@ -69,7 +70,7 @@ def dropout_weights(weights, p):

optimizer = torch.optim.Adam([weights], lr=0.05)

desc = range(2048)
desc = range(128) if os.environ.get("DATA_LITE") else range(2048)

for epoch in desc:
optimizer.zero_grad()
Expand All @@ -78,7 +79,7 @@ def dropout_weights(weights, p):
l.backward()
optimizer.step()
if epoch % 50 == 0:
print(f"Loss: {l.item()}, Epoch: {epoch}")
print(f"Loss: {l.item()}, Epoch: {epoch}", flush=True)

final_weights = torch.exp(weights).detach().numpy()
mapping_matrix = pd.read_csv(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
from tqdm import tqdm
import h5py
import os
from policyengine_uk_data.storage import STORAGE_FOLDER


from loss import (
from policyengine_uk_data.datasets.frs.local_areas.local_authorities.loss import (
create_local_authority_target_matrix,
create_national_target_matrix,
)
Expand Down Expand Up @@ -62,7 +63,7 @@ def dropout_weights(weights, p):

optimizer = torch.optim.Adam([weights], lr=0.05)

desc = range(2048)
desc = range(128) if os.environ.get("DATA_LITE") else range(2048)

for epoch in desc:
optimizer.zero_grad()
Expand Down
3 changes: 2 additions & 1 deletion policyengine_uk_data/utils/reweight.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
import os


def reweight(
Expand Down Expand Up @@ -47,7 +48,7 @@ def dropout_weights(weights, p):

start_loss = None

iterator = range(1_000)
iterator = range(128) if os.environ.get("DATA_LITE") else range(2048)
for i in iterator:
optimizer.zero_grad()
weights_ = dropout_weights(weights, dropout_rate)
Expand Down

0 comments on commit ef50800

Please sign in to comment.