minor tweaks to simplify (#597)
winglian authored Sep 18, 2023
1 parent 6b9b229 commit 31b9e0c
Showing 2 changed files with 6 additions and 11 deletions.
9 changes: 2 additions & 7 deletions src/axolotl/utils/tokenization.py
```diff
@@ -18,21 +18,16 @@ def check_example_labels(example, tokenizer, text_only=False):
     # Get the input_ids, labels, and attention_mask from the dataset
     input_ids = example["input_ids"]
     labels = example["labels"]
-    attention_mask = example["attention_mask"]
 
     # You can compare the input_ids and labels element-wise
     # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
     colored_tokens = []
-    for _, (input_id, label_id, mask) in enumerate(
-        zip(input_ids, labels, attention_mask)
-    ):
+    for _, (input_id, label_id) in enumerate(zip(input_ids, labels)):
         decoded_input_token = tokenizer.decode(input_id)
         # Choose the color based on whether the label has the ignore value or not
         color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
         colored_token = colored(decoded_input_token, color) + (
-            not text_only
-            and colored(f"({label_id}, {mask}, {input_id})", "white")
-            or ""
+            not text_only and colored(f"({label_id}, {input_id})", "white") or ""
         )
         colored_tokens.append(colored_token)
 
```
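The simplified loop now colors each decoded token using only `input_ids` and `labels`, dropping the attention mask. A minimal, self-contained sketch of the same coloring logic; the `gpt2` tokenizer and the label values below are illustrative placeholders, not part of the commit:

```python
from termcolor import colored
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

example = {
    "input_ids": tokenizer("Hello world").input_ids,  # two tokens under gpt2
    "labels": [-100, 995],  # pretend the first token is masked from the loss
}

colored_tokens = []
for input_id, label_id in zip(example["input_ids"], example["labels"]):
    decoded = tokenizer.decode(input_id)
    # red = ignored (-100), yellow = label 0, green = supervised
    color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
    colored_tokens.append(
        colored(decoded, color) + colored(f"({label_id}, {input_id})", "white")
    )

print(" ".join(colored_tokens))
```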
8 changes: 4 additions & 4 deletions src/axolotl/utils/trainer.py
```diff
@@ -429,7 +429,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
             .values
         )
-        LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
+        LOG.info(f"total_num_tokens: {total_num_tokens}")
         cfg.total_num_tokens = total_num_tokens
 
     if not cfg.total_supervised_tokens:
```
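The log line this hunk rewrites reports a token count gathered just above it: the length of every `input_ids` row, summed. A plain-Python sketch of that pattern, not the exact pandas/pyarrow chain axolotl uses:

```python
import numpy as np

# stand-in rows for train_dataset's input_ids column
rows = [[1, 2, 3], [4, 5], [6]]
total_num_tokens = int(np.sum([len(x) for x in rows]))
print(f"total_num_tokens: {total_num_tokens}")  # total_num_tokens: 6
```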
```diff
@@ -489,6 +489,8 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
             data_loader_len = data_loader.len_w_stats()
             actual_eff = data_loader.efficiency()
             LOG.info(f"data_loader_len: {data_loader_len}")
+            # FIXME: is there a bug here somewhere? the total num steps depends
+            # on the agreed on value for sample_packing_eff_est
             total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
 
         def calc_sample_packing_eff_est(estimates: List[float]):
```
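The new FIXME flags that this step count is only as trustworthy as the agreed `sample_packing_eff_est`, since that estimate feeds into the packed dataloader's reported length. A sketch of the calculation with made-up numbers:

```python
import math

num_epochs = 3
data_loader_len = 180  # assumed: length reported by the packing dataloader,
                       # which already folds in sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * num_epochs))
print(total_num_steps)  # 540
```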
```diff
@@ -502,10 +504,8 @@ def calc_sample_packing_eff_est(estimates: List[float]):
                 sample_packing_eff_est = (
                     math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
                 )
-                LOG.info(
-                    f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {sample_packing_eff_est}`"
-                )
                 cfg.sample_packing_eff_est = sample_packing_eff_est
+                LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
             else:
                 total_num_steps = int(
                     math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
```
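For reference, the two calculations visible in this last hunk, with illustrative values: the measured packing efficiency is rounded up to two decimals before being stored and logged, and the non-packing branch derives steps from raw example counts:

```python
import math

# round the measured packing efficiency *up* to two decimals
sample_packing_actual_eff_all = 0.8731  # assumed measured value
sample_packing_eff_est = math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
print(f"sample_packing_eff_est: {sample_packing_eff_est}")  # 0.88

# non-packing fallback: steps from examples, epochs, and batch size
dataset_len, num_epochs, batch_size = 1000, 3, 4
total_num_steps = int(math.ceil(dataset_len * num_epochs / batch_size))
print(total_num_steps)  # 750
```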
