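"""Training entry point for a Self-ONN image denoiser.

Parses command-line arguments, configures optional Weights & Biases logging and
model checkpointing, and trains a LitDenoiser with PyTorch Lightning, optionally
resuming from or warm-starting with an existing checkpoint.
"""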
# common
from argparse import ArgumentParser
# lightning module and trainer utilities (re-exported by utils.lit_module)
from utils.lit_module import LitDenoiser, WandbLogger, ModelCheckpoint, Trainer, LearningRateMonitor
# main function
def main():
    # argument parsing
    parser = ArgumentParser()
    parser = LitDenoiser.add_model_specific_args(parser)
    # training args
    parser.add_argument('--max_epochs', type=int, default=100)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--gpu_id', type=int, default=None)
    parser.add_argument('--overfit_batches', type=float, default=0)
    # NOTE: argparse's type=bool is a pitfall -- bool('False') is True, so any
    # non-empty value would enable these options; store_true flags avoid that
    parser.add_argument('--fast_dev_run', action='store_true')
    parser.add_argument('--checkpointing', action='store_true')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    parser.add_argument('--load_from_checkpoint', type=str, default=None)
    # wandb args
    parser.add_argument('--log_to_wandb', action='store_true')
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--project', type=str, default='debugging')
    parser.add_argument('--version', type=str, default=None)
    # data args
    parser.add_argument('--data_path', type=str, default='C:/Users/PA_hm17901/')
    args = parser.parse_args()
    # Init logger and callbacks
    # only build the W&B logger when it is requested, so no unused wandb run is started
    if args.log_to_wandb:
        logger = WandbLogger(
            name=args.name,
            project=args.project,
            log_model=False,
        )
    else:
        logger = True  # fall back to Lightning's default logger
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor='val_psnr',  # keep the checkpoint with the best validation PSNR
        save_top_k=1,
        mode='max',
        verbose=False,
    )
    lr_callback = LearningRateMonitor(logging_interval='epoch')
    if args.checkpointing:
        callbacks.append(checkpoint_callback)
    callbacks.append(lr_callback)
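    # NOTE: the Trainer arguments below follow the pre-1.5 PyTorch Lightning API;
    # newer releases renamed or removed several of them (distributed_backend -> strategy,
    # resume_from_checkpoint -> Trainer.fit(ckpt_path=...), checkpoint_callback ->
    # enable_checkpointing), so pin an older Lightning version to run this as-is.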
    trainer = Trainer(
        max_epochs=args.max_epochs,                                   # maximum number of epochs
        gpus=args.gpus if args.gpu_id is None else str(args.gpu_id),  # gpu count, or a specific gpu id
        overfit_batches=args.overfit_batches,                         # for debugging
        fast_dev_run=args.fast_dev_run,                               # for debugging
        logger=logger,                                                # wandb logger or Lightning default
        checkpoint_callback=args.checkpointing,                       # enable/disable checkpointing
        callbacks=callbacks,                                          # lr monitor (+ checkpoint callback)
        distributed_backend='dp' if args.gpu_id is None else None,    # data-parallel unless a single gpu id is given
        resume_from_checkpoint=args.resume_from_checkpoint,           # resume from checkpoint
    )
    if args.load_from_checkpoint is not None:
        print("Loading from checkpoint...")
        denoiser = LitDenoiser.load_from_checkpoint(args.load_from_checkpoint)
        denoiser.hparams.args = args  # override stored hyperparameters with the current args
    else:
        denoiser = LitDenoiser(args)
    # Start training
    trainer.fit(denoiser)

if __name__ == '__main__':
    main()
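
# Example invocations (illustrative sketches; the project/name values are
# placeholders, not taken from the repo):
#   python main_selfonn.py --max_epochs 100 --gpus 1 --checkpointing \
#       --log_to_wandb --project sidd-denoising --name selfonn-run1
#   python main_selfonn.py --resume_from_checkpoint path/to/last.ckpt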