-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogging_composer.py
77 lines (58 loc) · 2.48 KB
/
logging_composer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# see https://torchmetrics.readthedocs.io/en/latest/
# for more info on torchmetrics.
# At a high level, a metric must implement `update`,
# which can update some state, and `compute`, which actually
# returns the metric.
from composer import Callback
import wandb
from torchmetrics import Metric
import torch
# this is likely not the optimal way to do this...
class Accuracy(Metric):
    """Token-level accuracy, accumulated across trainer steps.

    Positions whose target equals the model's ``ignore_index`` are excluded
    from both the numerator and the denominator.
    """

    def __init__(self):
        super().__init__()
        # Integer counters; summed across processes in distributed runs.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, model, batch, output):
        """Accumulate counts from one trainer step.

        Args:
            model: the model being trained; only ``model.ignore_index`` is read.
            batch: the step's batch; only ``batch["targets"]`` is read.
            output: ``(loss, logits)`` from the step; the loss is unused here.
        """
        _, logits = output
        targets = batch["targets"]
        valid = targets != model.ignore_index
        predictions = logits.argmax(dim=-1)
        self.correct += ((predictions == targets) * valid).sum()
        self.total += valid.sum()

    def compute(self):
        """Return the fraction of correctly predicted non-ignored tokens."""
        return self.correct.float() / self.total
class BitsPerByte(Metric):
    """Ratio of accumulated loss mass to the number of raw bytes consumed.

    NOTE(review): the numerator is ``loss * num_valid_tokens`` — whether it is
    actually measured in bits depends on the units of the trainer's loss
    (e.g. whether it was already converted from nats); confirm with the caller.
    """

    def __init__(self):
        super().__init__()
        # Float loss mass and integer byte count; summed across processes.
        self.add_state("bits", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("bytes", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, model, batch, output):
        """Accumulate from one trainer step.

        Args:
            model: the model being trained; only ``model.ignore_index`` is read.
            batch: the step's batch; reads ``targets`` and ``tokenized_bytes``.
            output: ``(loss, logits)`` from the step; the logits are unused.
        """
        loss, _ = output
        targets = batch["targets"]
        valid_count = (targets != model.ignore_index).sum()
        # Scale the (presumably per-token mean) loss back up by the token count.
        self.bits += (loss * valid_count).sum()
        self.bytes += batch["tokenized_bytes"].sum()

    def compute(self):
        """Return accumulated loss mass divided by accumulated byte count."""
        return self.bits / self.bytes
class Loss(Metric):
    """Average per-token loss over non-ignored positions, across all steps."""

    def __init__(self):
        super().__init__()
        # Running loss mass and token count; summed across processes.
        self.add_state("loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, model, batch, output):
        """Accumulate from one trainer step.

        Args:
            model: the model being trained; only ``model.ignore_index`` is read.
            batch: the step's batch; only ``batch["targets"]`` is read.
            output: ``(loss, logits)`` from the step; the logits are unused.
        """
        step_loss, _ = output
        valid = batch["targets"] != model.ignore_index
        n = valid.sum()
        # Weight by the token count so batches of different sizes average
        # correctly (assumes step_loss is a per-token mean — TODO confirm).
        self.loss += step_loss * n.float()
        self.total += n

    def compute(self):
        """Return the token-weighted average loss."""
        return self.loss / self.total