-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path eval.py
69 lines (58 loc) · 1.86 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from networks import *
# Parse the command-line options once; the rest of the script reads them
# through the module-level constants defined here.
args = get_args()

# Evaluation configuration: prefer the first CUDA device, fall back to CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
NUM_CLASSES, ROOT = args.num_classes, args.root
ENCODER_NAME = 'resnet50'
print(f"Device using: {DEVICE}")
# DeepLabV3+ segmentation network with a ResNet-50 encoder and a sigmoid
# output head. Every parameter is replaced by the checkpoint restored below
# (load_state_dict is strict by default), so no ImageNet encoder weights are
# requested here; this avoids a pointless download / cache dependency when
# only evaluating.
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER_NAME,
    encoder_weights=None,
    classes=NUM_CLASSES,
    activation='sigmoid',
)
model.to(DEVICE)
# Restore the trained weights. torch.load raises if the checkpoint file is
# missing, so evaluation never silently runs on untrained weights;
# map_location places the loaded tensors on DEVICE.
PRETRAINED = 'checkpoints/lastest_model.pth'
checkpoint = torch.load(PRETRAINED, map_location=DEVICE)
model.load_state_dict(checkpoint['state_dict'])

# Bookkeeping saved next to the weights at training time; reported below.
start_epoch = checkpoint['epoch']
iou = checkpoint['iou']
f1score = checkpoint['f1-score']
print(
    f'Evaluation by weights ---{PRETRAINED}--- have IoU = {iou}, '
    f'F1Score = {f1score}, trained with epoch: {start_epoch} \n'
)
# Encoder-specific input preprocessing (presumably smp's get_preprocessing_fn,
# re-exported by `networks` -- confirm).
preprocessing_fn = get_preprocessing_fn(ENCODER_NAME)

# Validation split, with validation-time augmentation and the preprocessing above.
val_dataset = BuildingsDataset(
    path_dataset=ROOT,
    mode='val',
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Evaluation loader: one sample per batch, fixed order.
# - drop_last=False: an evaluation pass must see every sample; with
#   batch_size=1 this is identical to before, but it stays correct if the
#   batch size is ever raised.
# - pin_memory only helps host->GPU copies, so request it only for CUDA.
valid_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    pin_memory=DEVICE.type == 'cuda',
    num_workers=args.num_workers,
    shuffle=False,
    drop_last=False,
)
# Dice loss, reported alongside the metrics during validation.
loss = smp.utils.losses.DiceLoss()

# Validation metrics: IoU binarised at 0.5, F-score with the library defaults.
metrics = [smp.utils.metrics.IoU(threshold=0.5), smp.utils.metrics.Fscore()]

# segmentation_models_pytorch's validation loop over `model`; eval() runs it
# once over the validation dataloader.
valid_epoch = smp.utils.train.ValidEpoch(
    model, loss=loss, metrics=metrics, device=DEVICE, verbose=True,
)
def eval():
    """Run one validation pass over ``valid_dataloader`` and report it.

    The name shadows the builtin ``eval`` inside this module; it is kept
    because the script entry point calls the function by this name.

    Returns:
        The logs object returned by ``valid_epoch.run`` (in smp this is a
        dict of loss/metric name -> mean value -- confirm for your version).
        It is also printed to stdout.
    """
    valid_logs = valid_epoch.run(valid_dataloader)
    print(valid_logs)
    return valid_logs
# Only start the evaluation when this file is executed directly. Processes
# that re-import the module (e.g. DataLoader workers under the "spawn" start
# method) then skip it. This guard was added as a workaround for the PyTorch
# "ParallelNative" error discussed at:
# https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206
# NOTE(review): the setup code above this guard is still module-level and
# will re-run in any spawned worker process.
if __name__ == "__main__":
    eval()