Add regression tests for training #730

Open · wants to merge 5 commits into main
24 changes: 24 additions & 0 deletions conftest.py
@@ -1,6 +1,8 @@
from datetime import timedelta
from typing import List

import pytest
import torch.distributed as dist

from olmo.config import (
DataConfig,
@@ -96,3 +98,25 @@ def lorem_ipsum_docs() -> List[str]:
@pytest.fixture(scope="function")
def model_path() -> str:
return "test_fixtures/test-olmo-model"


@pytest.fixture(scope="function")
def xtiny_model_path() -> str:
return "test_fixtures/test-olmo-model-xtiny"


@pytest.fixture(scope="function")
def make_process_group():
initialized = False

def init():
nonlocal initialized

dist.init_process_group(
backend="nccl", timeout=timedelta(minutes=30), world_size=1, rank=0, store=dist.HashStore()
)
initialized = True

yield init
if initialized:
dist.destroy_process_group()
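
For context, these fixtures are meant to be consumed by the regression tests this PR adds: `xtiny_model_path` points at the checked-in xtiny fixture, and `make_process_group` yields an initializer that a test calls only when the code under test needs a (single-rank) process group, with teardown handling `destroy_process_group`. A minimal sketch of such a test follows; the test name and the wiring around the config are assumptions, not code from this PR:

```python
# Hypothetical regression test using the fixtures above (names are illustrative).
from olmo.config import TrainConfig


def test_training_regression_unsharded(xtiny_model_path, make_process_group):
    # Load the xtiny fixture config that ships with this PR.
    cfg = TrainConfig.load(f"{xtiny_model_path}/config.yaml")

    # Initialize the single-rank process group only if the code under test
    # requires one; the fixture tears it down afterwards if it was created.
    make_process_group()

    # ... build the model/optimizer/trainer from cfg, train a handful of steps,
    # and compare the recorded losses against reference values in the fixture dir.
    assert cfg.model.d_model == 32
```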
25 changes: 25 additions & 0 deletions olmo/checkpoint.py
@@ -50,6 +50,7 @@
from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .exceptions import OLMoCheckpointError
from .model import OLMo
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import (
@@ -653,6 +654,17 @@ def save_checkpoint(
model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
)

# Then get the optimizer state dict
optim_state_dict = optim.state_dict()
self._write_optim_dict(
optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
)
elif isinstance(dist_model, OLMo):
model_state_dict = dist_model.state_dict()
self._write_model_dict(
model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
)

# Then get the optimizer state dict
optim_state_dict = optim.state_dict()
self._write_optim_dict(
@@ -775,6 +787,19 @@ def restore_checkpoint(
gc.collect()
torch.cuda.empty_cache()
barrier()
elif isinstance(dist_model, OLMo):
with torch.no_grad():
state_dict_to_load = load_state_dict(
load_path, "model.pt", local_cache=local_cache, map_location="cpu"
)
dist_model.load_state_dict(state_dict_to_load)

# Load optimizer state.
if load_optimizer_state:
optim_state_dict_to_load = load_state_dict(
load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
)
optim.load_state_dict(optim_state_dict_to_load)
else:
raise NotImplementedError(
"`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
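
The new `isinstance(dist_model, OLMo)` branches above let `FullCheckpointer` handle a bare, non-wrapped model: state is written and read as plain, unsharded `model.pt` / `optim.pt` files, which is what single-process regression tests rely on. A condensed, self-contained sketch of that round trip (the helper names below are illustrative, not the checkpointer's API):

```python
# Standalone sketch of the unsharded save/restore flow the new branches implement.
import tempfile
from pathlib import Path

import torch


def save_unsharded(model: torch.nn.Module, optim: torch.optim.Optimizer, ckpt_dir: Path) -> None:
    # Mirrors the new save_checkpoint branch: plain state dicts, no sharding, no process group.
    torch.save(model.state_dict(), ckpt_dir / "model.pt")
    torch.save(optim.state_dict(), ckpt_dir / "optim.pt")


def restore_unsharded(model: torch.nn.Module, optim: torch.optim.Optimizer, ckpt_dir: Path) -> None:
    # Mirrors the new restore_checkpoint branch: load everything onto CPU.
    model.load_state_dict(torch.load(ckpt_dir / "model.pt", map_location="cpu"))
    optim.load_state_dict(torch.load(ckpt_dir / "optim.pt", map_location="cpu"))


if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    with tempfile.TemporaryDirectory() as ckpt_dir:
        save_unsharded(model, optim, Path(ckpt_dir))
        restore_unsharded(model, optim, Path(ckpt_dir))
```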
32 changes: 20 additions & 12 deletions olmo/train.py
@@ -51,6 +51,7 @@
get_fs_local_rank,
get_global_rank,
get_world_size,
is_distributed,
move_to_device,
peak_gpu_memory,
synchronize_flag,
@@ -208,7 +209,7 @@ def fused_loss_fn(
class Trainer:
cfg: TrainConfig
model: OLMo
dist_model: Union[DDP, FSDP]
dist_model: Union[DDP, FSDP, OLMo]
optim: Optimizer
scheduler: Scheduler
train_loader: DataLoader
@@ -319,7 +320,7 @@ def scheduler_max(self) -> int:
raise NotImplementedError(self.cfg.scheduler.units)

def trainer_state_dict(self) -> Dict[str, Any]:
return {
state_dict: Dict[str, Any] = {
"epoch": self.epoch or 0,
"global_step": self.global_step,
"global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
@@ -332,10 +333,14 @@ def trainer_state_dict(self) -> Dict[str, Any]:
"python": random.getstate(),
"numpy": np.random.get_state(),
"torch": torch.random.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
},
}

if torch.cuda.is_available():
state_dict["rng"]["cuda"] = torch.cuda.get_rng_state()

return state_dict

def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Checkpoint paths.
self.checkpoints = [
@@ -429,7 +434,8 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
random.setstate(rng_state["python"])
np.random.set_state(rng_state["numpy"])
torch.set_rng_state(rng_state["torch"])
torch.cuda.set_rng_state(rng_state["cuda"])
if "cuda" in rng_state:
torch.cuda.set_rng_state(rng_state["cuda"])

def _save_checkpoint(
self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
@@ -674,26 +680,28 @@ def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.ut
)
trace_save_folder.mkdir(parents=True)

module_num = 0

def trace_outputs_hook(
module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
) -> None:
nonlocal module_num

if len(args) == 0:
log.info("No input args for module %s, output %s", module_name, output)

module_input = args[0] if len(args) > 0 else torch.tensor(())
trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
trace_save_folder.mkdir(parents=True, exist_ok=True)

module_occurence_num = 0
while (
module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
).exists():
module_occurence_num += 1
module_input_filepath = trace_save_folder / f"{module_name}_{module_num}_input.pt"
torch.save(module_input, module_input_filepath)

module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
module_output_filepath = trace_save_folder / f"{module_name}_{module_num}_output.pt"
torch.save(output, module_output_filepath)

module_num += 1

output_hooks = []
for module_name, module in self.model.named_modules(prefix="model"):
output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
@@ -838,7 +846,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
ce_batch_loss, z_batch_loss = self.train_batch(batch)

# Collect loss, potentially reducing over all ranks.
if reduce_global_loss:
if reduce_global_loss and get_world_size() > 1:
dist.reduce(ce_batch_loss, 0)
ce_batch_loss.div_(get_world_size())
if z_batch_loss is not None:
@@ -852,7 +860,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
collect_param_metrics=should_log_optim_metrics_this_step,
# passing this process group here ensures metrics are reduced correctly when we're using
# HYBRID sharding.
process_group=self.dist_model.process_group,
process_group=self.dist_model.process_group if is_distributed() else None,
)

# Adjust the learning rate.
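
Taken together, the train.py changes above are what let the trainer run without a CUDA device or a real process group: the CUDA RNG state becomes optional in the trainer state, and loss reduction plus the FSDP process group are only used when more than one rank is present. The RNG handling in particular round-trips like this (a sketch that mirrors the "rng" entry in the diff):

```python
# CPU-safe RNG save/restore, mirroring the changes to trainer_state_dict / restore_rng_state.
import random

import numpy as np
import torch

# Save: only include the CUDA generator state when CUDA is actually available.
rng_state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "torch": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
    rng_state["cuda"] = torch.cuda.get_rng_state()

# Restore: a checkpoint written on a CPU-only machine simply has no "cuda" key.
random.setstate(rng_state["python"])
np.random.set_state(rng_state["numpy"])
torch.set_rng_state(rng_state["torch"])
if "cuda" in rng_state:
    torch.cuda.set_rng_state(rng_state["cuda"])
```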
Binary file added test_fixtures/random_data.npy
Binary file not shown.
101 changes: 101 additions & 0 deletions test_fixtures/test-olmo-model-xtiny/config.yaml
@@ -0,0 +1,101 @@
run_name: test-olmo-model-xtiny
seed: 12345
dry_run: false

wandb: null

model:
d_model: 32
n_heads: 4
n_layers: 2
mlp_ratio: 4
weight_tying: false
alibi: false
rope: true
rope_theta: 500000
flash_attention: false
attention_dropout: 0.0
include_bias: false
block_type: sequential
layer_norm_type: rms
layer_norm_with_affine: true
layer_norm_eps: 1e-6
bias_for_layer_norm: false
attention_layer_norm: true
attention_layer_norm_with_affine: true
norm_after: true
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
max_sequence_length: 1024
vocab_size: 100
embedding_size: 128
eos_token_id: 1
pad_token_id: 2
init_device: meta
init_fn: normal
init_std: 0.02
init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: false

compile: null

optimizer:
name: adamw
learning_rate: 4.0e-4
weight_decay: 0.1
eps: 1e-8
decay_norm_and_bias: true
decay_embeddings: false
betas:
- 0.9
- 0.95
metrics_log_interval: 1

scheduler:
name: cosine_with_warmup
units: steps
t_warmup: 1
t_max: 1000
alpha_f: 0.1
warmup_min_lr: 0.0

tokenizer:
identifier: allenai/dolma2-tokenizer
truncate_direction: right

save_folder: /weka/oe-training-default/ai2-llm/checkpoints/OLMo-small/${run_name}
save_overwrite: false

save_interval: null
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core

save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 4
device_train_microbatch_size: 2

precision: amp_bf16

fsdp:
wrapping_strategy: null
sharding_strategy: SHARD_GRAD_OP
precision: mixed

# activation_checkpointing: whole_layer

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
window_size: 1

gen1_gc_interval: 10
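
For a local regression run, a test would typically load this fixture config and override the cluster-specific `save_folder` above. A rough sketch, assuming `TrainConfig.load` accepts OmegaConf-style dotlist overrides as it does elsewhere in the repo:

```python
# Load the xtiny fixture config and redirect save_folder away from the /weka path.
import tempfile

from olmo.config import TrainConfig

with tempfile.TemporaryDirectory() as tmp_dir:
    cfg = TrainConfig.load(
        "test_fixtures/test-olmo-model-xtiny/config.yaml",
        overrides=[f"save_folder={tmp_dir}"],
    )
    # Sanity-check a couple of the values defined above.
    assert cfg.model.d_model == 32 and cfg.model.n_layers == 2
```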
Binary file added test_fixtures/test-olmo-model-xtiny/model.pt
Binary file not shown.
Binary file added test_fixtures/test-olmo-model-xtiny/optim.pt
Binary file not shown.
Binary file added test_fixtures/test-olmo-model-xtiny/train.pt
Binary file not shown.