Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add timer impl #44

Merged
merged 16 commits into from
Dec 14, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Breaking changes

### Bug Fixes
- `exp_manager.max_time_per_run` is now respected, the trainers will save and run val if we've reached the time limit before exiting.
trias702 marked this conversation as resolved.
Show resolved Hide resolved

## [0.1.0] - 2023-12-04
### Added
Expand Down
5 changes: 4 additions & 1 deletion examples/nlp/gpt/train_gpt_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
from functools import partial

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from omegaconf.omegaconf import OmegaConf

from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
Expand Down Expand Up @@ -131,6 +132,7 @@ def main(cfg) -> None:

logger.log_hyperparams(OmegaConf.to_container(cfg))

timer = Timer(cfg.exp_manager.get("max_time_per_run"))
dpo_trainer = DPOTrainer(
cfg=cfg.trainer.dpo,
model=ptl_model,
Expand All @@ -141,6 +143,7 @@ def main(cfg) -> None:
test_dataloader=None,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

if custom_trainer_state_dict is not None:
Expand Down
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_gpt_ppo_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel
from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMCriticClient
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
Expand Down Expand Up @@ -153,6 +154,7 @@ def main(cfg) -> None:
logger.log_hyperparams(OmegaConf.to_container(cfg))

rm_critic = RemoteGPTRMCriticClient(cfg.remote_critic_rm)
timer = Timer(cfg.exp_manager.get("max_time_per_run"))

ppo_trainer = PPOTrainer(
cfg=cfg.trainer.ppo,
Expand All @@ -164,6 +166,7 @@ def main(cfg) -> None:
rm_critic=rm_critic,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

if custom_trainer_state_dict is not None:
Expand Down
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
Expand Down Expand Up @@ -214,6 +215,7 @@ def main(cfg) -> None:
ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

logger.log_hyperparams(OmegaConf.to_container(cfg))
timer = Timer(cfg.exp_manager.get("max_time_per_run"))

sft_trainer = SupervisedTrainer(
cfg=cfg.trainer.sft,
Expand All @@ -225,6 +227,7 @@ def main(cfg) -> None:
test_dataloader=None,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

if custom_trainer_state_dict is not None:
Expand Down
5 changes: 4 additions & 1 deletion examples/nlp/gpt/train_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
build_train_valid_test_regression_rm_datasets,
build_train_valid_test_rm_datasets,
)

from nemo_aligner.models.nlp.gpt.reward_model_classes import REWARD_MODEL_CLASS_DICT, RewardModelType
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
Expand Down Expand Up @@ -131,6 +131,8 @@ def main(cfg) -> None:

logger.log_hyperparams(OmegaConf.to_container(cfg))

timer = Timer(cfg.exp_manager.get("max_time_per_run"))

rm_trainer = SupervisedTrainer(
cfg=cfg.trainer.rm,
model=ptl_model,
Expand All @@ -141,6 +143,7 @@ def main(cfg) -> None:
test_dataloader=None,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

if custom_trainer_state_dict is not None:
Expand Down
13 changes: 13 additions & 0 deletions nemo_aligner/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
MegatronPretrainingRandomBatchSampler,
)
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.utils import logging
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
test_dataloader,
logger,
ckpt_callback,
run_timer,
):
self.model = model
self.train_dataloader = train_dataloader
Expand All @@ -89,6 +91,9 @@ def __init__(
self.optimizer = optimizer
self.scheduler = scheduler

# this timer checks if we should stop training
self.run_timer = run_timer

self.step = 0
self.epoch = 0
self.consumed_samples = 0
Expand Down Expand Up @@ -188,6 +193,8 @@ def fit(self):
# epoch done
return

self.run_timer.start_time()

for _ in epoch_iter:
loop_iter = range(self.step, self.max_steps)

Expand Down Expand Up @@ -223,12 +230,14 @@ def fit(self):

self.step += 1

run_time_exceeded = self.run_timer.is_finished()
run_val, save_model, is_train_end = check_progress(
self.step,
self.max_steps,
self.cfg.val_check_interval,
self.cfg.save_interval,
self.limit_val_batches,
run_time_exceeded=run_time_exceeded,
)

if run_val:
Expand All @@ -246,6 +255,10 @@ def fit(self):
metrics = {k: torch.as_tensor(v) for k, v in metrics.items()}
self.save(metrics, is_train_end=is_train_end)

if run_time_exceeded:
logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run")
return

metrics.clear()

self.epoch += 1
Expand Down
31 changes: 23 additions & 8 deletions nemo_aligner/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tqdm import tqdm

from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.utils import logging
from nemo_aligner.utils.distributed import (
SyncTimer,
masked_global_mean_var,
Expand All @@ -36,6 +37,7 @@
)
from nemo_aligner.utils.server_utils import FutureResult
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress
from nemo_aligner.utils.utils import clear_memory, cpu_dict, masked_mean


Expand All @@ -61,6 +63,7 @@ def __init__(
rm_critic,
logger,
ckpt_callback,
run_timer,
):
self.cfg = cfg
self.model = model
Expand All @@ -72,6 +75,9 @@ def __init__(
self.logger = logger
self.ckpt_callback = ckpt_callback

# this timer checks if we should stop training
self.run_timer = run_timer

self.consumed_samples = 0
self.epoch = 0
# the step here is PPO step
Expand All @@ -95,10 +101,6 @@ def __init__(
reduction="mean", sync_cuda=True, buffer_size=1, reduce_op=torch.distributed.ReduceOp.MAX
)

assert (
self.cfg.save_interval % self.cfg.val_check_interval == 0
), f"{self.cfg.save_interval=} must be divisible by {self.cfg.val_check_interval=}"

def generate_ppo_data(self, rollout_batches):
"""generate ppo specific data for training
"""
Expand Down Expand Up @@ -372,6 +374,7 @@ def fit(self):

num_to_load_on_each_dp = divide(self.cfg.model_gbs, dp_size)

self.run_timer.start_time()
for _ in global_pbar:
step_metrics = {}
timing_metrics = {}
Expand Down Expand Up @@ -412,8 +415,16 @@ def fit(self):

self.step += 1

is_train_end = self.step == self.max_steps
run_val = (self.step % self.cfg.val_check_interval == 0) or is_train_end
run_time_exceeded = self.run_timer.is_finished()
run_val, save_model, is_train_end = check_progress(
self.step,
self.max_steps,
self.cfg.val_check_interval,
self.cfg.save_interval,
1.0, # TODO:(geshen): allow for limit val batches
run_time_exceeded=run_time_exceeded,
)

if run_val:
self.timer.start("validation_time")
val_metrics = self.run_validation()
Expand All @@ -439,10 +450,14 @@ def fit(self):
step_metrics.update({f"train_{k}": v for k, v in metrics.items()})
global_pbar.set_postfix(step_metrics)

step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()}
if run_val and (self.step % self.cfg.save_interval == 0 or is_train_end):
if save_model:
step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()}
self.save(step_metrics, is_train_end=is_train_end)

if run_time_exceeded:
logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run")
return

self.epoch += 1

self.logger.finalize()
Expand Down
14 changes: 14 additions & 0 deletions nemo_aligner/algorithms/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
from nemo.utils import logging


from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
Expand All @@ -40,6 +42,7 @@ def __init__(
test_dataloader,
logger,
ckpt_callback,
run_timer,
):
self.model = model
self.train_dataloader = train_dataloader
Expand All @@ -50,6 +53,9 @@ def __init__(
self.optimizer = optimizer
self.scheduler = scheduler

# this timer checks if we should stop training
self.run_timer = run_timer

self.step = 0
self.epoch = 0
self.consumed_samples = 0
Expand Down Expand Up @@ -138,6 +144,8 @@ def fit(self):
# epoch done
return

self.run_timer.start_time()

for _ in epoch_iter:
loop_iter = range(self.step, self.max_steps)

Expand Down Expand Up @@ -166,12 +174,14 @@ def fit(self):

self.step += 1

run_time_exceeded = self.run_timer.is_finished()
run_val, save_model, is_train_end = check_progress(
self.step,
self.max_steps,
self.cfg.val_check_interval,
self.cfg.save_interval,
self.limit_val_batches,
run_time_exceeded=run_time_exceeded,
)

if run_val:
Expand All @@ -189,6 +199,10 @@ def fit(self):
metrics = {k: torch.as_tensor(v) for k, v in metrics.items()}
self.save(metrics, is_train_end=is_train_end)

if run_time_exceeded:
logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run")
return
odelalleau marked this conversation as resolved.
Show resolved Hide resolved

metrics.clear()

self.epoch += 1
Expand Down
38 changes: 38 additions & 0 deletions nemo_aligner/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@

"""distributed utils for communicating between different ranks"""

import time
import warnings
from collections import defaultdict
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, Optional, Union

import torch
from megatron.core import parallel_state, tensor_parallel
Expand Down Expand Up @@ -320,3 +324,37 @@ def sync_and_consume_over_stored_time(self, name=""):
yield from output_list

del self.stored_results[name]


@dataclass
class Timer:
"""Timer to tell us when the time limit is reached
"""

duration: Optional[str]

def __post_init__(self):
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
self._duration = float("inf")

if self.duration is not None:
days, hours, mins, seconds = map(int, self.duration.strip().split(":"))
self._duration = timedelta(days=days, hours=hours, minutes=mins, seconds=seconds).total_seconds()

def start_time(self):
self._start_time = time.monotonic()
gshennvm marked this conversation as resolved.
Show resolved Hide resolved

def get_time_elapsed(self):
return time.monotonic() - self._start_time

def get_time_remaining(self):
return self._duration - self.get_time_elapsed()

def is_finished(self):
time_left = self.get_time_remaining()

is_finished = time_left <= 0
is_finished_tensor = torch.tensor([is_finished], dtype=torch.bool, device="cuda")

# only respect rank 0 timing
torch.distributed.broadcast(is_finished_tensor, 0)
return is_finished_tensor.item()
Loading
Loading