diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 28084b7fb4..8e30554475 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -15,6 +15,8 @@ jobs:
           base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
         - name: '2.0.1_cu118'
           base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
+        - name: '2.1.0_cu121'
+          base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
 
     steps:
       - name: Maximize Build Space on Worker
diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml
index 6af87346c8..efdf8eec58 100644
--- a/.github/workflows/pr-cpu.yaml
+++ b/.github/workflows/pr-cpu.yaml
@@ -27,6 +27,10 @@ jobs:
           container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
           markers: 'not gpu'
           pytest_command: 'coverage run -m pytest'
+        - name: 'cpu-2.1.0'
+          container: mosaicml/pytorch:2.1.0_cpu-python3.10-ubuntu20.04
+          markers: 'not gpu'
+          pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
     if: github.repository_owner == 'mosaicml'
     with:
diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index d228802ddc..769b345e39 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -24,7 +24,11 @@ jobs:
           markers: 'gpu'
           pytest_command: 'coverage run -m pytest'
         - name: 'gpu-2.0.1'
-          container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+          container: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
+          markers: 'gpu'
+          pytest_command: 'coverage run -m pytest'
+        - name: 'gpu-2.1.0'
+          container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
           markers: 'gpu'
           pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index b364b73bef..ba861e21bc 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -63,10 +63,7 @@ def __init__(self, om_model_config: Union[DictConfig,
                                                nn.Module],
                  tokenizer: PreTrainedTokenizerBase):
         # set up training and eval metrics
-        train_metrics = [
-            LanguageCrossEntropy(),
-            LanguagePerplexity(),
-        ]
+        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
@@ -90,6 +87,9 @@ def __init__(self, om_model_config: Union[DictConfig,
                 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.'
             )
 
+        if not om_model_config.get('use_train_metrics', True):
+            train_metrics = []
+
         # load the model config
         trust_remote_code = om_model_config.get('trust_remote_code', True)
         use_auth_token = om_model_config.get('use_auth_token', False)
@@ -107,6 +107,7 @@ def __init__(self, om_model_config: Union[DictConfig,
                 )
 
             attr = getattr(config, k)
+            # attempt to disallow typos in nested configs
             if isinstance(attr, Mapping):
                 extra_keys = [
                     _k for _k in v.keys() if _k not in attr.keys()
@@ -118,6 +119,10 @@ def __init__(self, om_model_config: Union[DictConfig,
                         f'Expected (a subset of) keys: {list(attr.keys())}.'
                     )
                 getattr(config, k).update(v)
+            # necessary case to allow for rope_scaling to be overridden in llama config
+            elif attr is None and isinstance(v, Mapping):
+                setattr(config, k, {})
+                getattr(config, k).update(v)
             else:
                 setattr(config, k, v)
 
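Note on the new `elif attr is None` branch: some attributes on Hugging Face configs default to `None` (notably `rope_scaling` on llama configs), so a nested override has nothing to `.update()` until a dict is created first. A minimal standalone sketch of that situation, assuming a `transformers` version whose `LlamaConfig` exposes `rope_scaling`; the override values are illustrative:

```python
from transformers import LlamaConfig

config = LlamaConfig()
assert config.rope_scaling is None  # no mapping to .update() yet

# what the new branch effectively does: seed an empty dict, then apply the override
setattr(config, 'rope_scaling', {})
config.rope_scaling.update({'type': 'dynamic', 'factor': 0.5})
assert config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
```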
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index b1dff15398..cd162195b6 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -694,7 +694,9 @@ def __init__(
         hf_config = MPTConfig.from_dict(resolved_om_model_config)
         model = MPTForCausalLM(hf_config)
 
-        train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()]
+        use_train_metrics = om_model_config.get('use_train_metrics', True)
+        train_metrics = [LanguageCrossEntropy(),
+                         LanguagePerplexity()] if use_train_metrics else []
         eval_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
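Both the HF causal LM wrapper (above) and the MPT wrapper now read an optional `use_train_metrics` flag from the model config. A rough sketch of how it might be set, using the same OmegaConf pattern as the tests in this repo; the model name and other fields are illustrative, not taken from this diff:

```python
from omegaconf import OmegaConf as om

model_cfg = om.create({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': 'gpt2',  # illustrative; any HF causal LM
    'pretrained': False,
    # new flag: skip LanguageCrossEntropy/LanguagePerplexity on training batches
    'use_train_metrics': False,
})
```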
diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py
index 806dbdbd14..2c2e6e2d35 100644
--- a/llmfoundry/optim/lion8b.py
+++ b/llmfoundry/optim/lion8b.py
@@ -53,7 +53,7 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
             by retaining information across optimizer steps.
 
     Raises:
-        NotImplemenetedError - If any of `quantize`, `compress_state_dict`,
+        NotImplementedError - If any of `quantize`, `compress_state_dict`,
             or `error_correction` are `True` and either a) there is no CUDA
             device, or b) step() is executed on a non-CUDA parameter.
     """
@@ -67,6 +67,7 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True):  # XXX this flag is mostly for testing...
+
         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
         if not 0.0 <= betas[0] <= 1.0:
@@ -131,11 +132,19 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
             mom, try_quantize=self._quantize)
         need_errs = (p.dtype != torch.float32) and self._error_correction
         if state.get('errors') is None and need_errs:
-            state['errors'] = torch.zeros(p.shape,
-                                          dtype=torch.uint8,
-                                          device=p.device)
+            numel = p.numel()
+            numel += numel % 2  # ensure even number of bytes
+            errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
+            # as of torch 2.1, FSDP can't shard int tensors, for no apparent reason
+            state['errors'] = errors.view(torch.bfloat16)
         decay_factor = hparams['weight_decay']
         decay_factor *= hparams['lr'] / hparams['initial_lr']
+        errors: Optional[torch.Tensor] = None
+        if 'errors' in state:
+            errors = state['errors']
+            assert errors is not None  # pyright
+            errors = errors.view(dtype=torch.uint8)
+            errors = errors[:p.numel()].view(p.shape)  # strip padding + reshape
         _lion8b_step(momentums=state['exp_avg'],
                      weights=p,
                      grads=p.grad,
@@ -144,7 +153,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
                      lr=hparams['lr'],
                      weight_decay=decay_factor,
                      fused=hparams['fused'],
-                     errors=state.get('errors'))
+                     errors=errors)
 
     def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
         # we override this function to quantize optimizer states when
@@ -166,7 +175,8 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
                     # we need to cast back to the correct dtype since optimizer
                     # load_state_dict casts to param dtype for fp params; see
                     # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa
-                    errs = param_state['errors'].to(dtype=torch.uint8)
+                    errs = param_state['errors'].to(dtype=torch.uint8).view(
+                        torch.bfloat16)
                     new_state['errors'] = errs
             opt_state[param_id] = new_state
         super().__setstate__(state)
@@ -192,6 +202,11 @@ def state_dict(self):
                     qtensor.state_dict(
                         name='exp_avg',
                         allow_quantized=self._compress_state_dict))
+            if 'errors' in param_state:
+                # fsdp apparently needs the states to be the same shape
+                # as the params
+                param_state['errors'] = param_state['errors'].view(
+                    torch.uint8).to(dtype=torch.bfloat16)
             opt_state[param_id] = param_state
         return d
 
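The byte padding and dtype views above exist because a bfloat16 view needs an even byte count, and (per the comment in this diff) FSDP in torch 2.1 will not shard integer state tensors. A minimal standalone sketch of the round trip, with an illustrative parameter shape:

```python
import torch

p = torch.zeros(5, 3)  # parameter with an odd numel (15)
numel = p.numel()
numel += numel % 2  # pad to an even number of bytes so the bf16 view is valid

errors_u8 = torch.zeros(numel, dtype=torch.uint8)
stored = errors_u8.view(torch.bfloat16)  # what gets kept in optimizer state

# at step time: view back as bytes, drop the padding, restore the param shape
restored = stored.view(dtype=torch.uint8)[:p.numel()].view(p.shape)
assert restored.shape == p.shape and restored.dtype == torch.uint8
```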
diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml
index 89c9c0cd9c..ae8f57abb6 100644
--- a/mcli/mcli-llama2-finetune.yaml
+++ b/mcli/mcli-llama2-finetune.yaml
@@ -2,6 +2,7 @@ integrations:
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
   git_branch: v0.3.0
+  # git_commit: # OR use your commit hash
   pip_install: -e .[gpu]
   ssh_clone: false  # Should be true if using a private repo
 
diff --git a/mcli/mcli-openai-eval.yaml b/mcli/mcli-openai-eval.yaml
index 6275d9d578..0b770626b9 100644
--- a/mcli/mcli-openai-eval.yaml
+++ b/mcli/mcli-openai-eval.yaml
@@ -1,8 +1,8 @@
 integrations:
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
-  git_branch:  # use your branch
-  # git_commit: 29d65cc26853c09f6de7542978056ddb0b07e98c # OR use your commit hash
+  git_branch: v0.3.0
+  # git_commit: # OR use your commit hash
   pip_install: -e ".[gpu,openai]"
   ssh_clone: false  # Should be true if using a private repo
 
diff --git a/mcli/mcli-pretokenize-oci-upload.yaml b/mcli/mcli-pretokenize-oci-upload.yaml
index 8163d8c3bd..b585b5f5f2 100644
--- a/mcli/mcli-pretokenize-oci-upload.yaml
+++ b/mcli/mcli-pretokenize-oci-upload.yaml
@@ -14,7 +14,7 @@ integrations:
   - oci-cli==3.23.2
 - integration_type: git_repo
   git_repo: mosaicml/llm-foundry
-  git_branch: v0.2.0
+  git_branch: v0.3.0
   # git_commit: # OR use your commit hash
   pip_install: '.'
   ssh_clone: false  # Should be true if using a private repo
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 96d383687b..9511e88618 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -392,9 +392,12 @@ def main(cfg: DictConfig) -> Trainer:
             and save_folder is not None \
             and not save_overwrite \
             and not save_weights_only:
+        autoresume_default = True
+
+    if cfg.get('autoresume') is None and autoresume_default:
         print('As run_name, save_folder, and save_latest_filename are set, \
 changing autoresume default to True...')
-        autoresume_default = True
+
     autoresume: bool = pop_config(cfg,
                                   'autoresume',
                                   must_exist=False,
diff --git a/tests/test_hf_config.py b/tests/test_hf_config.py
index 99d01f309f..5b3bb3d150 100644
--- a/tests/test_hf_config.py
+++ b/tests/test_hf_config.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 import tempfile
 from copy import deepcopy
 from pathlib import Path
@@ -139,3 +140,30 @@ def test_hf_config_override(
                 assert getattr(hf_model.config, k)[_k] == _v
         else:
             assert getattr(hf_model.config, k) == v
+
+
+@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
+                    reason='CI does not have access to llama2')
+def test_rope_scaling_override():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
+        'config_overrides': {
+            'num_hidden_layers': 2,
+            'hidden_size': 32,
+            'intermediate_size': 64,
+            'rope_scaling': {
+                'type': 'dynamic',
+                'factor': 0.5
+            }
+        },
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'cpu',
+    }
+    model_cfg = om.create(model_cfg)
+
+    model = COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer=None)
+    # This would error if the config isn't parsed into a proper dictionary
+    model.get_metadata()
+    assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index dbd6ff6352..ddb70e882b 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -387,7 +387,7 @@ class _DummyModule(nn.Module):
     def __init__(self, device: str, dtype: torch.dtype):
         super().__init__()
         self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype)
-        self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype)
+        self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:  # type:ignore
         return self.linear1(self.linear0(x))
@@ -416,7 +416,7 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
     torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
 
     if not dist.is_initialized():
-        dist.init_process_group()
+        dist.init_process_group(backend='nccl')
     assert dist.get_world_size() >= 2, 'Misconfigured test run!'
 
     mod = FSDP(_DummyModule(device=device, dtype=dtype))
@@ -460,7 +460,7 @@ def _set_state_dict_type(model: nn.Module):
 
     # load state dict into the new optimizer
     opt_state_dict_slice = FSDP.optim_state_dict_to_load(
-        opt_state_dict, mod_new, opt_new)
+        optim_state_dict=opt_state_dict, model=mod_new, optim=opt_new)
     opt_new.load_state_dict(opt_state_dict_slice)
 
     new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new)
@@ -481,7 +481,7 @@ def _set_state_dict_type(model: nn.Module):
         assert mom_orig.shape == mom_new.shape
         assert mom_orig.dtype == mom_new.dtype
 
-        if use_errors:
+        if use_errors and (dtype != torch.float32):
             errs_orig = d_orig['errors']
             errs_new = d_new['errors']
             assert errs_orig.shape == errs_new.shape