Skip to content

Commit

Permalink
Use background commits in flan t5 finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
thundergolfer committed Dec 1, 2023
1 parent 6a335d7 commit 2e563e9
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions 06_gpu_and_ml/flan_t5/flan_t5_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@
stub.restart_tracker_dict = modal.Dict.new()


def track_restarts(restart_tracker: modal.Dict):
def track_restarts(restart_tracker: modal.Dict) -> int:
if not restart_tracker.contains("count"):
preemption_count = 0
print(f"Starting first time. {preemption_count=}")
restart_tracker["count"] = preemption_count = 0
restart_tracker["count"] = preemption_count
else:
preemption_count = restart_tracker.get("count") + 1
print(f"Restarting after pre-emption. {preemption_count=}")
restart_tracker["count"] = preemption_count
return preemption_count


# ## Finetuning Flan-T5 on XSum dataset
Expand All @@ -72,19 +73,22 @@ def track_restarts(restart_tracker: modal.Dict):
gpu="A10g",
timeout=7200,
volumes={VOL_MOUNT_PATH: output_vol},
_allow_background_volume_commits=True,
)
def finetune(num_train_epochs: int = 1, size_percentage: int = 10):
from modal.exception import simulate_preemption

simulate_preemption(300)
from datasets import load_dataset
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
TrainerCallback,
)

track_restarts(stub.restart_tracker_dict)
restarts = track_restarts(stub.restart_tracker_dict)

# Use size percentage to retrieve subset of the dataset to iterate faster
if size_percentage:
Expand Down Expand Up @@ -148,18 +152,6 @@ def preprocess(batch):
pad_to_multiple_of=batch_size,
)

class CheckpointCallback(TrainerCallback):
def __init__(self, volume):
self.volume = volume

def on_save(self, args, state, control, **kwargs):
"""
Event called after a checkpoint save.
"""
if state.is_world_process_zero:
print("running commit on modal.Volume after model checkpoint")
self.volume.commit()

training_args = Seq2SeqTrainingArguments(
# Save checkpoints to the mounted volume
output_dir=str(VOL_MOUNT_PATH / "model"),
Expand All @@ -172,22 +164,24 @@ def on_save(self, args, state, control, **kwargs):
logging_steps=100,
evaluation_strategy="steps",
save_strategy="steps",
save_steps=750,
save_steps=100,
save_total_limit=2,
load_best_model_at_end=True,
)

trainer = Seq2SeqTrainer(
model=model,
args=training_args,
callbacks=[CheckpointCallback(output_vol)],
data_collator=data_collator,
train_dataset=tokenized_xsum_train,
eval_dataset=tokenized_xsum_test,
)

try:
trainer.train(resume_from_checkpoint=True)
resume = restarts > 0
if resume:
print("resuming from checkpoint")
trainer.train(resume_from_checkpoint=resume)
except KeyboardInterrupt: # handle possible preemption
print("received interrupt; saving state and model")
trainer.save_state()
Expand Down

0 comments on commit 2e563e9

Please sign in to comment.