Commit: Clean up code for PR review

Glavin001 committed Sep 9, 2023
1 parent 83e6b29 commit aaf4d1e
Showing 3 changed files with 67 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
@@ -34,4 +34,4 @@
// "env": {"CUDA_LAUNCH_BLOCKING": "1"}
},
]
}
}
4 changes: 0 additions & 4 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -194,10 +194,6 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
-    # if cu_seqlens is not None and max_seqlen is not None:
-    # if cu_seqlens is not None and max_seqlen is not None and self.training:
-    # if cu_seqlens is not None and max_seqlen is not None and query_states.shape == key_states.shape:
-    # if cu_seqlens is not None and max_seqlen is not None and len(cu_seqlens[0]) > 2:
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
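Note on the condition kept above: with sample packing, flash-attn's varlen path receives `cu_seqlens`, a 1-D tensor of cumulative sequence lengths marking where each packed sequence starts and ends. A minimal sketch of that layout (illustrative only, not part of this commit):

```python
# Illustrative sketch: how a 1-D cu_seqlens tensor describes sequences
# packed back to back into a single row.
import torch

# Three sequences of lengths 5, 4, and 3 packed together.
seq_lens = torch.tensor([5, 4, 3])

# Cumulative sequence lengths: [0, 5, 9, 12]. Varlen attention kernels
# use these offsets to locate each sequence's boundaries.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
)
max_seqlen = int(seq_lens.max())  # longest packed sequence: 5

assert cu_seqlens.dim() == 1  # the check kept in flashattn_forward
print(cu_seqlens)  # tensor([ 0,  5,  9, 12], dtype=torch.int32)
```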
162 changes: 66 additions & 96 deletions src/axolotl/utils/callbacks.py
@@ -6,7 +6,6 @@
 import os
 from typing import TYPE_CHECKING, Dict, List
 
-import itertools
 import evaluate
 import numpy as np
 import pandas as pd
@@ -16,12 +15,12 @@
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
 from transformers import (
+    GenerationConfig,
     Trainer,
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
-    GenerationConfig,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

@@ -324,8 +323,6 @@ def on_evaluate(


 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
-    LOG.info("log_prediction_callback_factory")
-
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""

@@ -338,70 +335,53 @@ def on_evaluate(
             args: AxolotlTrainingArguments,
             state: TrainerState,
             control: TrainerControl,
-            model,
-            # tokenizer,
             train_dataloader,
             eval_dataloader,
             **kwargs,
         ):
-            LOG.info("=" * 80)
-            LOG.info("logging predictions")
-
             trainer.model.eval()
             device = torch.device(self.cfg.device)
 
-            def logits_to_tokens(logits) -> str:
-                probabilities = torch.softmax(logits, dim=-1)
-                # Get the predicted token ids (the ones with the highest probability)
-                predicted_token_ids = torch.argmax(probabilities, dim=-1)
-                return predicted_token_ids
-
             def find_ranges(lst):
                 ranges = []
                 start = 0
                 for i in range(1, len(lst)):
                     if lst[i] == 0:
-                        ranges.append((start, i-1))
+                        ranges.append((start, i - 1))
                         start = i
-                ranges.append((start, len(lst)-1)) # for the last range
+                end = len(lst) - 1
+                ranges.append((start, end))
                 return ranges
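For reference, `find_ranges` recovers the per-sequence index ranges from a packed `position_ids` row, where each packed sequence restarts its positions at 0. A standalone usage sketch (not part of the diff):

```python
# Standalone sketch: find_ranges splits a packed position_ids row into
# (start, end) index ranges, one per packed sequence.
def find_ranges(lst):
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:  # a position id of 0 marks the next sequence
            ranges.append((start, i - 1))
            start = i
    end = len(lst) - 1
    ranges.append((start, end))
    return ranges

# Three sequences of lengths 3, 2, and 4 packed into one row:
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(find_ranges(position_ids))  # [(0, 2), (3, 4), (5, 8)]
```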

             def log_table_from_dataloader(name: str, table_dataloader):
-                # Initialize an empty wandb.Table
-                table = wandb.Table(columns=["id", "Prompt", "Correct Completion", "Predicted Completion 1", "Predicted Completion 2"])
-
-                batch_index = 0
+                table = wandb.Table(
+                    columns=[
+                        "id",
+                        "Prompt",
+                        "Correct Completion",
+                        "Predicted Completion",
+                    ]
+                )
+                row_index = 0
+                max_new_tokens = 128
 
                 for batch in tqdm(table_dataloader, total=len(table_dataloader)):
-                    # For each batch I want prompt, completion, 2x predictions
-
-                    # (loss, logits, labels) = trainer.prediction_step(
-                    # (batch_loss, batch_logits, batch_labels) = trainer.prediction_step(
-                    #     trainer.model,
-                    #     batch,
-                    #     prediction_loss_only=False,
-                    # )
-
-                    batch_labels = batch['labels'].to(device)
-                    batch_input_ids = batch['input_ids'].to(device)
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
 
-                    if 'position_ids' in batch:
-                        batch_pos_ids = batch['position_ids'].tolist()
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
                     else:
-                        batch_pos_ids = [None] * len(batch['input_ids'])
+                        batch_pos_ids = [None] * len(batch["input_ids"])

                     prompt_token_ids_list = []
                     completion_token_ids_list = []
-                    # completion_texts = []
-                    # prediction_texts = []
 
-                    # for input_ids in batch['input_ids']:
-                    # for batch_item_idx, (input_ids, labels) in enumerate(zip(batch['input_ids'], logits, labels)):
-                    # for batch_item_idx, (input_ids, logits, labels) in enumerate(zip(batch['input_ids'].to(device), batch_logits, batch_labels)):
-                    for batch_item_idx, (input_ids_all, labels_all, pos_ids) in enumerate(zip(batch_input_ids, batch_labels, batch_pos_ids)):
-
+                    for input_ids_all, labels_all, pos_ids in zip(
+                        batch_input_ids, batch_labels, batch_pos_ids
+                    ):
                         if pos_ids is None:
-                            pos_ranges = [(0, len(input_ids_all)-1)]
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
                         else:
                             pos_ranges = find_ranges(pos_ids)

@@ -410,37 +390,33 @@ def log_table_from_dataloader(name: str, table_dataloader):
                         if start == end:
                             continue
 
-                        input_ids = input_ids_all[start:end+1]
-                        labels = labels_all[start:end+1]
-                        # input_ids[start:end] = tokenizer.pad_token_id
+                        input_ids = input_ids_all[start : end + 1]
+                        labels = labels_all[start : end + 1]
 
-                        tokens_without_loss = (labels == IGNORE_INDEX)
-                        tokens_with_loss = (labels != IGNORE_INDEX)
-                        tokens_exclude_padding = (input_ids != tokenizer.pad_token_id)
+                        tokens_without_loss = labels == IGNORE_INDEX
+                        tokens_with_loss = labels != IGNORE_INDEX
+                        tokens_exclude_padding = input_ids != tokenizer.pad_token_id
 
-                        prompt_token_includes = tokens_without_loss & tokens_exclude_padding
+                        prompt_token_includes = (
+                            tokens_without_loss & tokens_exclude_padding
+                        )
 
                         prompt_token_ids = input_ids[prompt_token_includes]
                         prompt_token_ids_list.append(prompt_token_ids)
 
                         completion_token_ids = input_ids[tokens_with_loss]
                         completion_token_ids_list.append(completion_token_ids)
-                        # completion_text = tokenizer.decode(completion_token_ids)
-                        # completion_texts.append(completion_text)
 
-                        # completion_logit = logits[tokens_with_loss]
-                        # predicted_tokens = logits_to_tokens(completion_logit)
-                        # prediction_text = tokenizer.decode(predicted_tokens)
-                        # prediction_texts.append(prediction_text)
-
-                    prompt_texts = tokenizer.batch_decode(prompt_token_ids_list, skip_special_tokens=True)
-                    completion_texts = tokenizer.batch_decode(completion_token_ids_list, skip_special_tokens=True)
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
 
                     with torch.no_grad():
                         generation_config = GenerationConfig(
-                            # repetition_penalty=1.1,
-                            max_new_tokens=128,
-                            # max_new_tokens=32,
+                            max_new_tokens=max_new_tokens,
                             bos_token_id=tokenizer.bos_token_id,
                             eos_token_id=tokenizer.eos_token_id,
                             pad_token_id=tokenizer.pad_token_id,
@@ -452,47 +428,41 @@ def log_table_from_dataloader(name: str, table_dataloader):
                             output_scores=False,
                         )
 
-                        encoding = tokenizer(prompt_texts, padding=True, return_tensors='pt').to(self.cfg.device)
-                        new_predictions = trainer.model.generate(**encoding, generation_config=generation_config)  # FIXME: when sample_packing=True then error: "TypeError: varlen_fwd(): incompatible function arguments."
-
-                        new_prediction_all_tokens = new_predictions["sequences"].cpu().tolist()
-                        new_prediction_without_prompt_tokens_list = []
-                        for prompt_token_ids, new_prediction_tokens in zip(prompt_token_ids_list, new_prediction_all_tokens):
-                            new_prediction_without_prompt_tokens = new_prediction_tokens[len(prompt_token_ids):]
-                            new_prediction_without_prompt_tokens_list.append(new_prediction_without_prompt_tokens)
-
-                        new_predicted_texts = tokenizer.batch_decode(new_prediction_without_prompt_tokens_list, skip_special_tokens=True)
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
 
-                    # for i, (prompt_text, completion_text, prediction_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, prediction_texts, new_predicted_texts)):
-                    for i, (prompt_text, completion_text, new_predicted_text) in enumerate(zip(prompt_texts, completion_texts, new_predicted_texts)):
-                        prediction_text = ""
-                        table.add_data(i, prompt_text, completion_text, prediction_text, new_predicted_text)
+                        prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                        prediction_without_prompt_tokens_list = []
+                        for prompt_token_ids, prediction_tokens in zip(
+                            prompt_token_ids_list, prediction_all_tokens
+                        ):
+                            prediction_without_prompt_tokens = prediction_tokens[
+                                len(prompt_token_ids) :
+                            ]
+                            prediction_without_prompt_tokens_list.append(
+                                prediction_without_prompt_tokens
+                            )
 
-                    batch_index += 1
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
 
-                wandb.run.log({ f"{name} - Predictions vs Ground Truth": table })
+                    for prompt_text, completion_text, prediction_text in zip(
+                        prompt_texts, completion_texts, predicted_texts
+                    ):
+                        table.add_data(
+                            row_index, prompt_text, completion_text, prediction_text
+                        )
+                        row_index += 1
 
-            # log_table_from_dataloader("Train", train_dataloader)
-            # log_table_from_dataloader("Train", train_dataloader)
+                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
 
-            # # Get first 10 records from train_dataloader as a new dataloader
-            # train_data_subset = [next(iter(train_dataloader)) for _ in range(10)]
-            # train_dataloader_subset = torch.utils.data.DataLoader(train_data_subset, batch_size=train_dataloader.batch_size, shuffle=False)
-            # log_table_from_dataloader("Train Subset", train_dataloader_subset)
-
             log_table_from_dataloader("Eval", eval_dataloader)
 
             return control
 
     return LogPredictionCallback
 
 
-def group_sublists_by(lst: List[int], value: int) -> List[List[int]]:
-    """
-    Group sublists of lst by value
-    """
-    grouped = []
-    for key, group in itertools.groupby(lst, lambda x: x == value):
-        if key:
-            grouped.append(list(group))
-    return grouped
