Optimized Llama 3.x perf with sharded residual #15142

Merged (18 commits) on Nov 25, 2024
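This PR moves the decode-mode residual stream from DRAM-interleaved to an L1 width-sharded layout, referred to throughout the diff as model_config["DECODE_RESIDUAL_MEMCFG"]. That config's actual definition lives in the model config and is not part of this diff; the following is only a minimal sketch of how such a width-sharded config is typically built with ttnn, mirroring the create_sharded_memory_config pattern visible in distributed_norm.py further down. The helper name and its parameters are illustrative assumptions, not code from this PR.

import ttnn

# Hypothetical sketch: width-shard the (batch rows x per-device hidden dim) residual
# across the L1 of `core_grid` cores. The real DECODE_RESIDUAL_MEMCFG is defined in
# the model config, not here.
def make_decode_residual_memcfg(tile_padded_batch_rows, dim_per_device, core_grid):
    return ttnn.create_sharded_memory_config(
        (
            tile_padded_batch_rows,                 # shard height (padded batch rows)
            dim_per_device // core_grid.num_cores,  # shard width per core
        ),
        core_grid,
        ttnn.ShardStrategy.WIDTH,
        ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=True,
    )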
4 changes: 3 additions & 1 deletion models/demos/llama3/demo/demo.py
@@ -378,6 +378,7 @@ def run_llama3_demo(
# Compile
logger.info(f"Compiling model trace...")
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
if tt_model.args.num_devices > 1:
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
@@ -399,6 +400,7 @@ def run_llama3_demo(
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
if tt_model.args.num_devices > 1:
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
@@ -716,7 +718,7 @@ def run_llama3_demo(
"single_layer",
],
)
@pytest.mark.parametrize("device_params", [{"trace_region_size": 14951424, "num_command_queues": 2}], indirect=True)
@pytest.mark.parametrize("device_params", [{"trace_region_size": 23887872, "num_command_queues": 2}], indirect=True)
@pytest.mark.parametrize(
"mesh_device",
[
6 changes: 6 additions & 0 deletions models/demos/llama3/lt
@@ -77,6 +77,7 @@ class OutputEntryList:
"Cancelled"
if entry_data["status"]
in [
"Waiting",
"Running",
"Resetting",
"Initializing device",
@@ -844,6 +845,11 @@ def parse_output_line(line, previous_line, current_status):
speed_match = re.search(r"@ (\d+\.\d+) tok/s/user", line)
if speed_match:
speed = float(speed_match.group(1))
else:
# Check for end_to_end_inference time from perf test
latency_match = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
if latency_match:
speed = 1000 * float(latency_match.group(1)) # convert to ms

# Check for PCC information
pcc = None
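For reference, a minimal standalone sketch of the metric extraction added above (the log-line formats are inferred from the regexes themselves): decode lines report tok/s/user directly, while perf-test lines report end-to-end latency in seconds, converted here to milliseconds.

import re

def extract_metric(line):
    # Decode-speed lines, e.g. "... @ 31.5 tok/s/user"
    speed_match = re.search(r"@ (\d+\.\d+) tok/s/user", line)
    if speed_match:
        return float(speed_match.group(1))
    # Perf-test lines, e.g. "end_to_end_inference: 0.50s", reported as milliseconds
    latency_match = re.search(r"end_to_end_inference: (\d+\.\d+)s", line)
    if latency_match:
        return 1000 * float(latency_match.group(1))
    return None

assert extract_metric("decode @ 31.5 tok/s/user") == 31.5
assert extract_metric("end_to_end_inference: 0.50s") == 500.0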
@@ -249,7 +249,7 @@ def test_llama_cross_attention_transformer_text_inference(
else:
tt_h = model_args.prepare_inputs_ttnn_decode(
h,
ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
position_ids = position_ids.reshape(1).expand(batch)
tt_position_id = ttnn.from_torch(
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_accuracy.py
@@ -161,7 +161,7 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
# Prepare input for TT model
decode_input = model_args.prepare_inputs_ttnn_decode(
pt_decode_input,
ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
# Run TT model
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
3 changes: 2 additions & 1 deletion models/demos/llama3/tests/test_llama_decoder.py
@@ -92,7 +92,8 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en

decode_input = model_args.prepare_inputs_ttnn_decode(
tt_decode_input,
ttnn.DRAM_MEMORY_CONFIG,
# ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)

# Run TT model
16 changes: 8 additions & 8 deletions models/demos/llama3/tests/test_llama_model.py
@@ -71,24 +71,24 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
final_model_pcc = {
"llama32_1b": 0.9991,
"llama32_3b": 0.9989,
"llama31_8b": 0.99899,
"llama32_11b": 0.9976,
"llama31_70b": 0.98454,
"llama31_8b": 0.9987,
"llama32_11b": 0.9987,
"llama31_70b": 0.9843,
}[model_name]

final_k_cache_pcc = {
"llama32_1b": 0.9998,
"llama32_3b": 0.9998,
"llama31_8b": 0.99986,
"llama31_8b": 0.9998,
"llama32_11b": 0.9995,
"llama31_70b": 0.99983,
"llama31_70b": 0.9998,
}[model_name]
final_v_cache_pcc = {
"llama32_1b": 0.9996,
"llama32_3b": 0.9998,
"llama31_8b": 0.99986,
"llama31_8b": 0.9998,
"llama32_11b": 0.9996,
"llama31_70b": 0.99985,
"llama31_70b": 0.9998,
}[model_name]
quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6, "llama31_70b": 6}[
model_name
@@ -176,7 +176,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,

decode_input = model_args.prepare_inputs_ttnn_decode(
tt_decode_input,
ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
current_pos_tensor = ttnn.from_torch(
torch.tensor([current_pos] * batch),
14 changes: 7 additions & 7 deletions models/demos/llama3/tests/test_llama_model_prefill.py
@@ -58,7 +58,7 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
# Use instruct weights instead of general weights
instruct = True

model_args = TtModelArgs(mesh_device, instruct=instruct)
model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1)
tokenizer = Tokenizer(model_args.tokenizer_path)

logger.info("Loading weights...")
@@ -84,12 +84,12 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
prompt_file = os.path.join(current_file_dir, "tale-of-two-cities.txt.bz2")

with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
prompts = f.read()
prompt = f.read()

if instruct:
encoded_prompts = [encode_prompt_llama_instruct(tokenizer, prompt) for prompt in prompts]
encoded_prompt = encode_prompt_llama_instruct(tokenizer, prompt)[:seq_len]
else:
encoded_prompts = tokenizer.encode(prompts, bos=True, eos=False)[:seq_len]
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False)[:seq_len]

if run_ref_pt:
reference_model = Transformer(model_args)
@@ -126,9 +126,9 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se

batch = 1

# Select the first token from the prompts for initial decoding
encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0]
pt_decode_input = embd(encoded_prompts_tensor).view(batch, seq_len, -1)
# Select the first token from the prompt for initial decoding
encoded_prompt_tensor = torch.tensor(encoded_prompt) # [:,0]
pt_decode_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1)

tt_decode_input = pt_decode_input

24 changes: 13 additions & 11 deletions models/demos/llama3/tests/test_llama_perf.py
@@ -31,11 +31,20 @@
@pytest.mark.parametrize(
"kv_cache_len, expected_compile_time",
(
(32, 20),
(128, 20),
(1024, 20),
(32, 30),
(128, 30),
(1024, 30),
),
)
@pytest.mark.parametrize(
"mesh_device",
[
{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids())
)
],
indirect=True,
)
def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_program_cache, reset_seeds, ensure_gc):
dtype = ttnn.bfloat8_b

@@ -162,15 +171,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
dtype=ttnn.uint32,
)

# Generate first input on host
pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1)
# Send first input to device
tt_decode_input = pt_decode_input
decode_input = tt_model.args.prepare_inputs_ttnn_decode(
tt_decode_input,
ttnn.DRAM_MEMORY_CONFIG,
)

current_pos = ttnn.from_torch(
torch.tensor([generation_start_pos] * batch),
device=mesh_device,
Expand All @@ -183,6 +184,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
profiler.start(f"model_run_for_inference_{i}")

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
ttnn.deallocate(tt_out)
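The mesh_device parametrization added above maps the FAKE_DEVICE environment variable to a mesh shape and falls back to the number of locally available devices. A small sketch of how that lookup resolves (device names and shapes are copied from the diff; the helper name is an illustrative assumption):

import os

FAKE_DEVICE_SHAPES = {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}

def resolve_mesh_shape(num_available_devices):
    # Unknown or unset FAKE_DEVICE falls back to the local device count,
    # mirroring len(ttnn.get_device_ids()) in the test.
    return FAKE_DEVICE_SHAPES.get(os.environ.get("FAKE_DEVICE"), num_available_devices)

os.environ["FAKE_DEVICE"] = "T3K"
assert resolve_mesh_shape(8) == (1, 8)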
4 changes: 3 additions & 1 deletion models/demos/llama3/tests/test_llama_rms_norm.py
@@ -73,7 +73,9 @@ def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, e
dtype=dtype,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
memory_config=ttnn.L1_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG,
memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
if mode == "decode"
else ttnn.DRAM_MEMORY_CONFIG,
)

tt_output = tt_model(tt_input, mode=mode)
39 changes: 4 additions & 35 deletions models/demos/llama3/tt/distributed_norm.py
@@ -10,48 +10,17 @@ class DistributedNorm(LightweightModule):
def __init__(self, norm, args):
self.norm = norm
self.args = args
norm_input_grid = args.dram_shard_core_grid_for_k(args.dim // args.num_devices)
self.gather_in_mem_cfg = ttnn.create_sharded_memory_config(
(
args.tile_padded_batch_rows,
args.dim // args.num_devices // norm_input_grid.num_cores,
),
norm_input_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.gather_out_mem_cfg = ttnn.create_sharded_memory_config(
(
args.tile_padded_batch_rows,
args.dim // norm_input_grid.num_cores,
),
norm_input_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)

def forward(self, x, mode):
"""Apply a norm, possibly gathering inputs if required."""
input_mem_cfg = self.norm.sharded_output_config if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG

# Distributed norm already performs a gather
if self.args.is_multichip and not self.args.is_distributed_norm(mode):
if mode == "decode":
x = ttnn.interleaved_to_sharded(x, self.gather_in_mem_cfg)
x = ttnn.all_gather(
x, dim=3, num_links=1, topology=self.args.ccl_topology(), memory_config=input_mem_cfg
)
else:
x = ttnn.all_gather(
x, dim=3, num_links=1, topology=self.args.ccl_topology(), memory_config=input_mem_cfg
)
elif mode == "decode":
# Gathered norms will be sharded for decode mode, so single-chip should be too
x = ttnn.interleaved_to_sharded(x, input_mem_cfg)

# x sharded in decode mode here
x = ttnn.all_gather(x, dim=3, num_links=1, topology=self.args.ccl_topology(), memory_config=input_mem_cfg)
else:
x = ttnn.to_memory_config(x, input_mem_cfg)

x = self.norm(x, mode=mode, in_sharded=(mode == "decode"), out_sharded=(mode == "decode"))

# Distributed norm requires a gather
32 changes: 17 additions & 15 deletions models/demos/llama3/tt/llama_attention.py
@@ -240,7 +240,7 @@ def forward_decode(
xqkv_fused,
num_heads=self.n_local_heads,
num_kv_heads=self.n_local_kv_heads,
memory_config=self.model_config["HEIGHT_SHARDED_MEMCFG"],
memory_config=ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
)

ttnn.deallocate(xqkv_fused)
@@ -310,7 +310,7 @@ def forward_decode(
scale=self.scale,
program_config=self.model_config["SDPA_DECODE_PROGCFG"],
compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"],
memory_config=ttnn.DRAM_MEMORY_CONFIG,
memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG?
)

ttnn.deallocate(q_heads_1BQD)
@@ -326,46 +326,48 @@ def forward_decode(
ttnn.deallocate(attn_output_1G4D)

if self.is_multichip and self.use_fused_all_gather_matmul:
attn_output_cat = ttnn.to_memory_config(
attn_output_cat, self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"]
)
_, dense_out_sharded, _ = ttnn.experimental.all_gather_matmul(
attn_output_cat,
self.wo,
dim=3,
all_gather_core_grid_offset=(0, 4),
num_links=1,
memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"],
memory_config_mm=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_PROGCFG"],
program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"],
compute_kernel_config=self.compute_kernel_config_hifi2,
memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"],
memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"],
)
else:
# program config matched to output of nlp_concat_heads_decode
dense_out_sharded = ttnn.linear(
attn_output_cat,
self.wo,
memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
program_config=self.model_config["ATTN_OUTPUT_PROGCFG"],
compute_kernel_config=self.compute_kernel_config_hifi2,
memory_config=attn_output_cat.memory_config(),
) # seqlen, 1, batch, hidden_size

ttnn.deallocate(attn_output_cat)
dense_out = ttnn.sharded_to_interleaved(
dense_out_sharded, ttnn.L1_MEMORY_CONFIG
) # TODO: remove as soon as we have sharded support in for all CCL

ttnn.deallocate(dense_out_sharded)

# All reduce
if self.is_multichip and not self.use_fused_all_gather_matmul:
dense_out_reduced = ttnn.reduce_scatter(
dense_out,
dense_out_sharded,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.L1_MEMORY_CONFIG,
memory_config=self.model_config[
"DECODE_RESIDUAL_MEMCFG"
], # Unlike matmuls, CCL ops can reshard to any valid output sharding for free
)
ttnn.deallocate(dense_out)
ttnn.deallocate(dense_out_sharded)
return dense_out_reduced
else:
return dense_out
dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"])
return dense_out_sharded

def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None):
seq_len = x_11SH.shape[-2]
19 changes: 6 additions & 13 deletions models/demos/llama3/tt/llama_decoder.py
@@ -88,14 +88,11 @@ def forward(
mode="decode",
page_table=None,
) -> ttnn.Tensor:
# x is fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode)
# FIXME: move to sharded residuals once support for this is added
# FIXME: Currently, for decode mode, we are using DRAM intereleaved as L1 interleaved results in h being corrupted in MLP
skip_mem_cfg = (
ttnn.DRAM_MEMORY_CONFIG
# self.model_config["DEC_SKIP_OUTPUT_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
)

# x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode)
skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
assert (
x.memory_config() == skip_mem_cfg
), f"decoder input memcfg mismatch: {x.memory_config()} != {skip_mem_cfg}"
# Norms take fractured inputs and output replicated across devices
attn_in = self.attention_norm(x, mode)
# Attention takes replicated inputs and produces fractured outputs
Expand All @@ -108,18 +105,14 @@ def forward(
mode,
page_table,
)

# Here x and attn_out are both fractured across devices
h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg)

# TODO: This deallocate may cause ND output. The reason seems to be related to either the input being on DRAM/L1 and the sharded spec in MLP using 32 cores instead of 16.
# ttnn.deallocate(attn_out)
ttnn.deallocate(attn_out)

# Norms take fractured inputs and output replicated across devices
ff_in = self.ff_norm(h, mode)
# MLP takes replicated inputs and produces fractured outputs
ff_out = self.feed_forward.forward(ff_in, mode)
# ff_out and h are both fractured across devices
out = ttnn.add(h, ff_out, memory_config=skip_mem_cfg)

return out # fractured across devices