-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_test.py
46 lines (37 loc) · 1.33 KB
/
run_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import hydra
from torch import set_float32_matmul_precision
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer
from pipeline import FPM_Pipeline
from pytorch_lightning.profilers import SimpleProfiler, AdvancedProfiler
@hydra.main(version_base=None, config_path="configs", config_name="config")
def eval_pipeline(args: DictConfig) -> None:
    """Evaluate an FPM_Pipeline checkpoint on the test split.

    Hydra composes the config named ``config`` from the ``configs``
    directory and passes it in as ``args``; the function is therefore
    invoked without arguments from the command line.

    Steps, in order:
      1. Set float32 matmul precision to ``'medium'``.
      2. Build the pipeline in test mode (``test_model=True``).
      3. Build a Lightning ``Trainer``, optionally with an ``AdvancedProfiler``.
      4. Print the experiment tag and every entry of ``args.test``.
      5. Run ``trainer.test`` on the pipeline's test dataloader, restoring
         weights from ``args.eval.ckpt``.

    Config keys read: ``profiler``, ``tmp.logs_out`` (only when profiling),
    ``device``, ``training.freq_valid``, ``training.n_epochs``, ``exp_tag``,
    ``test``, ``eval.ckpt``.
    """
    # 'medium' lets PyTorch trade float32 matmul precision for speed
    # (lower-precision internal math where the hardware supports it).
    set_float32_matmul_precision('medium')

    pipeline = FPM_Pipeline(args, test_model=True)

    # Profiling is opt-in via a truthy args.profiler; the report is written
    # under args.tmp.logs_out with the name 'profiler_log'.
    profiler = (
        AdvancedProfiler(args.tmp.logs_out, 'profiler_log')
        if args.profiler
        else None
    )

    # Keyword order matches evaluation order (get_callbacks() is called here).
    trainer_kwargs = dict(
        # NOTE(review): what logger=None means depends on the Lightning
        # version (no logging vs. default logger) — confirm; logger=False
        # disables logging explicitly.
        logger=None,
        profiler=profiler,
        enable_checkpointing=True,
        # Skip the pre-fit validation sanity check.
        num_sanity_val_steps=0,
        callbacks=pipeline.get_callbacks(),
        accelerator=args.device,
        # The scheduling options below mirror the training config; they
        # presumably have no effect on trainer.test — confirm.
        log_every_n_steps=10,
        devices='auto',
        num_nodes=1,
        check_val_every_n_epoch=args.training.freq_valid,
        max_epochs=args.training.n_epochs,
    )
    trainer = Trainer(**trainer_kwargs)

    # Echo the run identity and test settings before evaluating.
    print(args.exp_tag)
    print("TEST CONFIGURATION:")
    for key, value in args.test.items():
        print(f'{key} : {value}')
    print("Loading checkpoint ", args.eval.ckpt)

    test_loader = pipeline.get_test_dataloader()
    # ckpt_path makes Lightning load these weights into the model before testing.
    trainer.test(pipeline, test_loader, ckpt_path=args.eval.ckpt)
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator composes the
    # config (including any command-line overrides) and passes it in.
    eval_pipeline()