Fix remaining minor input/output issues with TG-Llama3 vLLM integration
Signed-off-by: Salar <[email protected]>
skhorasganiTT committed Jan 3, 2025
1 parent 8e36303 commit 1754885
Showing 2 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion models/demos/llama3/demo/demo.py
@@ -417,7 +417,7 @@ def run_llama3_demo(
dims=(3, 1) if model_args.is_galaxy else (1, -1),
mesh_shape=model_args.cluster_shape,
),
- )[0, 0, (decoding_pos[batch_id] - 1) % 32, :]
+ )[0, 0, (decoding_pos[batch_id] - 1) % 32, : model_args.vocab_size]
)
ttnn.deallocate(tt_out)

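The slice change above matters because the tensor brought back from device can be wider along its last dimension than the model's real vocabulary (Llama 3 uses 128256 entries), so the demo now keeps only the first model_args.vocab_size columns before sampling. A minimal torch-only sketch of the idea, using toy sizes rather than the real ones:

    # Toy illustration (hypothetical sizes): drop padded columns past the true
    # vocab width before taking the argmax for the next token.
    import torch

    vocab_size = 100           # stand-in for model_args.vocab_size
    padded_vocab = 128         # hypothetical padded width coming off device
    tt_output_torch = torch.randn(1, 1, 32, padded_vocab)

    last_idx = 7               # plays the role of (decoding_pos[batch_id] - 1) % 32
    logits = tt_output_torch[0, 0, last_idx, :vocab_size]   # shape: (100,)
    next_token = torch.argmax(logits)
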
29 changes: 19 additions & 10 deletions models/demos/llama3/tt/llama_model.py
@@ -110,14 +110,12 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag

tokens = tokens.reshape(1, 1, 1, -1)
S = tokens.shape[-1]
- dims = (None, None) # replicate
- mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)
tokens = ttnn.from_torch(
tokens,
device=self.mesh_device,
dtype=ttnn.uint32,
layout=ttnn.ROW_MAJOR_LAYOUT,
- mesh_mapper=mesh_mapper,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
tokens_embd = self.embd(tokens)
tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd)
@@ -176,14 +174,11 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None):
assert current_pos.shape[0] == B, "Batch size mismatch"
assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size"

- dims = (None, -1) if self.args.is_galaxy else (None, None)
- mesh_mapper = ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=self.args.cluster_shape)

tokens = ttnn.from_torch(
tokens.view(-1),
device=None,
dtype=ttnn.uint32,
- mesh_mapper=mesh_mapper,
+ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device),
)
tokens = ttnn.unsqueeze_to_4D(tokens)
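
The mapper change above replaces a 2D shard mapping, which on galaxy systems split the flattened token ids along their last dimension across one mesh axis, with plain replication, so every device now receives the complete token tensor. A torch-only sketch of the difference, with a hypothetical mesh width:

    # Torch-only sketch; the mesh width below is an assumption for illustration.
    import torch

    tokens = torch.arange(32, dtype=torch.int32)     # one token id per user
    mesh_width = 4                                   # hypothetical devices along one mesh axis

    # ShardTensor2dMesh with dims=(None, -1): each device gets a slice of the last dim
    sharded = list(torch.chunk(tokens, mesh_width))  # four tensors of 8 ids each

    # ReplicateTensorToMesh: each device gets the full tensor
    replicated = [tokens.clone() for _ in range(mesh_width)]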

@@ -228,6 +223,10 @@ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_ta
tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs)
tt_tokens = self.embd(tokens)
tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens)
+ tt_tokens = ttnn.to_memory_config(
+     tt_tokens,
+     self.args.model_config["DECODE_RESIDUAL_MEMCFG"],
+ )
return tt_tokens, current_pos, tt_rot_mats, page_table

def process_output_prefill(self, tt_out, last_token_idx):
@@ -240,21 +239,31 @@ def process_output_prefill(self, tt_out, last_token_idx):
mesh_composer=ttnn.ConcatMesh2dToTensor(
self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape
),
- )[0, 0, last_token_idx, :]
+ )[0, 0, last_token_idx, : self.vocab_size]
return logits

def process_output_decode(self, tt_out, B, S=1):
"""
Input is ttnn device tensor of logits. Output is torch logits tensor
"""
if self.args.num_devices > 1:
- tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
+ if self.args.is_galaxy:
+     tt_out = ttnn.all_gather(
+         tt_out,
+         dim=3,
+         num_links=2,
+         cluster_axis=0,
+         mesh_device=self.mesh_device,
+         topology=ttnn.Topology.Linear,
+     )
+ else:
+     tt_out = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear)
tt_out = ttnn.untilize(tt_out, use_multicore=True)
if self.args.num_devices > 1:
tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float()
else:
tt_out = ttnn.to_torch(tt_out).float()
- tt_out = tt_out[:, :, :B, :].view(B, S, -1)
+ tt_out = tt_out[:, :, :B, : self.vocab_size].view(B, S, -1)
return tt_out

def ttnn_prefill_forward(
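
For context on the decode-output change: the logits appear to be vocab-parallel across devices, so process_output_decode first all-gathers them along dim 3 (with a galaxy-specific gather over cluster axis 0) and then keeps only the real batch rows and vocab columns. A torch-only sketch of what that gather-then-slice produces, with hypothetical device counts and sizes:

    # Torch-only sketch; device count, shard width, and vocab size are assumptions.
    import torch

    num_devices = 4
    shard_width = 32                                  # hypothetical per-device vocab shard
    shards = [torch.randn(1, 1, 32, shard_width) for _ in range(num_devices)]

    gathered = torch.cat(shards, dim=3)               # (1, 1, 32, 128), like an all-gather on dim 3
    B, S, vocab_size = 32, 1, 100                     # hypothetical true vocab size
    logits = gathered[:, :, :B, :vocab_size].view(B, S, -1)   # (32, 1, 100)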
