Skip to content

Commit

Permalink
fix function to find correct checkpoint path and UID for restarting b…
Browse files Browse the repository at this point in the history
…reaking IID experiments
  • Loading branch information
mcw92 committed Jan 7, 2025
1 parent eeebbca commit 0bfcd8b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
6 changes: 5 additions & 1 deletion scripts/examples/get_breaking_iid_checkpoint_dir.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#!/usr/bin/env python

import pathlib

from specialcouscous.utils.slurm import find_checkpoint_dir_and_uuid

if __name__ == "__main__":
base_path = pathlib.Path("./breaking_iid/")
base_path = pathlib.Path(
"/hkfs/work/workspace/scratch/ku4408-SpecialCouscous/results/"
)
log_n_samples = 6
log_n_features = 4
n_classes = 10
Expand Down
18 changes: 9 additions & 9 deletions specialcouscous/utils/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,10 @@ def find_checkpoint_dir_and_uuid(

# Construct the expected directory path pattern
search_pattern = (
f"chunking/n{log_n_samples}_m{log_n_features}/nodes_16/"
f"breaking_iid/n{log_n_samples}_m{log_n_features}/nodes_16/"
f"*_{data_seed}_{model_seed}_{mu_global_str}_{mu_local_str}/"
)
print(f"The search pattern is {search_pattern}.")

# Search for matching directories.
matching_dirs = list(base_path.glob(search_pattern))
Expand All @@ -233,13 +234,12 @@ def find_checkpoint_dir_and_uuid(

checkpoint_dir = matching_dirs[0]

# Extract the UUID from filenames in the directory.
uuid_pattern = re.compile(
r"(\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b)"
)
# Extract the UUID from the results file.
uuid_pattern = re.compile(r"--([\d\w-]+)_results\.csv$")
for file in checkpoint_dir.iterdir():
match = uuid_pattern.search(file.name)
if match:
return checkpoint_dir, match.group(1)
if file.name.endswith("_results.csv"):
match = uuid_pattern.search(file.name)
if match:
return checkpoint_dir, match.group(1)

raise ValueError(f"No UUID found in files within {checkpoint_dir}.")
raise ValueError(f"No UUID found in results files within {checkpoint_dir}")

0 comments on commit 0bfcd8b

Please sign in to comment.