Skip to content

Commit

Permalink
Add debug option for RL dataset preprocessing (#1404)
Browse files Browse the repository at this point in the history
* adding debug option for RL dataset preprocessing

* Refine formatting of debugging code in RL dataset preprocessing

* Update __init__.py

* chore: fix lint

---------

Co-authored-by: NanoCode012 <[email protected]>
  • Loading branch information
abhinand5 and NanoCode012 authored Apr 30, 2024
1 parent a8bdb14 commit a56e062
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 3 deletions.
17 changes: 17 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,23 @@ def load_rl_datasets(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")

tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)

return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
Expand Down
61 changes: 58 additions & 3 deletions src/axolotl/utils/tokenization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Module for tokenization utilities"""


import logging
import re
from typing import Dict, List
Expand All @@ -10,10 +9,19 @@
LOG = logging.getLogger("axolotl")


def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
def check_dataset_labels(
dataset,
tokenizer,
num_examples=5,
text_only=False,
rl_mode=False,
):
# the dataset is already shuffled, so let's just check the first 5 elements
for idx in range(num_examples):
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
if not rl_mode:
check_example_labels(dataset[idx], tokenizer, text_only=text_only)
else:
check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)


def check_example_labels(example, tokenizer, text_only=False):
Expand All @@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
return " ".join(colored_tokens)


def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
"""Helper function to color tokens based on their type."""
colored_text = colored(decoded_token, color)
return (
colored_text
if text_only
else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
)


def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens)
]
return colored_tokens


def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"

input_tokens = example[field_prompt]
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]

# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)

# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "

# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")

return delimiter.join(colored_tokens)


GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",
Expand Down

0 comments on commit a56e062

Please sign in to comment.