diff --git a/models/demos/t3000/llama2_70b/demo/demo.py b/models/demos/t3000/llama2_70b/demo/demo.py
index 1c5ceb14b55..5e4aa62a7ef 100644
--- a/models/demos/t3000/llama2_70b/demo/demo.py
+++ b/models/demos/t3000/llama2_70b/demo/demo.py
@@ -50,7 +50,9 @@ def main(args):
 
     # Run decode
     with torch.no_grad():
-        all_text = run_decode(args=args, model=model, tokenizer=tokenizer, prompt_tokens=tokenized, prompts=prompts)
+        for i in range(10000):
+            logger.info(f"Running decode stress test iteration {i}")
+            all_text = run_decode(args=args, model=model, tokenizer=tokenizer, prompt_tokens=tokenized, prompts=prompts)
 
     if args.output_at_end:
         with open(
@@ -80,6 +82,7 @@ def build_generator(args):
         max_batch_size=args.max_batch_size,
         skip_model_load=args.skip_model_load,
         n_layers=1 if args.implementation == "tt" else args.num_layers,
+        # n_layers=args.num_layers,  # Since I don't have weights yet, use fake weights
     )
 
     state_dict = load_llama_state_dict(args.ckpt_dir, n_layers=args.num_layers)
@@ -278,7 +281,7 @@ def __init__(
         skip_model_load=False,
         max_batch_size=32,
         num_layers=None,
-        max_seq_len=4096,
+        max_seq_len=2048,
         # Generation args
         num_tokens=128,
         prompts_file="models/demos/t3000/llama2_70b/demo/data/multi_prompt.json",
@@ -357,8 +360,8 @@ def construct_arg(**kwargs):
 @pytest.mark.parametrize(
     "num_tokens, output_at_end, top_p, top_k, temperature",
     (
-        (128, True, 1, 1, 1.0),
-        (128, True, 0.9, 10, 1.0),
+        (2000, True, 1, 1, 1.0),
+        (2000, True, 0.9, 10, 1.0),
     ),
     ids=("greedy", "sampling"),
 )
@@ -399,7 +402,7 @@ def test_LlamaModel_demo(
 
     for i in t3k_device_mesh.get_device_ids():
         device = t3k_device_mesh.get_device(i)
-        device.enable_async(True)
+        device.enable_async(False)
 
     args = construct_arg(
         implementation=implementation,
diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index 84bb6a9de94..245deeb20de 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -81,9 +81,12 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, *args, **kwargs):
             attn_mask,
         )
 
-        del tt_inp_emb
-        del rot_mat
-        del attn_mask
+        # del tt_inp_emb
+        # del rot_mat
+        # del attn_mask
+        tt_inp_emb.deallocate(True)
+        rot_mat.deallocate(True)
+        attn_mask.deallocate(True)
 
         # for device in self.devices:
         #     ttl.device.Synchronize(device)
@@ -95,7 +98,8 @@ def decode_forward(self, tokens: torch.Tensor, start_pos: int, *args, **kwargs):
 
         # logits = torch.cat([tt2torch_tensor(tt_o) for tt_o in tt_logits], -1)
         logits = logits[..., : self.params.vocab_size].float()
         logits = logits.permute(2, 1, 0, 3).squeeze().unsqueeze(1)  # [batch, 1, vocab_size]
-        del tt_logits
+        # del tt_logits
+        tt_logits.deallocate(True)
 
         return logits
diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
index cf10318ecf3..d0c603ac7fc 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
@@ -253,7 +253,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
             memory_config=self.model_config["DRAM_MEMCFG"],
             mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
         )
-        rot_mats = ttnn.to_device(rot_mats, self.device_mesh)
+        # rot_mats = ttnn.to_device(rot_mats, self.device_mesh)
 
         rot_mats = tt_lib.tensor.interleaved_to_sharded(
             rot_mats, sharded_mem_config=self.model_config["ROT_MAT_MM_IN1_MEMCFG"]
@@ -275,7 +275,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None):
             mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
             device=self.device_mesh,
         )
-        attn_masks = ttnn.to_device(attn_masks, self.device_mesh)
+        # attn_masks = ttnn.to_device(attn_masks, self.device_mesh)
 
         repeat_shape = (1, batch, 1, 1)
         attn_masks = tt_lib.tensor.repeat(
diff --git a/models/demos/t3000/llama2_70b/tt/model_config.py b/models/demos/t3000/llama2_70b/tt/model_config.py
index dae1027a9c2..e41dc7b8702 100644
--- a/models/demos/t3000/llama2_70b/tt/model_config.py
+++ b/models/demos/t3000/llama2_70b/tt/model_config.py
@@ -24,7 +24,7 @@ def pretty_print_model_config(model_config):
 
 
 def get_model_config(
-    llama_version="llama3", batch=32, seq_len=1, num_devices=8, max_batch_size=32, max_context_len=4096
+    llama_version="llama3", batch=32, seq_len=1, num_devices=8, max_batch_size=32, max_context_len=2048
 ):
     llm_mode = "decode" if seq_len == 1 else "prefill"
     assert num_devices == 8
diff --git a/models/demos/t3000/llama3_70b/demo/demo.py b/models/demos/t3000/llama3_70b/demo/demo.py
index a50ed5ebc29..87cbbcf5a56 100644
--- a/models/demos/t3000/llama3_70b/demo/demo.py
+++ b/models/demos/t3000/llama3_70b/demo/demo.py
@@ -48,8 +48,8 @@
 @pytest.mark.parametrize(
     "num_tokens, output_at_end, top_p, top_k, temperature",
     (
-        (128, True, 1, 1, 1.0),
-        (128, True, 0.9, 10, 1.0),
+        (2000, True, 1, 1, 1.0),
+        (2000, True, 0.9, 10, 1.0),
     ),
     ids=("greedy", "sampling"),
 )