Skip to content

Commit

Permalink
minimal run w/ logfire & random seed
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 22, 2024
1 parent f226d0a commit 757bb70
Showing 1 changed file with 55 additions and 47 deletions.
102 changes: 55 additions & 47 deletions poc/actual/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
import statistics
import time
from dsi.utils import set_random_seed
import logfire

import accelerate

Check failure on line 8 in poc/actual/main.py

View workflow job for this annotation

GitHub Actions / ci (3.11.9, ubuntu-latest)

Ruff (F401)

poc/actual/main.py:8:8: F401 `accelerate` imported but unused
Expand Down Expand Up @@ -119,57 +120,64 @@ async def run_our_implementation():
return manager.tok_ids

latencies = defaultdict(list)
prompts = prompts[1:]
for prompt in tqdm(prompts, desc="Prompts"):
tok_ids = encode(prompt, verifier_name)
for manager_cls in [Manager, ManagerSI, ManagerNonSI]:
with logfire.span("{cls_name} - Prompt: {prompt}", cls_name=manager_cls.__name__, prompt=prompt):
manager = manager_cls(
draft_queue=draft_queue,
verify_queue=verify_queue,
response_queue=response_queue,
pubsub=pubsub,
tok_ids=tok_ids,
max_new_tokens=max_new_tokens,
vocab_size=128256,
lookahead=5,
)
print(f"Main: Created {manager.__class__.__name__}")
# TODO: Remove this
prompts = prompts[:1]
with logfire.span("Run with args: {args}", args=locals()):
for prompt in tqdm(prompts, desc="Prompts"):
with logfire.span("Prompt: {prompt}", prompt=prompt):
tok_ids = encode(prompt, verifier_name)
# for manager_cls in [Manager, ManagerSI, ManagerNonSI]:
for manager_cls in [ManagerSI]:
with logfire.span("{cls_name}", cls_name=manager_cls):
manager = manager_cls(
draft_queue=draft_queue,
verify_queue=verify_queue,
response_queue=response_queue,
pubsub=pubsub,
tok_ids=tok_ids,
max_new_tokens=max_new_tokens,
vocab_size=128256,
lookahead=5,
)
print(f"Main: Created {manager.__class__.__name__}")
cleanup()
latency, out_tok_ids = await get_latency(run_our_implementation)
latencies[manager.__class__.__name__].append(latency)
print(f"Main: Output tok_ids:\n{out_tok_ids}")
logfire.info("Main: Output tok_ids:\n{out_tok_ids}", out_tok_ids=out_tok_ids)
out_str = decode(out_tok_ids, verifier_name)
print(f"Main: Final output:\n{out_str}")
logfire.info("Main: Final output:\n{out_str}", out_str=out_str)
for worker in workers:
worker.reset()
cleanup()
latency, out_tok_ids = await get_latency(run_our_implementation)
latencies[manager.__class__.__name__].append(latency)
print(f"Main: Output tok_ids:\n{out_tok_ids}")
logfire.info("Main: Output tok_ids:\n{out_tok_ids}", out_tok_ids=out_tok_ids)
out_str = decode(out_tok_ids, verifier_name)
print(f"Main: Final output:\n{out_str}")
logfire.info("Main: Final output:\n{out_str}", out_str=out_str)
for worker in workers:
worker.reset()
cleanup()
print("Running non-SI HF...")
with logfire.span("NonSI HF"):
latency, out_tok_ids = await get_latency(run_nonsi_hf)
latencies["NonSI HF"].append(latency)
print(f"Main: Output tok_ids:\n{out_tok_ids}")
logfire.info("Main: Output tok_ids:\n{out_tok_ids}", out_tok_ids=out_tok_ids)
out_str = decode(out_tok_ids, verifier_name)
print(f"Main: Final output:\n{out_str}")
logfire.info("Main: Final output:\n{out_str}", out_str=out_str)
for worker in workers:
worker.reset()
print(f"Latencies: {latencies}")
logfire.info("Latencies: {latencies}", latencies=latencies)
for manager_cls, latencies in latencies.items():
print(f"Latencies for {manager_cls.__name__}: {latencies}")
mean_latency = sum(latencies) / len(latencies)
print(f"Mean latency: {mean_latency:.2f} seconds")
stddev = statistics.stdev(latencies)
print(f"Standard deviation: {stddev:.2f} seconds")
logfire.info("Mean latency for {manager_cls}: {mean_latency:.2f} seconds", manager_cls=manager_cls, mean_latency=mean_latency)
logfire.info("Standard deviation for {manager_cls}: {stddev:.2f} seconds", manager_cls=manager_cls, stddev=stddev)
print("Running non-SI HF...")
with logfire.span("NonSI HF"):
latency, out_tok_ids = await get_latency(run_nonsi_hf)
latencies["NonSI HF"].append(latency)
print(f"Main: Output tok_ids:\n{out_tok_ids}")
logfire.info("Main: Output tok_ids:\n{out_tok_ids}", out_tok_ids=out_tok_ids)
out_str = decode(out_tok_ids, verifier_name)
print(f"Main: Final output:\n{out_str}")
logfire.info("Main: Final output:\n{out_str}", out_str=out_str)
for worker in workers:
worker.reset()
print(f"Latencies: {latencies}")
logfire.info("Latencies: {latencies}", latencies=latencies)
for manager_cls, latencies in latencies.items():

Check failure on line 167 in poc/actual/main.py

View workflow job for this annotation

GitHub Actions / ci (3.11.9, ubuntu-latest)

Ruff (B020)

poc/actual/main.py:167:26: B020 Loop control variable `latencies` overrides iterable it iterates
print(f"Latencies for {manager_cls}: {latencies}")
mean_latency = sum(latencies) / len(latencies)
print(f"Mean latency: {mean_latency:.2f} seconds")
logfire.info("Mean latency for {manager_cls}: {mean_latency:.2f} seconds", manager_cls=manager_cls, mean_latency=mean_latency)
if len(latencies) > 1:
stddev = statistics.stdev(latencies)
print(f"Standard deviation: {stddev:.2f} seconds")
logfire.info("Standard deviation for {manager_cls}: {stddev:.2f} seconds", manager_cls=manager_cls, stddev=stddev)


if __name__ == "__main__":
print("Starting script.")
set_random_seed(42)
with cuda_memory_recording(max_entries=1_000):
asyncio.run(main())
shutdown_asyncio()
Expand Down

0 comments on commit 757bb70

Please sign in to comment.