From 71c925c02f1f22c8ff855133c4a2b29c4c468d21 Mon Sep 17 00:00:00 2001 From: Sean Date: Mon, 30 Sep 2024 10:16:08 +0800 Subject: [PATCH] Feature/improvement (#100) * support configuring filter_deduplicate * support normal evaluation * support configuring more parameters * dont compute metrics * predict loss only in evaluation * hard code to fix evaluation error * remove label_names * remove_unused_columns=None * custom prediction step * bugfix * pop labels * print eval loss * fix * fix * no_grad * set do_eval * bugfix * set evaluation_strategy no when valid_ds is empty --- angle_emb/angle.py | 30 +++++++++++++++++++++--------- angle_emb/angle_trainer.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 6693ede..77ff58e 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -821,7 +821,7 @@ def compute_mlm_loss(self, logits, mask_target_labels): ignore_index=self.pad_token_id, ) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs: bool = False): """ Compute loss for AnglE. :param model: Huggingface model. @@ -859,6 +859,11 @@ def compute_loss(self, model, inputs, return_outputs=False): return (loss, outputs) if return_outputs else loss + @torch.no_grad() + def prediction_step(self, model, inputs, *args, **kwargs): + eval_loss = self.compute_loss(model, inputs, return_outputs=False) + return eval_loss, None, None + class AngleESETrainer(AngleTrainer): """ @@ -1412,13 +1417,15 @@ def detect_dataset_format(self, ds: Dataset): def fit(self, train_ds: Dataset, valid_ds: Optional[Dataset] = None, + valid_ds_for_callback: Optional[Dataset] = None, batch_size: int = 32, output_dir: Optional[str] = None, epochs: int = 1, learning_rate: float = 1e-5, warmup_steps: int = 1000, logging_steps: int = 10, - eval_steps: Optional[int] = None, + eval_steps: int = 1000, + evaluation_strategy: str = 'steps', save_steps: int = 100, save_strategy: str = 'steps', save_total_limit: int = 10, @@ -1439,13 +1446,17 @@ def fit(self, :param train_ds: Dataset. tokenized train dataset. Required. :param valid_ds: Optional[Dataset]. tokenized valid dataset. Default None. + :param valid_ds_for_callback: Optional[Dataset]. tokenized valid dataset for callback use. + The dataset format should be `DatasetFormats.A`. The spearmans' correlation will be computed + after each epoch training and the best model will be saved. Default None. :param batch_size: int. Default 32. :param output_dir: Optional[str]. save dir. Default None. :param epochs: int. Default 1. :param learning_rate: float. Default 1e-5. :param warmup_steps: int. Default 1000. :param logging_steps: int. Default 10. - :param eval_steps: Optional[int]. Default None. + :param eval_steps: int. Default 1000. + :param evaluation_strategy: str. Default 'steps'. :param save_steps: int. Default 100. :param save_strategy: str. Default steps. :param save_total_limit: int. Default 10. @@ -1491,16 +1502,16 @@ def fit(self, trainer_kwargs = {} callbacks = None - if valid_ds is not None: + if valid_ds_for_callback is not None: # check format - for obj in valid_ds: + for obj in valid_ds_for_callback: if obj['extra']['dataset_format'] != DatasetFormats.A: raise ValueError('Currently only support evaluation for DatasetFormats.A.') break best_ckpt_dir = None if output_dir is not None: best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint') - evaluate_callback = EvaluateCallback(self, valid_ds, + evaluate_callback = EvaluateCallback(self, valid_ds_for_callback, partial(self.evaluate, batch_size=batch_size), save_dir=best_ckpt_dir, push_to_hub=push_to_hub, @@ -1519,7 +1530,7 @@ def fit(self, model=self.backbone, dataset_format=self.detect_dataset_format(train_ds), train_dataset=train_ds, - eval_dataset=None, + eval_dataset=valid_ds, loss_kwargs=loss_kwargs, tokenizer=self.tokenizer, args=TrainingArguments( @@ -1530,14 +1541,15 @@ def fit(self, learning_rate=learning_rate, fp16=fp16, logging_steps=logging_steps, + save_steps=save_steps, save_strategy=save_strategy, + evaluation_strategy=evaluation_strategy if valid_ds is not None else 'no', eval_steps=eval_steps, - save_steps=save_steps, output_dir=output_dir, save_total_limit=save_total_limit, load_best_model_at_end=False, ddp_find_unused_parameters=False if self.gpu_count > 1 else None, - label_names=AnglE.special_columns, + remove_unused_columns=False, **argument_kwargs, ), callbacks=callbacks, diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py index 653c6f8..0ef536d 100644 --- a/angle_emb/angle_trainer.py +++ b/angle_emb/angle_trainer.py @@ -35,11 +35,20 @@ help='Specify huggingface datasets subset name for valid set, default None') parser.add_argument('--valid_split_name', type=str, default='train', help='Specify huggingface datasets split name for valid set, default `train`') +parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None, + help='Specify huggingface datasets name or local file path for callback valid set. ' + 'The dataset format should be `DatasetFormats.A`. Default None.') +parser.add_argument('--valid_subset_name_for_callback', type=str, default=None, + help='Specify huggingface datasets subset name for valid set for callback use, default None') +parser.add_argument('--valid_split_name_for_callback', type=str, default='train', + help='Specify huggingface datasets split name for valid set for callback use, default `train`') parser.add_argument('--prompt_template', type=str, default=None, help='Specify prompt_template like "xxx: {text}", default None.' 'This prompt will be applied for all text columns.' 'If you want to specify different prompts for different text columns,' 'please handle it in the preprocessing step.') +parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1], + help='Specify filter_duplicate, choices [0, 1], defaut 1') parser.add_argument('--save_dir', type=str, default=None, help='Specify save dir, default None') parser.add_argument('--seed', type=int, default=-1, @@ -84,6 +93,11 @@ parser.add_argument('--max_steps', type=int, default=-1, help='Specify max steps, default -1 (Automatically calculated from epochs)') parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 1000') +parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch'], + help='Specify save_strategy, default steps') +parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000') +parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch'], + help='Specify evaluation_strategy, default steps') parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32') parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512') parser.add_argument('--streaming', action='store_true', default=False, @@ -227,6 +241,25 @@ def main(): AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), num_proc=args.workers) + valid_ds_for_callback = None + if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None: + logger.info('Validation for callback detected, processing validation...') + if os.path.exists(args.valid_name_or_path_for_callback): + valid_ds_for_callback = load_dataset( + 'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers) + else: + if args.valid_subset_name_for_callback is not None: + valid_ds_for_callback = load_dataset( + args.valid_name_or_path_for_callback, + args.valid_subset_name_for_callback, + num_proc=args.workers) + else: + valid_ds_for_callback = load_dataset( + args.valid_name_or_path_for_callback, num_proc=args.workers) + valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map( + AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + num_proc=args.workers) + argument_kwargs = {} if args.push_to_hub: assert args.hub_model_id is not None, 'Please specify hub_mode_id via --hub_model_id xxx' @@ -254,11 +287,15 @@ def main(): model.fit( train_ds=train_ds, valid_ds=valid_ds, + valid_ds_for_callback=valid_ds_for_callback, output_dir=args.save_dir, batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, save_steps=args.save_steps, + save_strategy=args.save_strategy, + eval_steps=args.eval_steps, + evaluation_strategy=args.evaluation_strategy, warmup_steps=args.warmup_steps, logging_steps=args.logging_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -271,6 +308,7 @@ def main(): 'angle_tau': args.angle_tau, }, fp16=args.fp16, + filter_duplicate=args.filter_duplicate, argument_kwargs=argument_kwargs, apply_ese=args.apply_ese, trainer_kwargs=trainer_kwargs,