Skip to content

Commit

Permalink
run w/ target 70b & max_tokens=50
Browse files Browse the repository at this point in the history
  • Loading branch information
keyboardAnt committed Sep 23, 2024
1 parent 192dd5b commit e2c2653
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions poc/actual/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,13 @@ async def get_latency(async_func, *args, **kwargs):
@torch.no_grad()
async def main():
print("Main started")
# manager_cls = Manager
manager_cls = ManagerNonSI
# verifier_cls = VerifierSlow
verifier_cls = Verifier
verifier_cls = VerifierSlow
drafter_cls = Drafter
# verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct"
verifier_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
verifier_name: str = "meta-llama/Meta-Llama-3.1-70B-Instruct"
verifier_load_in_8bit = True
verifier_dtype = torch.float16
num_verifiers = 2
max_new_tokens = 200
# max_new_tokens = 100
# max_new_tokens = 20
max_new_tokens = 50
verifier_device_map_filename = (
"device_map_meta-llama_Meta-Llama-3.1-70B-Instruct_8bit_on_3A40_custom.json"
)
Expand Down Expand Up @@ -122,15 +116,12 @@ async def run_our_implementation():
return manager.tok_ids

latencies = defaultdict(list)
# TODO: Remove this
prompts = prompts[:1]
ts = time.time()
with logfire.span("Run with args:\n{args}\nat {ts}", args=locals(), ts=ts):
for prompt in tqdm(prompts, desc="Prompts"):
with logfire.span("Prompt: {prompt}", prompt=prompt):
tok_ids = encode(prompt, verifier_name)
# for manager_cls in [Manager, ManagerSI, ManagerNonSI]:
for manager_cls in [ManagerSI]:
for manager_cls in [Manager, ManagerSI, ManagerNonSI]:
with logfire.span("{cls_name}", cls_name=manager_cls):
manager = manager_cls(
draft_queue=draft_queue,
Expand Down

0 comments on commit e2c2653

Please sign in to comment.