#8075: [Falcon40b demo cleanup] Add back inference runs, fix perf measurements, compile with single layer, perf mode, top-k/top-p sampling, update input json

Signed-off-by: Salar <[email protected]>
skhorasganiTT committed May 6, 2024
1 parent 674b558 commit c3e78e9
Showing 4 changed files with 235 additions and 216 deletions.
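Note on the sampling change named in the title: top-k/top-p (nucleus) sampling keeps only the k most likely tokens, then drops the tail of that distribution whose cumulative probability exceeds p before sampling. A minimal illustrative sketch (hypothetical helper, not code from this diff; assumes 1-D logits as a torch tensor):

import torch

def sample_top_k_top_p(logits, top_k=10, top_p=0.9):
    # Hypothetical helper illustrating top-k/top-p filtering.
    values, indices = torch.topk(logits, top_k)       # top-k filter; values sorted descending
    probs = torch.softmax(values, dim=-1)             # renormalize over the k tokens
    cumulative = torch.cumsum(probs, dim=-1)
    probs[cumulative - probs > top_p] = 0.0           # nucleus filter; always keeps the top token
    probs = probs / probs.sum()                       # renormalize the surviving mass
    choice = torch.multinomial(probs, num_samples=1)  # sample one token
    return indices[choice]                            # map back to a vocab id

next_id = sample_top_k_top_p(torch.randn(65024))      # 65024 = Falcon vocab size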
16 changes: 6 additions & 10 deletions models/demos/falcon7b/demo/demo.py
@@ -206,9 +206,7 @@ def run_falcon_demo_kv(
 
     synchronize_devices(devices)
 
-    logger.info("Moving weights to device; might take some time...")
-    profiler.start(f"moving_to_device")
-
+    logger.info("Moving weights (single layer) to device...")
     base_url = ""
 
     tt_FalconCausalLM_singlelayer = TtFalconCausalLM(
@@ -222,9 +220,7 @@ def run_falcon_demo_kv(
         tt_cache_path,
         nearest_32(num_input_tokens),
     ) # single layer only used for compile
-
-    logger.info("Moved weights to device!")
-    profiler.end(f"moving_to_device")
+    logger.info("Moved weights (single layer) to device!")
 
     synchronize_devices(devices)
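The single-layer model in these two hunks exists only to warm up kernel compilation cheaply before the all-layers model is built. A self-contained sketch of the pattern (DummyModel is a stand-in for TtFalconCausalLM, whose real constructor takes devices, state_dict, config, and cache paths as shown in the diff):

class DummyModel:
    def __init__(self, num_layers):
        # Running a model triggers kernel compilation on first use;
        # with a single layer that compilation is much cheaper.
        self.num_layers = num_layers

    def forward(self, tokens):
        return tokens  # the real model returns logits

compile_model = DummyModel(num_layers=1)  # single layer only used for compile
compile_model.forward([0] * 32)           # compile run; output is discarded
del compile_model                         # free device memory, as the diff does

full_model = DummyModel(num_layers=32)    # all layers; reuses compiled kernels
output = full_model.forward([0] * 32)     # actual inference runs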

@@ -328,6 +324,8 @@ def run_falcon_demo_kv(
     del tt_FalconCausalLM_singlelayer
     del kv_cache_singlelayer
 
+    logger.info("Moving weights (all layers) to device; might take some time...")
+    profiler.start(f"moving_to_device")
     tt_FalconCausalLM = TtFalconCausalLM(
         devices,
         state_dict,
@@ -339,13 +337,13 @@ def run_falcon_demo_kv(
         tt_cache_path,
         nearest_32(num_input_tokens),
     )
+    logger.info("Moved weights (all layers) to device!")
+    profiler.end(f"moving_to_device")
 
     ### Second prefill run without compile ###
+    profiler.enable()
     enable_persistent_kernel_cache()
 
-    post_processor = partial(post_process)
     use_cache = True
     output_ids = torch.zeros(num_users, 1, dtype=torch.int64)
     logger.info("Running inference prefill stage...")
     time_prefill_inference = 0
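With this hunk, profiler.start/end brackets the all-layers weight load instead of the single-layer compile model, and profiler.enable() turns measurement on before the non-compile runs. A minimal sketch of the start/end/enable/disable interface these calls assume (illustrative only; the demo imports its real profiler from the repo's model utilities):

import time

class Profiler:
    # Sketch of the interface used in this diff, not the actual implementation.
    def __init__(self):
        self.enabled = True
        self.durations = {}
        self._starts = {}

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def start(self, key):
        if self.enabled:
            self._starts[key] = time.time()

    def end(self, key):
        if self.enabled and key in self._starts:
            elapsed = time.time() - self._starts.pop(key)
            self.durations[key] = self.durations.get(key, 0.0) + elapsed

profiler = Profiler()
profiler.start("moving_to_device")  # now times the all-layers weight load
profiler.end("moving_to_device")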
@@ -406,8 +404,6 @@ def run_falcon_demo_kv(
 
     generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)
 
-    profiler.disable()
-
     ### Inference run decode ###
     logger.info("Running inference decode stage...")