Basic components with tracing support #4

Open · wants to merge 17 commits into base: main
18 changes: 18 additions & 0 deletions deserve_client/README.md
@@ -0,0 +1,18 @@
# DeServe Client

## How To Run

For text completion:
```bash
python3 -m deserve_client.client complete meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt."
```

For dumping prefill traces:
```bash
python3 -m deserve_client.client trace meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt."
```

For verifying the correctness of a trace:
```bash
python3 -m deserve_client.client verify meta-llama/Meta-Llama-3-8B-Instruct "Here is a text prompt."
```
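
The same endpoints can be called programmatically. Below is a minimal sketch using `requests`, assuming the controller is running at its default address (`http://localhost:19000`, the default used by `deserve_client.client`):

```python
import requests

# Stream a completion from the controller; the endpoint and JSON body
# mirror what `deserve_client.client complete` sends.
response = requests.post(
    "http://localhost:19000/complete",
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "prompt": "Here is a text prompt.",
    },
    stream=True,
)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)
```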
Empty file added deserve_client/__init__.py
100 changes: 100 additions & 0 deletions deserve_client/client.py
@@ -0,0 +1,100 @@
import pickle
from typing import Any

import requests
import safetensors.torch
import torch
import typer
from transformers import AutoTokenizer # type: ignore

from deserve_client.model import (
CheckCtx,
Transformer,
VerifyCtx,
llama_3_8b_args,
main_device,
)
from deserve_controller.controller_api import app
from deserve_worker.trace import OpId

cli = typer.Typer()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")


def loads(b: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
"""
Load tensors and metadata from bytes
"""

metadata_length = int.from_bytes(b[:4], byteorder="big")
metadata = pickle.loads(b[4 : 4 + metadata_length])
tensors = safetensors.torch.load(b[4 + metadata_length :])
return tensors, metadata
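

def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes:
    """
    Illustrative inverse of `loads`, shown for reference and not called by
    this client: a 4-byte big-endian metadata length, the pickled metadata,
    then the safetensors-encoded tensors.
    """
    metadata_bytes = pickle.dumps(metadata)
    return (
        len(metadata_bytes).to_bytes(4, byteorder="big")
        + metadata_bytes
        + safetensors.torch.save(tensors)
    )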


@cli.command()
def complete(model: str, prompt: str, entry_point: str = "http://localhost:19000"):
response = requests.post(
f"{entry_point}/complete",
json={"model": model, "prompt": prompt},
stream=True,
)
    if response.status_code != 200:
        typer.echo(f"Error: HTTP {response.status_code}")
        return

    # chunk_size=None yields chunks as they arrive; the default of one byte
    # per chunk could split multi-byte UTF-8 characters and break decoding.
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"), end="", flush=True)


@cli.command()
def trace(model: str, prompt: str, entry_point: str = "http://localhost:19000"):
response = requests.post(
f"{entry_point}/trace",
json={"model": model, "prompt": prompt},
stream=True,
)
    if response.status_code != 200:
        typer.echo(f"Error: HTTP {response.status_code}")
        return

    # Accumulate traced tensors from the streamed chunks, keyed by op id.
    tensors: dict[str, torch.Tensor] = {}
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            temp_tensors, _ = loads(chunk)
            tensors.update(temp_tensors)
    print(list(tensors.keys()))


@cli.command()
def verify(model: str, prompt: str, entry_point: str = "http://localhost:19000"):
response = requests.post(
f"{entry_point}/trace",
json={"model": model, "prompt": prompt},
stream=True,
)
    if response.status_code != 200:
        typer.echo(f"Error: HTTP {response.status_code}")
        return

    # Accumulate traced tensors from the streamed chunks.
    tensors: dict[str, torch.Tensor] = {}
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            temp_tensors, _ = loads(chunk)
            tensors.update(temp_tensors)

    # Recompute the forward pass locally and compare each op against the
    # server-side trace, allowing a deviation of up to 0.03 per op.
    traces = {OpId.from_str(k): v for k, v in tensors.items()}
    transformer = Transformer(llama_3_8b_args)
    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"].to(main_device)
    result = transformer.forward(tokens, CheckCtx(0.03, traces))
    if isinstance(result, torch.Tensor):
        # The forward pass completed without exceeding the threshold anywhere.
        print("No difference found")
    elif not transformer.verify(tokens, VerifyCtx(result.op_id, 0.03, traces)):
        # Targeted re-verification confirms the mismatch at this op.
        print("Difference found for", result.op_id)
    else:
        # The check flagged this op, but targeted re-verification did not
        # reproduce the mismatch.
        print("Difference reported at", result.op_id, "but not confirmed by verification")


if __name__ == "__main__":
cli()
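
For intuition, the `0.03` passed to `CheckCtx` and `VerifyCtx` above is a tolerance on how far the locally recomputed activations may deviate from the traced ones. The actual comparison lives in `deserve_client.model`; the helper below is only a sketch of what such a threshold check could look like, not the project's implementation:

```python
import torch

def within_threshold(
    recomputed: torch.Tensor, traced: torch.Tensor, threshold: float = 0.03
) -> bool:
    # Illustrative only: treat the tensors as matching when no element
    # deviates from the traced value by more than `threshold`.
    return bool(torch.all((recomputed - traced).abs() <= threshold))
```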