Commit
Add eval_table_size, eval_table_max_new_tokens configs & clean up code
Glavin001 committed Sep 12, 2023
1 parent aaf4d1e commit 14d26e1
Showing 5 changed files with 61 additions and 30 deletions.
11 changes: 4 additions & 7 deletions docker-compose.yaml
@@ -1,23 +1,20 @@
# version: '3.8'
services:
axolotl:
# build:
# context: .
# dockerfile: ./docker/Dockerfile
image: winglian/axolotl:main-py3.10-cu118-2.0.1
build:
context: .
dockerfile: ./docker/Dockerfile
volumes:
- .:/workspace/axolotl
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
- WANDB_API_KEY=${WANDB_API_KEY}
ports:
- "5678:5678"
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
# count: 1
capabilities: [gpu]
command: tail -f /dev/null
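
The compose service mounts the repo and the Hugging Face cache, passes WANDB_API_KEY through, reserves NVIDIA GPU capability, and idles on tail -f /dev/null so the container stays up for interactive use. An illustrative sanity check from inside the container (not part of this commit) might be:

```python
# Illustrative GPU sanity check, assuming torch is installed in the image.
import torch

print("CUDA available:", torch.cuda.is_available())
print("Visible GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))
```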
2 changes: 2 additions & 0 deletions examples/llama-2/lora.yml
@@ -55,6 +55,8 @@ flash_attention: true

warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
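The two new keys drive the W&B evaluation table added in callbacks.py: eval_table_size is roughly the number of eval examples logged per evaluation (0 disables the table), and eval_table_max_new_tokens caps how many tokens are generated for each example. An illustrative way to read them back with the same fallbacks normalize_config applies later in this commit (not part of the diff):

```python
# Illustrative sketch: read the new keys with the same fallbacks that
# normalize_config() applies in this commit.
import yaml

with open("examples/llama-2/lora.yml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

eval_table_size = cfg.get("eval_table_size") or 0                 # 0 disables the eval table
eval_table_max_new_tokens = cfg.get("eval_table_max_new_tokens") or 128
print(eval_table_size, eval_table_max_new_tokens)
```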
2 changes: 2 additions & 0 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -262,6 +262,8 @@ def flashattn_forward(
if attention_mask is not None
else None,
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
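The new guard casts the packed key/value tensor to the query dtype before calling flash_attn_varlen_kvpacked_func, since the flash-attention kernels expect q and kv to share a dtype. A minimal standalone sketch of the same pattern, with made-up shapes:

```python
# Minimal standalone sketch of the dtype guard above; shapes are made up for
# illustration and not taken from axolotl.
import torch

q_unpad = torch.randn(16, 32, 128, dtype=torch.bfloat16)     # (total_tokens, n_heads, head_dim)
kv_unpad = torch.randn(16, 2, 32, 128, dtype=torch.float16)  # (total_tokens, 2, n_heads, head_dim)

if q_unpad.dtype != kv_unpad.dtype:
    kv_unpad = kv_unpad.to(q_unpad.dtype)  # align dtypes before the varlen kernel

assert q_unpad.dtype == kv_unpad.dtype == torch.bfloat16
```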
74 changes: 51 additions & 23 deletions src/axolotl/utils/callbacks.py
@@ -339,9 +339,33 @@ def on_evaluate(
eval_dataloader,
**kwargs,
):
eval_table_size = self.cfg.eval_table_size

if eval_table_size <= 0:
return control

trainer.model.eval()
device = torch.device(self.cfg.device)

generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

def logits_to_tokens(logits) -> str:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
return predicted_token_ids

def find_ranges(lst):
ranges = []
start = 0
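
logits_to_tokens greedily decodes per-position logits by taking the argmax over the vocabulary; the softmax only converts logits to probabilities and does not change which index is largest. A small self-contained usage sketch with arbitrary dimensions:

```python
# Self-contained sketch of the greedy decode helper; seq_len and vocab_size are
# arbitrary example values.
import torch

def logits_to_tokens(logits):
    probabilities = torch.softmax(logits, dim=-1)
    # Predicted token ids are the ones with the highest probability per position.
    return torch.argmax(probabilities, dim=-1)

logits = torch.randn(10, 32000)        # (seq_len, vocab_size)
token_ids = logits_to_tokens(logits)   # shape: (seq_len,)
print(token_ids.shape)
```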
@@ -359,13 +383,17 @@ def log_table_from_dataloader(name: str, table_dataloader):
"id",
"Prompt",
"Correct Completion",
"Predicted Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0
max_new_tokens = 128

for batch in tqdm(table_dataloader, total=len(table_dataloader)):
for batch in tqdm(table_dataloader):

if row_index > eval_table_size:
break

batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

@@ -374,11 +402,18 @@ def log_table_from_dataloader(name: str, table_dataloader):
else:
batch_pos_ids = [None] * len(batch["input_ids"])

(_, batch_logits, _) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)

prompt_token_ids_list = []
pred_step_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids, batch_labels, batch_pos_ids
for input_ids_all, labels_all, pos_ids, logits in zip(
batch_input_ids, batch_labels, batch_pos_ids, batch_logits,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
@@ -396,7 +431,6 @@ def log_table_from_dataloader(name: str, table_dataloader):
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id

prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
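
The masks above split each example into prompt and completion tokens: positions labeled IGNORE_INDEX carry no loss and, excluding padding, form the prompt, while positions with real labels form the completion. A self-contained sketch with made-up token ids:

```python
# Self-contained sketch of the prompt/completion masking; token ids, pad id and
# labels are made up for illustration (IGNORE_INDEX follows the HF -100 convention).
import torch

IGNORE_INDEX = -100
pad_token_id = 0
input_ids = torch.tensor([12, 55, 7, 901, 33, 0])
labels    = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 7, 901, 33, IGNORE_INDEX])

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != pad_token_id

prompt_token_ids = input_ids[tokens_without_loss & tokens_exclude_padding]
completion_token_ids = input_ids[tokens_with_loss]
print(prompt_token_ids.tolist())       # [12, 55]
print(completion_token_ids.tolist())   # [7, 901, 33]
```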
@@ -407,27 +441,20 @@ def log_table_from_dataloader(name: str, table_dataloader):
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

pred_step_token_ids = logits_to_tokens(logits[start : end + 1])[tokens_with_loss]
pred_step_token_ids_list.append(pred_step_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
pred_step_texts = tokenizer.batch_decode(
pred_step_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)

prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
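
This hunk tokenizes the decoded prompts; the elided lines then presumably call model.generate with the GenerationConfig built earlier and strip the prompt tokens before decoding (hence prediction_without_prompt_tokens_list below). A hedged, self-contained sketch of that generate-and-strip pattern, with a tiny public checkpoint standing in for the trained model:

```python
# Hedged sketch of the generate-and-strip pattern; the tiny checkpoint, prompt
# text, and max_new_tokens value are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

generation_config = GenerationConfig(
    max_new_tokens=16,                    # cfg.eval_table_max_new_tokens in the callback
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
    return_dict_in_generate=True,
)

encoding = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**encoding, generation_config=generation_config)

# Drop the prompt tokens so only the newly generated completion is decoded.
new_tokens = out.sequences[0][encoding["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```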
@@ -451,17 +478,18 @@ def log_table_from_dataloader(name: str, table_dataloader):
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

for prompt_text, completion_text, prediction_text in zip(
prompt_texts, completion_texts, predicted_texts
for prompt_text, completion_text, prediction_text, pred_step_text in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table.add_data(
row_index, prompt_text, completion_text, prediction_text
row_index, prompt_text, completion_text, prediction_text, pred_step_text
)
row_index += 1

wandb.run.log({f"{name} - Predictions vs Ground Truth": table})

log_table_from_dataloader("Eval", eval_dataloader)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

return control

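log_table_from_dataloader builds a wandb.Table with the five columns above, adds one row per eval example, and logs it once per evaluation; the new is_main_process() guard keeps duplicate tables from being logged under distributed training. A self-contained sketch of the table-logging pattern (offline mode and dummy row are illustrative, this is not the callback itself):

```python
# Self-contained sketch of the W&B table logging pattern; project name, offline
# mode, and the dummy row are illustrative only.
import wandb

wandb.init(project="axolotl-eval-table-demo", mode="offline")
table = wandb.Table(
    columns=[
        "id",
        "Prompt",
        "Correct Completion",
        "Predicted Completion (model.generate)",
        "Predicted Completion (trainer.prediction_step)",
    ]
)
table.add_data(0, "example prompt", "example completion", "generate output", "prediction_step output")
wandb.run.log({"Eval - Predictions vs Ground Truth": table})
wandb.finish()
```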
2 changes: 2 additions & 0 deletions src/axolotl/utils/config.py
@@ -48,6 +48,8 @@ def normalize_config(cfg):
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
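normalize_config now fills in defaults for the two new options, so a config that omits them gets the eval table disabled (size 0) and a 128-token generation cap. A tiny sketch of the same defaulting pattern, with a plain namespace standing in for axolotl's attribute-style config object:

```python
# Tiny sketch of the defaulting pattern; SimpleNamespace stands in for axolotl's
# config object, which also allows attribute access.
from types import SimpleNamespace

cfg = SimpleNamespace(eval_table_size=None, eval_table_max_new_tokens=None)
cfg.eval_table_size = cfg.eval_table_size or 0                   # eval table disabled by default
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
print(cfg.eval_table_size, cfg.eval_table_max_new_tokens)        # -> 0 128
```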
