diff --git a/examples/tiny-llama/lora-mps.yml b/examples/tiny-llama/lora-mps.yml
new file mode 100644
index 0000000000..e744638ba4
--- /dev/null
+++ b/examples/tiny-llama/lora-mps.yml
@@ -0,0 +1,65 @@
+base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+eval_sample_packing: false
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16: false
+tf32: true
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: false
+
+warmup_steps: 10
+evals_per_epoch: 0
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/setup.py b/setup.py
index 6f816ce4a6..d4a39b76ea 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,7 @@
 """setup.py for axolotl"""
+import platform
+import re
 from importlib.metadata import PackageNotFoundError, version
 
 from setuptools import find_packages, setup
 
@@ -26,11 +28,25 @@ def parse_requirements():
             _install_requires.append(line)
 
     try:
-        torch_version = version("torch")
-        _install_requires.append(f"torch=={torch_version}")
-        if torch_version.startswith("2.1."):
+        if "Darwin" in platform.system():
             _install_requires.pop(_install_requires.index("xformers==0.0.22"))
-            _install_requires.append("xformers>=0.0.23")
+        else:
+            torch_version = version("torch")
+            _install_requires.append(f"torch=={torch_version}")
+
+            version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
+            if version_match:
+                major, minor, patch = version_match.groups()
+                major, minor = int(major), int(minor)
+                patch = (
+                    int(patch) if patch is not None else 0
+                )  # Default patch to 0 if not present
+            else:
+                raise ValueError("Invalid version format")
+
+            if (major, minor) >= (2, 1):
+                _install_requires.pop(_install_requires.index("xformers==0.0.22"))
+                _install_requires.append("xformers>=0.0.23")
     except PackageNotFoundError:
         pass
     return _install_requires, _dependency_links
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 63141635ab..e43c58650a 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -186,8 +186,8 @@ def mask_2d_to_4d(
     # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
     binary_mask = torch.where(
         mask != 0,
-        torch.tensor(1).to(dtype),
-        torch.tensor(0).to(dtype),
+        torch.tensor(1, device=mask.device).to(dtype),
+        torch.tensor(0, device=mask.device).to(dtype),
     )
     # Create a block-diagonal mask.
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index 8f33665c69..c039e790a1 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -47,6 +47,12 @@ def gpu_memory_usage_all(device=0):
     return usage, reserved - usage, max(0, smi - reserved)
 
 
+def mps_memory_usage_all():
+    usage = torch.mps.current_allocated_memory() / 1024.0**3
+    reserved = torch.mps.driver_allocated_memory() / 1024.0**3
+    return usage, reserved - usage, 0
+
+
 @check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
@@ -63,7 +69,10 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    usage, cache, misc = gpu_memory_usage_all(device)
+    if torch.backends.mps.is_available():
+        usage, cache, misc = mps_memory_usage_all()
+    else:
+        usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 52a81ea2c0..1df6228ab5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -429,6 +429,10 @@ def load_model(
     model_kwargs["device_map"] = device_map
     model_kwargs["torch_dtype"] = cfg.torch_dtype
+
+    if torch.backends.mps.is_available():
+        model_kwargs["device_map"] = "mps:0"
+
     # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
     # if cfg.rl:
     #     if torch.cuda.device_count() > 1:
@@ -668,7 +672,7 @@ def load_model(
     ):
         model.config.eos_token_id = tokenizer.eos_token_id
 
-    if hasattr(model, "device") and model.device.type == "cuda":
+    if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
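For quick verification on an Apple Silicon machine, the sketch below reproduces the accounting done by the new `mps_memory_usage_all()` helper in `src/axolotl/utils/bench.py`. It assumes a PyTorch build with the MPS backend (torch >= 2.0); the `torch.mps` calls are the same ones the helper uses, everything else is illustrative.

```python
# Minimal sketch (not part of the diff): report MPS memory the same way
# mps_memory_usage_all() does -- usage is memory held by tensors, the rest of
# what the Metal driver has reserved is treated as cache.
import torch

if torch.backends.mps.is_available():
    usage = torch.mps.current_allocated_memory() / 1024.0**3
    reserved = torch.mps.driver_allocated_memory() / 1024.0**3
    print(f"MPS memory: {usage:.3f}GB used, +{reserved - usage:.3f}GB cache")
else:
    print("MPS not available; gpu_memory_usage_all() would be used instead")
```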
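The `setup.py` change can also be exercised in isolation. The sketch below mirrors its version gate: parse the installed torch version with the same regex and swap the `xformers==0.0.22` pin for `xformers>=0.0.23` once torch is 2.1 or newer; the local variable names are illustrative only.

```python
# Standalone sketch of the torch-version gate from setup.py (illustrative names).
import re
from importlib.metadata import PackageNotFoundError, version

install_requires = ["xformers==0.0.22"]

try:
    torch_version = version("torch")  # e.g. "2.1.2"
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if match is None:
        raise ValueError("Invalid version format")
    major, minor = int(match.group(1)), int(match.group(2))
    if (major, minor) >= (2, 1):
        # torch 2.1+ needs a newer xformers than the default pin
        install_requires.remove("xformers==0.0.22")
        install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
    pass  # torch not installed; keep the default pin

print(install_requires)
```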