diff --git a/README.md b/README.md
index 30f7b4844f..3e3d06801a 100644
--- a/README.md
+++ b/README.md
@@ -543,6 +543,9 @@ eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time
max_steps:
+eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
+eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
+
# save model as safetensors (require safetensors package)
save_safetensors:
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 2a0af130be..fc2307faed 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -55,6 +55,8 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index 3ad2a7e4fd..872b8935d4 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -57,6 +57,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
+eval_table_size: 5
save_steps:
debug:
deepspeed:
diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml
new file mode 100644
index 0000000000..a53c9c831b
--- /dev/null
+++ b/examples/llama-2/tiny-llama.yml
@@ -0,0 +1,69 @@
+base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ bos_token: ""
+ eos_token: ""
+ unk_token: ""
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 39cfb5c173..f6adcdac25 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -194,7 +194,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
- if cu_seqlens is not None and max_seqlen is not None:
+ if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
@@ -262,6 +262,8 @@ def flashattn_forward(
if attention_mask is not None
else None,
)
+ if q_unpad.dtype != kv_unpad.dtype:
+ kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 8fc5a918b3..2f7b4fd1dd 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -11,10 +11,13 @@
import pandas as pd
import torch
import torch.distributed as dist
+import wandb
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
+ GenerationConfig,
+ Trainer,
TrainerCallback,
TrainerControl,
TrainerState,
@@ -317,3 +320,191 @@ def on_evaluate(
trainer.log(results)
return BenchEvalCallback
+
+
+def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+ class LogPredictionCallback(TrainerCallback):
+ """Callback to log prediction values during each evaluation"""
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.logged = False
+
+ def on_evaluate(
+ self,
+ args: AxolotlTrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl,
+ train_dataloader, # pylint: disable=unused-argument
+ eval_dataloader,
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ eval_table_size = self.cfg.eval_table_size
+
+ if eval_table_size <= 0:
+ return control
+
+ trainer.model.eval()
+ device = torch.device(self.cfg.device)
+
+ # pylint: disable=duplicate-code
+ generation_config = GenerationConfig(
+ max_new_tokens=self.cfg.eval_table_max_new_tokens,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ output_scores=False,
+ )
+
+ def logits_to_tokens(logits) -> str:
+ probabilities = torch.softmax(logits, dim=-1)
+ # Get the predicted token ids (the ones with the highest probability)
+ predicted_token_ids = torch.argmax(probabilities, dim=-1)
+ return predicted_token_ids
+
+ def find_ranges(lst):
+ ranges = []
+ start = 0
+ for i in range(1, len(lst)):
+ if lst[i] == 0:
+ ranges.append((start, i - 1))
+ start = i
+ end = len(lst) - 1
+ ranges.append((start, end))
+ return ranges
+
+ def log_table_from_dataloader(name: str, table_dataloader):
+ table = wandb.Table(
+ columns=[
+ "id",
+ "Prompt",
+ "Correct Completion",
+ "Predicted Completion (model.generate)",
+ "Predicted Completion (trainer.prediction_step)",
+ ]
+ )
+ row_index = 0
+
+ for batch in tqdm(table_dataloader):
+ if row_index > eval_table_size:
+ break
+
+ batch_labels = batch["labels"].to(device)
+ batch_input_ids = batch["input_ids"].to(device)
+
+ if "position_ids" in batch:
+ batch_pos_ids = batch["position_ids"].tolist()
+ else:
+ batch_pos_ids = [None] * len(batch["input_ids"])
+
+ (_, batch_logits, _) = trainer.prediction_step(
+ trainer.model,
+ batch,
+ prediction_loss_only=False,
+ )
+
+ prompt_token_ids_list = []
+ pred_step_token_ids_list = []
+ completion_token_ids_list = []
+
+ for input_ids_all, labels_all, pos_ids, logits in zip(
+ batch_input_ids,
+ batch_labels,
+ batch_pos_ids,
+ batch_logits,
+ ):
+ if pos_ids is None:
+ pos_ranges = [(0, len(input_ids_all) - 1)]
+ else:
+ pos_ranges = find_ranges(pos_ids)
+
+ for pos_range in pos_ranges:
+ start, end = pos_range
+ if start == end:
+ continue
+
+ input_ids = input_ids_all[start : end + 1]
+ labels = labels_all[start : end + 1]
+
+ tokens_without_loss = labels == IGNORE_INDEX
+ tokens_with_loss = labels != IGNORE_INDEX
+ tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+ prompt_token_includes = (
+ tokens_without_loss & tokens_exclude_padding
+ )
+
+ prompt_token_ids = input_ids[prompt_token_includes]
+ prompt_token_ids_list.append(prompt_token_ids)
+
+ completion_token_ids = input_ids[tokens_with_loss]
+ completion_token_ids_list.append(completion_token_ids)
+
+ pred_step_token_ids = logits_to_tokens(
+ logits[start : end + 1]
+ )[tokens_with_loss]
+ pred_step_token_ids_list.append(pred_step_token_ids)
+
+ prompt_texts = tokenizer.batch_decode(
+ prompt_token_ids_list, skip_special_tokens=True
+ )
+ completion_texts = tokenizer.batch_decode(
+ completion_token_ids_list, skip_special_tokens=True
+ )
+ pred_step_texts = tokenizer.batch_decode(
+ pred_step_token_ids_list, skip_special_tokens=True
+ )
+
+ with torch.no_grad():
+ prompt_encoding = tokenizer(
+ prompt_texts, padding=True, return_tensors="pt"
+ ).to(self.cfg.device)
+ predictions = trainer.model.generate(
+ **prompt_encoding, generation_config=generation_config
+ )
+
+ prediction_all_tokens = predictions["sequences"].cpu().tolist()
+ prediction_without_prompt_tokens_list = []
+ for prompt_token_ids, prediction_tokens in zip(
+ prompt_token_ids_list, prediction_all_tokens
+ ):
+ prediction_without_prompt_tokens = prediction_tokens[
+ len(prompt_token_ids) :
+ ]
+ prediction_without_prompt_tokens_list.append(
+ prediction_without_prompt_tokens
+ )
+
+ predicted_texts = tokenizer.batch_decode(
+ prediction_without_prompt_tokens_list, skip_special_tokens=True
+ )
+
+ for (
+ prompt_text,
+ completion_text,
+ prediction_text,
+ pred_step_text,
+ ) in zip(
+ prompt_texts, completion_texts, predicted_texts, pred_step_texts
+ ):
+ table.add_data(
+ row_index,
+ prompt_text,
+ completion_text,
+ prediction_text,
+ pred_step_text,
+ )
+ row_index += 1
+
+ wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
+
+ if is_main_process():
+ log_table_from_dataloader("Eval", eval_dataloader)
+
+ return control
+
+ return LogPredictionCallback
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 93a23f7738..ce1f2255b8 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -48,6 +48,8 @@ def normalize_config(cfg):
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ cfg.eval_table_size = cfg.eval_table_size or 0
+ cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 9f0795af76..e4ff7b5d42 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -340,10 +340,10 @@ def load_model(
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
- and cfg.sequence_len >= model.config.max_position_embeddings
+ and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
- f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+ f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index f91f4e318e..f7d0b4329a 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -30,6 +30,7 @@
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
bench_eval_callback_factory,
+ log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -719,6 +720,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs,
)
+ if cfg.use_wandb and cfg.eval_table_size > 0:
+ LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
+ trainer.add_callback(LogPredictionCallback(cfg))
+
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))