Add RL-mode (DPO prompt/chosen/rejected) support to the dataset-label debug
tooling. `load_rl_datasets` now honors `--debug`, sampling
`debug_num_examples` random rows and printing them via
`check_dataset_labels(..., rl_mode=True)`, which dispatches to the new
`check_rl_example_labels` (prompt in yellow, chosen in green, rejected in red).

Review fix applied below: the original added line used
`random.randrange(0, len(train_dataset) - 1)`, which can never pick the last
row and raises ValueError on a 1-row dataset; replaced with
`random.randrange(len(train_dataset))` (same line count, hunk headers intact).

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 0670f64e3a..9f40bb476d 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -433,6 +433,23 @@ def load_rl_datasets(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
 
+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+
+        tokenizer = load_tokenizer(cfg)
+        check_dataset_labels(
+            train_dataset.select(
+                [
+                    random.randrange(len(train_dataset))  # nosec
+                    for _ in range(cli_args.debug_num_examples)
+                ]
+            ),
+            tokenizer,
+            num_examples=cli_args.debug_num_examples,
+            text_only=cli_args.debug_text_only,
+            rl_mode=True,
+        )
+
 return TrainDatasetMeta(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index afbdef8778..845296b7a6 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -1,6 +1,5 @@
 """Module for tokenization utilities"""
-
 import logging
 import re
 from typing import Dict, List
@@ -10,10 +9,19 @@
 LOG = logging.getLogger("axolotl")
 
 
-def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
+def check_dataset_labels(
+    dataset,
+    tokenizer,
+    num_examples=5,
+    text_only=False,
+    rl_mode=False,
+):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(num_examples):
-        check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        if not rl_mode:
+            check_example_labels(dataset[idx], tokenizer, text_only=text_only)
+        else:
+            check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
 
 
 def check_example_labels(example, tokenizer, text_only=False):
@@ -40,6 +48,53 @@
     return " ".join(colored_tokens)
 
 
+def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
+    """Helper function to color tokens based on their type."""
+    colored_text = colored(decoded_token, color)
+    return (
+        colored_text
+        if text_only
+        else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
+    )
+
+
+def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
+    """Helper function to process and color tokens."""
+    colored_tokens = [
+        color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
+        for token in tokenizer.encode(tokens)
+    ]
+    return colored_tokens
+
+
+def check_rl_example_labels(example, tokenizer, text_only=False):
+    field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
+
+    input_tokens = example[field_prompt]
+    labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
+
+    # Process and color each type of token
+    colored_tokens = process_tokens_for_rl_debug(
+        input_tokens, "yellow", tokenizer, text_only
+    )
+    colored_chosens = process_tokens_for_rl_debug(
+        labels_chosen, "green", tokenizer, text_only
+    )
+    colored_rejecteds = process_tokens_for_rl_debug(
+        labels_rejected, "red", tokenizer, text_only
+    )
+
+    # Create a delimiter based on text_only flag
+    delimiter = "" if text_only else " "
+
+    # Logging information
+    LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
+    LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
+    LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
+
+    return delimiter.join(colored_tokens)
+
+
 GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
 GLAIVE_TO_SHAREGPT_ROLE = {
     "SYSTEM": "system",