diff --git a/conftest.py b/conftest.py
index a1ab2aed5..5d87fd73c 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,6 +1,8 @@
+from datetime import timedelta
 from typing import List
 
 import pytest
+import torch.distributed as dist
 
 from olmo.config import (
     DataConfig,
@@ -96,3 +98,25 @@ def lorem_ipsum_docs() -> List[str]:
 @pytest.fixture(scope="function")
 def model_path() -> str:
     return "test_fixtures/test-olmo-model"
+
+
+@pytest.fixture(scope="function")
+def xtiny_model_path() -> str:
+    return "test_fixtures/test-olmo-model-xtiny"
+
+
+@pytest.fixture(scope="function")
+def make_process_group():
+    initialized = False
+
+    def init():
+        nonlocal initialized
+
+        dist.init_process_group(
+            backend="nccl", timeout=timedelta(minutes=30), world_size=1, rank=0, store=dist.HashStore()
+        )
+        initialized = True
+
+    yield init
+    if initialized:
+        dist.destroy_process_group()
diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 544441450..6f6278435 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -50,6 +50,7 @@
 from .aliases import PathOrStr
 from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
 from .exceptions import OLMoCheckpointError
+from .model import OLMo
 from .optim import Optimizer, fix_optim_state_dict
 from .safetensors_util import safetensors_file_to_state_dict
 from .torch_util import (
@@ -653,6 +654,17 @@ def save_checkpoint(
                     model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                 )
 
+                # Then get the optimizer state dict
+                optim_state_dict = optim.state_dict()
+                self._write_optim_dict(
+                    optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
+                )
+            elif isinstance(dist_model, OLMo):
+                model_state_dict = dist_model.state_dict()
+                self._write_model_dict(
+                    model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
+                )
+
                 # Then get the optimizer state dict
                 optim_state_dict = optim.state_dict()
                 self._write_optim_dict(
@@ -775,6 +787,19 @@ def restore_checkpoint(
             gc.collect()
             torch.cuda.empty_cache()
             barrier()
+        elif isinstance(dist_model, OLMo):
+            with torch.no_grad():
+                state_dict_to_load = load_state_dict(
+                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
+                )
+                dist_model.load_state_dict(state_dict_to_load)
+
+            # Load optimizer state.
+            if load_optimizer_state:
+                optim_state_dict_to_load = load_state_dict(
+                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
+                )
+                optim.load_state_dict(optim_state_dict_to_load)
         else:
             raise NotImplementedError(
                 "`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
diff --git a/olmo/train.py b/olmo/train.py
index 341055003..74ac49a04 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -51,6 +51,7 @@
     get_fs_local_rank,
     get_global_rank,
     get_world_size,
+    is_distributed,
     move_to_device,
     peak_gpu_memory,
     synchronize_flag,
@@ -208,7 +209,7 @@ def fused_loss_fn(
 class Trainer:
     cfg: TrainConfig
     model: OLMo
-    dist_model: Union[DDP, FSDP]
+    dist_model: Union[DDP, FSDP, OLMo]
     optim: Optimizer
     scheduler: Scheduler
     train_loader: DataLoader
@@ -319,7 +320,7 @@ def scheduler_max(self) -> int:
         raise NotImplementedError(self.cfg.scheduler.units)
 
     def trainer_state_dict(self) -> Dict[str, Any]:
-        return {
+        state_dict: Dict[str, Any] = {
             "epoch": self.epoch or 0,
             "global_step": self.global_step,
             "global_train_examples_seen_this_epoch": self.global_train_examples_seen_this_epoch,
@@ -332,10 +333,14 @@ def trainer_state_dict(self) -> Dict[str, Any]:
                 "python": random.getstate(),
                 "numpy": np.random.get_state(),
                 "torch": torch.random.get_rng_state(),
-                "cuda": torch.cuda.get_rng_state(),
             },
         }
 
+        if torch.cuda.is_available():
+            state_dict["rng"]["cuda"] = torch.cuda.get_rng_state()
+
+        return state_dict
+
     def load_trainer_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Checkpoint paths.
         self.checkpoints = [
@@ -429,7 +434,8 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None:
         random.setstate(rng_state["python"])
         np.random.set_state(rng_state["numpy"])
         torch.set_rng_state(rng_state["torch"])
-        torch.cuda.set_rng_state(rng_state["cuda"])
+        if "cuda" in rng_state:
+            torch.cuda.set_rng_state(rng_state["cuda"])
 
     def _save_checkpoint(
         self, checkpointer: Checkpointer, checkpoint_type: CheckpointType
@@ -674,9 +680,13 @@ def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.ut
             )
         trace_save_folder.mkdir(parents=True)
 
+        module_num = 0
+
         def trace_outputs_hook(
             module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
         ) -> None:
+            nonlocal module_num
+
             if len(args) == 0:
                 log.info("No input args for module %s, output %s", module_name, output)
 
@@ -684,16 +694,14 @@ def trace_outputs_hook(
             trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
             trace_save_folder.mkdir(parents=True, exist_ok=True)
 
-            module_occurence_num = 0
-            while (
-                module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
-            ).exists():
-                module_occurence_num += 1
+            module_input_filepath = trace_save_folder / f"{module_name}_{module_num}_input.pt"
             torch.save(module_input, module_input_filepath)
 
-            module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
+            module_output_filepath = trace_save_folder / f"{module_name}_{module_num}_output.pt"
             torch.save(output, module_output_filepath)
 
+            module_num += 1
+
         output_hooks = []
         for module_name, module in self.model.named_modules(prefix="model"):
             output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
@@ -838,7 +846,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         ce_batch_loss, z_batch_loss = self.train_batch(batch)
 
         # Collect loss, potentially reducing over all ranks.
-        if reduce_global_loss:
+        if reduce_global_loss and get_world_size() > 1:
             dist.reduce(ce_batch_loss, 0)
             ce_batch_loss.div_(get_world_size())
             if z_batch_loss is not None:
@@ -852,7 +860,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
                 collect_param_metrics=should_log_optim_metrics_this_step,
                 # passing this process group here ensures metrics are reduced correctly when we're using
                 # HYBRID sharding.
-                process_group=self.dist_model.process_group,
+                process_group=self.dist_model.process_group if is_distributed() else None,
             )
 
         # Adjust the learning rate.
diff --git a/test_fixtures/random_data.npy b/test_fixtures/random_data.npy
new file mode 100644
index 000000000..69a39a8c5
Binary files /dev/null and b/test_fixtures/random_data.npy differ
diff --git a/test_fixtures/test-olmo-model-xtiny/config.yaml b/test_fixtures/test-olmo-model-xtiny/config.yaml
new file mode 100644
index 000000000..c4387722e
--- /dev/null
+++ b/test_fixtures/test-olmo-model-xtiny/config.yaml
@@ -0,0 +1,101 @@
+run_name: test-olmo-model-xtiny
+seed: 12345
+dry_run: false
+
+wandb: null
+
+model:
+  d_model: 32
+  n_heads: 4
+  n_layers: 2
+  mlp_ratio: 4
+  weight_tying: false
+  alibi: false
+  rope: true
+  rope_theta: 500000
+  flash_attention: false
+  attention_dropout: 0.0
+  include_bias: false
+  block_type: sequential
+  layer_norm_type: rms
+  layer_norm_with_affine: true
+  layer_norm_eps: 1e-6
+  bias_for_layer_norm: false
+  attention_layer_norm: true
+  attention_layer_norm_with_affine: true
+  norm_after: true
+  activation_type: swiglu
+  residual_dropout: 0.0
+  embedding_dropout: 0.0
+  max_sequence_length: 1024
+  vocab_size: 100
+  embedding_size: 128
+  eos_token_id: 1
+  pad_token_id: 2
+  init_device: meta
+  init_fn: normal
+  init_std: 0.02
+  init_cutoff_factor: 3
+
+softmax_auxiliary_loss: true
+auxiliary_loss_multiplier: 1e-5
+fused_loss: false
+
+compile: null
+
+optimizer:
+  name: adamw
+  learning_rate: 4.0e-4
+  weight_decay: 0.1
+  eps: 1e-8
+  decay_norm_and_bias: true
+  decay_embeddings: false
+  betas:
+  - 0.9
+  - 0.95
+  metrics_log_interval: 1
+
+scheduler:
+  name: cosine_with_warmup
+  units: steps
+  t_warmup: 1
+  t_max: 1000
+  alpha_f: 0.1
+  warmup_min_lr: 0.0
+
+tokenizer:
+  identifier: allenai/dolma2-tokenizer
+  truncate_direction: right
+
+save_folder: /weka/oe-training-default/ai2-llm/checkpoints/OLMo-small/${run_name}
+save_overwrite: false
+
+save_interval: null
+save_num_checkpoints_to_keep: -1
+sharded_checkpointer: olmo_core
+
+save_interval_unsharded: null
+save_num_unsharded_checkpoints_to_keep: -1
+
+load_path: null
+
+max_duration: 1ep
+global_train_batch_size: 4
+device_train_microbatch_size: 2
+
+precision: amp_bf16
+
+fsdp:
+  wrapping_strategy: null
+  sharding_strategy: SHARD_GRAD_OP
+  precision: mixed
+
+# activation_checkpointing: whole_layer
+
+max_grad_norm: 1.0
+max_grad_norm_ratio: null
+
+speed_monitor:
+  window_size: 1
+
+gen1_gc_interval: 10
\ No newline at end of file
diff --git a/test_fixtures/test-olmo-model-xtiny/model.pt b/test_fixtures/test-olmo-model-xtiny/model.pt
new file mode 100644
index 000000000..585dbace2
Binary files /dev/null and b/test_fixtures/test-olmo-model-xtiny/model.pt differ
diff --git a/test_fixtures/test-olmo-model-xtiny/optim.pt b/test_fixtures/test-olmo-model-xtiny/optim.pt
new file mode 100644
index 000000000..cc501fff1
Binary files /dev/null and b/test_fixtures/test-olmo-model-xtiny/optim.pt differ
diff --git a/test_fixtures/test-olmo-model-xtiny/traces_cpu.tar.gz b/test_fixtures/test-olmo-model-xtiny/traces_cpu.tar.gz
new file mode 100644
index 000000000..d4ea97e57
Binary files /dev/null and b/test_fixtures/test-olmo-model-xtiny/traces_cpu.tar.gz differ
diff --git a/test_fixtures/test-olmo-model-xtiny/traces_cuda.tar.gz b/test_fixtures/test-olmo-model-xtiny/traces_cuda.tar.gz
new file mode 100644
index 000000000..78943866f
Binary files /dev/null and b/test_fixtures/test-olmo-model-xtiny/traces_cuda.tar.gz differ
diff --git a/test_fixtures/test-olmo-model-xtiny/train.pt b/test_fixtures/test-olmo-model-xtiny/train.pt
new file mode 100644
index 000000000..9a790592b
Binary files /dev/null and b/test_fixtures/test-olmo-model-xtiny/train.pt differ
diff --git a/tests/train_test.py b/tests/train_test.py
index 2b9b3d115..b55902596 100644
--- a/tests/train_test.py
+++ b/tests/train_test.py
@@ -1,8 +1,22 @@
+import logging
+import shutil
+from pathlib import Path
+from typing import List, Optional, Union
+
 import pytest
 import torch
+import torch.multiprocessing as mp
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
-from olmo.train import cross_entropy_loss, fused_loss_fn
+from olmo.config import DistributedStrategy, TrainConfig
+from olmo.data import build_train_dataloader
+from olmo.model import OLMo
+from olmo.optim import build_optimizer, build_scheduler
+from olmo.train import Trainer, cross_entropy_loss, fused_loss_fn
+
+logger = logging.getLogger(__name__)
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device")
@@ -21,3 +35,323 @@ def test_fused_loss(batch_size, seq_len, vocab_size, z_loss_multiplier):
     # Note: This is allowing for very big differences!
     assert_close(loss, f_loss, atol=1e-2, rtol=1e-3)
     assert_close(z_loss, f_z_loss, atol=1e-2, rtol=1e-3)
+
+
+def _get_module_names(checkpoint_traces_folder: Path) -> List[str]:
+    module_names = []
+    for trace_file in checkpoint_traces_folder.iterdir():
+        trace_file_name = trace_file.name
+        if trace_file_name.endswith("_input.pt"):
+            module_name = trace_file_name.removesuffix("_input.pt")
+        elif trace_file_name.endswith("_output.pt"):
+            module_name = trace_file_name.removesuffix("_output.pt")
+        else:
+            assert False, f"Cannot get parameter from trace file {trace_file_name}"
+
+        module_names.append(module_name)
+
+    return module_names
+
+
+def _compare_module_output(
+    original_traces_folder: Path,
+    new_traces_folder: Path,
+    module_name: str,
+    *,
+    include_non_tensor_outputs: bool = True,
+):
+    original_module_input_path = original_traces_folder / f"{module_name}_input.pt"
+    original_module_output_path = original_traces_folder / f"{module_name}_output.pt"
+    new_module_input_path = new_traces_folder / f"{module_name}_input.pt"
+    new_module_output_path = new_traces_folder / f"{module_name}_output.pt"
+
+    map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    original_input = torch.load(str(original_module_input_path), map_location=map_location)
+    new_input = torch.load(str(new_module_input_path), map_location=map_location)
+
+    assert (
+        original_input.dtype == new_input.dtype
+    ), f"{module_name} input dtype is different for new model. Original {original_input.dtype}, new {new_input.dtype}"
+    assert (
+        original_input.shape == new_input.shape
+    ), f"{module_name} input shape is different for new model. Original {original_input.shape}, new {new_input.shape}"
+    if "wte" in module_name:
+        mismatching_element_count = torch.sum(torch.logical_not(torch.eq(original_input, new_input)))
+        assert (
+            mismatching_element_count == 0
+        ), f"Number of {module_name} mis-matching inputs: {mismatching_element_count}"
+
+    if (norm := torch.linalg.vector_norm((new_input - original_input).float())) > 0.0:
+        logger.info("Difference of norm of %s input is non-trivial: %f", module_name, norm)
+    assert_close(
+        new_input, original_input, msg=lambda msg: f"{module_name} inputs are not sufficiently close.\n{msg}"
+    )
+
+    original_output = torch.load(str(original_module_output_path), map_location=map_location)
+    new_output = torch.load(str(new_module_output_path), map_location=map_location)
+
+    if isinstance(original_output, torch.Tensor):
+        assert (
+            original_output.dtype == new_output.dtype
+        ), f"{module_name} output dtype is different for new model. Original {original_output.dtype}, new {new_output.dtype}"
+        assert (
+            original_output.shape == new_output.shape
+        ), f"{module_name} output shape is different for new model. Original {original_output.shape}, new {new_output.shape}"
+        if (norm := torch.linalg.vector_norm((new_output - original_output).float())) > 0.0:
+            logger.info("Difference of norm of %s output is non-trivial: %f", module_name, norm)
+        assert_close(
+            new_output,
+            original_output,
+            msg=lambda msg: f"{module_name} outputs are not sufficiently close.\n{msg}",
+        )
+    elif include_non_tensor_outputs:
+        pass
+        # logger.info("%s outputs: %s %s", module_name, original_output, new_output)
+
+
+def _compare_module_outputs(
+    original_traces_folder: Path,
+    new_traces_folder: Path,
+    *,
+    include_non_tensor_outputs: bool = True,
+):
+    original_modules = set(_get_module_names(original_traces_folder))
+    new_modules = set(_get_module_names(new_traces_folder))
+
+    original_only_modules = original_modules - new_modules
+    assert len(original_only_modules) == 0, f"Found modules only in base model: {', '.join(original_only_modules)}"
+
+    new_only_modules = new_modules - original_modules
+    assert len(new_only_modules) == 0, f"Found modules only in new model: {', '.join(new_only_modules)}"
+
+    common_modules = original_modules.intersection(new_modules)
+    for module_name in sorted(common_modules, key=lambda mod_name: int(mod_name.split("_")[-1])):
+        _compare_module_output(
+            original_traces_folder,
+            new_traces_folder,
+            module_name,
+            include_non_tensor_outputs=include_non_tensor_outputs,
+        )
+
+
+def _get_train_config(model_path: Path, save_folder: Path) -> TrainConfig:
+    cfg = TrainConfig.load(model_path / "config.yaml")
+    cfg.save_folder = str(save_folder)
+    cfg.data.paths = ["test_fixtures/random_data.npy"]
+    cfg.precision = "amp_bf16"
+    cfg.device_train_batch_size = 1
+    cfg.global_train_batch_size = 1
+    cfg.save_interval = None
+    cfg.save_interval_unsharded = 1000
+
+    # Keep model small enough
+    cfg.model.vocab_size = 100
+    cfg.model.embedding_size = 128
+    cfg.model.eos_token_id = 2
+    cfg.model.pad_token_id = 3
+
+    # Need to set these to 0 to get deterministic results
+    cfg.model.attention_dropout = 0.0
+    cfg.model.residual_dropout = 0.0
+    cfg.model.embedding_dropout = 0.0
+
+    return cfg
+
+
+def _get_dist_model(
+    cfg: TrainConfig, olmo_model: OLMo, distributed_strategy: Optional[DistributedStrategy]
+) -> Union[FSDP, DDP, OLMo]:
+    if distributed_strategy is None:
+        return olmo_model
+    if distributed_strategy == DistributedStrategy.fsdp:
+        try:
+            mp.set_start_method("spawn", force=True)
+        except RuntimeError as e:
+            print(f"failed to set multiprocessing start method: {e}")
+
+        # Set CUDA device.
+        torch.cuda.set_device("cuda:0")
+
+        assert cfg.fsdp is not None
+
+        def dummy_init_fn(module: torch.nn.Module) -> None:
+            module.to_empty(device=torch.device("cuda:0"))
+
+        param_init_fn = dummy_init_fn
+
+        return FSDP(
+            olmo_model,
+            sharding_strategy=cfg.fsdp.sharding_strategy,
+            mixed_precision=cfg.fsdp_precision,
+            auto_wrap_policy=olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy),
+            use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
+            limit_all_gathers=True,
+            device_id=0,
+            param_init_fn=param_init_fn,
+        )
+    if distributed_strategy == DistributedStrategy.ddp:
+        return DDP(olmo_model.to(torch.device("cuda:0")))
+
+    raise NotImplementedError
+
+
+def _train_model(
+    model_path: str,
+    cfg: TrainConfig,
+    *,
+    distributed_strategy: Optional[DistributedStrategy] = None,
+    cuda: bool = False,
+    replace_existing_model: bool = False,
+    replace_existing_traces: bool = False,
+):
+    device = torch.device("cuda") if cuda else torch.device("cpu")
+
+    olmo_model = OLMo(cfg.model).to_empty(device=device)
+    olmo_model.reset_parameters()
+    dist_model = _get_dist_model(cfg, olmo_model, distributed_strategy)
+
+    optim = build_optimizer(cfg, dist_model)
+    scheduler = build_scheduler(cfg)
+    train_loader = build_train_dataloader(cfg)
+
+    with Trainer(
+        cfg=cfg,
+        epoch=cfg.epoch,
+        model=olmo_model,
+        dist_model=dist_model,  # type: ignore
+        optim=optim,
+        scheduler=scheduler,
+        train_loader=train_loader,
+        device=device,
+        evaluators=[],
+        indices_file=None,
+    ) as trainer:
+        if replace_existing_model:
+            # Save model and move *.pt files to right place
+            trainer.save_unsharded_checkpoint()
+            for path in (Path(cfg.save_folder) / "step0-unsharded/").glob("*.pt"):
+                shutil.copy(path, Path(model_path) / path.name)
+
+        trainer.restore_unsharded_checkpoint(model_path)
+        trainer.fit()
+
+        if replace_existing_traces:
+            # Replace existing trace files
+            model_traces_path = Path(model_path) / ("traces_cuda.tar.gz" if cuda else "traces_cpu.tar.gz")
+            model_traces_path.unlink(missing_ok=True)
+            archive_path = shutil.make_archive(
+                Path(cfg.save_folder).name, "gztar", root_dir=Path(cfg.save_folder), base_dir="traces"
+            )
+            Path(archive_path).rename(model_traces_path)
+
+
+@pytest.mark.parametrize(
+    "cuda, distributed_strategy",
+    [
+        pytest.param(False, None),
+        pytest.param(
+            True,
+            DistributedStrategy.fsdp,
+            marks=(
+                pytest.mark.gpu,
+                pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
+            ),
+        ),
+    ],
+)
+def test_train_forward_unchanged(
+    xtiny_model_path: str,
+    tmp_path: Path,
+    cuda: bool,
+    distributed_strategy: Optional[DistributedStrategy],
+    make_process_group,
+    replace_existing_model: bool = False,
+    replace_existing_traces: bool = False,
+):
+    """
+    This test checks that the output of model forward of the 1st step has not changed (relative to an existing checkpoint).
+
+    Set replace_existing_model and/or replace_existing_traces to True if a non-backwards-compatible change is being
+    intentionally made and so this test needs to be updated.
+    """
+    cfg = _get_train_config(Path(xtiny_model_path), tmp_path / "test_forward")
+    cfg.module_outputs_save_steps = [1, 2]
+    cfg.stop_at = 2
+
+    if cuda:
+        make_process_group()
+
+    _train_model(
+        xtiny_model_path,
+        cfg,
+        distributed_strategy=distributed_strategy,
+        cuda=cuda,
+        replace_existing_model=replace_existing_model,
+        replace_existing_traces=replace_existing_traces,
+    )
+
+    assert (Path(cfg.save_folder) / "traces/step1").is_dir(), "Output traces not found for newly trained model"
+    original_traces_archive = Path(xtiny_model_path) / ("traces_cuda.tar.gz" if cuda else "traces_cpu.tar.gz")
+    shutil.unpack_archive(original_traces_archive, tmp_path / "test_forward_baseline")
+    _compare_module_outputs(
+        tmp_path / "test_forward_baseline/traces/step1", Path(cfg.save_folder) / "traces/step1"
+    )
+
+    assert not replace_existing_model, "Test successfully updated, please disable replace_existing_model"
+    assert not replace_existing_traces, "Test successfully updated, please disable replace_existing_traces"
+
+
+@pytest.mark.parametrize(
+    "cuda, distributed_strategy",
+    [
+        pytest.param(False, None),
+        pytest.param(
+            True,
+            DistributedStrategy.fsdp,
+            marks=(
+                pytest.mark.gpu,
+                pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Requires CUDA device"),
+            ),
+        ),
+    ],
+)
+def test_train_second_step_unchanged(
+    xtiny_model_path: str,
+    tmp_path: Path,
+    cuda: bool,
+    distributed_strategy: Optional[DistributedStrategy],
+    make_process_group,
+    replace_existing_model: bool = False,
+    replace_existing_traces: bool = False,
+):
+    """
+    This test checks that the output of model forward of the 2nd step has not changed (relative to an existing checkpoint).
+
+    Set replace_existing_model and/or replace_existing_traces to True if a non-backwards-compatible change is being
+    intentionally made and so this test needs to be updated.
+    """
+    cfg = _get_train_config(Path(xtiny_model_path), tmp_path / "test_forward")
+    cfg.module_outputs_save_steps = [1, 2]
+    cfg.stop_at = 2
+
+    if cuda:
+        make_process_group()
+
+    _train_model(
+        xtiny_model_path,
+        cfg,
+        distributed_strategy=distributed_strategy,
+        cuda=cuda,
+        replace_existing_model=replace_existing_model,
+        replace_existing_traces=replace_existing_traces,
+    )
+
+    assert (Path(cfg.save_folder) / "traces/step2").is_dir(), "Output traces not found for newly trained model"
+    original_traces_archive = Path(xtiny_model_path) / ("traces_cuda.tar.gz" if cuda else "traces_cpu.tar.gz")
+    shutil.unpack_archive(original_traces_archive, tmp_path / "test_forward_baseline")
+    _compare_module_outputs(
+        tmp_path / "test_forward_baseline/traces/step2", Path(cfg.save_folder) / "traces/step2"
+    )
+
+    assert not replace_existing_model, "Test successfully updated, please disable replace_existing_model"
+    assert not replace_existing_traces, "Test successfully updated, please disable replace_existing_traces"