diff --git a/projects/Llama/llama.py b/projects/Llama/llama.py
index ea1b73541..b149ac9ad 100644
--- a/projects/Llama/llama.py
+++ b/projects/Llama/llama.py
@@ -515,9 +515,28 @@ def __init__(self) -> None:
         self.lm_loss = CrossEntropyLoss()
 
     def forward(self, logits, lm_labels):
+        lm_labels = lm_labels.to_global(placement=logits.placement)
         lm_loss = self.lm_loss(logits, lm_labels)
         lm_loss = lm_loss.mean()
-        return {"lm_loss": lm_loss}
+
+        if self.training:
+            # token throughput: log the number of tokens processed in this iteration
+            done_tokens = (
+                flow.zeros(
+                    1,
+                    sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+                    placement=lm_labels.placement,
+                )
+                + logits.shape[0] * logits.shape[1]
+            )
+            return {
+                "lm_loss": lm_loss,
+                "done_tokens": done_tokens,
+            }
+        else:
+            return {
+                "lm_loss": lm_loss,
+            }
 
 
 class LlamaForCausalLM(nn.Module, Generator):
diff --git a/projects/Llama/train_net.py b/projects/Llama/train_net.py
index 12f55e912..dd584c9fb 100644
--- a/projects/Llama/train_net.py
+++ b/projects/Llama/train_net.py
@@ -26,6 +26,8 @@
 from libai.engine import DefaultTrainer, default_setup
 from libai.utils.checkpoint import Checkpointer
 from projects.Llama.utils.llama_loader import LlamaLoaderHuggerFace
+from libai.utils.events import JSONWriter, TensorboardXWriter
+from projects.Llama.utils.llama_metric_printer import LlamaMetricPrinter
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
 
@@ -59,6 +61,35 @@ def build_model(cls, cfg):
         model._apply(dist.convert_to_distributed_default_setting)
         return model
 
+    def build_writers(self):
+        """
+        Build a list of writers to be used. By default it contains
+        writers that write metrics to the screen,
+        a json file, and a tensorboard event file respectively.
+        If you'd like a different list of writers, you can override it in
+        your trainer.
+
+        Returns:
+            list[EventWriter]: a list of :class:`EventWriter` objects.
+
+        It is now implemented by:
+
+        .. code-block:: python
+
+            return [
+                LlamaMetricPrinter(self.global_batch_size, self.max_iter, self.cfg.train.log_period),
+                JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
+                TensorboardXWriter(self.cfg.train.output_dir),
+            ]
+        """
+        # The print/log frequency is taken from cfg.train.log_period.
+        return [
+            # It may not always print what you want to see, since it only prints metrics known to this printer.
+            LlamaMetricPrinter(self.global_batch_size, self.max_iter, self.cfg.train.log_period),
+            JSONWriter(os.path.join(self.cfg.train.output_dir, "metrics.json")),
+            TensorboardXWriter(self.cfg.train.output_dir),
+        ]
+
 
 def main(args):
     cfg = LazyConfig.load(args.config_file)
diff --git a/projects/Llama/utils/llama_metric_printer.py b/projects/Llama/utils/llama_metric_printer.py
new file mode 100644
index 000000000..e90151d6d
--- /dev/null
+++ b/projects/Llama/utils/llama_metric_printer.py
@@ -0,0 +1,110 @@
+import datetime
+import logging
+import time
+
+from libai.utils.events import EventWriter, get_event_storage
+
+
+class LlamaMetricPrinter(EventWriter):
+    """
+    Print **Llama** training metrics to the terminal, including
+    iteration time, ETA, memory, all losses, token throughput, and the learning rate.
+    Most metrics are smoothed over a window of ``log_period`` elements.
+    It is meant for the metrics produced by the Llama training loop.
+    To print metrics in a more customized way, implement a similar printer yourself.
+    """
+
+    def __init__(self, batch_size, max_iter, log_period):
+        """
+        Args:
+            max_iter (int): the maximum number of iterations to train.
+                Used to compute ETA.
+        """
+        self.logger = logging.getLogger("libai." + __name__)
+        self._batch_size = batch_size
+        self._max_iter = max_iter
+        self._last_write = None
+        self._log_period = log_period
+
+    def write(self):
+        storage = get_event_storage()
+        iteration = storage.iter
+        consumed_samples = storage.samples
+        try:
+            done_tokens = storage.history("done_tokens").avg(self._log_period)
+            token_time = storage.history("time").avg(self._log_period)
+        except KeyError:
+            done_tokens = None
+        try:
+            correct_tokens = storage.history("correct_tokens").avg(self._log_period)
+            denominator = storage.history("denominator").avg(self._log_period)
+            acc_mlm = correct_tokens / denominator
+        except KeyError:
+            acc_mlm = None
+
+        if iteration == self._max_iter:
+            # This hook only reports training progress (loss, ETA, etc.) but not other data,
+            # therefore do not write anything after training succeeds, even if this method
+            # is called.
+            return
+
+        try:
+            data_time = storage.history("data_time").avg(self._log_period)
+        except KeyError:
+            # data_time may not exist in the first few iterations (due to warmup)
+            # or when SimpleTrainer is not used
+            data_time = None
+
+        eta_string = None
+        try:
+            iter_time = storage.history("time").global_avg()
+            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
+            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+        except KeyError:
+            iter_time = None
+            # estimate eta on our own - more noisy
+            if self._last_write is not None:
+                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+                    iteration - self._last_write[0]
+                )
+                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            self._last_write = (iteration, time.perf_counter())
+
+        try:
+            lr = "{:.2e}".format(storage.history("lr").latest())
+        except KeyError:
+            lr = "N/A"
+
+        max_mem_mb = None  # memory stats are not collected here, so the max_mem field stays empty
+
+        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+        self.logger.info(
+            " {eta} {iter} {sample} {losses} {time} {data_time} {tpt} lr: {lr} {memory} "
+            " {tokens_speed} {acc_mlm}".format(
+                eta=f"eta: {eta_string}" if eta_string else "",
+                iter=f"iteration: {iteration}/{self._max_iter}",
+                sample=f"consumed_samples: {consumed_samples}",
+                losses=" ".join(
+                    [
+                        "{}: {:.4g}".format(k, v.median(200))
+                        for k, v in storage.histories().items()
+                        if "loss" in k
+                    ]
+                ),
+                time="time: {:.4f} s/iter ".format(iter_time) if iter_time is not None else "",
+                data_time="data_time: {:.4f} s/iter".format(data_time)
+                if data_time is not None
+                else "",
+                tpt="total_throughput: {:.2f} samples/s".format(self._batch_size / iter_time)
+                if iter_time is not None
+                else "",
+                lr=lr,
+                memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
+                tokens_speed="tokens_throughput: {:.4f} tokens/s".format(done_tokens / token_time)
+                if done_tokens is not None
+                else "",
+                acc_mlm="acc_mlm: {:.4f}".format(acc_mlm) if acc_mlm is not None else "",
+            )
+        )
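
Note on the new tokens_throughput field: the loss module logs done_tokens = logits.shape[0] * logits.shape[1] (batch size times sequence length) once per iteration, and LlamaMetricPrinter.write divides the done_tokens history, averaged over the last log_period entries, by the equally averaged per-iteration "time" history. A minimal sketch of that arithmetic with made-up numbers (not taken from a real run):

    # Hypothetical example values; during training these come from the event storage.
    batch_size = 2        # logits.shape[0] in the loss forward
    seq_length = 2048     # logits.shape[1] in the loss forward
    iter_time = 1.25      # averaged "time" history, seconds per iteration

    done_tokens = batch_size * seq_length        # value logged each iteration
    tokens_throughput = done_tokens / iter_time  # what the printer reports
    print(f"tokens_throughput: {tokens_throughput:.4f} tokens/s")  # 3276.8000 tokens/s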