Showing 3 changed files with 161 additions and 1 deletion.
@@ -0,0 +1,110 @@
import datetime
import logging
import time

from libai.utils.events import EventWriter, get_event_storage


class LlamaMetricPrinter(EventWriter):
    """
    Print **Llama** training metrics to the terminal, including
    iteration time, ETA, memory, all losses, and the learning rate,
    smoothed over a window of ``log_period`` elements.
    It is meant to print Llama metrics in the default Llama format;
    to print metrics in a more customized way, implement a similar printer yourself.
    """

    def __init__(self, batch_size, max_iter, log_period):
        """
        Args:
            batch_size (int): the total training batch size,
                used to compute sample throughput.
            max_iter (int): the maximum number of iterations to train.
                Used to compute ETA.
            log_period (int): the window size (in iterations) over which
                metrics are smoothed.
        """
        self.logger = logging.getLogger("libai." + __name__)
        self._batch_size = batch_size
        self._max_iter = max_iter
        self._last_write = None
        self._log_period = log_period

    def write(self):
        storage = get_event_storage()
        iteration = storage.iter
        consumed_samples = storage.samples
        try:
            done_tokens = storage.history("done_tokens").avg(self._log_period)
            token_time = storage.history("time").avg(self._log_period)
        except KeyError:
            # token statistics may not be logged for every model
            done_tokens = None
            token_time = None
        try:
            correct_tokens = storage.history("correct_tokens").avg(self._log_period)
            denominator = storage.history("denominator").avg(self._log_period)
            acc_mlm = correct_tokens / denominator
        except KeyError:
            acc_mlm = None

        if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc.) but not other data,
            # therefore do not write anything after training succeeds, even if this method
            # is called.
            return

        try:
            data_time = storage.history("data_time").avg(self._log_period)
        except KeyError:
            # they may not exist in the first few iterations (due to warmup)
            # or when SimpleTrainer is not used
            data_time = None

        eta_string = None
        try:
            iter_time = storage.history("time").global_avg()
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            iter_time = None
            # estimate eta on our own - more noisy
            if self._last_write is not None:
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())

        try:
            lr = "{:.2e}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        # memory statistics are not collected here, so the max_mem field is omitted from the log
        max_mem_mb = None

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            " {eta} {iter} {sample} {losses} {time} {data_time} {tpt} lr: {lr} {memory} "
            " {tokens_speed} {acc_mlm}".format(
                eta=f"eta: {eta_string}" if eta_string else "",
                iter=f"iteration: {iteration}/{self._max_iter}",
                sample=f"consumed_samples: {consumed_samples}",
                losses=" ".join(
                    [
                        "{}: {:.4g}".format(k, v.median(200))
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]
                ),
                time="time: {:.4f} s/iter ".format(iter_time) if iter_time is not None else "",
                data_time="data_time: {:.4f} s/iter".format(data_time)
                if data_time is not None
                else "",
                tpt="total_throughput: {:.2f} samples/s".format(self._batch_size / iter_time)
                if iter_time is not None
                else "",
                lr=lr,
                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
                tokens_speed="tokens_throughput: {:.4f} tokens/s".format(done_tokens / token_time)
                if done_tokens is not None
                else "",
                acc_mlm="acc_mlm: {:.4f}".format(acc_mlm) if acc_mlm is not None else "",
            )
        )
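
For context, here is a minimal sketch of how this writer could be wired into a training loop. The hook class (hooks.PeriodicWriter) and the config values shown are assumptions based on libai's detectron2-style trainer API; they are not part of this commit.

# A minimal usage sketch. Assumption: libai exposes a detectron2-style
# hooks.PeriodicWriter; the config values below are placeholders.
from libai.engine import hooks

batch_size = 256   # global batch size (assumed value)
max_iter = 10000   # total training iterations (assumed value)
log_period = 20    # smoothing window / print interval (assumed value)

writer = LlamaMetricPrinter(batch_size, max_iter, log_period)
periodic_writer_hook = hooks.PeriodicWriter([writer], period=log_period)
# trainer.register_hooks([periodic_writer_hook])  # attach to a trainer-style loop

A PeriodicWriter-style hook would call write() every log_period iterations, at which point the printer reads the smoothed histories from the event storage populated by the trainer.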