diff --git a/.vscode/launch.json b/.vscode/launch.json
index e264f9d69f..3f60f05d36 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -34,4 +34,4 @@
             // "env": {"CUDA_LAUNCH_BLOCKING": "1"}
         },
     ]
-}
\ No newline at end of file
+}
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index d90d5e5497..6b3fd4355e 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -194,10 +194,6 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
-    # if cu_seqlens is not None and max_seqlen is not None:
-    # if cu_seqlens is not None and max_seqlen is not None and self.training:
-    # if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape:
-    # if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2:
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 8ab19707bb..2272ce0a2b 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -6,7 +6,6 @@
 import os
 from typing import TYPE_CHECKING, Dict, List
 
-import itertools
 import evaluate
 import numpy as np
 import pandas as pd
@@ -16,12 +15,12 @@
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
 from transformers import (
+    GenerationConfig,
     Trainer,
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
-    GenerationConfig,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
@@ -324,8 +323,6 @@ def on_evaluate(
 
 
 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
-    LOG.info("log_prediction_callback_factory")
-
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
 
@@ -338,70 +335,53 @@ def on_evaluate(
             args: AxolotlTrainingArguments,
             state: TrainerState,
             control: TrainerControl,
-            model,
-            # tokenizer,
             train_dataloader,
             eval_dataloader,
             **kwargs,
         ):
-            LOG.info("=" * 80)
-            LOG.info("logging predictions")
-
             trainer.model.eval()
             device = torch.device(self.cfg.device)
 
-            def logits_to_tokens(logits) -> str:
-                probabilities = torch.softmax(logits, dim=-1)
-                # Get the predicted token ids (the ones with the highest probability)
-                predicted_token_ids = torch.argmax(probabilities, dim=-1)
-                return predicted_token_ids
-
             def find_ranges(lst):
                 ranges = []
                 start = 0
                 for i in range(1, len(lst)):
                     if lst[i] == 0:
-                        ranges.append((start, i-1))
+                        ranges.append((start, i - 1))
                         start = i
-                ranges.append((start, len(lst)-1)) # for the last range
+                end = len(lst) - 1
+                ranges.append((start, end))
                 return ranges
 
             def log_table_from_dataloader(name: str, table_dataloader):
+                table = wandb.Table(
+                    columns=[
+                        "id",
+                        "Prompt",
+                        "Correct Completion",
+                        "Predicted Completion",
+                    ]
+                )
+                row_index = 0
+                max_new_tokens = 128
 
-                # Initialize an empty wandb.Table
-                table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"])
-
-                batch_index = 0
                 for batch in tqdm(table_dataloader, total=len(table_dataloader)):
-                    # For each batch I want prompt, completion, 2x predictions
-
-                    # (loss, logits, labels) = trainer.prediction_step(
-                    # (batch_loss, batch_logits, batch_labels) = trainer.prediction_step(
-                    #     trainer.model,
-                    #     batch,
-                    #     prediction_loss_only=False,
-                    # )
-
-                    batch_labels = batch['labels'].to(device)
-                    batch_input_ids = batch['input_ids'].to(device)
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
 
-                    if 'position_ids' in batch:
-                        batch_pos_ids = batch['position_ids'].tolist()
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
                     else:
-                        batch_pos_ids = [None] * len(batch['input_ids'])
+                        batch_pos_ids = [None] * len(batch["input_ids"])
 
                     prompt_token_ids_list = []
                     completion_token_ids_list = []
 
-                    # completion_texts = []
-                    # prediction_texts = []
-
-                    # for input_ids in batch['input_ids']:
-                    # for batch_item_idx, (input_ids, labels) in enumerate(zip(batch['input_ids'], logits, labels)):
-                    # for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'].to(device), batch_logits, batch_labels)):
-                    for batch_item_idx, (input_ids_all, labels_all, pos_ids) in enumerate(zip(batch_input_ids, batch_labels, batch_pos_ids)):
+                    for input_ids_all, labels_all, pos_ids in zip(
+                        batch_input_ids, batch_labels, batch_pos_ids
+                    ):
                         if pos_ids is None:
-                            pos_ranges = [(0, len(input_ids_all)-1)]
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
                         else:
                             pos_ranges = find_ranges(pos_ids)
 
@@ -410,37 +390,33 @@ def log_table_from_dataloader(name: str, table_dataloader):
                             if start == end:
                                 continue
 
-                            input_ids = input_ids_all[start:end+1]
-                            labels = labels_all[start:end+1]
-                            # input_ids[start:end] = tokenizer.pad_token_id
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
 
-                            tokens_without_loss = (labels == IGNORE_INDEX)
-                            tokens_with_loss = (labels != IGNORE_INDEX)
-                            tokens_exclude_padding = (input_ids != tokenizer.pad_token_id)
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
 
-                            prompt_token_includes = tokens_without_loss & tokens_exclude_padding
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
                             prompt_token_ids = input_ids[prompt_token_includes]
                             prompt_token_ids_list.append(prompt_token_ids)
 
                             completion_token_ids = input_ids[tokens_with_loss]
                             completion_token_ids_list.append(completion_token_ids)
 
-                            # completion_text = tokenizer.decode(completion_token_ids)
-                            # completion_texts.append(completion_text)
-
-                            # completion_logit = logits[tokens_with_loss]
-                            # predicted_tokens = logits_to_tokens(completion_logit)
-                            # prediction_text = tokenizer.decode(predicted_tokens)
-                            # prediction_texts.append(prediction_text)
-
-                    prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True)
-                    completion_texts = tokenizer.batch_decode(completion_token_ids_list, skip_special_tokens=True)
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
 
                     with torch.no_grad():
                         generation_config = GenerationConfig(
-                            # repetition_penalty=1.1,
-                            max_new_tokens=128,
-                            # max_new_tokens=32,
+                            max_new_tokens=max_new_tokens,
                             bos_token_id=tokenizer.bos_token_id,
                             eos_token_id=tokenizer.eos_token_id,
                             pad_token_id=tokenizer.pad_token_id,
@@ -452,47 +428,41 @@
                             output_scores=False,
                         )
 
-                        encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device)
-                        new_predictions = trainer.model.generate(**encoding, generation_config=generation_config)  # FIXME: when sample_packing=True then error: "TypeError: varlen_fwd(): incompatible function arguments."
-
-                        new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist()
-                        new_prediction_without_prompt_tokens_list = []
-                        for prompt_token_ids, new_prediction_tokens in zip(prompt_token_ids_list, new_prediction_all_tokens):
-                            new_prediction_without_prompt_tokens = new_prediction_tokens[len(prompt_token_ids):]
-                            new_prediction_without_prompt_tokens_list.append(new_prediction_without_prompt_tokens)
-
-                        new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True)
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
 
-                        # for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)):
-                        for i, (prompt_text, completion_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, new_predicted_texts)):
-                            prediction_text = ""
-                            table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text)
+                        prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                        prediction_without_prompt_tokens_list = []
+                        for prompt_token_ids, prediction_tokens in zip(
+                            prompt_token_ids_list, prediction_all_tokens
+                        ):
+                            prediction_without_prompt_tokens = prediction_tokens[
+                                len(prompt_token_ids) :
+                            ]
+                            prediction_without_prompt_tokens_list.append(
+                                prediction_without_prompt_tokens
+                            )
 
-                    batch_index += 1
+                        predicted_texts = tokenizer.batch_decode(
+                            prediction_without_prompt_tokens_list, skip_special_tokens=True
+                        )
 
-                wandb.run.log({ f"{name} - Predictions vs Ground Truth": table })
+                        for prompt_text, completion_text, prediction_text in zip(
+                            prompt_texts, completion_texts, predicted_texts
+                        ):
+                            table.add_data(
+                                row_index, prompt_text, completion_text, prediction_text
+                            )
+                            row_index += 1
 
-                # log_table_from_dataloader("Train", train_dataloader)
-                # log_table_from_dataloader("Train", train_dataloader)
+                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
 
-                # # Get first 10 records from train_dataloader as a new dataloader
-                # train_data_subset = [next(iter(train_dataloader)) for _ in range(10)]
-                # train_dataloader_subset = torch.utils.data.DataLoader(train_data_subset, batch_size=train_dataloader.batch_size, shuffle=False)
-                # log_table_from_dataloader("Train Subset", train_dataloader_subset)
-
             log_table_from_dataloader("Eval", eval_dataloader)
 
             return control
 
     return LogPredictionCallback
-
-
-def group_sublists_by(lst: List[int], value: int) -> List[List[int]]:
-    """
-    Group sublists of lst by value
-    """
-    grouped = []
-    for key, group in itertools.groupby(lst, lambda x: x == value):
-        if key:
-            grouped.append(list(group))
-    return grouped