From f4cabc2351798596b64d38e86c7ec4dc5fd00838 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 30 Nov 2024 08:37:49 +0700 Subject: [PATCH] fix: ds3 and fsdp lmbench eval (#2102) [ski[p ci] * fix: ds3 and fsdp lmbench eval * chore: update comment * fix: test signature --- src/axolotl/utils/callbacks/__init__.py | 173 ++++++++++++---------- src/axolotl/utils/callbacks/perplexity.py | 16 +- tests/test_perplexity.py | 15 +- 3 files changed, 117 insertions(+), 87 deletions(-) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8768bc2bf7..6bf433319e 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -28,6 +28,7 @@ TrainingArguments, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +from trl.models import unwrap_model_for_generation from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage @@ -46,6 +47,7 @@ if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments + IGNORE_INDEX = -100 LOG = logging.getLogger("axolotl.callbacks") @@ -64,7 +66,10 @@ def on_step_end( control: TrainerControl, **kwargs, ): - if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1: + if ( + args.evaluation_strategy == IntervalStrategy.STEPS + and state.global_step == 1 + ): control.should_evaluate = True return control @@ -375,7 +380,10 @@ def __maybe_load_metrics(self): for metric in self.cfg.eval_causal_lm_metrics: if metric == "perplexity": max_seq_len = self.cfg.eval_max_new_tokens - metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len) + metrics[metric] = Perplexity( + tokenizer=tokenizer, + max_seq_len=max_seq_len, + ) else: try: metrics[metric] = evaluate.load(metric) @@ -392,8 +400,11 @@ def on_evaluate( eval_dataloader, **kwargs, # pylint: disable=unused-argument ): - trainer.model.eval() - device = torch.device(self.cfg.device) + trainer.model_wrapped.eval() + + device = torch.device( + self.cfg.device + ) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded # pylint: disable=duplicate-code generation_config = GenerationConfig( @@ -430,6 +441,10 @@ def compute(metric: evaluate.Metric, **kwargs): for k in metric._feature_names() # pylint: disable=protected-access if k in kwargs } + + if isinstance(metric, Perplexity): + metric_kwargs["model"] = trainer.model_wrapped + metric_score = metric.compute(**metric_kwargs) return ( metric_score["score"] @@ -465,89 +480,97 @@ def evaluate_preds(sources, predictions, references): def predict_with_generate(): eval_src, eval_pred, eval_ref = [], [], [] - for batch in tqdm(eval_dataloader): - batch_labels = batch["labels"].to(device) - batch_input_ids = batch["input_ids"].to(device) - - if "position_ids" in batch: - batch_pos_ids = batch["position_ids"].tolist() - else: - batch_pos_ids = [None] * len(batch["input_ids"]) - - prompt_token_ids_list = [] - completion_token_ids_list = [] + with unwrap_model_for_generation( + trainer.model_wrapped, trainer.accelerator + ) as unwrapped_model: + for batch in tqdm(eval_dataloader, disable=not is_main_process()): + batch_labels = batch["labels"].to(device) + batch_input_ids = batch["input_ids"].to(device) - for input_ids_all, labels_all, pos_ids in zip( - batch_input_ids, - batch_labels, - batch_pos_ids, - ): - if pos_ids is None: - pos_ranges = [(0, len(input_ids_all) - 1)] + if "position_ids" in batch: + batch_pos_ids = batch["position_ids"].tolist() else: - pos_ranges = find_ranges(pos_ids) - - for pos_range in pos_ranges: - start, end = pos_range - if start == end: - continue + batch_pos_ids = [None] * len(batch["input_ids"]) + + prompt_token_ids_list = [] + completion_token_ids_list = [] + + for input_ids_all, labels_all, pos_ids in zip( + batch_input_ids, + batch_labels, + batch_pos_ids, + ): + if pos_ids is None: + pos_ranges = [(0, len(input_ids_all) - 1)] + else: + pos_ranges = find_ranges(pos_ids) + + for pos_range in pos_ranges: + start, end = pos_range + if start == end: + continue + + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] + + tokens_without_loss = labels == IGNORE_INDEX + tokens_with_loss = labels != IGNORE_INDEX + tokens_exclude_padding = ( + input_ids != tokenizer.pad_token_id + ) + prompt_token_includes = ( + tokens_without_loss & tokens_exclude_padding + ) + + prompt_token_ids = input_ids[prompt_token_includes] + prompt_token_ids_list.append(prompt_token_ids) + + completion_token_ids = input_ids[tokens_with_loss] + completion_token_ids_list.append(completion_token_ids) + + prompt_texts = tokenizer.batch_decode( + prompt_token_ids_list, skip_special_tokens=True + ) + completion_texts = tokenizer.batch_decode( + completion_token_ids_list, skip_special_tokens=True + ) - input_ids = input_ids_all[start : end + 1] - labels = labels_all[start : end + 1] + with torch.no_grad(): + prompt_encoding = tokenizer( + prompt_texts, padding=True, return_tensors="pt" + ).to(device) - tokens_without_loss = labels == IGNORE_INDEX - tokens_with_loss = labels != IGNORE_INDEX - tokens_exclude_padding = input_ids != tokenizer.pad_token_id - prompt_token_includes = ( - tokens_without_loss & tokens_exclude_padding + predictions = unwrapped_model.generate( + **prompt_encoding, generation_config=generation_config ) - prompt_token_ids = input_ids[prompt_token_includes] - prompt_token_ids_list.append(prompt_token_ids) - - completion_token_ids = input_ids[tokens_with_loss] - completion_token_ids_list.append(completion_token_ids) - - prompt_texts = tokenizer.batch_decode( - prompt_token_ids_list, skip_special_tokens=True - ) - completion_texts = tokenizer.batch_decode( - completion_token_ids_list, skip_special_tokens=True - ) - - with torch.no_grad(): - prompt_encoding = tokenizer( - prompt_texts, padding=True, return_tensors="pt" - ).to(self.cfg.device) - predictions = trainer.model.generate( - **prompt_encoding, generation_config=generation_config - ) + del prompt_encoding + + prediction_all_tokens = predictions["sequences"].cpu().tolist() + prediction_without_prompt_tokens_list = [] + for prompt_token_ids, prediction_tokens in zip( + prompt_token_ids_list, prediction_all_tokens + ): + prediction_without_prompt_tokens = prediction_tokens[ + len(prompt_token_ids) : + ] + prediction_without_prompt_tokens_list.append( + prediction_without_prompt_tokens + ) - prediction_all_tokens = predictions["sequences"].cpu().tolist() - prediction_without_prompt_tokens_list = [] - for prompt_token_ids, prediction_tokens in zip( - prompt_token_ids_list, prediction_all_tokens - ): - prediction_without_prompt_tokens = prediction_tokens[ - len(prompt_token_ids) : - ] - prediction_without_prompt_tokens_list.append( - prediction_without_prompt_tokens + predicted_texts = tokenizer.batch_decode( + prediction_without_prompt_tokens_list, + skip_special_tokens=True, ) - predicted_texts = tokenizer.batch_decode( - prediction_without_prompt_tokens_list, skip_special_tokens=True - ) - - eval_src.extend(prompt_texts) - eval_pred.extend(predicted_texts) - eval_ref.extend(completion_texts) + eval_src.extend(prompt_texts) + eval_pred.extend(predicted_texts) + eval_ref.extend(completion_texts) return eval_src, eval_pred, eval_ref - if is_main_process(): - eval_preds = predict_with_generate() - trainer.log(evaluate_preds(*eval_preds)) + eval_preds = predict_with_generate() + trainer.log(evaluate_preds(*eval_preds)) return control diff --git a/src/axolotl/utils/callbacks/perplexity.py b/src/axolotl/utils/callbacks/perplexity.py index 2e64176812..d3a362c4cd 100644 --- a/src/axolotl/utils/callbacks/perplexity.py +++ b/src/axolotl/utils/callbacks/perplexity.py @@ -8,6 +8,8 @@ from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils import PreTrainedTokenizer +from axolotl.utils.distributed import is_main_process + class Perplexity: """ @@ -17,16 +19,13 @@ class Perplexity: def __init__( self, - model: PreTrainedModel, tokenizer: PreTrainedTokenizer, max_seq_len: int, stride: int = 512, ) -> None: self.max_seq_len = max_seq_len self.stride = stride - self.model = model self.tokenizer = tokenizer - self.device = model.device self.name = "perplexity" def _feature_names(self) -> List[str]: @@ -34,6 +33,7 @@ def _feature_names(self) -> List[str]: def compute( self, + model: PreTrainedModel, references: Optional[List[str]] = None, ) -> Dict[str, float]: """ @@ -41,17 +41,21 @@ def compute( """ assert references is not None, "Missing parameter: references" + model.eval() + references_tokenized = self.tokenizer( references, return_tensors="pt", padding=True, truncation=True ) input_ids: Tensor = references_tokenized["input_ids"] # type: ignore - input_ids = input_ids.to(self.device) + input_ids = input_ids.to(model.device) sequence_length = input_ids.size(1) losses = [] prev_end_loc = 0 - for begin_loc in tqdm(range(0, sequence_length, self.stride)): + for begin_loc in tqdm( + range(0, sequence_length, self.stride), disable=not is_main_process() + ): end_loc = min(begin_loc + self.max_seq_len, sequence_length) trg_len = end_loc - prev_end_loc input_ids_slice = input_ids[:, begin_loc:end_loc] @@ -59,7 +63,7 @@ def compute( labels_slice[:, :-trg_len] = -100 with torch.no_grad(): - outputs: CausalLMOutput = self.model( + outputs: CausalLMOutput = model( input_ids=input_ids_slice, labels=labels_slice ) diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index e66e95d0cd..8688827cec 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -12,9 +12,12 @@ @fixture() def metric(tokenizer): - model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True) + return Perplexity(tokenizer=tokenizer, max_seq_len=512) - return Perplexity(model, tokenizer, 512) + +@fixture() +def model(): + return AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True) @fixture() @@ -22,20 +25,20 @@ def tokenizer(): return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) -def test_perplexity_longer_than_stride(metric): +def test_perplexity_longer_than_stride(model, metric): # taken from https://huggingface.co/datasets/roneneldan/TinyStories sample_text = """ Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after. One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends. """ - result = metric.compute([sample_text]) + result = metric.compute(model, [sample_text]) ppl = result["score"] assert round(ppl, 2) == 5.37 -def test_perplexity_short(metric): +def test_perplexity_short(model, metric): # taken from https://huggingface.co/datasets/roneneldan/TinyStories sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun." - result = metric.compute([sample_text]) + result = metric.compute(model, [sample_text]) ppl = result["score"] assert round(ppl, 2) == 10.02