Commit eb300b6: add mps support
maximegmd committed Feb 4, 2024
1 parent 2d65f47 commit eb300b6
Showing 5 changed files with 91 additions and 9 deletions.
65 changes: 65 additions & 0 deletions examples/tiny-llama/lora-mps.yml
@@ -0,0 +1,65 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false

warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
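Apart from device-specific toggles, this file follows the stock TinyLlama LoRA recipe: flash_attention is false and xformers_attention is left unset because neither kernel runs on MPS, and evaluation is effectively disabled (val_set_size: 0, evals_per_epoch: 0). With axolotl installed on an Apple Silicon machine, a run would look something like "python -m axolotl.cli.train examples/tiny-llama/lora-mps.yml"; the exact entry point can vary between axolotl versions.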
15 changes: 10 additions & 5 deletions setup.py
@@ -1,7 +1,8 @@
"""setup.py for axolotl"""

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version, parse
import platform
from setuptools import find_packages, setup


@@ -26,11 +27,15 @@ def parse_requirements():
             _install_requires.append(line)
 
     try:
-        torch_version = version("torch")
-        _install_requires.append(f"torch=={torch_version}")
-        if torch_version.startswith("2.1."):
+        if "Darwin" in platform.system():
             _install_requires.pop(_install_requires.index("xformers==0.0.22"))
             _install_requires.append("xformers>=0.0.23")
+        else:
+            torch_version = parse(version("torch"))
+            _install_requires.append(f"torch=={torch_version}")
+
+            if torch_version >= Version("2.1"):
+                _install_requires.pop(_install_requires.index("xformers==0.0.22"))
+                _install_requires.append("xformers>=0.0.23")
     except PackageNotFoundError:
         pass

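Two details in the new setup.py logic are worth spelling out. platform.system() returns "Darwin" on macOS, which is the gate for skipping the CUDA-oriented xformers pin. And parse()/Version from packaging compare release numbers semantically, where the old startswith("2.1.") string test matched only 2.1.x. A standalone sketch of that difference (illustrative, not part of the commit):

from packaging.version import Version, parse

# semantic comparison: 2.2.0 satisfies ">= 2.1", which startswith("2.1.") rejects
assert parse("2.2.0") >= Version("2.1")

# string comparison also breaks on double-digit parts: "2.10.0" < "2.9.0" as text
assert parse("2.10.0") > Version("2.9.0")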
4 changes: 2 additions & 2 deletions src/axolotl/monkeypatch/utils.py
@@ -186,8 +186,8 @@ def mask_2d_to_4d(
     # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
     binary_mask = torch.where(
         mask != 0,
-        torch.tensor(1).to(dtype),
-        torch.tensor(0).to(dtype),
+        torch.tensor(1, device=mask.device).to(dtype),
+        torch.tensor(0, device=mask.device).to(dtype),
     )
 
     # Create a block-diagonal mask.
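The device= argument is the entire fix here: a bare torch.tensor(1) is allocated on the CPU, and torch.where needs its operands on one device, which an MPS-resident mask can otherwise trip over. A minimal standalone sketch of the corrected pattern (tensor values are illustrative):

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
mask = torch.tensor([[1, 1, 0, 0], [2, 2, 2, 0]], device=device)
dtype = torch.float32

# fill values are created on the mask's own device, so no operand
# of torch.where straddles CPU and accelerator memory
binary_mask = torch.where(
    mask != 0,
    torch.tensor(1, device=mask.device).to(dtype),
    torch.tensor(0, device=mask.device).to(dtype),
)
print(binary_mask)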
10 changes: 9 additions & 1 deletion src/axolotl/utils/bench.py
@@ -46,6 +46,11 @@ def gpu_memory_usage_all(device=0):
     smi = gpu_memory_usage_smi(device)
     return usage, reserved - usage, max(0, smi - reserved)
 
+def mps_memory_usage_all():
+    usage = torch.mps.current_allocated_memory() / 1024.0**3
+    reserved = torch.mps.driver_allocated_memory() / 1024.0**3
+    return usage, reserved - usage, 0
+
 
 @check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
@@ -63,7 +68,10 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
-    usage, cache, misc = gpu_memory_usage_all(device)
+    if torch.backends.mps.is_available():
+        usage, cache, misc = mps_memory_usage_all()
+    else:
+        usage, cache, misc = gpu_memory_usage_all(device)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")
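mps_memory_usage_all() mirrors the (usage, cache, misc) triple of the CUDA helper: torch.mps.current_allocated_memory() counts live tensor allocations and torch.mps.driver_allocated_memory() counts everything the Metal driver holds, while the trailing 0 reflects that there is no nvidia-smi-style external reading to fold in. A standalone probe along the same lines:

import torch

if torch.backends.mps.is_available():
    usage_gb = torch.mps.current_allocated_memory() / 1024.0**3
    driver_gb = torch.mps.driver_allocated_memory() / 1024.0**3
    print(f"MPS memory: {usage_gb:.3f}GB used, +{driver_gb - usage_gb:.3f}GB cache")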
6 changes: 5 additions & 1 deletion src/axolotl/utils/models.py
Expand Up @@ -429,6 +429,10 @@ def load_model(

model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype

if torch.backends.mps.is_available():
model_kwargs["device_map"] = "mps:0"

# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
@@ -668,7 +672,7 @@ def load_model(
     ):
         model.config.eos_token_id = tokenizer.eos_token_id
 
-    if hasattr(model, "device") and model.device.type == "cuda":
+    if hasattr(model, "device") and (model.device.type == "cuda" or model.device.type == "mps"):
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
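The models.py change does two things: it pins the whole model to the single MPS device through device_map, since there is no multi-device sharding to plan on a Mac, and it widens the post-load device check so the memory log also fires on MPS. A hedged sketch of the same placement outside axolotl, reusing the base model from the example config:

import torch
from transformers import AutoModelForCausalLM

device_map = "mps:0" if torch.backends.mps.is_available() else "auto"
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    device_map=device_map,        # mirrors model_kwargs["device_map"] above
    torch_dtype=torch.float32,
)
print(model.device.type)          # "mps" on Apple Silicon, so the memory log runs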
