Skip to content

Commit

Permalink
[feature] add pytorch profiling (#2182)
Browse files Browse the repository at this point in the history
* add pytorch profiling

* kick off the profiler asap since things may get allcoated before train start

* document feature

* add url for visualizer [skip ci]
  • Loading branch information
winglian authored Dec 16, 2024
1 parent 1720cc0 commit b4eebc3
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
# snapshots can be visualized @ https://pytorch.org/memory_viz

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

Expand Down
8 changes: 8 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
Expand Down Expand Up @@ -1363,6 +1364,13 @@ def get_callbacks(self) -> List[TrainerCallback]:
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)

if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)

if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
Expand Down
43 changes: 43 additions & 0 deletions src/axolotl/utils/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
HF Trainer callback for creating pytorch profiling snapshots
"""
from pathlib import Path
from pickle import dump # nosec B403

import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)


class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
"""

def __init__(self, steps_to_profile: int = 5):
self.steps_to_profile = steps_to_profile
if self.steps_to_profile:
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all"
)

def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
if state.global_step == self.steps_to_profile:
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled=None
)
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ class Config:
load_best_model_at_end: Optional[bool] = False
save_only_model: Optional[bool] = False
use_tensorboard: Optional[bool] = None
profiler_steps: Optional[int] = None

neftune_noise_alpha: Optional[float] = None

Expand Down

0 comments on commit b4eebc3

Please sign in to comment.