diff --git a/06_gpu_and_ml/flan_t5/flan_t5_finetune.py b/06_gpu_and_ml/flan_t5/flan_t5_finetune.py
index 7d1964f30..b56c2fc7a 100644
--- a/06_gpu_and_ml/flan_t5/flan_t5_finetune.py
+++ b/06_gpu_and_ml/flan_t5/flan_t5_finetune.py
@@ -52,15 +52,16 @@
 stub.restart_tracker_dict = modal.Dict.new()


-def track_restarts(restart_tracker: modal.Dict):
+def track_restarts(restart_tracker: modal.Dict) -> int:
     if not restart_tracker.contains("count"):
         preemption_count = 0
         print(f"Starting first time. {preemption_count=}")
-        restart_tracker["count"] = preemption_count = 0
+        restart_tracker["count"] = preemption_count
     else:
         preemption_count = restart_tracker.get("count") + 1
         print(f"Restarting after pre-emption. {preemption_count=}")
         restart_tracker["count"] = preemption_count
+    return preemption_count


 # ## Finetuning Flan-T5 on XSum dataset
@@ -72,6 +73,7 @@ def track_restarts(restart_tracker: modal.Dict):
     gpu="A10g",
     timeout=7200,
     volumes={VOL_MOUNT_PATH: output_vol},
+    _allow_background_volume_commits=True,
 )
 def finetune(num_train_epochs: int = 1, size_percentage: int = 10):
     from datasets import load_dataset
@@ -81,10 +83,9 @@ def finetune(num_train_epochs: int = 1, size_percentage: int = 10):
         DataCollatorForSeq2Seq,
         Seq2SeqTrainer,
         Seq2SeqTrainingArguments,
-        TrainerCallback,
     )

-    track_restarts(stub.restart_tracker_dict)
+    restarts = track_restarts(stub.restart_tracker_dict)

     # Use size percentage to retrieve subset of the dataset to iterate faster
     if size_percentage:
@@ -148,18 +149,6 @@ def preprocess(batch):
         pad_to_multiple_of=batch_size,
     )

-    class CheckpointCallback(TrainerCallback):
-        def __init__(self, volume):
-            self.volume = volume
-
-        def on_save(self, args, state, control, **kwargs):
-            """
-            Event called after a checkpoint save.
-            """
-            if state.is_world_process_zero:
-                print("running commit on modal.Volume after model checkpoint")
-                self.volume.commit()
-
     training_args = Seq2SeqTrainingArguments(
         # Save checkpoints to the mounted volume
         output_dir=str(VOL_MOUNT_PATH / "model"),
@@ -172,7 +161,7 @@ def on_save(self, args, state, control, **kwargs):
         logging_steps=100,
         evaluation_strategy="steps",
         save_strategy="steps",
-        save_steps=750,
+        save_steps=100,
         save_total_limit=2,
         load_best_model_at_end=True,
     )
@@ -180,14 +169,16 @@ def on_save(self, args, state, control, **kwargs):
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
-        callbacks=[CheckpointCallback(output_vol)],
         data_collator=data_collator,
         train_dataset=tokenized_xsum_train,
         eval_dataset=tokenized_xsum_test,
     )

     try:
-        trainer.train(resume_from_checkpoint=True)
+        resume = restarts > 0
+        if resume:
+            print("resuming from checkpoint")
+        trainer.train(resume_from_checkpoint=resume)
     except KeyboardInterrupt:  # handle possible preemption
         print("received interrupt; saving state and model")
         trainer.save_state()
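
Note on the resume guard: `Trainer.train(resume_from_checkpoint=True)` raises a `ValueError` when the output directory contains no `checkpoint-*` folder yet, so the unconditional `True` broke the very first run; gating on the restart count from the `modal.Dict` means a fresh run passes `False` and only pre-empted containers resume. The `_allow_background_volume_commits=True` flag has Modal persist volume writes in the background, which is what lets the manual `volume.commit()` in `CheckpointCallback.on_save` go away, and the lower `save_steps` bounds how much work a pre-emption can lose. For comparison, here is a minimal sketch of a checkpoint-presence guard built on the `get_last_checkpoint` helper from `transformers`; this is an illustrative alternative, not what the diff ships, and the directory argument is a stand-in for `VOL_MOUNT_PATH / "model"`:

```python
# Sketch of an alternative resume guard: ask the filesystem instead of a
# restart counter. `get_last_checkpoint` is the same helper Trainer itself
# uses internally to resolve `resume_from_checkpoint=True`.
import os

from transformers.trainer_utils import get_last_checkpoint


def resume_target(output_dir: str):
    """Return the newest checkpoint path in output_dir, or None on a fresh run."""
    if not os.path.isdir(output_dir):
        return None  # nothing saved yet; train from scratch
    return get_last_checkpoint(output_dir)  # None when no checkpoint-* subdirs exist


# Hypothetical wiring, mirroring the try block in the diff:
#   trainer.train(resume_from_checkpoint=resume_target(str(VOL_MOUNT_PATH / "model")))
# A None (or False) value trains from scratch; a path resumes from that checkpoint.
```

One practical difference: since `stub.restart_tracker_dict` is created per app run with `modal.Dict.new()`, the counter-based guard resumes only after a genuine pre-emption within the same run, so a deliberate fresh invocation still starts from scratch even when older checkpoints remain on the persisted volume.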