Skip to content

Commit

Permalink
#0: Update vision model to work with sharded text decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 20, 2024
1 parent 9ad6c85 commit a8d0494
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 28 deletions.
4 changes: 2 additions & 2 deletions models/demos/llama3/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def run_llama3_demo(
# Compile
logger.info(f"Compiling model trace...")
decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
if tt_model.args.num_devices > 1:
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
Expand All @@ -400,7 +400,7 @@ def run_llama3_demo(
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
if tt_model.args.num_devices > 1:
tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
Expand Down
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
# Prepare input for TT model
decode_input = model_args.prepare_inputs_ttnn_decode(
pt_decode_input,
model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
# Run TT model
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
Expand Down
4 changes: 2 additions & 2 deletions models/demos/llama3/tests/test_llama_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, ensure_gc):
dtype = ttnn.bfloat8_b

mesh_device.enable_async(False) # NOCOMMIT
mesh_device.enable_async(True)

model_args = TtModelArgs(mesh_device)
model_args.n_layers = 1
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds, en
decode_input = model_args.prepare_inputs_ttnn_decode(
tt_decode_input,
# ttnn.DRAM_MEMORY_CONFIG,
model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)

# Run TT model
Expand Down
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,

decode_input = model_args.prepare_inputs_ttnn_decode(
tt_decode_input,
model_args.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
model_args.model_config["DECODE_RESIDUAL_MEMCFG"],
)
current_pos_tensor = ttnn.from_torch(
torch.tensor([current_pos] * batch),
Expand Down
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos
profiler.start(f"model_run_for_inference_{i}")

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
decode_input = ttnn.to_memory_config(decode_input, tt_model.args.model_config["DECODE_RESIDUAL_MEMCFG"])
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_rm = ttnn.untilize(tt_out, use_multicore=True)
ttnn.deallocate(tt_out)
Expand Down
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_llama_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, e
dtype=dtype,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1),
memory_config=model_args.get_model_config()["DEC_SKIP_OUTPUT_MEMCFG"]
memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"]
if mode == "decode"
else ttnn.DRAM_MEMORY_CONFIG,
)
Expand Down
6 changes: 3 additions & 3 deletions models/demos/llama3/tt/llama_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def forward_decode(
program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"],
compute_kernel_config=self.compute_kernel_config_hifi2,
memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"],
memory_config_mm=self.model_config["DEC_SKIP_OUTPUT_MEMCFG"],
memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"],
)
else:
# program config matched to output of nlp_concat_heads_decode
Expand All @@ -360,13 +360,13 @@ def forward_decode(
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=self.model_config[
"DEC_SKIP_OUTPUT_MEMCFG"
"DECODE_RESIDUAL_MEMCFG"
], # Unlike matmuls, CCL ops can reshard to any valid output sharding for free
)
ttnn.deallocate(dense_out_sharded)
return dense_out_reduced
else:
dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"])
return dense_out_sharded

def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = 0, page_table=None):
Expand Down
9 changes: 1 addition & 8 deletions models/demos/llama3/tt/llama_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,7 @@ def forward(
page_table=None,
) -> ttnn.Tensor:
# x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode)
# FIXME: move to sharded residuals once support for this is added
# FIXME: Currently, for decode mode, we are using DRAM intereleaved as L1 interleaved results in h being corrupted in MLP
skip_mem_cfg = (
# ttnn.DRAM_MEMORY_CONFIG
self.model_config["DEC_SKIP_OUTPUT_MEMCFG"]
if mode == "decode"
else ttnn.DRAM_MEMORY_CONFIG
)
skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG
assert (
x.memory_config() == skip_mem_cfg
), f"decoder input memcfg mismatch: {x.memory_config()} != {skip_mem_cfg}"
Expand Down
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/llama_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,5 +145,5 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:

# reshard to residual, no-op if already correct
if mode == "decode":
result = ttnn.to_memory_config(result, self.model_config["DEC_SKIP_OUTPUT_MEMCFG"])
result = ttnn.to_memory_config(result, self.model_config["DECODE_RESIDUAL_MEMCFG"])
return result
3 changes: 3 additions & 0 deletions models/demos/llama3/tt/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def forward(
page_table=None,
get_last_token=-1,
):
# No-op if callers already provide the right memory config
x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])

for layer in self.layers:
x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)

Expand Down
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TtModelArgs:
"ATTN_OUTPUT",
"ATTN_W_LAYOUT",
# Decoder
"DEC_SKIP_OUTPUT",
"DECODE_RESIDUAL",
"OUTPUT_MM",
)

Expand Down Expand Up @@ -229,7 +229,7 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s
self.model_config["COMPUTE_KERNEL_CONFIG_HIFI2"] = self.compute_kernel_config_hifi2

residual_grid = self.dram_shard_core_grid_for_k(self.dim // self.num_devices)
self.model_config["DEC_SKIP_OUTPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
self.model_config["DECODE_RESIDUAL_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
self.dim // residual_grid.num_cores // self.num_devices,
Expand Down
9 changes: 4 additions & 5 deletions models/demos/llama3/tt/multimodal/llama_cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,16 +269,15 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,

# All reduce
if self.is_multichip:
dense_out_reduced = ttnn.reduce_scatter(
output = ttnn.reduce_scatter(
output,
scatter_dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.L1_MEMORY_CONFIG,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
return dense_out_reduced
else:
return output

return ttnn.to_memory_config(output, self.model_config["DECODE_RESIDUAL_MEMCFG"])

def forward_prefill(
self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, xattn_cache, user_id, vision_tokens
Expand Down
5 changes: 5 additions & 0 deletions models/demos/llama3/tt/multimodal/llama_cross_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,15 @@ def forward(
user_id=user_id,
vision_tokens=vision_tokens,
)
# FIXME: DRAM workaround for No circular buffer with id error
attn_out = ttnn.to_memory_config(attn_out, memory_config=ttnn.DRAM_MEMORY_CONFIG)
attn_out = ttnn.mul(attn_out, ttnn.tanh(self.gate_attn))

res = ttnn.add(x_11SH, attn_out)
mlp_out = self.feed_forward(self.ffn_norm(res, mode=mode), mode=mode)
# FIXME: DRAM workaround for No circular buffer with id error
mlp_out = ttnn.to_memory_config(mlp_out, memory_config=ttnn.DRAM_MEMORY_CONFIG)

if mode == "prefill":
mlp_out = ttnn.mul(mlp_out, full_text_row_masked_out_mask_11SD)
mlp_out = ttnn.mul(mlp_out, ttnn.tanh(self.gate_ffwd))
Expand Down
4 changes: 3 additions & 1 deletion models/demos/llama3/tt/multimodal/llama_vision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ def prepare_inputs_decode(self, tokens, cross_attention_masks, full_text_row_mas
B=tokens.shape[0],
)

tt_h = ttnn.to_memory_config(tt_h, self.configuration.model_config["DECODE_RESIDUAL_MEMCFG"])

return (
tt_h,
tt_xattn_mask,
Expand All @@ -413,7 +415,7 @@ def prepare_decode_inputs_host(self, tokens, cross_attention_masks, full_text_ro
h = self.prepare_inputs_common(position_ids, tokens)
tt_h = self.configuration.prepare_inputs_ttnn_decode(
h,
ttnn.DRAM_MEMORY_CONFIG,
ttnn.DRAM_MEMORY_CONFIG, # L1 memory_configs are not respected for on_host tensors
on_host=True,
)

Expand Down

0 comments on commit a8d0494

Please sign in to comment.