From 12aaeb11330ed310cebc0fa1f541523d6ff138db Mon Sep 17 00:00:00 2001
From: Wing Lian <wing.lian@gmail.com>
Date: Mon, 18 Sep 2023 11:45:44 -0400
Subject: [PATCH] minor tweaks to simplify (#597)

---
 src/axolotl/utils/tokenization.py | 9 ++-------
 src/axolotl/utils/trainer.py      | 8 ++++----
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 82fcbc638e..4746ceddef 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -18,21 +18,16 @@ def check_example_labels(example, tokenizer, text_only=False):
     # Get the input_ids, labels, and attention_mask from the dataset
     input_ids = example["input_ids"]
     labels = example["labels"]
-    attention_mask = example["attention_mask"]
 
     # You can compare the input_ids and labels element-wise
     # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
     colored_tokens = []
-    for _, (input_id, label_id, mask) in enumerate(
-        zip(input_ids, labels, attention_mask)
-    ):
+    for input_id, label_id in zip(input_ids, labels):
         decoded_input_token = tokenizer.decode(input_id)
         # Choose the color based on whether the label has the ignore value or not
         color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
         colored_token = colored(decoded_input_token, color) + (
-            not text_only
-            and colored(f"({label_id}, {mask}, {input_id})", "white")
-            or ""
+            not text_only and colored(f"({label_id}, {input_id})", "white") or ""
         )
         colored_tokens.append(colored_token)
 
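Note (illustrative, not part of this patch): with attention_mask dropped, the debug loop
only needs input_ids and labels. Below is a minimal sketch of the same rendering;
`render_labels`, `example`, and `tokenizer` are hypothetical stand-ins, and the
space-joined return value is only for illustration:

    from termcolor import colored

    def render_labels(example, tokenizer, text_only=False):
        # -100 (ignored) renders red, label 0 renders yellow, everything else green,
        # matching the color scheme in check_example_labels above.
        colored_tokens = []
        for input_id, label_id in zip(example["input_ids"], example["labels"]):
            decoded = tokenizer.decode(input_id)
            color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
            suffix = "" if text_only else colored(f"({label_id}, {input_id})", "white")
            colored_tokens.append(colored(decoded, color) + suffix)
        return " ".join(colored_tokens)
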
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8f9e5e4265..2067a90069 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -429,7 +429,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
                 .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
                 .values
             )
-            LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
+            LOG.info(f"total_num_tokens: {total_num_tokens}")
             cfg.total_num_tokens = total_num_tokens
 
         if not cfg.total_supervised_tokens:
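
Note (illustrative, not part of this patch): the total_num_tokens logged above is the sum
of per-example input lengths. A rough sketch of that idea, with `train_dataset` standing in
as a hypothetical list of tokenized rows rather than the repository's actual dataset object:

    # total tokens = sum of len(input_ids) over every example in the dataset
    total_num_tokens = sum(len(ex["input_ids"]) for ex in train_dataset)
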
@@ -489,6 +489,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             data_loader_len = data_loader.len_w_stats()
             actual_eff = data_loader.efficiency()
             LOG.info(f"data_loader_len: {data_loader_len}")
+            # FIXME: is there a bug here somewhere? the total number of steps
+            # depends on the agreed-upon value for sample_packing_eff_est
             total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
 
             def calc_sample_packing_eff_est(estimates: List[float]):
@@ -502,10 +504,8 @@ def calc_sample_packing_eff_est(estimates: List[float]):
             sample_packing_eff_est = (
                 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
             )
-            LOG.info(
-                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
-            )
             cfg.sample_packing_eff_est = sample_packing_eff_est
+            LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
     else:
         total_num_steps = int(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
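
Note (illustrative, not part of this patch): a rough worked example of the step-count and
efficiency math touched by this diff; all numbers and cfg values below are made up:

    import math

    # packing branch: the measured efficiency is rounded *up* to two decimals
    # before being written back to cfg.sample_packing_eff_est
    sample_packing_actual_eff_all = 0.9132
    sample_packing_eff_est = math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
    assert sample_packing_eff_est == 0.92

    # packing branch: total steps come from the packed data loader length
    data_loader_len, num_epochs = 1234, 3
    assert int(math.floor(data_loader_len * num_epochs)) == 3702

    # non-packing fallback: plain ceil(len(train_dataset) * num_epochs / batch_size)
    train_dataset_len, batch_size = 10000, 8
    assert int(math.ceil(train_dataset_len * num_epochs / batch_size)) == 3750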