#9837: Llama demo segfault repro
cglagovichTT committed Jul 3, 2024
1 parent 3687a43 commit 04e22ad
Showing 5 changed files with 21 additions and 14 deletions.
13 changes: 8 additions & 5 deletions models/demos/t3000/llama2_70b/demo/demo.py
@@ -50,7 +50,9 @@ def main(args):
 
     # Run decode
     with torch.no_grad():
-        all_text = run_decode(args=args, model=model, tokenizer=tokenizer, prompt_tokens=tokenized, prompts=prompts)
+        for i in range(10000):
+            logger.info(f"Running decode stress test iteration {i}")
+            all_text = run_decode(args=args, model=model, tokenizer=tokenizer, prompt_tokens=tokenized, prompts=prompts)
 
     if args.output_at_end:
         with open(
@@ -80,6 +82,7 @@ def build_generator(args):
         max_batch_size=args.max_batch_size,
         skip_model_load=args.skip_model_load,
         n_layers=1 if args.implementation == "tt" else args.num_layers,
+        # n_layers=args.num_layers, # Since I don't have weights yet, use fake weights
     )
 
     state_dict = load_llama_state_dict(args.ckpt_dir, n_layers=args.num_layers)
@@ -278,7 +281,7 @@ def __init__(
         skip_model_load=False,
         max_batch_size=32,
         num_layers=None,
-        max_seq_len=4096,
+        max_seq_len=2048,
         # Generation args
         num_tokens=128,
         prompts_file="models/demos/t3000/llama2_70b/demo/data/multi_prompt.json",
@@ -357,8 +360,8 @@ def construct_arg(**kwargs):
 @pytest.mark.parametrize(
     "num_tokens, output_at_end, top_p, top_k, temperature",
     (
-        (128, True, 1, 1, 1.0),
-        (128, True, 0.9, 10, 1.0),
+        (2000, True, 1, 1, 1.0),
+        (2000, True, 0.9, 10, 1.0),
     ),
     ids=("greedy", "sampling"),
 )
@@ -399,7 +402,7 @@ def test_LlamaModel_demo(
 
     for i in t3k_device_mesh.get_device_ids():
         device = t3k_device_mesh.get_device(i)
-        device.enable_async(True)
+        device.enable_async(False)
 
     args = construct_arg(
         implementation=implementation,
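Note on the demo.py changes above: the single decode pass becomes a 10,000-iteration stress loop, each parametrized test now generates 2000 tokens instead of 128, and async device dispatch is switched off, which typically makes a crash surface synchronously at the op that triggered it rather than on a background worker. A minimal, self-contained sketch of the same stress-loop pattern (pure Python; decode_once and the iteration count are placeholders, not repo code):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def decode_once(i: int) -> str:
        # Stand-in for run_decode(...): swap in the operation under test.
        return f"output {i}"

    for i in range(10_000):
        logger.info(f"Running decode stress test iteration {i}")
        all_text = decode_once(i)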
12 changes: 8 additions & 4 deletions models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -81,9 +81,12 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, *args, **kwargs):
             attn_mask,
         )
 
-        del tt_inp_emb
-        del rot_mat
-        del attn_mask
+        # del tt_inp_emb
+        # del rot_mat
+        # del attn_mask
+        tt_inp_emb.deallocate(True)
+        rot_mat.deallocate(True)
+        attn_mask.deallocate(True)
 
         # for device in self.devices:
         #     ttl.device.Synchronize(device)
@@ -95,7 +98,8 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, *args, **kwargs):
         # logits = torch.cat([tt2torch_tensor(tt_o) for tt_o in tt_logits], -1)
         logits = logits[..., : self.params.vocab_size].float()
         logits = logits.permute(2, 1, 0, 3).squeeze().unsqueeze(1)  # [batch, 1, vocab_size]
-        del tt_logits
+        # del tt_logits
+        tt_logits.deallocate(True)
 
         return logits

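The substantive change in llama_generation.py swaps Python del for an explicit deallocate(True) on the device-resident tensors: del only drops the host-side reference and leaves the device buffer alive until the object is finalized, while deallocate(True) frees the buffer immediately. A toy, pure-Python illustration of that difference (my own sketch for exposition, not tt-metal code):

    class DeviceTensor:
        # Toy stand-in for a device-resident tensor handle.
        def __init__(self, nbytes: int):
            self.buffer = bytearray(nbytes)  # stands in for a device allocation

        def deallocate(self, force: bool = False) -> None:
            # Release the underlying storage now, regardless of other references.
            self.buffer = None

    t = DeviceTensor(1024)
    alias = t                # a second reference, e.g. held by a cache
    del t                    # buffer survives: alias keeps the object alive
    assert alias.buffer is not None
    alias.deallocate(True)   # explicit free works even with live references
    assert alias.buffer is None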
4 changes: 2 additions & 2 deletions models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -253,7 +253,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
             memory_config=self.model_config["DRAM_MEMCFG"],
             mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
         )
-        rot_mats = ttnn.to_device(rot_mats, self.device_mesh)
+        # rot_mats = ttnn.to_device(rot_mats, self.device_mesh)
 
         rot_mats = tt_lib.tensor.interleaved_to_sharded(
             rot_mats, sharded_mem_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"]
@@ -275,7 +275,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
             mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
             device=self.device_mesh,
         )
-        attn_masks = ttnn.to_device(attn_masks, self.device_mesh)
+        # attn_masks = ttnn.to_device(attn_masks, self.device_mesh)
 
         repeat_shape = (1, batch, 1, 1)
         attn_masks = tt_lib.tensor.repeat(
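Both hunks in llama_model_optimized.py comment out a ttnn.to_device(...) hop. The visible context shows the tensors are already placed on the mesh when created (the attn_masks call passes device=self.device_mesh alongside mesh_mapper=ReplicateTensorToMesh(...)), so the extra transfer looks redundant and is evidently a suspect in the segfault hunt. A hedged sketch of the resulting pattern (the creating call sits outside the hunk, so ttnn.as_tensor here is my assumption, and the fragment is not runnable on its own):

    attn_masks = ttnn.as_tensor(
        attn_mask,                                        # host-side torch tensor
        mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
        device=self.device_mesh,                          # placed on device here
    )
    # attn_masks = ttnn.to_device(attn_masks, self.device_mesh)  # redundant hop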
2 changes: 1 addition & 1 deletion models/demos/t3000/llama2_70b/tt/model_config.py
@@ -24,7 +24,7 @@ def pretty_print_model_config(model_config):
 
 
 def get_model_config(
-    llama_version="llama3", batch=32, seq_len=1, num_devices=8, max_batch_size=32, max_context_len=4096
+    llama_version="llama3", batch=32, seq_len=1, num_devices=8, max_batch_size=32, max_context_len=2048
 ):
     llm_mode = "decode" if seq_len == 1 else "prefill"
     assert num_devices == 8
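model_config.py halves the default context window from 4096 to 2048, matching the max_seq_len change in demo.py above. With num_tokens raised to 2000 in the test parametrization, the reduced window is nearly saturated; a quick budget check (my own sketch, not repo code):

    # Stress parameters after this commit.
    max_context_len = 2048   # get_model_config default
    num_tokens = 2000        # pytest parametrization
    prompt_budget = max_context_len - num_tokens
    assert prompt_budget >= 0
    print(f"Tokens left for the prompt: {prompt_budget}")  # 48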
4 changes: 2 additions & 2 deletions models/demos/t3000/llama3_70b/demo/demo.py
@@ -48,8 +48,8 @@
 @pytest.mark.parametrize(
     "num_tokens, output_at_end, top_p, top_k, temperature",
     (
-        (128, True, 1, 1, 1.0),
-        (128, True, 0.9, 10, 1.0),
+        (2000, True, 1, 1, 1.0),
+        (2000, True, 0.9, 10, 1.0),
     ),
     ids=("greedy", "sampling"),
 )

1 comment on commit 04e22ad

@github-actions (Contributor)

[stale:cglagovich/9837]

@cglagovichTT Your branch cglagovich/9837 hasn't been updated in the last 180 days and is marked as stale. It will be removed in 7 days.
If you want to keep this branch around, add new commits to this branch or protect it.
