fix: ds3 and fsdp lmbench eval (#2102) [skip ci]
* fix: ds3 and fsdp lmbench eval

* chore: update comment

* fix: test signature
NanoCode012 authored Nov 30, 2024
1 parent 6e0fb4a commit f4cabc2
Showing 3 changed files with 117 additions and 87 deletions.
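The substance of the fix: during benchmark evaluation, generation now runs against trainer.model_wrapped, with DeepSpeed ZeRO-3 / FSDP parameter shards gathered on the fly via trl's unwrap_model_for_generation, and every rank iterates the eval dataloader instead of only the main process. A minimal sketch of that pattern, assuming a Trainer-like object exposing model_wrapped and accelerator; the helper name, generation settings, and explicit device are illustrative rather than code from this commit:

    import torch
    from transformers import GenerationConfig
    from trl.models import unwrap_model_for_generation


    def generate_for_eval(trainer, eval_dataloader, device):
        # Illustrative settings; axolotl derives these from its own config.
        generation_config = GenerationConfig(max_new_tokens=128, do_sample=False)
        sequences = []
        # Temporarily gather the sharded (ZeRO-3 / FSDP) weights so .generate() sees
        # a full model; the shards are released when the context manager exits.
        with unwrap_model_for_generation(
            trainer.model_wrapped, trainer.accelerator
        ) as unwrapped_model:
            for batch in eval_dataloader:  # every rank iterates; no main-process-only guard
                # Use an explicit device: the wrapped model may report "cpu" when FSDP offload is enabled.
                input_ids = batch["input_ids"].to(device)
                with torch.no_grad():
                    output_ids = unwrapped_model.generate(
                        input_ids=input_ids, generation_config=generation_config
                    )
                sequences.extend(output_ids.cpu().tolist())
        return sequences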
173 changes: 98 additions & 75 deletions src/axolotl/utils/callbacks/__init__.py
@@ -28,6 +28,7 @@
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from trl.models import unwrap_model_for_generation

from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
@@ -46,6 +47,7 @@
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments


IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")

@@ -64,7 +66,10 @@ def on_step_end(
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and state.global_step == 1
):
control.should_evaluate = True
return control

@@ -375,7 +380,10 @@ def __maybe_load_metrics(self):
for metric in self.cfg.eval_causal_lm_metrics:
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
metrics[metric] = Perplexity(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
else:
try:
metrics[metric] = evaluate.load(metric)
@@ -392,8 +400,11 @@ def on_evaluate(
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
trainer.model_wrapped.eval()

device = torch.device(
self.cfg.device
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded

# pylint: disable=duplicate-code
generation_config = GenerationConfig(
@@ -430,6 +441,10 @@ def compute(metric: evaluate.Metric, **kwargs):
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}

if isinstance(metric, Perplexity):
metric_kwargs["model"] = trainer.model_wrapped

metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
@@ -465,89 +480,97 @@ def evaluate_preds(sources, predictions, references):
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []

for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])

prompt_token_ids_list = []
completion_token_ids_list = []
with unwrap_model_for_generation(
trainer.model_wrapped, trainer.accelerator
) as unwrapped_model:
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
batch_pos_ids = [None] * len(batch["input_ids"])

prompt_token_ids_list = []
completion_token_ids_list = []

for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)

for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = (
input_ids != tokenizer.pad_token_id
)
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)

input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(device)

tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
predictions = unwrapped_model.generate(
**prompt_encoding, generation_config=generation_config
)

prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)

completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)

prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)

with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
del prompt_encoding

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)

prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list,
skip_special_tokens=True,
)

predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)

eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)

return eval_src, eval_pred, eval_ref

if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))

return control
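The other half of the change shows up at the end of this file: the Perplexity metric is now constructed without a model and receives the live trainer.model_wrapped at compute time, so it always scores whatever wrapper DeepSpeed or FSDP has applied. A short sketch of the resulting call flow (the helper function and its arguments are illustrative; the Perplexity API matches the perplexity.py diff that follows):

    from axolotl.utils.callbacks.perplexity import Perplexity


    def score_perplexity(trainer, tokenizer, completion_texts, max_seq_len):
        # Build the metric without a model; only the tokenizer and window size are fixed up front.
        metric = Perplexity(tokenizer=tokenizer, max_seq_len=max_seq_len)
        metric_kwargs = {"references": completion_texts}
        if isinstance(metric, Perplexity):
            # Inject the live wrapped model per evaluation instead of binding one at construction.
            metric_kwargs["model"] = trainer.model_wrapped
        return metric.compute(**metric_kwargs)["score"]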

16 changes: 10 additions & 6 deletions src/axolotl/utils/callbacks/perplexity.py
@@ -8,6 +8,8 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

from axolotl.utils.distributed import is_main_process


class Perplexity:
"""
@@ -17,49 +19,51 @@ class Perplexity:

def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"

def _feature_names(self) -> List[str]:
return ["references"]

def compute(
self,
model: PreTrainedModel,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
Compute perplexity in a fixed length sliding window across the sequence.
"""
assert references is not None, "Missing parameter: references"

model.eval()

references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(self.device)
input_ids = input_ids.to(model.device)

sequence_length = input_ids.size(1)

losses = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
for begin_loc in tqdm(
range(0, sequence_length, self.stride), disable=not is_main_process()
):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
labels_slice = input_ids_slice.clone()
labels_slice[:, :-trg_len] = -100

with torch.no_grad():
outputs: CausalLMOutput = self.model(
outputs: CausalLMOutput = model(
input_ids=input_ids_slice, labels=labels_slice
)
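For reference, the sliding-window computation that compute() performs reduces to the following standalone sketch (single pre-tokenized sequence, no padding or distributed handling; simplified from the method above, with the function name chosen here for illustration):

    import math

    import torch


    def sliding_window_perplexity(model, input_ids, max_seq_len=512, stride=512):
        # input_ids: tensor of shape (1, sequence_length) on the same device as `model`.
        losses = []
        prev_end_loc = 0
        seq_len = input_ids.size(1)
        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_seq_len, seq_len)
            trg_len = end_loc - prev_end_loc  # only score tokens not covered by a previous window
            ids = input_ids[:, begin_loc:end_loc]
            labels = ids.clone()
            labels[:, :-trg_len] = -100  # IGNORE_INDEX: excluded from the loss
            with torch.no_grad():
                loss = model(input_ids=ids, labels=labels).loss
            losses.append(loss)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break
        return math.exp(torch.stack(losses).mean().item())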

15 changes: 9 additions & 6 deletions tests/test_perplexity.py
@@ -12,30 +12,33 @@

@fixture()
def metric(tokenizer):
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
return Perplexity(tokenizer=tokenizer, max_seq_len=512)

return Perplexity(model, tokenizer, 512)

@fixture()
def model():
return AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)


@fixture()
def tokenizer():
return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)


def test_perplexity_longer_than_stride(metric):
def test_perplexity_longer_than_stride(model, metric):
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
sample_text = """
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
"""
result = metric.compute([sample_text])
result = metric.compute(model, [sample_text])
ppl = result["score"]
assert round(ppl, 2) == 5.37


def test_perplexity_short(metric):
def test_perplexity_short(model, metric):
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
result = metric.compute([sample_text])
result = metric.compute(model, [sample_text])
ppl = result["score"]
assert round(ppl, 2) == 10.02
