diff --git a/README.md b/README.md index 30f7b4844f..3e3d06801a 100644 --- a/README.md +++ b/README.md @@ -543,6 +543,9 @@ eval_steps: # leave empty to eval at each epoch save_total_limit: # checkpoints saved at a time max_steps: +eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 +eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128 + # save model as safetensors (require safetensors package) save_safetensors: diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 2a0af130be..fc2307faed 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -55,6 +55,8 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 save_steps: debug: deepspeed: diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 3ad2a7e4fd..872b8935d4 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -57,6 +57,7 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 save_steps: debug: deepspeed: diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml new file mode 100644 index 0000000000..a53c9c831b --- /dev/null +++ b/examples/llama-2/tiny-llama.yml @@ -0,0 +1,69 @@ +base_model: PY007/TinyLlama-1.1B-step-50K-105b +base_model_config: PY007/TinyLlama-1.1B-step-50K-105b + +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 20 +eval_table_size: 5 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 39cfb5c173..f6adcdac25 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -194,7 +194,7 @@ def flashattn_forward( # only on first autoregressive step q,k,v have same seqlen is_causal = key_states.shape == query_states.shape - if cu_seqlens is not None and max_seqlen is not None: + if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: # special handling using sample packing qkv = torch.stack( [query_states, key_states, value_states], dim=2 @@ -262,6 +262,8 @@ def flashattn_forward( if attention_mask is not None else None, ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) output_unpad = flash_attn_varlen_kvpacked_func( q_unpad, kv_unpad, diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 8fc5a918b3..2f7b4fd1dd 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -11,10 +11,13 @@ import pandas as pd import torch import torch.distributed as dist +import wandb from datasets import load_dataset from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( + GenerationConfig, + Trainer, TrainerCallback, TrainerControl, TrainerState, @@ -317,3 +320,191 @@ def on_evaluate( trainer.log(results) return BenchEvalCallback + + +def log_prediction_callback_factory(trainer: Trainer, tokenizer): + class LogPredictionCallback(TrainerCallback): + """Callback to log prediction values during each evaluation""" + + def __init__(self, cfg): + self.cfg = cfg + self.logged = False + + def on_evaluate( + self, + args: AxolotlTrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, + train_dataloader, # pylint: disable=unused-argument + eval_dataloader, + **kwargs, # pylint: disable=unused-argument + ): + eval_table_size = self.cfg.eval_table_size + + if eval_table_size <= 0: + return control + + trainer.model.eval() + device = torch.device(self.cfg.device) + + # pylint: disable=duplicate-code + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_table_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def logits_to_tokens(logits) -> str: + probabilities = torch.softmax(logits, dim=-1) + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + return predicted_token_ids + + def find_ranges(lst): + ranges = [] + start = 0 + for i in range(1, len(lst)): + if lst[i] == 0: + ranges.append((start, i - 1)) + start = i + end = len(lst) - 1 + ranges.append((start, end)) + return ranges + + def log_table_from_dataloader(name: str, table_dataloader): + table = wandb.Table( + columns=[ + "id", + "Prompt", + "Correct Completion", + "Predicted Completion (model.generate)", + "Predicted Completion (trainer.prediction_step)", + ] + ) + row_index = 0 + + for batch in tqdm(table_dataloader): + if row_index > eval_table_size: + break + + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) + + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() + else: + batch_pos_ids = [None] * len(batch["input_ids"]) + + (_, batch_logits, _) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + + prompt_token_ids_list = [] + pred_step_token_ids_list = [] + completion_token_ids_list = [] + + for input_ids_all, labels_all, pos_ids, logits in zip( + batch_input_ids, + batch_labels, + batch_pos_ids, + batch_logits, + ): + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all) - 1)] + else: + pos_ranges = find_ranges(pos_ids) + + for pos_range in pos_ranges: + start, end = pos_range + if start == end: + continue + + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] + + tokens_without_loss = labels == IGNORE_INDEX + tokens_with_loss = labels != IGNORE_INDEX + tokens_exclude_padding = input_ids != tokenizer.pad_token_id + prompt_token_includes = ( + tokens_without_loss & tokens_exclude_padding + ) + + prompt_token_ids = input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) + + completion_token_ids = input_ids[tokens_with_loss] + completion_token_ids_list.append(completion_token_ids) + + pred_step_token_ids = logits_to_tokens( + logits[start : end + 1] + )[tokens_with_loss] + pred_step_token_ids_list.append(pred_step_token_ids) + + prompt_texts = tokenizer.batch_decode( + prompt_token_ids_list, skip_special_tokens=True + ) + completion_texts = tokenizer.batch_decode( + completion_token_ids_list, skip_special_tokens=True + ) + pred_step_texts = tokenizer.batch_decode( + pred_step_token_ids_list, skip_special_tokens=True + ) + + with torch.no_grad(): + prompt_encoding = tokenizer( + prompt_texts, padding=True, return_tensors="pt" + ).to(self.cfg.device) + predictions = trainer.model.generate( + **prompt_encoding, generation_config=generation_config + ) + + prediction_all_tokens = predictions["sequences"].cpu().tolist() + prediction_without_prompt_tokens_list = [] + for prompt_token_ids, prediction_tokens in zip( + prompt_token_ids_list, prediction_all_tokens + ): + prediction_without_prompt_tokens = prediction_tokens[ + len(prompt_token_ids) : + ] + prediction_without_prompt_tokens_list.append( + prediction_without_prompt_tokens + ) + + predicted_texts = tokenizer.batch_decode( + prediction_without_prompt_tokens_list, skip_special_tokens=True + ) + + for ( + prompt_text, + completion_text, + prediction_text, + pred_step_text, + ) in zip( + prompt_texts, completion_texts, predicted_texts, pred_step_texts + ): + table.add_data( + row_index, + prompt_text, + completion_text, + prediction_text, + pred_step_text, + ) + row_index += 1 + + wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) + + if is_main_process(): + log_table_from_dataloader("Eval", eval_dataloader) + + return control + + return LogPredictionCallback diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 93a23f7738..ce1f2255b8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -48,6 +48,8 @@ def normalize_config(cfg): ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.ddp: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 9f0795af76..e4ff7b5d42 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -340,10 +340,10 @@ def load_model( if ( hasattr(model.config, "max_position_embeddings") and model.config.max_position_embeddings - and cfg.sequence_len >= model.config.max_position_embeddings + and cfg.sequence_len > model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" + f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}" ) model.config.max_position_embeddings = cfg.sequence_len diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f91f4e318e..f7d0b4329a 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -30,6 +30,7 @@ SaveBetterTransformerModelCallback, SavePeftModelCallback, bench_eval_callback_factory, + log_prediction_callback_factory, ) from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.dataloader import MultipackDistributedDataloader @@ -719,6 +720,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ **trainer_kwargs, ) + if cfg.use_wandb and cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) + trainer.add_callback(LogPredictionCallback(cfg)) + if cfg.do_bench_eval: trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))