From 14d26e15acab90327b1e67e80f27df21aefa2a61 Mon Sep 17 00:00:00 2001 From: Glavin Wiechert Date: Tue, 12 Sep 2023 06:51:26 +0000 Subject: [PATCH] Add eval_table_size, eval_table_max_new_tokens configs & clean up code --- docker-compose.yaml | 11 +-- examples/llama-2/lora.yml | 2 + .../monkeypatch/llama_attn_hijack_flash.py | 2 + src/axolotl/utils/callbacks.py | 74 +++++++++++++------ src/axolotl/utils/config.py | 2 + 5 files changed, 61 insertions(+), 30 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 143d5f6c91..d40422f94f 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,23 +1,20 @@ # version: '3.8' services: axolotl: - # build: - # context: . - # dockerfile: ./docker/Dockerfile - image: winglian/axolotl:main-py3.10-cu118-2.0.1 + build: + context: . + dockerfile: ./docker/Dockerfile volumes: - .:/workspace/axolotl - ~/.cache/huggingface/:/root/.cache/huggingface/ # set environment variables environment: - WANDB_API_KEY=${WANDB_API_KEY} - ports: - - "5678:5678" deploy: resources: reservations: devices: - driver: nvidia - count: 1 + # count: 1 capabilities: [gpu] command: tail -f /dev/null diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 2a0af130be..fc2307faed 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -55,6 +55,8 @@ flash_attention: true warmup_steps: 10 eval_steps: 20 +eval_table_size: 5 +eval_table_max_new_tokens: 128 save_steps: debug: deepspeed: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6b3fd4355e..f6adcdac25 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -262,6 +262,8 @@ def flashattn_forward( if attention_mask is not None else None, ) + if q_unpad.dtype != kv_unpad.dtype: + kv_unpad = kv_unpad.to(q_unpad.dtype) output_unpad = flash_attn_varlen_kvpacked_func( q_unpad, kv_unpad, diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 2272ce0a2b..be81503cbf 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -339,9 +339,33 @@ def on_evaluate( eval_dataloader, **kwargs, ): + eval_table_size = self.cfg.eval_table_size + + if eval_table_size <= 0: + return control + trainer.model.eval() device = torch.device(self.cfg.device) + generation_config = GenerationConfig( + max_new_tokens=self.cfg.eval_table_max_new_tokens, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + + def logits_to_tokens(logits) -> str: + probabilities = torch.softmax(logits, dim=-1) + # Get the predicted token ids (the ones with the highest probability) + predicted_token_ids = torch.argmax(probabilities, dim=-1) + return predicted_token_ids + def find_ranges(lst): ranges = [] start = 0 @@ -359,13 +383,17 @@ def log_table_from_dataloader(name: str, table_dataloader): "id", "Prompt", "Correct Completion", - "Predicted Completion", + "Predicted Completion (model.generate)", + "Predicted Completion (trainer.prediction_step)", ] ) row_index = 0 - max_new_tokens = 128 - for batch in tqdm(table_dataloader, total=len(table_dataloader)): + for batch in tqdm(table_dataloader): + + if row_index > eval_table_size: + break + batch_labels = batch["labels"].to(device) batch_input_ids = batch["input_ids"].to(device) @@ -374,11 +402,18 @@ def log_table_from_dataloader(name: str, table_dataloader): else: batch_pos_ids = [None] * len(batch["input_ids"]) + (_, batch_logits, _) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + prompt_token_ids_list = [] + pred_step_token_ids_list = [] completion_token_ids_list = [] - for input_ids_all, labels_all, pos_ids in zip( - batch_input_ids, batch_labels, batch_pos_ids + for input_ids_all, labels_all, pos_ids, logits in zip( + batch_input_ids, batch_labels, batch_pos_ids, batch_logits, ): if pos_ids is None: pos_ranges = [(0, len(input_ids_all) - 1)] @@ -396,7 +431,6 @@ def log_table_from_dataloader(name: str, table_dataloader): tokens_without_loss = labels == IGNORE_INDEX tokens_with_loss = labels != IGNORE_INDEX tokens_exclude_padding = input_ids != tokenizer.pad_token_id - prompt_token_includes = ( tokens_without_loss & tokens_exclude_padding ) @@ -407,27 +441,20 @@ def log_table_from_dataloader(name: str, table_dataloader): completion_token_ids = input_ids[tokens_with_loss] completion_token_ids_list.append(completion_token_ids) + pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss] + pred_step_token_ids_list.append(pred_step_token_ids) + prompt_texts = tokenizer.batch_decode( prompt_token_ids_list, skip_special_tokens=True ) completion_texts = tokenizer.batch_decode( completion_token_ids_list, skip_special_tokens=True ) + pred_step_texts = tokenizer.batch_decode( + pred_step_token_ids_list, skip_special_tokens=True + ) with torch.no_grad(): - generation_config = GenerationConfig( - max_new_tokens=max_new_tokens, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - prompt_encoding = tokenizer( prompt_texts, padding=True, return_tensors="pt" ).to(self.cfg.device) @@ -451,17 +478,18 @@ def log_table_from_dataloader(name: str, table_dataloader): prediction_without_prompt_tokens_list, skip_special_tokens=True ) - for prompt_text, completion_text, prediction_text in zip( - prompt_texts, completion_texts, predicted_texts + for prompt_text, completion_text, prediction_text, pred_step_text in zip( + prompt_texts, completion_texts, predicted_texts, pred_step_texts ): table.add_data( - row_index, prompt_text, completion_text, prediction_text + row_index, prompt_text, completion_text, prediction_text, pred_step_text ) row_index += 1 wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) - log_table_from_dataloader("Eval", eval_dataloader) + if is_main_process(): + log_table_from_dataloader("Eval", eval_dataloader) return control diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 93a23f7738..ce1f2255b8 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -48,6 +48,8 @@ def normalize_config(cfg): ) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + cfg.eval_table_size = cfg.eval_table_size or 0 + cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128 choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 if cfg.ddp: