diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 895e689..3235e7b 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -1048,39 +1048,6 @@ def __call__(self, return loss -class EvaluateCallback(TrainerCallback): - """ - Custom TrainerCallback for Angle. - This callback will compute corrcoef for each epoch. - - :param model: PreTrainedModel. - :param valid_ds: Dataset. - :param evaluate_fn: Callable. It will receive valid_ds as input like `evaluate_fn(valid_ds)`. - :param save_dir: Optional[str]. specify dir to save model with best results. - """ - def __init__(self, - model: PreTrainedModel, - valid_ds: Dataset, - evaluate_fn: Callable, - save_dir: Optional[str] = None): - super().__init__() - self.model = model - self.valid_ds = valid_ds - self.evaluate_fn = evaluate_fn - self.save_dir = save_dir - self.best_corrcoef = 0 - - def on_epoch_end(self, args, state, control, **kwargs): - corrcoef, accuracy = self.evaluate_fn(self.valid_ds) - if corrcoef > self.best_corrcoef: - self.best_corrcoef = corrcoef - print('new best corrcoef!') - if self.save_dir is not None: - self.model.save_pretrained(self.save_dir) - print(f'save to {self.save_dir}') - print(f'corrcoef: {corrcoef}, accuracy: {accuracy}, best corrcoef: {self.best_corrcoef}') - - class AnglE: """ AnglE. Everything is heređź‘‹ @@ -1493,10 +1460,14 @@ def fit(self, best_ckpt_dir = None if output_dir is not None: best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint') - self.tokenizer.save_pretrained(best_ckpt_dir) - evaluate_callback = EvaluateCallback(self.backbone, valid_ds, + evaluate_callback = EvaluateCallback(self, valid_ds, partial(self.evaluate, batch_size=batch_size, device=self.device), - save_dir=best_ckpt_dir) + save_dir=best_ckpt_dir, + push_to_hub=push_to_hub, + hub_model_id=hub_model_id, + hub_private_repo=hub_private_repo) + # set False to ensure only best checkpoint is pushed + argument_kwargs['push_to_hub'] = False callbacks = [evaluate_callback] CustomTrainer = AngleESETrainer if apply_ese else AngleTrainer @@ -1662,3 +1633,48 @@ def save_pretrained(self, output_dir: str, exist_ok: bool = True): os.makedirs(output_dir) self.tokenizer.save_pretrained(output_dir) self.backbone.save_pretrained(output_dir) + + +class EvaluateCallback(TrainerCallback): + """ + Custom TrainerCallback for Angle. + This callback will compute corrcoef for each epoch. + + :param model: PreTrainedModel. + :param valid_ds: Dataset. + :param evaluate_fn: Callable. It will receive valid_ds as input like `evaluate_fn(valid_ds)`. + :param save_dir: Optional[str]. specify dir to save model with best results. + """ + def __init__(self, + model: AnglE, + valid_ds: Dataset, + evaluate_fn: Callable, + save_dir: Optional[str] = None, + push_to_hub: bool = False, + hub_model_id: Optional[str] = None, + hub_private_repo: bool = True): + super().__init__() + self.model = model + self.valid_ds = valid_ds + self.evaluate_fn = evaluate_fn + self.save_dir = save_dir + self.best_corrcoef = 0 + self.push_to_hub = push_to_hub + self.hub_model_id = hub_model_id + self.hub_private_repo = hub_private_repo + + def on_epoch_end(self, args, state, control, **kwargs): + corrcoef, accuracy = self.evaluate_fn(self.valid_ds) + if corrcoef > self.best_corrcoef: + self.best_corrcoef = corrcoef + print('new best corrcoef!') + if self.save_dir is not None: + self.model.save_pretrained(self.save_dir) + print(f'save to {self.save_dir}') + if self.push_to_hub: + self.model.push_to_hub( + self.hub_model_id, + private=self.hub_private_repo, + exist_ok=True, + commit_message='new best checkpoint') + print(f'corrcoef: {corrcoef}, accuracy: {accuracy}, best corrcoef: {self.best_corrcoef}')