load_in_8bit
keyboardAnt committed Sep 15, 2024
1 parent 92bfc22 commit 29a35b8
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions poc/dsi.py
@@ -284,15 +284,6 @@ async def run(self) -> None:
                 )
                 self.response_queue.task_done()
                 continue
-            # For debugging:
-            if response.tok_ids.shape[1] > 1 and not response.is_draft:
-                print(f"Manager: Response {response.id} (verification) has {response.tok_ids.shape[1]} tokens")
-                print(f"Manager: Check if the GPU memory has been released before garbage collection:")
-                print_gpu_memory()
-                print(f"Manager: Run garbage collection...")
-                garbage_collect()
-                print(f"Manager: Check if the GPU memory has been released after garbage collection:")
-                print_gpu_memory()
             mask: torch.Tensor = self.id_to_mask.pop(response.id)
             print(f"Manager: Popped mask {mask} for response {response.id}")
             if response.is_draft:
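
Note: the deleted debug block called two helpers, print_gpu_memory and garbage_collect, defined elsewhere in poc/dsi.py. A minimal sketch of what such helpers typically look like (the bodies below are assumptions, not the repository's code):

    import gc
    import torch

    def print_gpu_memory() -> None:
        # Report allocated and reserved CUDA memory for each visible device.
        for i in range(torch.cuda.device_count()):
            allocated_gib = torch.cuda.memory_allocated(i) / 1024**3
            reserved_gib = torch.cuda.memory_reserved(i) / 1024**3
            print(f"GPU {i}: allocated={allocated_gib:.2f} GiB, reserved={reserved_gib:.2f} GiB")

    def garbage_collect() -> None:
        # Drop unreachable Python objects, then release cached CUDA blocks
        # back to the driver so freed memory is visible outside the process.
        gc.collect()
        torch.cuda.empty_cache()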
@@ -413,6 +404,7 @@ async def load_model(
         self,
         name: str,
         dtype: torch.dtype,
+        load_in_8bit: bool,
         device_map: str,
         cache_dir: None | str = None,
     ) -> None:
@@ -430,7 +422,7 @@ async def load_model(
             cache_dir = os.environ["TRANSFORMERS_CACHE"]
         print(f"{self.__class__.__name__}: Loading model {name} with {device_map=}")
         self.model = AutoModelForCausalLM.from_pretrained(
-            name, torch_dtype=dtype, device_map=device_map, cache_dir=cache_dir
+            name, torch_dtype=dtype, device_map=device_map, cache_dir=cache_dir, load_in_8bit=load_in_8bit
         )
         self.model.eval()
         # if device != cpu:
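
Passing load_in_8bit straight to from_pretrained, as this hunk does, works on the transformers versions this PoC targets and requires the bitsandbytes package plus a CUDA device. Recent transformers releases deprecate the bare kwarg in favor of a quantization config; a rough equivalent on those versions (the model name is a placeholder, not one used in this repo):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",  # placeholder model name
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )

With 8-bit loading, weights are quantized at load time with bitsandbytes' LLM.int8() scheme, roughly halving memory use relative to float16.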
@@ -617,12 +609,7 @@ async def perform_task(self, request: Request) -> Response:
         # Run in executor (i.e., separate thread) to avoid blocking the event loop
         scores: torch.Tensor
         tok_ids: torch.Tensor
-        scores, tok_ids = await loop.run_in_executor(
-            None,
-            self.forward,
-            tok_ids,
-            request.n,
-        )
+        scores, tok_ids = await self.forward(tok_ids, request.n)
         # Move scores and tok_ids to the CPU
         scores = scores.to("cpu")
         tok_ids = tok_ids.to("cpu")
@@ -636,7 +623,7 @@ async def perform_task(self, request: Request) -> Response:
             tok_ids=tok_ids,
         )
 
-    def forward(
+    async def forward(
         self, tok_ids: torch.Tensor, n: int
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -780,6 +767,7 @@ async def run(
     verifier_dtype: torch.dtype,
     drafter_dtype: torch.dtype,
     verifier_load_in_8bit: bool,
+    drafter_load_in_8bit: bool,
     lookahead: int,
     tok_ids: torch.Tensor,
     max_new_tokens: int,
@@ -824,6 +812,7 @@ async def run(
                 verifier_name,
                 dtype=verifier_dtype,
                 device_map="balanced_low_0",
+                load_in_8bit=verifier_load_in_8bit,
                 cache_dir=os.environ["TRANSFORMERS_CACHE"],
             )
             for verifier in verifiers
@@ -835,6 +824,7 @@ async def run(
         drafter_name,
         dtype=drafter_dtype,
         device_map=get_device_map_with_only_gpu_0(drafter_name),
+        load_in_8bit=drafter_load_in_8bit,
         cache_dir=os.environ["TRANSFORMERS_CACHE"],
     )
     print_gpu_memory()
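
For placement, the verifiers load with device_map="balanced_low_0" (accelerate balances their layers across GPUs while keeping GPU 0 lightly loaded), while the drafter is pinned to GPU 0 via the repo's get_device_map_with_only_gpu_0 helper. A sketch of what such a helper could reduce to (an assumption, not the repository's implementation):

    # transformers/accelerate accept {"": device} as a device map that places
    # the entire model on one device; the small drafter then owns GPU 0.
    def get_device_map_with_only_gpu_0(model_name: str) -> dict[str, int]:
        return {"": 0}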
@@ -869,7 +859,7 @@ async def run(
 
 
 def generate(
-    model_name: str, dtype: torch.dtype, tok_ids: torch.Tensor, max_new_tokens: int
+    model_name: str, dtype: torch.dtype, load_in_8bit: bool, tok_ids: torch.Tensor, max_new_tokens: int
 ) -> str:
     setup_hf_cache()
     print(f"Loading tokenizer for {model_name}")
@@ -879,6 +869,7 @@ def generate(
         model_name,
         torch_dtype=dtype,
         device_map="auto",
+        load_in_8bit=load_in_8bit,
         cache_dir=os.environ["TRANSFORMERS_CACHE"],
     )
     model.eval()
@@ -945,8 +936,9 @@ async def main():
     verifier_dtype: torch.dtype = torch.float16
     drafter_dtype: torch.dtype = torch.float16
     verifier_load_in_8bit: bool = True
+    drafter_load_in_8bit: bool = True
     vocab_size: int = 128256
-    lookahead: int = 1
+    lookahead: int = 6
     max_new_tokens: int = 100
     prompt: str = """Below is an instruction that describes a
 task, paired with an input that provides
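
Raising lookahead from 1 to 6 lets the drafter propose six tokens per verification round instead of one. Under the standard speculative-decoding analysis, with an i.i.d. per-token acceptance rate p (an illustrative assumption, not a measurement from this repo), the expected number of tokens produced per verifier call at lookahead k is (1 - p^(k+1)) / (1 - p):

    def expected_tokens_per_verification(p: float, k: int) -> float:
        # Geometric-series form of the expected accepted-token count.
        return (1 - p ** (k + 1)) / (1 - p)

    print(expected_tokens_per_verification(0.8, 1))  # ~1.80 tokens with lookahead=1
    print(expected_tokens_per_verification(0.8, 6))  # ~3.95 tokens with lookahead=6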
@@ -962,6 +954,7 @@ async def main():
     # tok_ids = generate(
     #     model_name=verifier_name,
     #     dtype=verifier_dtype,
+    #     load_in_8bit=verifier_load_in_8bit,
     #     tok_ids=tok_ids,
     #     max_new_tokens=max_new_tokens,
     # )
@@ -972,6 +965,7 @@ async def main():
         verifier_dtype=verifier_dtype,
         drafter_dtype=drafter_dtype,
         verifier_load_in_8bit=verifier_load_in_8bit,
+        drafter_load_in_8bit=drafter_load_in_8bit,
         lookahead=lookahead,
         tok_ids=tok_ids,
         max_new_tokens=max_new_tokens,