Commit

add: token monitor
Lusfie committed Nov 7, 2024
1 parent 555ade1 commit eb16fab
Showing 3 changed files with 161 additions and 1 deletion.
projects/Llama/llama.py (20 additions & 1 deletion)
@@ -515,9 +515,28 @@ def __init__(self) -> None:
         self.lm_loss = CrossEntropyLoss()
 
     def forward(self, logits, lm_labels):
         lm_labels = lm_labels.to_global(placement=logits.placement)
         lm_loss = self.lm_loss(logits, lm_labels)
         lm_loss = lm_loss.mean()
-        return {"lm_loss": lm_loss}
+
+        if self.training:
+            # token throughput
+            done_tokens = (
+                flow.zeros(
+                    1,
+                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+                    placement=lm_labels.placement,
+                )
+                + logits.shape[0] * logits.shape[1]
+            )
+            return {
+                "lm_loss": lm_loss,
+                "done_tokens": done_tokens,
+            }
+        else:
+            return {
+                "lm_loss": lm_loss,
+            }
 
 
 class LlamaForCausalLM(nn.Module, Generator):
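For orientation, below is a rough sketch of what the new done_tokens scalar represents and how the printer later turns it into tokens/s. It assumes logits is laid out as [batch_size, seq_len, vocab_size], so logits.shape[0] * logits.shape[1] is the number of tokens processed in one step; the broadcast global tensor merely carries that count to every rank. All numbers below are hypothetical, not values from this commit.

# Sketch only: per-step token count -> tokens/s (values are made up).
batch_size, seq_len = 4, 2048
done_tokens = batch_size * seq_len            # tokens processed in one iteration
avg_iter_time = 0.85                          # smoothed seconds per iteration
tokens_per_second = done_tokens / avg_iter_time
print(f"tokens_throughput: {tokens_per_second:.4f} tokens/s")  # ~9637.6 tokens/s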
projects/Llama/train_net.py (31 additions & 0 deletions)
@@ -26,6 +26,8 @@
 from libai.engine import DefaultTrainer, default_setup
 from libai.utils.checkpoint import Checkpointer
 from projects.Llama.utils.llama_loader import LlamaLoaderHuggerFace
+from libai.utils.events import JSONWriter, TensorboardXWriter
+from projects.Llama.utils.llama_metric_printer import LlamaMetricPrinter
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
 
@@ -59,6 +61,35 @@ def build_model(cls, cfg):
         model._apply(dist.convert_to_distributed_default_setting)
         return model
 
+    def build_writers(self):
+        """
+        Build a list of writers to be used. By default it contains
+        writers that write metrics to the screen,
+        a json file, and a tensorboard event file respectively.
+        If you'd like a different list of writers, you can overwrite it in
+        your trainer.
+
+        Returns:
+            list[EventWriter]: a list of :class:`EventWriter` objects.
+
+        It is now implemented by:
+
+        .. code-block:: python
+
+            return [
+                LlamaMetricPrinter(self.global_batch_size, self.max_iter, self.cfg.train.log_period),
+                JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
+                TensorboardXWriter(self.cfg.train.output_dir),
+            ]
+        """
+        # Assume the default print/log frequency.
+        return [
+            # It may not always print what you want to see, since it prints "common" metrics only.
+            LlamaMetricPrinter(self.global_batch_size, self.max_iter, self.cfg.train.log_period),
+            JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
+            TensorboardXWriter(self.cfg.train.output_dir),
+        ]
+
 
 def main(args):
     cfg = LazyConfig.load(args.config_file)
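The writers built above only read scalars that already live in the event storage; it is the training loop's job to put them there. The sketch below shows that hand-off using only the storage calls that appear elsewhere in this diff (get_event_storage, put_scalar). The hook name and the assumption that the trainer mirrors every entry of the model's output dict, including done_tokens, into the storage are mine, not code from this commit.

import time

from libai.utils.events import get_event_storage


def record_iteration_metrics(loss_dict, iter_start_time):
    # Hypothetical trainer-side hook: copy the model's output dict
    # (e.g. {"lm_loss": ..., "done_tokens": ...}) into the event storage so
    # that LlamaMetricPrinter.write() can read it back via storage.history().
    # get_event_storage() assumes an EventStorage context is active, as it is
    # inside the trainer, and values are assumed to be plain Python numbers.
    storage = get_event_storage()
    for name, value in loss_dict.items():
        storage.put_scalar(name, float(value), smoothing_hint=False)
    # Per-iteration wall time, consumed as "time" by the printer.
    storage.put_scalar("time", time.perf_counter() - iter_start_time, smoothing_hint=False)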
projects/Llama/utils/llama_metric_printer.py (110 additions & 0 deletions, new file)
@@ -0,0 +1,110 @@
import datetime
import logging
import time

from libai.utils.events import EventWriter, get_event_storage


class LlamaMetricPrinter(EventWriter):
    """
    Print **Llama** training metrics to the terminal, including
    iteration time, ETA, all losses, token and sample throughput, and the learning rate.
    Scalars are smoothed over a trailing window of recent values
    (``log_period`` iterations for timing and throughput, longer windows for losses and ETA).
    This printer only reports the common Llama metrics; to print something
    more customized, implement a similar printer yourself.
    """

    def __init__(self, batch_size, max_iter, log_period):
        """
        Args:
            batch_size (int): the global training batch size,
                used to compute sample throughput.
            max_iter (int): the maximum number of iterations to train.
                Used to compute ETA.
            log_period (int): window size (in iterations) used when
                smoothing the logged scalars.
        """
        self.logger = logging.getLogger("libai." + __name__)
        self._batch_size = batch_size
        self._max_iter = max_iter
        self._last_write = None
        self._log_period = log_period

    def write(self):
        storage = get_event_storage()
        iteration = storage.iter
        consumed_samples = storage.samples
        try:
            done_tokens = storage.history("done_tokens").avg(self._log_period)
            token_time = storage.history("time").avg(self._log_period)
        except KeyError:
            done_tokens = None
        try:
            correct_tokens = storage.history("correct_tokens").avg(self._log_period)
            denominator = storage.history("denominator").avg(self._log_period)
            acc_mlm = correct_tokens / denominator
        except KeyError:
            acc_mlm = None

        if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc) but not other data,
            # therefore do not write anything after training succeeds, even if this method
            # is called.
            return

        try:
            data_time = storage.history("data_time").avg(self._log_period)
        except KeyError:
            # they may not exist in the first few iterations (due to warmup)
            # or when SimpleTrainer is not used
            data_time = None

        eta_string = None
        try:
            iter_time = storage.history("time").global_avg()
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:
            iter_time = None
            # estimate eta on our own - more noisy
            if self._last_write is not None:
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())

        try:
            lr = "{:.2e}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        max_mem_mb = None

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            " {eta} {iter} {sample} {losses} {time} {data_time} {tpt} lr: {lr} {memory} "
            " {tokens_speed} {acc_mlm}".format(
                eta=f"eta: {eta_string}" if eta_string else "",
                iter=f"iteration: {iteration}/{self._max_iter}",
                sample=f"consumed_samples: {consumed_samples}",
                losses=" ".join(
                    [
                        "{}: {:.4g}".format(k, v.median(200))
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]
                ),
                time="time: {:.4f} s/iter ".format(iter_time) if iter_time is not None else "",
                data_time="data_time: {:.4f} s/iter".format(data_time)
                if data_time is not None
                else "",
                tpt="total_throughput: {:.2f} samples/s".format(self._batch_size / iter_time)
                if iter_time is not None
                else "",
                lr=lr,
                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
                tokens_speed="tokens_throughput: {:.4f} tokens/s".format(done_tokens / token_time)
                if done_tokens is not None
                else "",
                acc_mlm="acc_mlm: {:.4f}".format(acc_mlm) if acc_mlm is not None else "",
            )
        )
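One detail worth noting in write() above: different fields use different smoothing. Losses are reported as the median of their last 200 values, the throughput and timing scalars (done_tokens, data_time, the token time) use a moving average over the last log_period values, and the time: field uses the global average of all iteration times. Below is a small pure-Python sketch of those reductions, with made-up histories standing in for storage.history(name); it is illustrative only.

import statistics

log_period = 20
lm_loss_hist = [2.5 - 0.001 * i for i in range(500)]          # hypothetical per-iter loss
iter_time_hist = [0.80 + 0.02 * (i % 3) for i in range(500)]  # hypothetical s/iter
done_tokens_hist = [8192] * 500                               # hypothetical tokens/iter

reported_loss = statistics.median(lm_loss_hist[-200:])              # losses -> median(200)
avg_token_time = sum(iter_time_hist[-log_period:]) / log_period     # token time -> avg(log_period)
avg_done_tokens = sum(done_tokens_hist[-log_period:]) / log_period  # done_tokens -> avg(log_period)
global_avg_time = sum(iter_time_hist) / len(iter_time_hist)         # "time:" field -> global_avg()

print(f"lm_loss: {reported_loss:.4g}  time: {global_avg_time:.4f} s/iter  "
      f"tokens_throughput: {avg_done_tokens / avg_token_time:.4f} tokens/s")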
