load drafter on gpu 0 only & free tokenizer
keyboardAnt committed Sep 15, 2024
1 parent 9322470 commit 44142bc
Showing 1 changed file with 81 additions and 52 deletions.
133 changes: 81 additions & 52 deletions poc/dsi.py
@@ -9,6 +9,7 @@
from typing import Dict, Tuple
from uuid import UUID, uuid4

import accelerate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -407,23 +408,26 @@ async def load_model(
        cache_dir: None | str = None,
    ) -> None:
        """Loads the model from the given name and moves it to the device."""
        device = cpu = "cpu"
        if torch.cuda.device_count() > self.gpu_id:
            print(f"GPU {self.gpu_id} available. Using GPU.")
            device = f"cuda:{self.gpu_id}"
        else:
            print(f"GPU {self.gpu_id} not available. Using CPU.")
        print(f"{self.__class__.__name__}: Loading model {name} on {device} (using device map {device_map})")
        # device = cpu = "cpu"
        # if torch.cuda.device_count() > self.gpu_id:
        #     print(f"GPU {self.gpu_id} available. Using GPU.")
        #     device = f"cuda:{self.gpu_id}"
        # else:
        #     print(f"GPU {self.gpu_id} not available. Using CPU.")
        # print(
        #     f"{self.__class__.__name__}: Loading model {name} on {device} (using device map {device_map})"
        # )
        if cache_dir is None:
            cache_dir = os.environ["TRANSFORMERS_CACHE"]
        print(f"{self.__class__.__name__}: Loading model {name} with {device_map=}")
        self.model = AutoModelForCausalLM.from_pretrained(
            name, torch_dtype=dtype, device_map=device_map, cache_dir=cache_dir
        )
        self.model.eval()
        # if device != cpu:
        #     print(f"{self.__class__.__name__}: Moving model to {device}")
        #     self.model.to(device)
        print(f"{self.__class__.__name__}: Model loaded on {device}")
        print(f"{self.__class__.__name__}: Model loaded")

    async def run(self) -> None:
        """
@@ -729,6 +733,16 @@ async def broadcast(self) -> None:
print(f"PubSub: Broadcast complete. Queue size: {self.queue.qsize()}")


def get_device_map_with_only_gpu_0(model_name):
    # Instantiate the model on the meta device (no real weight allocation);
    # only the module tree is needed to infer a placement.
    with accelerate.init_empty_weights():
        model = AutoModelForCausalLM.from_pretrained(
            model_name, cache_dir="/workspace/hf_cache"
        )
    # A zero-byte budget for every GPU except 0 forces infer_auto_device_map
    # to place all weights on GPU 0, spilling to CPU only if GPU 0 overflows.
    max_memory = {i: 0 for i in range(1, torch.cuda.device_count())}
    max_memory[0] = f"{torch.cuda.mem_get_info(0)[0] / 1024 / 1024 / 1024:.2f} GB"
    return accelerate.infer_auto_device_map(model, max_memory=max_memory)
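
Note (not part of the commit): when the model is known to fit on a single GPU, transformers also accepts an explicit one-entry device map, which pins everything to GPU 0 without the empty-weights pass. A minimal sketch; the checkpoint name and dtype are placeholders, not this repo's actual drafter:

    # Sketch: pin the whole model to cuda:0 directly.
    drafter_model = AutoModelForCausalLM.from_pretrained(
        "org/drafter-checkpoint",   # placeholder checkpoint name
        torch_dtype=torch.float16,  # illustrative dtype
        device_map={"": 0},  # the empty key assigns every submodule to GPU 0
    )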


def setup_hf_cache():
    if torch.cuda.device_count() > 0:
        os.environ["TRANSFORMERS_CACHE"] = "/workspace/hf_cache"
@@ -748,7 +762,7 @@ async def run(
    verifier_dtype: torch.dtype,
    drafter_dtype: torch.dtype,
    lookahead: int,
    prompt: str,
    tok_ids: torch.Tensor,
    max_new_tokens: int,
) -> None:
    setup_hf_cache()
@@ -761,10 +775,8 @@
print("Main: Creating server instances")
# Define the missing arguments
print(f"Loading tokenizer for {verifier_name}")
tokenizer = AutoTokenizer.from_pretrained(verifier_name)
tok_ids = tokenizer.encode(prompt, return_tensors="pt")
print_gpu_memory()
print(f"Main: Creating manager with prompt: {prompt}")
print("Main: Creating manager")
manager = Manager(
draft_queue,
verify_queue,
@@ -775,7 +787,7 @@
        lookahead,
    )
    drafter = Drafter(draft_queue, response_queue, manager, 0)
    print(f"Main: Creating drafter with prompt: {prompt}")
    print("Main: Creating drafter")
    print_gpu_memory()
    available_gpus = torch.cuda.device_count()
    print(f"Main: Available GPUs: {available_gpus}")
@@ -803,7 +815,7 @@ async def run(
    await drafter.load_model(
        drafter_name,
        dtype=drafter_dtype,
        device_map="auto",
        device_map=get_device_map_with_only_gpu_0(drafter_name),
        cache_dir=os.environ["TRANSFORMERS_CACHE"],
    )
    print_gpu_memory()
@@ -834,24 +846,14 @@ async def run(
f"Main: Manager task completed. Time taken: {time_end - time_start:.2f} seconds"
)
print(f"Main: Final tok_ids: {manager.tok_ids}")
decoded_output = tokenizer.batch_decode(manager.tok_ids, skip_special_tokens=True)
print(f"Main: Final output: {decoded_output}")
# Close all asyncio tasks or resources without waiting for them to complete
for task in asyncio.all_tasks():
if task is not asyncio.current_task():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
print("Main: All servers are closed")
return manager.tok_ids


def generate(
    model_name: str, dtype: torch.dtype, prompt: str, max_new_tokens: int
    model_name: str, dtype: torch.dtype, tok_ids: torch.Tensor, max_new_tokens: int
) -> torch.Tensor:
    setup_hf_cache()
    print(f"Loading tokenizer for {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tok_ids = tokenizer.encode(prompt, return_tensors="pt")
    print_gpu_memory()
    print(f"Loading model {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
@@ -882,7 +884,7 @@ def generate(
    print(
        f"Generating with model {model_name} took {time_end - time_start:.2f} seconds"
    )
    return tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return outputs.sequences
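
Note: generate now returns outputs.sequences, a tensor of token ids of shape (batch, sequence_length), so decoding is the caller's job. A sketch using the decode helper defined later in this file; the dtype and token budget are illustrative:

    seq_ids = generate(
        model_name="meta-llama/Meta-Llama-3.1-70B-Instruct",
        dtype=torch.float16,       # assumed dtype for illustration
        tok_ids=tok_ids,
        max_new_tokens=100,        # illustrative budget
    )
    print(decode(seq_ids, "meta-llama/Meta-Llama-3.1-70B-Instruct"))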


def garbage_collect():
@@ -894,10 +896,29 @@ def garbage_collect():
def print_gpu_memory():
    print(f"The current device is {torch.cuda.current_device()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.mem_get_info(i)[0] / 1024 / 1024 / 1024:.2f} GB free, {torch.cuda.mem_get_info(i)[1] / 1024 / 1024 / 1024:.2f} GB total")
        print(
            f"GPU {i}: {torch.cuda.mem_get_info(i)[0] / 1024 / 1024 / 1024:.2f} GB free, {torch.cuda.mem_get_info(i)[1] / 1024 / 1024 / 1024:.2f} GB total"
        )
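
For reference, torch.cuda.mem_get_info returns a (free, total) pair in bytes for the queried device, which is what both divisions above convert (strictly to GiB, since they divide by 1024**3):

    free_bytes, total_bytes = torch.cuda.mem_get_info(0)  # bytes on cuda:0
    free_gib = free_bytes / 1024**3  # same conversion the loop above performs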


if __name__ == "__main__":
def encode(prompt: str, tokenizer_name: str) -> torch.Tensor:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tok_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Free the tokenizer right away so it does not stay resident in memory.
    del tokenizer
    garbage_collect()
    return tok_ids


def decode(tok_ids: torch.Tensor, tokenizer_name: str) -> list[str]:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    decoded_output = tokenizer.batch_decode(tok_ids, skip_special_tokens=True)
    del tokenizer
    garbage_collect()
    return decoded_output
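
A quick round trip through the two helpers, using the verifier checkpoint named in main() below; each call builds a tokenizer, uses it once, and releases it:

    ids = encode("Hello, world!", "meta-llama/Meta-Llama-3.1-70B-Instruct")  # shape (1, n)
    texts = decode(ids, "meta-llama/Meta-Llama-3.1-70B-Instruct")  # list with one string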


@torch.no_grad()
async def main():
    print("Script started")
    print_gpu_memory()
    verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct"
@@ -917,27 +938,35 @@ def print_gpu_memory():
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: Produce an annotated bibliography for the provided paper. ### Input: The provided paper is an article about the impact of social media on mental health titled “Exploring the Impact of Social Media on Mental Health”. ### Response: Chou, S., S. C. Y., & Ioerger, M. (2019). Exploring the Impact of Social Media on Mental Health. International Journal of Mental Health, 32(2). This paper discusses the impacts of different kinds of social media on the mental health of users. Firstly, the paper uses a systematic review to explore the existing research on the link between social media use and mental health. The review provides an overview of the current state of the research and discusses both positive and negative impacts that social media can have. Results suggest that there is both a positive and a negative correlation between the use of social media and mental health. The authors then proceed to focus on the effects of four specific types of social media: Instagram, Facebook, Twitter and Youtube. For each type of social media, the authors explore the effects on three specific mental health domains: depression, anxiety, and self-esteem. The paper concludes that more research is needed to understand the relationship between social media and mental health. Bates, M., & LeMesurier, S. (2017). The effect of social media on mental health. Mental Illness, 5(1). This paper examines the potential effects of social media on aspects of mental health such as self-esteem, depression, and anxiety. The authors use an online survey of 644 participants to investigate the relationship between mental health and particular uses of Facebook and Instagram. The survey looks at different users' motives for using social media, the frequency with which they use it, feelings of loneliness or anxiety while using it, and how their real life values and perspectives are impacted by social media. Results indicate that several factors influence the relationship between social media and mental health, including positive and negative attributes of different social media types. The findings of the paper suggest that research into what might influence how users engage with social media and the particular effects of platforms and ways of using them could be beneficial to understanding the relationship between social media and mental health. Olah, Z., Z. Szatmári, & Font, S. (2019). Effects of Social Media on Mental Health. Frontiers in Psychology, 10(2). The authors of this paper explore the potential effects of social media on mental health. The paper highlights both positive and negative outcomes from the use of different types of social media. It also highlights the ways in which our mental health is inextricably linked to our social life and environment. Results from a systematic review suggest that different types of social media have different effects on individuals. For example, it found that social media use has a positive effect on collaboration, connectedness, and communication, while it can also have a negative effect on loneliness, anxiety, depression and self-esteem. The paper concludes that more research is needed to understand how these different types of social media affect our mental wellbeing.
### Response:
"""
    with torch.no_grad():
        garbage_collect()
        # print(
        #     generate(
        #         model_name=verifier_name,
        #         dtype=verifier_dtype,
        #         prompt=prompt,
        #         max_new_tokens=max_new_tokens,
        #     )
        # )
        # garbage_collect()
        asyncio.run(
            run(
                verifier_name=verifier_name,
                drafter_name=drafter_name,
                vocab_size=vocab_size,
                verifier_dtype=verifier_dtype,
                drafter_dtype=drafter_dtype,
                lookahead=lookahead,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
            )
        )
    tok_ids = encode(prompt, verifier_name)
    tok_ids = generate(
        model_name=verifier_name,
        dtype=verifier_dtype,
        tok_ids=tok_ids,
        max_new_tokens=max_new_tokens,
    )
    print(f"Main: Final output: {decode(tok_ids, verifier_name)}")
    # garbage_collect()
    # tok_ids = await run(
    #     verifier_name=verifier_name,
    #     drafter_name=drafter_name,
    #     vocab_size=vocab_size,
    #     verifier_dtype=verifier_dtype,
    #     drafter_dtype=drafter_dtype,
    #     lookahead=lookahead,
    #     tok_ids=tok_ids,
    #     max_new_tokens=max_new_tokens,
    # )
    # print(f"Main: Final output: {decode(tok_ids, verifier_name)}")
    # Close all asyncio tasks or resources without waiting for them to complete
    for task in asyncio.all_tasks():
        if task is not asyncio.current_task():
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
    print("Main: All servers are closed")
    print("Script completed")


if __name__ == "__main__":
    asyncio.run(main())
