Skip to content

Commit

Permalink
push best checkpoint to hub
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Jun 29, 2024
1 parent db997f2 commit 8374b95
Showing 1 changed file with 52 additions and 36 deletions.
88 changes: 52 additions & 36 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,39 +1048,6 @@ def __call__(self,
return loss


class EvaluateCallback(TrainerCallback):
    """
    Custom TrainerCallback for Angle.
    This callback runs `evaluate_fn` on the validation set at the end of each
    epoch and tracks the best corrcoef seen so far. Whenever a new best is
    reached, the model is saved to `save_dir` (when one is given).

    :param model: PreTrainedModel. Its `save_pretrained` is called on a new best.
    :param valid_ds: Dataset.
    :param evaluate_fn: Callable. It will receive valid_ds as input like `evaluate_fn(valid_ds)`
        and must return a `(corrcoef, accuracy)` pair.
    :param save_dir: Optional[str]. specify dir to save model with best results.
    """
    def __init__(self,
                 model: PreTrainedModel,
                 valid_ds: Dataset,
                 evaluate_fn: Callable,
                 save_dir: Optional[str] = None):
        super().__init__()
        self.model = model
        self.valid_ds = valid_ds
        self.evaluate_fn = evaluate_fn
        self.save_dir = save_dir
        # A correlation coefficient can be zero or negative. Starting from
        # -inf (instead of 0) guarantees the first finite score is recorded
        # as the best, so a checkpoint is saved even for weak models.
        self.best_corrcoef = float('-inf')

    def on_epoch_end(self, args, state, control, **kwargs):
        """
        Evaluate on the validation set and save the model on a new best corrcoef.

        `args`, `state`, `control` and `kwargs` are accepted to match the
        callback hook signature; they are not used here.
        """
        corrcoef, accuracy = self.evaluate_fn(self.valid_ds)
        # Strict comparison: ties keep the earlier checkpoint, and a NaN score
        # never compares greater, so it is never treated as a new best.
        if corrcoef > self.best_corrcoef:
            self.best_corrcoef = corrcoef
            print('new best corrcoef!')
            if self.save_dir is not None:
                self.model.save_pretrained(self.save_dir)
                print(f'save to {self.save_dir}')
        print(f'corrcoef: {corrcoef}, accuracy: {accuracy}, best corrcoef: {self.best_corrcoef}')


class AnglE:
"""
AnglE. Everything is here👋
Expand Down Expand Up @@ -1493,10 +1460,14 @@ def fit(self,
best_ckpt_dir = None
if output_dir is not None:
best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint')
self.tokenizer.save_pretrained(best_ckpt_dir)
evaluate_callback = EvaluateCallback(self.backbone, valid_ds,
evaluate_callback = EvaluateCallback(self, valid_ds,
partial(self.evaluate, batch_size=batch_size, device=self.device),
save_dir=best_ckpt_dir)
save_dir=best_ckpt_dir,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
hub_private_repo=hub_private_repo)
# set False to ensure only best checkpoint is pushed
argument_kwargs['push_to_hub'] = False
callbacks = [evaluate_callback]

CustomTrainer = AngleESETrainer if apply_ese else AngleTrainer
Expand Down Expand Up @@ -1662,3 +1633,48 @@ def save_pretrained(self, output_dir: str, exist_ok: bool = True):
os.makedirs(output_dir)
self.tokenizer.save_pretrained(output_dir)
self.backbone.save_pretrained(output_dir)


class EvaluateCallback(TrainerCallback):
    """
    Custom TrainerCallback for Angle.
    This callback runs `evaluate_fn` on the validation set at the end of each
    epoch and tracks the best corrcoef seen so far. Whenever a new best is
    reached, the model is saved to `save_dir` (when one is given) and pushed
    to the hub (when `push_to_hub` is enabled).

    :param model: AnglE. Its `save_pretrained` / `push_to_hub` are called on a new best.
    :param valid_ds: Dataset.
    :param evaluate_fn: Callable. It will receive valid_ds as input like `evaluate_fn(valid_ds)`
        and must return a `(corrcoef, accuracy)` pair.
    :param save_dir: Optional[str]. specify dir to save model with best results.
    :param push_to_hub: bool. Whether to push every new best checkpoint to the hub. Default False.
    :param hub_model_id: Optional[str]. Passed as the first argument to `model.push_to_hub`.
    :param hub_private_repo: bool. Passed as `private` to `model.push_to_hub`. Default True.
    """
    def __init__(self,
                 model: AnglE,
                 valid_ds: Dataset,
                 evaluate_fn: Callable,
                 save_dir: Optional[str] = None,
                 push_to_hub: bool = False,
                 hub_model_id: Optional[str] = None,
                 hub_private_repo: bool = True):
        super().__init__()
        self.model = model
        self.valid_ds = valid_ds
        self.evaluate_fn = evaluate_fn
        self.save_dir = save_dir
        # A correlation coefficient can be zero or negative. Starting from
        # -inf (instead of 0) guarantees the first finite score is recorded
        # as the best, so a checkpoint is saved/pushed even for weak models.
        self.best_corrcoef = float('-inf')
        self.push_to_hub = push_to_hub
        self.hub_model_id = hub_model_id
        self.hub_private_repo = hub_private_repo

    def on_epoch_end(self, args, state, control, **kwargs):
        """
        Evaluate on the validation set; on a new best corrcoef, save and/or push the model.

        `args`, `state`, `control` and `kwargs` are accepted to match the
        callback hook signature; they are not used here.
        """
        corrcoef, accuracy = self.evaluate_fn(self.valid_ds)
        # Strict comparison: ties keep the earlier checkpoint, and a NaN score
        # never compares greater, so it is never treated as a new best.
        if corrcoef > self.best_corrcoef:
            self.best_corrcoef = corrcoef
            print('new best corrcoef!')
            if self.save_dir is not None:
                self.model.save_pretrained(self.save_dir)
                print(f'save to {self.save_dir}')
            # Saving locally and pushing are independent: a push can happen
            # even when no save_dir was configured.
            if self.push_to_hub:
                self.model.push_to_hub(
                    self.hub_model_id,
                    private=self.hub_private_repo,
                    # the repo is re-used for every new best, so it must be
                    # allowed to exist already
                    exist_ok=True,
                    commit_message='new best checkpoint')
        print(f'corrcoef: {corrcoef}, accuracy: {accuracy}, best corrcoef: {self.best_corrcoef}')

0 comments on commit 8374b95

Please sign in to comment.