Add Ascend NPU support (#1758)
MengqingCao authored Nov 21, 2024
1 parent 2e99bb3 commit 838b74d
Showing 5 changed files with 114 additions and 16 deletions.
15 changes: 14 additions & 1 deletion src/axolotl/utils/bench.py
@@ -4,6 +4,9 @@
import pynvml
import torch
from pynvml.nvml import NVMLError
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.distributed import get_device_type


def check_cuda_device(default_value):
@@ -53,6 +56,12 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0


def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0
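
For context, a minimal sketch (not part of this commit) of how the new NPU helper can be exercised; the wrapper name and the guard are assumptions, and torch.npu is only present once the Ascend torch_npu extension is installed:

# Hypothetical usage sketch: report allocated/cached NPU memory in GB,
# falling back to zeros when no Ascend device is available.
import torch
from transformers.utils.import_utils import is_torch_npu_available

def report_npu_memory(device=0):
    if not is_torch_npu_available():  # torch.npu requires the torch_npu package
        return 0.0, 0.0, 0.0
    usage = torch.npu.memory_allocated(device) / 1024.0**3
    reserved = torch.npu.memory_reserved(device) / 1024.0**3
    return usage, reserved - usage, 0.0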


@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):


def log_gpu_memory_usage(log, msg, device):
cur_device = get_device_type()
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2,
)
return usage, cache, misc
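
As a rough usage sketch (not from this commit; the logger name and sample output are assumptions), the helper now labels messages with the detected device type rather than always saying "GPU":

# Hypothetical call site: log memory after the model has been placed on a device.
import logging
import torch
from axolotl.utils.bench import log_gpu_memory_usage

LOG = logging.getLogger(__name__)
log_gpu_memory_usage(LOG, "after model load", torch.device("npu:0"))
# e.g. "npu memory usage after model load: 3.214GB (+0.512GB cache)"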
8 changes: 7 additions & 1 deletion src/axolotl/utils/config/__init__.py
@@ -5,6 +5,7 @@

import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
@@ -29,7 +30,10 @@ def get_device():
if torch.backends.mps.is_available():
return "mps"

raise SystemError("No CUDA/mps device found")
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"

raise SystemError("No CUDA/mps/npu device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"

@@ -39,6 +43,8 @@ def get_device():
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}

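For reference, a simplified sketch (the helper name and the CUDA branch, which this hunk does not show, are assumptions) of the device-resolution order now used by get_device — CUDA, then MPS, then NPU, with a CPU fallback on any error:

# Hypothetical standalone version of the resolution logic above.
import torch
from transformers.utils.import_utils import is_torch_npu_available

def resolve_device(local_rank: int = 0) -> str:
    try:
        if torch.cuda.is_available():  # assumed branch, not shown in this hunk
            return f"cuda:{local_rank}"
        if torch.backends.mps.is_available():
            return "mps"
        if is_torch_npu_available():
            return f"npu:{local_rank}"
        raise SystemError("No CUDA/mps/npu device found")
    except Exception:  # pylint: disable=broad-exception-caught
        return "cpu"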
35 changes: 35 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -19,6 +19,7 @@
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.config.models.internals import GPUCapabilities

@@ -1433,6 +1434,40 @@ def check_torch_compile_deepspeed(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
raise NotImplementedError(
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
)

# check quant config
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
optimizer = data.get("optimizer")
raise NotImplementedError(
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
)

quant_list = ["load_in_8bit", "load_in_4bit"]
for quant in quant_list:
if data.get(quant):
raise NotImplementedError(
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
)

# check dtype config
if data.get("tf32"):
raise NotImplementedError(
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
)

return data


class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
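A rough illustration (hypothetical values, not from this commit) of the settings the new check_npu_config validator rejects when an Ascend NPU is detected:

# Each of these options would raise NotImplementedError on an Ascend host.
unsupported_on_npu = {
    "flash_attention": True,        # flash/sdp/s2 attention are rejected
    "optimizer": "adamw_bnb_8bit",  # any optimizer name containing "bit" is rejected
    "load_in_4bit": True,           # 4-bit/8-bit quantized loading is rejected
    "tf32": True,                   # tf32 is rejected
}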
52 changes: 44 additions & 8 deletions src/axolotl/utils/distributed.py
@@ -9,10 +9,44 @@
import torch
import torch.distributed as dist
from accelerate import PartialState
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
)

distributed_state = None # pylint: disable=invalid-name


def get_device_type():
device = torch.device("cpu")
if is_torch_cuda_available():
device = torch.device("cuda")
elif is_torch_mps_available():
device = torch.device("mps")
elif is_torch_npu_available():
device = torch.device("npu")
return device


def get_device_count():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.device_count()
if "npu" in str(cur_device):
return torch.npu.device_count()
return 1


def get_current_device():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.current_device()
if "npu" in str(cur_device):
return torch.npu.current_device()
return 0
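
For reference, a small sketch (the wrapper name is an assumption) showing how these helpers compose into the fully qualified device string used by the collectives below, e.g. "cuda:0", "npu:1", or "mps:0":

# Hypothetical convenience wrapper built on the new helpers.
def current_device_string() -> str:
    return f"{get_device_type()}:{get_current_device()}"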


def is_distributed():
"""
Check if distributed training is initialized.
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(
- value_scalar, device=torch.cuda.current_device()
+ value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()

if not is_main_process():
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
if not is_distributed():
return vals

cur_device = get_device_type()
if is_main_process():
data_byte = pickle.dumps(vals)
- data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
- data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+ data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
+ data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
else:
- data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
- data_size = torch.IntTensor([0]).to("cuda")
+ data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
+ data_size = torch.IntTensor([0]).to(cur_device)

dist.broadcast(data_size, 0)
if not is_main_process():
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
Returns:
- The computed value (int or float).
"""
cur_device = f"{get_device_type()}:{get_current_device()}"
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
- value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
+ value_scalar, device=cur_device, dtype=torch.float32
)
else:
value_tensor = torch.tensor(
- 0.0, device=torch.cuda.current_device(), dtype=torch.float32
+ 0.0, device=cur_device, dtype=torch.float32
) # Placeholder tensor

# Broadcast the tensor to all processes.
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
value_scalar = fn()
value_tensor = torch.tensor(
- value_scalar, device=torch.cuda.current_device()
+ value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()

# Placeholder tensor for gathering results
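A minimal sketch (assumed helper name; assumes torch.distributed is already initialized) of the device-agnostic placement pattern these collectives now share — build the tensor on whichever accelerator backs the current process instead of hard-coding "cuda":

# Hypothetical helper following the same pattern as compute_and_broadcast above.
import torch
import torch.distributed as dist
from axolotl.utils.distributed import get_current_device, get_device_type

def broadcast_scalar(value: float, src: int = 0) -> float:
    device = f"{get_device_type()}:{get_current_device()}"
    tensor = torch.tensor(value, device=device, dtype=torch.float32)
    dist.broadcast(tensor, src=src)
    return tensor.item()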
20 changes: 14 additions & 6 deletions src/axolotl/utils/models.py
@@ -55,7 +55,7 @@
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
- from axolotl.utils.distributed import zero_only
+ from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -570,7 +570,8 @@ def set_device_map_config(self) -> None:
)

max_memory = {}
- for i in range(torch.cuda.device_count()):
+ num_device = get_device_count()
+ for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything

@@ -595,8 +596,11 @@ def set_device_map_config(self) -> None:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype

- if torch.backends.mps.is_available():
+ cur_device = get_device_type()
+ if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
+ elif "npu" in str(cur_device):
+ self.model_kwargs["device_map"] = "npu:0"

# TODO can we put the reference model on its own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -1050,7 +1054,11 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
self.ajust_model_config()

# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)

# make sure these are fp32 per Ramesh et al. (2021)
@@ -1118,9 +1126,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
and not skip_move_to_device
):
# TODO revalidate this conditional
self.model.to(f"cuda:{self.cfg.local_rank}")
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")

- if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
+ if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)

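Finally, a rough sketch (hypothetical function; the memory limits are illustrative) of the generalized max_memory construction used in set_device_map_config — one entry per visible accelerator via get_device_count, plus a large CPU budget, regardless of whether the backend is CUDA or NPU:

# Hypothetical standalone version of the per-device memory map built above.
from axolotl.utils.distributed import get_device_count

def build_max_memory(gpu_memory_limit: str = "24GiB") -> dict:
    max_memory = {i: gpu_memory_limit for i in range(get_device_count())}
    max_memory["cpu"] = "256GiB"  # large enough to offload anything remaining
    return max_memory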
