-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
154 lines (114 loc) · 3.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import logging
import tomllib
import typing
import tyro
import saev
# One shared record layout for the whole script: timestamp, level, logger
# name, then the message itself.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"

# Configure the root logger once at import time so that the INFO-level
# messages emitted by the entry points below are actually shown.
logging.basicConfig(format=log_format, level=logging.INFO)

# Logger used by every function in this file.
logger = logging.getLogger("main")
def activations(cfg: typing.Annotated[saev.ActivationsConfig, tyro.conf.arg(name="")]):
    """
    Save ViT activations for use later on.

    This is a thin CLI entry point: it hands `cfg` straight to
    `saev.activations.dump` and returns nothing.

    Args:
        cfg: Configuration for activations.
    """
    # The submodule is imported when the subcommand runs rather than at module
    # import time.
    import saev.activations as activations_lib

    activations_lib.dump(cfg)
def sweep(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")], sweep: str):
    """
    Run a grid search over a set of hyperparameters.

    The sweep file is parsed as TOML and combined with the baseline config by
    `saev.config.grid`. If that reports any errors, each one is logged as a
    warning and nothing is submitted. Otherwise all resulting configs are
    handed to `saev.training.main` as a single submitit job, and this function
    waits for that job's result.

    Args:
        cfg: Baseline config for training an SAE.
        sweep: Path to .toml file defining the sweep parameters.
    """
    import submitit

    import saev.config
    import saev.training

    # tomllib requires the file to be opened in binary mode.
    with open(sweep, "rb") as fd:
        grid_spec = tomllib.load(fd)
        cfgs, errs = saev.config.grid(cfg, grid_spec)

    if errs:
        # Report every problem, then bail out without launching anything.
        for err in errs:
            logger.warning("Error in config: %s", err)
        return

    logger.info("Sweep has %d experiments.", len(cfgs))

    if not cfg.slurm:
        executor = submitit.DebugExecutor(folder=cfg.log_to)
    else:
        executor = submitit.SlurmExecutor(folder=cfg.log_to)
        slurm_params = dict(
            time=60,
            partition="preemptible",
            gpus_per_node=1,
            cpus_per_task=cfg.n_workers + 4,
            stderr_to_stdout=True,
            account=cfg.slurm_acct,
        )
        executor.update_parameters(**slurm_params)

    # A single job receives the whole collection of configs.
    executor.submit(saev.training.main, cfgs).result()
# for i, result in enumerate(submitit.helpers.as_completed(jobs)):
# exp_id = result.result()
# logger.info("Finished task %s (%d/%d)", exp_id, i + 1, len(jobs))
def train(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")]):
    """
    Train from a single config by submitting one submitit job and waiting for it.

    When `cfg.slurm` is set, the job goes through a `submitit.SlurmExecutor`
    configured with the parameters below; otherwise a `submitit.DebugExecutor`
    is used. In both cases job files are written under `cfg.log_to`.

    Args:
        cfg: Config for training an SAE.
    """
    import submitit

    def fn():
        # Imported inside the job function, so the import happens wherever the
        # job body actually executes.
        import saev.training

        # `sweep` in this file calls saev.training.main with the collection of
        # configs produced by saev.config.grid. Wrap the single config so both
        # entry points use the same calling convention.
        # NOTE(review): assumes saev.training.main expects a collection of
        # configs, as in `sweep` -- confirm against its signature.
        saev.training.main([cfg])

    if cfg.slurm:
        executor = submitit.SlurmExecutor(folder=cfg.log_to)
        executor.update_parameters(
            time=30,
            partition="debug",
            gpus_per_node=1,
            cpus_per_task=12,
            stderr_to_stdout=True,
            account=cfg.slurm_acct,
        )
    else:
        executor = submitit.DebugExecutor(folder=cfg.log_to)

    job = executor.submit(fn)
    # Wait for the job; its return value is not used.
    job.result()
def evaluate(cfg: typing.Annotated[saev.EvaluateConfig, tyro.conf.arg(name="")]):
    """
    Submit evaluation jobs through submitit and wait for each of them.

    Three job functions are defined (histograms, Broden, ImageNet-1K), but
    only `run_imagenet1k` is currently submitted; the other two submissions
    are commented out.

    Args:
        cfg: Evaluation config. Its `histograms`, `broden` and `imagenet`
            attributes are passed to the respective evaluation functions, and
            `slurm`, `slurm_acct` and `log_to` control how jobs are launched.
    """
    import submitit

    def run_histograms():
        import saev.training

        return saev.training.evaluate(cfg.histograms)

    def run_broden():
        import saev.broden

        return saev.broden.evaluate(cfg.broden)

    def run_imagenet1k():
        import saev.imagenet1k

        return saev.imagenet1k.evaluate(cfg.imagenet)

    if not cfg.slurm:
        executor = submitit.DebugExecutor(folder=cfg.log_to)
    else:
        executor = submitit.SlurmExecutor(folder=cfg.log_to)
        slurm_params = dict(
            time=30,
            partition="debug",
            gpus_per_node=1,
            cpus_per_task=12,
            stderr_to_stdout=True,
            account=cfg.slurm_acct,
        )
        executor.update_parameters(**slurm_params)

    # Disabled for now:
    #   executor.submit(run_histograms)
    #   executor.submit(run_broden)
    jobs = [executor.submit(run_imagenet1k)]

    # Wait for every submitted job; return values are discarded.
    for job in jobs:
        job.result()
# def webapp(cfg: typing.Annotated[saev.WebappConfig, tyro.conf.arg(name="")]):
# import saev.webapp
# saev.webapp.main(cfg)
# print()
# print("To view the webapp, run:")
# print()
# print(" uv run marimo edit webapp.py")
# print()
if __name__ == "__main__":
    # Subcommand name -> function implementing it; tyro builds the CLI
    # from this mapping.
    # NOTE(review): `train` is defined in this file but is not registered
    # here -- confirm whether that is intentional.
    subcommands = {
        "activations": activations,
        "sweep": sweep,
        "evaluate": evaluate,
        # "webapp": webapp,
    }
    tyro.extras.subcommand_cli_from_dict(subcommands)
    logger.info("Done.")