diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py
index 98491a0a33f..a89ae87a776 100644
--- a/models/demos/llama3/demo/demo.py
+++ b/models/demos/llama3/demo/demo.py
@@ -417,7 +417,7 @@ def run_llama3_demo(
                         dims=(3, 1) if model_args.is_galaxy else (1, -1),
                         mesh_shape=model_args.cluster_shape,
                     ),
-                )[0, 0, (decoding_pos[batch_id] - 1) % 32, :]
+                )[0, 0, (decoding_pos[batch_id] - 1) % 32, : model_args.vocab_size]
             )
 
         ttnn.deallocate(tt_out)
diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index 14c9bdb0cf9..429d1d25c7b 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -110,14 +110,12 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag
         tokens = tokens.reshape(1, 1, 1, -1)
         S = tokens.shape[-1]
 
-        dims = (None, None)  # replicate
-        mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)
         tokens = ttnn.from_torch(
             tokens,
             device=self.mesh_device,
             dtype=ttnn.uint32,
             layout=ttnn.ROW_MAJOR_LAYOUT,
-            mesh_mapper=mesh_mapper,
+            mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
         )
         tokens_embd = self.embd(tokens)
         tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)
@@ -176,14 +174,11 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
         assert current_pos.shape[0] == B, "Batch size mismatch"
         assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"
 
-        dims = (None, -1) if self.args.is_galaxy else (None, None)
-        mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)
-
         tokens = ttnn.from_torch(
             tokens.view(-1),
             device=None,
             dtype=ttnn.uint32,
-            mesh_mapper=mesh_mapper,
+            mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
         )
         tokens = ttnn.unsqueeze_to_4D(tokens)
 
@@ -228,6 +223,10 @@ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_ta
         tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs)
         tt_tokens = self.embd(tokens)
         tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens)
+        tt_tokens = ttnn.to_memory_config(
+            tt_tokens,
+            self.args.model_config["DECODE_RESIDUAL_MEMCFG"],
+        )
         return tt_tokens, current_pos, tt_rot_mats, page_table
 
     def process_output_prefill(self, tt_out, last_token_idx):
@@ -240,7 +239,7 @@ def process_output_prefill(self, tt_out, last_token_idx):
             mesh_composer=ttnn.ConcatMesh2dToTensor(
                 self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape
             ),
-        )[0, 0, last_token_idx, :]
+        )[0, 0, last_token_idx, : self.vocab_size]
         return logits
 
     def process_output_decode(self, tt_out, B, S=1):
@@ -248,13 +247,23 @@ def process_output_decode(self, tt_out, B, S=1):
         Input is ttnn device tensor of logits. Output is torch logits tensor
         """
         if self.args.num_devices > 1:
-            tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+            if self.args.is_galaxy:
+                tt_out = ttnn.all_gather(
+                    tt_out,
+                    dim=3,
+                    num_links=2,
+                    cluster_axis=0,
+                    mesh_device=self.mesh_device,
+                    topology=ttnn.Topology.Linear,
+                )
+            else:
+                tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
         tt_out = ttnn.untilize(tt_out, use_multicore=True)
         if self.args.num_devices > 1:
             tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
         else:
             tt_out = ttnn.to_torch(tt_out).float()
-        tt_out = tt_out[:, :, :B, :].view(B, S, -1)
+        tt_out = tt_out[:, :, :B, : self.vocab_size].view(B, S, -1)
         return tt_out
 
     def ttnn_prefill_forward(
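Not part of the patch: a minimal torch-only sketch of what the new `: self.vocab_size` / `: model_args.vocab_size` slices above do, assuming the gathered device logits are padded along the last dimension beyond the true vocabulary size. The helper name and the padded width below are made up for illustration.

```python
import torch

def trim_padded_logits(logits: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Drop the padding columns a device-padded logits tensor carries past the
    # true vocabulary, so downstream argmax/sampling never picks a pad index.
    return logits[..., :vocab_size]

# Hypothetical shapes: batch 1, seq 1, 32 user rows, padded vocab width.
padded = torch.randn(1, 1, 32, 131072)            # assumed padded width
logits = trim_padded_logits(padded, 128256)        # Llama 3 vocabulary size
assert logits.shape[-1] == 128256
```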