From a08973b1277f5807cc081e66f8bd808e8dc61abc Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Wed, 9 Oct 2024 19:18:14 +0000 Subject: [PATCH] #13650: Add example of hybrid TP/DP llama-70b model on TG - Tile T3000 model configuration 4x on TG --- models/MODEL_HYBRID_TP_DP.md | 44 ++++ .../tests/test_llama_perf_decode.py | 216 ++++++++++++++++++ tests/scripts/tg/run_tg_model_perf_tests.sh | 1 + 3 files changed, 261 insertions(+) create mode 100644 models/MODEL_HYBRID_TP_DP.md diff --git a/models/MODEL_HYBRID_TP_DP.md b/models/MODEL_HYBRID_TP_DP.md new file mode 100644 index 000000000000..299cfdc369c7 --- /dev/null +++ b/models/MODEL_HYBRID_TP_DP.md @@ -0,0 +1,44 @@ +# Hybrid Tensor and Data Parallelism Implementation + +This short guide explains how to add hybrid tensor and data parallelism to your model using submesh tiling across a larger mesh. + +## Overview of Changes + +The main changes involve: + +1. Creating multiple submeshes from the main mesh +2. Running the model on each submesh +3. Capturing and replaying a trace across all submeshes in parallel + +## Key Implementation Details + +### 1. Submesh Creation + +```python + # Work with submesh device as you would with a regular ttnn.MeshDevice + submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) +``` + +### 2. Compile & Run the Model on Each Submesh + +```python + # Run the model on each submesh + for submesh_device in submesh_devices: + model(..., device=submesh_device) +``` + +### 3. Capture & Replay the Trace + +```python + + # Capture Model Trace spanning all submeshes + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for submesh_device in submesh_devices: + model(..., device=submesh) # Run the model on each submesh + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + + # Execute Model Trace across all submeshes in parallel + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ttnn.release_trace(mesh_device, trace_id) + +``` diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py index fbccd4176c32..cc9c1dac0317 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py @@ -254,3 +254,219 @@ def test_Llama_perf_host( tokenizer_path, cache_path, ) + + +def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel( + mesh_device, + llama_version, + batch, + seq_len, + max_context_len, + model_config, + n_layers, + n_devices, + generation_length, + expected_compile_time, + expected_inference_time, + ckpt_dir, + tokenizer_path, + cache_path, +): + # Prepare paths and devices + skip_model_load = should_skip_model_load() + + logger.info(f"Running num_layer: {n_layers}") + + generator = Llama.build( + ckpt_dir, + tokenizer_path, + max_seq_len=max_context_len, + max_batch_size=batch, + n_layers=1, + skip_model_load=skip_model_load, + ) + hugging_face_reference_model, tokenizer = generator.model, generator.tokenizer + hugging_face_reference_model.eval() + # state_dict = hugging_face_reference_model.state_dict() + state_dict = load_llama_state_dict(ckpt_dir, n_layers=n_layers) + configuration = hugging_face_reference_model.params + + # Prepare input ----------------------------------------------------------------------- + torch.manual_seed(0) + total_len = min(max_context_len, generation_length + 1) + n_iters = 100 # Number of iterations to run in order to get a perf estimate + tokens = torch.randint(0, 10000, (batch, 1), dtype=torch.long) + # Clear global profiler state before starting measurements + profiler.clear() + + submesh_to_metadata = defaultdict(dict) + submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) + for submesh in submeshes: + # Set up model ----------------------------------------------------------------------- + logger.info("Moving weights to devices; might take some time...") + profiler.start("TT_llama_model_setup") + tt_model = TtLlamaModel_optimized( + submesh, + state_dict, + BASE_URL, + n_layers, + model_config, + configuration, + cache_path=cache_path, + read_cache=True, + ) + + for i in submesh.get_device_ids(): + device = submesh.get_device(i) + ttnn.synchronize_device(device) + + profiler.end("TT_llama_model_setup") + + ##### Prepare Inputs ##### + prev_pos = total_len - 1 + tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens, prev_pos) + tt_inp_emb = ttnn.to_device(tt_inp_emb, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_inp_emb = tt_model.tt_embd(tt_inp_emb) + tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) + + rot_mat = ttnn.to_device(rot_mat, submesh, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"]) + cache_idxs = ttnn.to_device(cache_idxs, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + ##### Compile Model ##### + logger.info("Compiling model") + profiler.start(f"compile_time") + tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode") + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_logits_tensors = ttnn.get_device_tensors(tt_logits) + logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT) + logits = ttnn.to_torch(logits_rm) + profiler.end(f"compile_time") + profiler.print() + compile_iter_time = profiler.get("compile_time") + logger.info(f"decode with compile time, single iter latency: {compile_iter_time}") + + submesh_to_metadata[submesh.get_mesh_id()] = { + "submesh": submesh, + "logits_rm": logits_rm, + "tt_model": tt_model, + "prev_pos": prev_pos, + "tt_inp_emb": tt_inp_emb, + "rot_mat": rot_mat, + "cache_idxs": cache_idxs, + } + + ##### Capture Trace ##### + logger.info("Capturing trace") + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + + for submesh in submeshes: + mesh_id = submesh.get_mesh_id() + tt_model = submesh_to_metadata[mesh_id]["tt_model"] + tt_inp_emb = submesh_to_metadata[mesh_id]["tt_inp_emb"] + rot_mat = submesh_to_metadata[mesh_id]["rot_mat"] + cache_idxs = submesh_to_metadata[mesh_id]["cache_idxs"] + prev_pos = submesh_to_metadata[mesh_id]["prev_pos"] + + tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode") + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_logits_tensors = ttnn.get_device_tensors(tt_logits) + logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + + ##### Execute Trace ##### + logger.info("Executing trace") + profiler.start(f"end_to_end_inference") + for i in range(n_iters): + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + logits = ttnn.to_torch(logits_rm) + profiler.end(f"end_to_end_inference") + ttnn.release_trace(mesh_device, trace_id) + + profiler.print() + loop_time = profiler.get("end_to_end_inference") + iter_time = loop_time / n_iters + logger.info(f"decode cached, single iter latency: {iter_time}") + + comment = f"num_layers={n_layers}L_n_devices={n_devices}" + + prep_perf_report( + model_name=f"{llama_version}_70b_{comment}", + batch_size=batch, + inference_and_compile_time=compile_iter_time, + inference_time=iter_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments=comment, + ) + + tokens_per_s_per_user = 1 / iter_time + tokens_per_s_overall = tokens_per_s_per_user * batch * len(submeshes) + + logger.info(f"Time per iteration: {iter_time}") + logger.info(f"Tokens per s per user: {tokens_per_s_per_user}") + logger.info(f"Tokens per s overall: {tokens_per_s_overall}") + + # assert compile_time <= expected_compile_time + assert iter_time <= expected_inference_time + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(4500) +@pytest.mark.model_perf_tg +@pytest.mark.parametrize( + "llama_version", + (("llama3"),), +) +@pytest.mark.parametrize( + "generation_length, expected_compile_time, expected_inference_time, batch, seq_len, max_context_len", + ( + (32, 10000, 0.0653 + 0.01, 32, 1, 4096), + (128, 10000, 0.0655 + 0.01, 32, 1, 4096), + (2048, 10000, 0.0771 + 0.01, 32, 1, 4096), + (8192, 10000, 0.0825 + 0.01, 16, 1, 8192), + (128 * 1024, 10000, 0.0918 + 0.01, 1, 1, 128 * 1024), + ), + ids=["gen32", "gen128", "gen2k", "gen8k", "gen128k"], +) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 20000000}], indirect=True) +@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) +def test_Llama_perf_hybrid_data_tensor_parallel( + mesh_device, + generation_length, + expected_compile_time, + expected_inference_time, + batch, + seq_len, + max_context_len, + llama_version, + use_program_cache, + n_layers=80, + n_devices=8, +): + model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( + llama_version=llama_version, + max_batch_size=batch, + max_context_len=max_context_len, + ) + + check_mesh_device(mesh_device, model_config) + mesh_device.enable_async(True) + + disable_compilation_reports() + + run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel( + mesh_device, + llama_version, + batch, + seq_len, + max_context_len, + model_config, + n_layers, + n_devices, + generation_length, + expected_compile_time, + expected_inference_time, + ckpt_dir, + tokenizer_path, + cache_path, + ) diff --git a/tests/scripts/tg/run_tg_model_perf_tests.sh b/tests/scripts/tg/run_tg_model_perf_tests.sh index 9501e79e4236..d86a7a966888 100755 --- a/tests/scripts/tg/run_tg_model_perf_tests.sh +++ b/tests/scripts/tg/run_tg_model_perf_tests.sh @@ -4,6 +4,7 @@ run_tg_llm_tests() { echo "LOG_METAL: Running run_t3000_llama2_70b_tests" pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$? + pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_tg" --timeout=600 ; fail+=$? # Merge all the generated reports env python models/perf/merge_perf_results.py; fail+=$?