From ed7ba999a9d5509178cf9b69ae201c64f2324cad Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Thu, 10 Aug 2023 23:39:04 -0700
Subject: [PATCH] enable xpu finetuning and inference

---
 inference/chat_completion.py | 10 ++++++--
 inference/inference.py       | 11 +++++++--
 inference/vLLM_inference.py  |  6 ++++-
 llama_finetuning.py          | 18 +++++++++++---
 utils/memory_utils.py        | 47 +++++++++++++++++++++++++-----------
 utils/train_utils.py         | 47 +++++++++++++++++++++++++++---------
 6 files changed, 105 insertions(+), 34 deletions(-)

diff --git a/inference/chat_completion.py b/inference/chat_completion.py
index d5c8378bd..f4920db21 100644
--- a/inference/chat_completion.py
+++ b/inference/chat_completion.py
@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@
                 sys.exit(1) # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 tokens,
                 max_new_tokens=max_new_tokens,
diff --git a/inference/inference.py b/inference/inference.py
index c010c07ca..95d323881 100644
--- a/inference/inference.py
+++ b/inference/inference.py
@@ -13,6 +13,7 @@
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
 from model_utils import load_model, load_peft_model, load_llama_from_config
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -48,7 +49,10 @@
         sys.exit(1)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
 
     model = load_model(model_name, quantization)
@@ -98,7 +102,10 @@
         sys.exit(1) # Exit the program with an error status
 
     batch = tokenizer(user_prompt, return_tensors="pt")
-    batch = {k: v.to("cuda") for k, v in batch.items()}
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
diff --git a/inference/vLLM_inference.py b/inference/vLLM_inference.py
index 63c644148..ff09b74cf 100644
--- a/inference/vLLM_inference.py
+++ b/inference/vLLM_inference.py
@@ -14,8 +14,12 @@
 )
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
-torch.cuda.manual_seed(42)
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):
diff --git a/llama_finetuning.py b/llama_finetuning.py
index b7c2dc603..77f42a982 100644
--- a/llama_finetuning.py
+++ b/llama_finetuning.py
@@ -64,6 +64,7 @@
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from accelerate.utils import is_xpu_available
 
 
 def main(**kwargs):
@@ -71,7 +72,10 @@
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
 
     if train_config.enable_fsdp:
@@ -82,7 +86,10 @@
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(rank)
+        if is_xpu_available():
+            torch.xpu.set_device(rank)
+        else:
+            torch.cuda.set_device(rank)
         setup_environ_flags(rank)
 
     # Calculate gradient accumulation steps
@@ -142,13 +149,16 @@ def main(**kwargs):
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 
diff --git a/utils/memory_utils.py b/utils/memory_utils.py
index ee134d286..db2edce38 100644
--- a/utils/memory_utils.py
+++ b/utils/memory_utils.py
@@ -8,6 +8,7 @@
 import numpy as np
 import psutil
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -15,9 +16,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -46,17 +52,30 @@ def __exit__(self, *exc):
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
diff --git a/utils/train_utils.py b/utils/train_utils.py
index 8113ef173..05bb7dafe 100644
--- a/utils/train_utils.py
+++ b/utils/train_utils.py
@@ -36,6 +36,7 @@
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -113,7 +114,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
@@ -125,17 +128,29 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
                 print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
         lr_scheduler.step()
@@ -259,6 +274,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
@@ -292,7 +309,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -316,7 +337,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -348,13 +372,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
    and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
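
Note on the pattern used above: every hunk in this patch applies the same device dispatch, with accelerate.utils.is_xpu_available() choosing between the torch.xpu and torch.cuda backends for seeding, device placement, and cache/memory queries. The standalone sketch below condenses that pattern. It is illustrative only and not part of the diff; the helper names get_device and seed_everything are invented for the example, and it assumes an accelerate release that exports is_xpu_available plus an XPU-enabled PyTorch build (e.g. via intel_extension_for_pytorch) when an XPU is actually present.

# Illustrative sketch only -- not part of the patch above.
# Assumes accelerate provides is_xpu_available(); helper names are hypothetical.
import torch
from accelerate.utils import is_xpu_available


def get_device() -> str:
    # Prefer an Intel XPU when present, then CUDA, then fall back to CPU.
    if is_xpu_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def seed_everything(seed: int) -> None:
    # Seed the CPU RNG plus whichever accelerator backend is active,
    # mirroring the seed-setting hunks in the patch.
    torch.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    elif torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


if __name__ == "__main__":
    seed_everything(42)
    device = get_device()
    batch = {"input_ids": torch.ones(1, 8, dtype=torch.long)}
    batch = {k: v.to(device) for k, v in batch.items()}  # same .to() dispatch as inference.py
    print(device, batch["input_ids"].device)

Keeping the dispatch behind small helpers like these avoids repeating the if/else at every call site, which is the main source of churn in the diff.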