forked from LukasKG/GAN_SHL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
33 lines (27 loc) · 1.15 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# -*- coding: utf-8 -*-
from sklearn.metrics import f1_score
import torch
def calc_accuracy(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Calculate the fraction of samples whose predicted label matches the true label.

    Each tensor is reduced to a flat 1-D vector of labels independently:
    a tensor with more than one column (dim 1) is reduced with argmax over
    dim 1 (class scores / one-hot labels). Anything else is flattened, and
    predictions are additionally rounded (e.g. binary probabilities -> 0/1).
    Flattening both sides avoids accidental broadcasting between (N,) and
    (N, 1) tensors.

    Args:
        predictions: model outputs, expected shape (N, C) with C > 1, or (N, 1) / (N,)
        labels: ground-truth labels, expected one-hot (N, C), or (N, 1) / (N,)

    Returns:
        accuracy in [0, 1]; 0.0 for empty input (mirrors zero_division=0 in calc_f1score)
    """
    if predictions.dim() > 1 and predictions.size(1) > 1:
        y_pred = predictions.argmax(dim=1)
    else:
        y_pred = predictions.reshape(-1).round()

    if labels.dim() > 1 and labels.size(1) > 1:
        y_true = labels.argmax(dim=1)
    else:
        y_true = labels.reshape(-1)

    n_samples = y_true.numel()
    if n_samples == 0:
        # Accuracy of an empty batch is undefined; avoid ZeroDivisionError.
        return 0.0
    return (y_pred == y_true).sum().item() / n_samples
def calc_f1score(predictions: torch.Tensor, labels: torch.Tensor, average: str = 'binary') -> float:
    """
    Calculate the F1 score of predictions against labels using sklearn.metrics.f1_score.

    Each tensor is reduced to a flat 1-D vector of labels independently:
    a tensor with more than one column (dim 1) is reduced with argmax over
    dim 1 (class scores / one-hot labels). Anything else is flattened, and
    predictions are additionally rounded (e.g. binary probabilities -> 0/1).

    Args:
        predictions: model outputs, expected shape (N, C) with C > 1, or (N, 1) / (N,)
        labels: ground-truth labels, expected one-hot (N, C), or (N, 1) / (N,)
        average: averaging method forwarded unchanged to sklearn's f1_score

    Returns:
        f1 score as computed by sklearn; undefined precision/recall terms
        are treated as 0 (zero_division=0)
    """
    # dim() guard: 1-D tensors have no dimension 1 and would raise IndexError.
    if predictions.dim() > 1 and predictions.size(1) > 1:
        y_pred = predictions.argmax(dim=1)
    else:
        y_pred = predictions.reshape(-1).round()

    if labels.dim() > 1 and labels.size(1) > 1:
        y_true = labels.argmax(dim=1)
    else:
        y_true = labels.reshape(-1)

    # sklearn consumes host numpy arrays: drop autograd history and move to CPU.
    y_pred = y_pred.detach().cpu().numpy()
    y_true = y_true.detach().cpu().numpy()
    return f1_score(y_true, y_pred, average=average, zero_division=0)