#13650: Add example of hybrid TP/DP llama-70b model on TG
- Tile T3000 model configuration 4x on TG
cfjchu committed Oct 10, 2024
1 parent 75390d9 commit a08973b
Showing 3 changed files with 261 additions and 0 deletions.
44 changes: 44 additions & 0 deletions models/MODEL_HYBRID_TP_DP.md
@@ -0,0 +1,44 @@
# Hybrid Tensor and Data Parallelism Implementation

This short guide explains how to add hybrid tensor and data parallelism to your model by tiling submeshes across a larger mesh: each submesh holds one tensor-parallel copy of the model, and the copies run data-parallel across submeshes.

## Overview of Changes

The main changes involve:

1. Creating multiple submeshes from the main mesh
2. Running the model on each submesh
3. Capturing and replaying a trace across all submeshes in parallel

## Key Implementation Details

### 1. Submesh Creation

```python
# Work with each submesh as you would with a regular ttnn.MeshDevice
submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
```
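
On a TG-class 8x4 mesh (as used in the perf test below), a `(2, 4)` submesh shape tiles the mesh into four T3000-sized submeshes. A minimal sketch, assuming `mesh_device` is an already-opened 8x4 `ttnn.MeshDevice`:

```python
# Sketch: tile an 8x4 TG mesh into four 2x4 (T3000-shaped) submeshes
submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
assert len(submeshes) == 4  # the 8x4 mesh is tiled 4x

for submesh in submeshes:
    # Each submesh behaves like an independent mesh with its own id and device ids
    print(submesh.get_mesh_id(), submesh.get_device_ids())
```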

### 2. Compile & Run the Model on Each Submesh

```python
# Run the model on each submesh
for submesh_device in submesh_devices:
model(..., device=submesh_device)
```
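
In practice each submesh hosts its own copy of the model and its own inputs, so the copies form the data-parallel axis while each copy stays tensor-parallel within its submesh. A minimal sketch of keeping per-submesh state for later tracing; `build_model` and `prepare_inputs` are hypothetical stand-ins for your model's own setup code, not a ttnn API:

```python
# Sketch: keep per-submesh state so each replica can be traced and replayed later
submesh_to_state = {}
for submesh_device in submesh_devices:
    model = build_model(submesh_device)      # tensor-parallel replica on this submesh
    inputs = prepare_inputs(submesh_device)  # this replica's share of the batch
    model(inputs)                            # compile/run once outside the trace
    submesh_to_state[submesh_device.get_mesh_id()] = (model, inputs)
```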

### 3. Capture & Replay the Trace

```python
# Capture Model Trace spanning all submeshes
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for submesh_device in submesh_devices:
    model(..., device=submesh_device)  # Run the model on each submesh
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

# Execute Model Trace across all submeshes in parallel
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
```
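
Once captured, the trace can be replayed many times; the TG perf test below replays it `n_iters` times and reads the output back each iteration. A minimal sketch, assuming `logits_rm` is an output tensor produced inside the captured region (as in the test):

```python
# Sketch: replay the captured trace and read back the per-iteration output
n_iters = 100
for _ in range(n_iters):
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    logits = ttnn.to_torch(logits_rm)  # read this iteration's output back to host
ttnn.release_trace(mesh_device, trace_id)
```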
216 changes: 216 additions & 0 deletions models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py
@@ -254,3 +254,219 @@ def test_Llama_perf_host(
tokenizer_path,
cache_path,
)


def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel(
mesh_device,
llama_version,
batch,
seq_len,
max_context_len,
model_config,
n_layers,
n_devices,
generation_length,
expected_compile_time,
expected_inference_time,
ckpt_dir,
tokenizer_path,
cache_path,
):
# Prepare paths and devices
skip_model_load = should_skip_model_load()

logger.info(f"Running num_layer: {n_layers}")

generator = Llama.build(
ckpt_dir,
tokenizer_path,
max_seq_len=max_context_len,
max_batch_size=batch,
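        # Only a single reference layer is instantiated here; the full n_layers state dict is loaded separately below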
n_layers=1,
skip_model_load=skip_model_load,
)
hugging_face_reference_model, tokenizer = generator.model, generator.tokenizer
hugging_face_reference_model.eval()
# state_dict = hugging_face_reference_model.state_dict()
state_dict = load_llama_state_dict(ckpt_dir, n_layers=n_layers)
configuration = hugging_face_reference_model.params

# Prepare input -----------------------------------------------------------------------
torch.manual_seed(0)
total_len = min(max_context_len, generation_length + 1)
n_iters = 100 # Number of iterations to run in order to get a perf estimate
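    # Random single-token decode inputs; this test measures latency only and does not check output correctness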
tokens = torch.randint(0, 10000, (batch, 1), dtype=torch.long)
# Clear global profiler state before starting measurements
profiler.clear()

submesh_to_metadata = defaultdict(dict)
submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring)
for submesh in submeshes:
# Set up model -----------------------------------------------------------------------
logger.info("Moving weights to devices; might take some time...")
profiler.start("TT_llama_model_setup")
tt_model = TtLlamaModel_optimized(
submesh,
state_dict,
BASE_URL,
n_layers,
model_config,
configuration,
cache_path=cache_path,
read_cache=True,
)

for i in submesh.get_device_ids():
device = submesh.get_device(i)
ttnn.synchronize_device(device)

profiler.end("TT_llama_model_setup")

##### Prepare Inputs #####
prev_pos = total_len - 1
tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens, prev_pos)
tt_inp_emb = ttnn.to_device(tt_inp_emb, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG)
tt_inp_emb = tt_model.tt_embd(tt_inp_emb)
tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])

rot_mat = ttnn.to_device(rot_mat, submesh, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"])
cache_idxs = ttnn.to_device(cache_idxs, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG)

##### Compile Model #####
logger.info("Compiling model")
profiler.start(f"compile_time")
tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode")
tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
tt_logits_tensors = ttnn.get_device_tensors(tt_logits)
logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT)
logits = ttnn.to_torch(logits_rm)
profiler.end(f"compile_time")
profiler.print()
compile_iter_time = profiler.get("compile_time")
logger.info(f"decode with compile time, single iter latency: {compile_iter_time}")

submesh_to_metadata[submesh.get_mesh_id()] = {
"submesh": submesh,
"logits_rm": logits_rm,
"tt_model": tt_model,
"prev_pos": prev_pos,
"tt_inp_emb": tt_inp_emb,
"rot_mat": rot_mat,
"cache_idxs": cache_idxs,
}

##### Capture Trace #####
logger.info("Capturing trace")
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)

for submesh in submeshes:
mesh_id = submesh.get_mesh_id()
tt_model = submesh_to_metadata[mesh_id]["tt_model"]
tt_inp_emb = submesh_to_metadata[mesh_id]["tt_inp_emb"]
rot_mat = submesh_to_metadata[mesh_id]["rot_mat"]
cache_idxs = submesh_to_metadata[mesh_id]["cache_idxs"]
prev_pos = submesh_to_metadata[mesh_id]["prev_pos"]

tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode")
tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
tt_logits_tensors = ttnn.get_device_tensors(tt_logits)
logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

##### Execute Trace #####
logger.info("Executing trace")
profiler.start(f"end_to_end_inference")
for i in range(n_iters):
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
logits = ttnn.to_torch(logits_rm)
profiler.end(f"end_to_end_inference")
ttnn.release_trace(mesh_device, trace_id)

profiler.print()
loop_time = profiler.get("end_to_end_inference")
iter_time = loop_time / n_iters
logger.info(f"decode cached, single iter latency: {iter_time}")

comment = f"num_layers={n_layers}L_n_devices={n_devices}"

prep_perf_report(
model_name=f"{llama_version}_70b_{comment}",
batch_size=batch,
inference_and_compile_time=compile_iter_time,
inference_time=iter_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments=comment,
)

tokens_per_s_per_user = 1 / iter_time
tokens_per_s_overall = tokens_per_s_per_user * batch * len(submeshes)

logger.info(f"Time per iteration: {iter_time}")
logger.info(f"Tokens per s per user: {tokens_per_s_per_user}")
logger.info(f"Tokens per s overall: {tokens_per_s_overall}")

# assert compile_time <= expected_compile_time
assert iter_time <= expected_inference_time


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(4500)
@pytest.mark.model_perf_tg
@pytest.mark.parametrize(
"llama_version",
(("llama3"),),
)
@pytest.mark.parametrize(
"generation_length, expected_compile_time, expected_inference_time, batch, seq_len, max_context_len",
(
(32, 10000, 0.0653 + 0.01, 32, 1, 4096),
(128, 10000, 0.0655 + 0.01, 32, 1, 4096),
(2048, 10000, 0.0771 + 0.01, 32, 1, 4096),
(8192, 10000, 0.0825 + 0.01, 16, 1, 8192),
(128 * 1024, 10000, 0.0918 + 0.01, 1, 1, 128 * 1024),
),
ids=["gen32", "gen128", "gen2k", "gen8k", "gen128k"],
)
@pytest.mark.parametrize("device_params", [{"trace_region_size": 20000000}], indirect=True)
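# TG mesh: the 8x4 grid is tiled into four 2x4 (T3000-shaped) submeshes inside the test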
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
def test_Llama_perf_hybrid_data_tensor_parallel(
mesh_device,
generation_length,
expected_compile_time,
expected_inference_time,
batch,
seq_len,
max_context_len,
llama_version,
use_program_cache,
n_layers=80,
n_devices=8,
):
model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env(
llama_version=llama_version,
max_batch_size=batch,
max_context_len=max_context_len,
)

check_mesh_device(mesh_device, model_config)
mesh_device.enable_async(True)

disable_compilation_reports()

run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel(
mesh_device,
llama_version,
batch,
seq_len,
max_context_len,
model_config,
n_layers,
n_devices,
generation_length,
expected_compile_time,
expected_inference_time,
ckpt_dir,
tokenizer_path,
cache_path,
)
1 change: 1 addition & 0 deletions tests/scripts/tg/run_tg_model_perf_tests.sh
@@ -4,6 +4,7 @@ run_tg_llm_tests() {

echo "LOG_METAL: Running run_t3000_llama2_70b_tests"
pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$?
pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_tg" --timeout=600 ; fail+=$?

# Merge all the generated reports
env python models/perf/merge_perf_results.py; fail+=$?
