upstream resolve conflict
abhilash1910 committed Sep 7, 2023
1 parent 81fecf3 commit 33da341
Showing 6 changed files with 114 additions and 33 deletions.
examples/chat_completion/chat_completion.py (10 changes: 8 additions & 2 deletions)
@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
             sys.exit(1)  # Exit the program with an error status
         tokens= torch.tensor(chat).long()
         tokens= tokens.unsqueeze(0)
-        tokens= tokens.to("cuda:0")
+        if is_xpu_available():
+            tokens= tokens.to("xpu:0")
+        else:
+            tokens= tokens.to("cuda:0")
         outputs = model.generate(
             input_ids=tokens,
             max_new_tokens=max_new_tokens,
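Both hunks in this file apply the same pattern: query accelerate.utils.is_xpu_available() and branch between the torch.xpu and torch.cuda namespaces for seeding and tensor placement (the import of is_xpu_available is not shown in these hunks). A minimal, self-contained sketch of that pattern; the pick_device and seed_everything helpers are illustrative, not code from the commit:

    import torch
    from accelerate.utils import is_xpu_available

    def pick_device() -> str:
        # Prefer an Intel XPU when the backend is present, then CUDA, then CPU.
        if is_xpu_available():
            return "xpu:0"
        if torch.cuda.is_available():
            return "cuda:0"
        return "cpu"

    def seed_everything(seed: int) -> None:
        # Seed the accelerator RNG that pick_device() would select, plus the global RNG.
        if is_xpu_available():
            torch.xpu.manual_seed(seed)
        elif torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        torch.manual_seed(seed)

    seed_everything(42)
    tokens = torch.ones(1, 8, dtype=torch.long).to(pick_device())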
examples/inference.py (19 changes: 18 additions & 1 deletion)
@@ -10,10 +10,16 @@
 
 import torch
 from transformers import LlamaTokenizer
+<<<<<<< HEAD:examples/inference.py
 
 from llama_recipes.inference.safety_utils import get_safety_checker
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
+=======
+from safety_utils import get_safety_checker
+from model_utils import load_model, load_peft_model, load_llama_from_config
+from accelerate.utils import is_xpu_available
+>>>>>>> ed7ba99 (enable xpu finetuning and inference):inference/inference.py
 
 def main(
     model_name,
@@ -50,7 +56,10 @@ def main(
         sys.exit(1)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
 
     model = load_model(model_name, quantization)
@@ -102,7 +111,15 @@ def main(
 
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
 
+<<<<<<< HEAD:examples/inference.py
     batch = {k: v.to("cuda") for k, v in batch.items()}
+=======
+    batch = tokenizer(user_prompt, return_tensors="pt")
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
+>>>>>>> ed7ba99 (enable xpu finetuning and inference):inference/inference.py
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
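Both hunks above commit literal merge-conflict markers (<<<<<<<, =======, >>>>>>>), so examples/inference.py as committed here will not parse. A resolved version would presumably keep the packaged llama_recipes.inference import paths from the HEAD side and add the XPU branch from the incoming side. A small illustrative helper for the batch-placement part (move_batch is not a name from the commit):

    from accelerate.utils import is_xpu_available

    def move_batch(batch):
        # Send every tensor produced by the tokenizer to the XPU when one is
        # available, otherwise to CUDA, mirroring the intent of the hunk above.
        device = "xpu" if is_xpu_available() else "cuda"
        return {k: v.to(device) for k, v in batch.items()}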
examples/vllm/inference.py (6 changes: 5 additions & 1 deletion)
@@ -6,9 +6,13 @@
 import torch
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 
-torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):
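The duplicated from vllm import LLM line is pre-existing context rather than something this commit adds, and the body of load_model is truncated by the view above. A purely illustrative usage sketch built only from the imports the file already has; the model name and sampling settings are placeholders, not values from the commit:

    import torch
    from vllm import LLM, SamplingParams
    from accelerate.utils import is_xpu_available

    # Seed the accelerator RNG and the global RNG before building the engine,
    # as the hunk above does.
    if is_xpu_available():
        torch.xpu.manual_seed(42)
    else:
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)

    llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=1)  # placeholder model id
    outputs = llm.generate(["I believe the meaning of life is"],
                           SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64))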
src/llama_recipes/finetuning.py (18 changes: 14 additions & 4 deletions)
@@ -42,14 +42,18 @@
     print_model_size,
     get_policies
 )
+from accelerate.utils import is_xpu_available
 
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
 
     if train_config.enable_fsdp:
@@ -60,7 +64,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
@@ -146,7 +153,7 @@ def main(**kwargs):
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -155,7 +162,10 @@
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 
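These hunks switch seeding, set_device, FSDP's device_id, and the non-FSDP model.to() call between XPU and CUDA, while the param_init_fn context line still targets torch.device("cuda"), so the low_cpu_fsdp path presumably still assumes CUDA here. A sketch of computing the selection once and reusing it; an illustrative refactor, not the committed code:

    import torch
    from accelerate.utils import is_xpu_available

    def bind_accelerator(local_rank: int) -> torch.device:
        # Bind this process to its local accelerator and return the device,
        # mirroring the branching the hunks above introduce.
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
            return torch.device("xpu", torch.xpu.current_device())
        torch.cuda.set_device(local_rank)
        return torch.device("cuda", torch.cuda.current_device())

    # Intended use (illustrative): device = bind_accelerator(local_rank), then
    # FSDP(..., device_id=device.index, ...) and model.to(device) on the non-FSDP path.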
src/llama_recipes/utils/memory_utils.py (47 changes: 33 additions & 14 deletions)
@@ -6,16 +6,22 @@
 import threading
 
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
 # This context manager is used to track the peak memory usage of the process
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -44,17 +50,30 @@ def __exit__(self, *exc):
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
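The XPU and CUDA branches of MemoryTrace differ only in which torch backend namespace they call; the hunk relies on torch.xpu exposing the same memory API as torch.cuda (empty_cache, memory_allocated, max_memory_allocated, memory_stats, max_memory_reserved). Under that same assumption, the duplication could be avoided by selecting the namespace once. An illustrative sketch, not the committed code:

    import gc

    import torch
    from accelerate.utils import is_xpu_available

    def _backend():
        # torch.xpu when an XPU is present, otherwise torch.cuda.
        return torch.xpu if is_xpu_available() else torch.cuda

    def snapshot_accelerator_memory():
        """Return (allocated_gb, peak_allocated_gb) for the active backend."""
        gc.collect()
        backend = _backend()
        backend.empty_cache()

        def to_gb(n):
            return int(n / 2**30)

        return to_gb(backend.memory_allocated()), to_gb(backend.max_memory_allocated())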
src/llama_recipes/utils/train_utils.py (47 changes: 36 additions & 11 deletions)
@@ -20,6 +20,7 @@
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
@@ -101,7 +102,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
@@ -113,17 +116,29 @@
 
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+                else:
                     print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                     print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                     print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                     print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
                 print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
+            else:
+                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
         lr_scheduler.step()
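The two report blocks above differ only in wording. Note that the XPU lines print memtrace.cuda_malloc_retires, while the MemoryTrace changes earlier in this commit store the XPU counter as xpu_malloc_retires, so memtrace.xpu_malloc_retires is presumably what was intended there. An illustrative way to emit the report once for whichever backend is active (the helper is not part of the commit):

    from accelerate.utils import is_xpu_available

    def print_memory_report(memtrace):
        # memtrace is the MemoryTrace instance used in the training loop above.
        if is_xpu_available():
            name, retries = "XPU", getattr(memtrace, "xpu_malloc_retires", None)
        else:
            name, retries = "CUDA", getattr(memtrace, "cuda_malloc_retires", None)
        print(f"Max {name} memory allocated was {memtrace.peak} GB")
        print(f"Max {name} memory reserved was {memtrace.max_reserved} GB")
        print(f"Peak active {name} memory was {memtrace.peak_active_gb} GB")
        print(f"{name} malloc retries : {retries}")
        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")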
@@ -246,6 +261,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.cuda.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
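Unlike the training-loop hunk earlier in this file, the evaluation-side check above tests torch.cuda.device_count() inside the is_xpu_available() branch and keeps a plain if rather than elif, so on an XPU-only host neither branch would fire; torch.xpu.device_count() with elif, as in the training loop, is presumably what was intended. A self-contained sketch of that intent (the helper name is illustrative):

    import torch
    import torch.distributed as dist
    from accelerate.utils import is_xpu_available

    def reduce_loss_across_ranks(loss, enable_fsdp: bool):
        # Sum a loss tensor over all ranks when FSDP spans more than one device
        # (illustrative helper; assumes the process group is already initialized).
        device_count = torch.xpu.device_count() if is_xpu_available() else torch.cuda.device_count()
        if enable_fsdp and device_count > 1:
            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
        return loss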
@@ -279,7 +296,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -303,7 +324,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu_empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
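The two hunks above pick the process-group backend and clear the accelerator cache. Note that the committed clear_gpu_cache calls torch.xpu_empty_cache(), whereas memory_utils.py in this same commit uses torch.xpu.empty_cache(), which is presumably the intended spelling. A compact sketch of both pieces under that assumption (the oneCCL bindings for PyTorch are assumed to be installed whenever the "ccl" backend is chosen):

    import torch
    import torch.distributed as dist
    from accelerate.utils import is_ccl_available, is_xpu_available

    def init_distributed():
        # oneCCL backend for XPU runs, NCCL otherwise.
        dist.init_process_group("ccl" if is_ccl_available() else "nccl")

    def empty_accelerator_cache():
        # torch.xpu.empty_cache (not torch.xpu_empty_cache) mirrors the CUDA call.
        if is_xpu_available():
            torch.xpu.empty_cache()
        else:
            torch.cuda.empty_cache()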
@@ -335,13 +359,14 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
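The doubled parentheses and the trailing ") or" in the last hunk make the new predicate hard to read. Using the names train_utils.py already imports, it is equivalent to the following, i.e. bf16 support is taken for granted whenever an XPU is available:

    verify_bfloat_support = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    ) or is_xpu_available()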
