
Commit

#0: Remove unused code and duplicate functions
cglagovichTT committed Dec 5, 2024
1 parent 82d1b9a commit c149b22
Showing 6 changed files with 9 additions and 198 deletions.
19 changes: 1 addition & 18 deletions models/demos/llama3/demo/simple_vision_demo.py
@@ -12,7 +12,6 @@
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import ImageMedia, UserMessage
-from models.demos.llama3.demo.tiny_demo import load_inputs

 from pkg_resources import resource_filename

@@ -27,22 +26,6 @@
 from models.demos.llama3.tt.generator import LlamaGenerator


-def get_sampler(temperature, top_p, tokenizer):
-    def sample(logits):
-        if temperature > 0:
-            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
-            next_token = llama_reference_generation.sample_top_p(probs, top_p)
-        else:
-            next_token = torch.argmax(logits[:, -1], dim=-1)
-
-        next_token = next_token.reshape(-1)
-        token = next_token[0].item()
-        text = tokenizer.decode(next_token.tolist())
-        return token, text
-
-    return sample
-
-
 def get_batch_sampler(temperature, top_p, tokenizer):
     def sample(logits):
         if temperature > 0:
@@ -204,7 +187,7 @@ def test_llama_multimodal_demo_text(
         for gen_idx in range(max_gen_len - 1):
             decode_start = time.perf_counter()
             position_id = prefill_lens + gen_idx
-            next_token_tensor = next_tokens.reshape(max_batch_size, 1)  # 1, B
+            next_token_tensor = next_tokens.reshape(max_batch_size, 1)

             if enable_trace:
                 logits = generator.easy_trace(
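
Note: the deleted get_sampler helper implemented the same temperature/top-p (nucleus) decoding that the surviving get_batch_sampler provides for a whole batch, which is why it could be dropped. For reference, a minimal standalone sketch of that logic is shown below; sample_top_p here is a local stand-in written from the usual nucleus-sampling recipe, not the llama_reference_generation.sample_top_p used in the repository.

import torch


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Keep the smallest prefix of tokens (sorted by probability) whose
    # cumulative mass exceeds p, renormalize, and sample one token per row.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > p] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, picked)


def get_sampler(temperature, top_p, tokenizer):
    # Mirrors the removed helper: greedy decode when temperature == 0,
    # otherwise softmax with temperature followed by top-p sampling.
    def sample(logits):
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        token = next_token[0].item()
        text = tokenizer.decode(next_token.tolist())
        return token, text

    return sample

A caller would build the closure once, e.g. sample = get_sampler(temperature=0.6, top_p=0.9, tokenizer=tokenizer), and invoke it on each step's logits, which is exactly what the batched variant kept above does per user.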
154 changes: 0 additions & 154 deletions models/demos/llama3/tests/multimodal/test_llama_vision_model.py

This file was deleted.

8 changes: 4 additions & 4 deletions models/demos/llama3/tt/generator.py
@@ -371,9 +371,7 @@ def capture_trace(
             tt_full_text_mask_expand_1NSH,
             tt_position_id,
             tt_rope_id,
-        ) = self.model.copy_host_to_device(
-            (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id)
-        )
+        ) = copy_host_to_device((tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id))

         trace_id = ttnn.begin_trace_capture(self.mesh_device, cq_id=0)
         tt_h_trace_input = tt_h
@@ -431,7 +429,7 @@ def decode_forward_trace(
             tokens, cross_attention_masks, full_text_row_masked_out_mask, position_id=position_id
         )

-        self.model.copy_host_to_device(
+        copy_host_to_device(
             host_tensors=(tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id),
             device_tensors=(
                 trace_h,
@@ -535,6 +533,8 @@ def generate(
             prefill_len=prefill_len,
         )

+        logits = logits.view(1, 1, self.model_args.max_vocab_size)
+
         def sample(logits):
             if temperature > 0:
                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
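
A note on the two lines added in generate: viewing the prefill logits as (1, 1, vocab_size) appears to let the same sample closure used during decode, which indexes logits[:, -1], consume the single-user prefill output without a special case. The toy shape check below illustrates this; the vocabulary size is a hypothetical value for illustration only.

import torch

VOCAB = 128256  # hypothetical vocabulary size, for illustration only

prefill_logits = torch.randn(VOCAB)        # stand-in for the prefill output
logits = prefill_logits.view(1, 1, VOCAB)  # (batch=1, seq=1, vocab), as in the commit
last = logits[:, -1]                       # (1, vocab): the slice sample() consumes
next_token = torch.argmax(last, dim=-1)    # greedy branch of sample()
assert last.shape == (1, VOCAB) and next_token.shape == (1,)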
1 change: 0 additions & 1 deletion models/demos/llama3/tt/llama_attention.py
@@ -570,7 +570,6 @@ def forward_prefill(self, x_11SH, rot_mats, user_id: int = 0, page_table=None, k
         else:
             return output_11SH

-    # TODO Miguel: Remove transformation_mats input (send at initialization instead)
     def forward(
         self,
         x,
7 changes: 2 additions & 5 deletions models/demos/llama3/tt/llama_common.py
@@ -218,11 +218,8 @@ def copy_host_to_device(host_tensors, device_tensors=None, mesh_device=None):
         assert mesh_device is not None, "mesh_device is required when device_tensors is None"
         ret = []
         for i in range(len(host_tensors)):
-            if host_tensors[i] is None:
-                ret.append(None)
-            else:
-                on_device = ttnn.to_device(host_tensors[i], device=mesh_device)
-                ret.append(on_device)
+            on_device = ttnn.to_device(host_tensors[i], device=mesh_device) if host_tensors[i] else None
+            ret.append(on_device)
         return ret
     else:
         for i in range(len(host_tensors)):
18 changes: 2 additions & 16 deletions models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -29,6 +29,7 @@
     get_prefill_rot_mat,
     get_rot_transformation_mat,
     get_single_rot_mat,
+    copy_host_to_device,
 )
 from models.utility_functions import (
     nearest_32,
@@ -386,7 +387,7 @@ def prepare_inputs_decode(
             tt_position_id,
             tt_rope_id,
             tt_page_table,
-        ) = self.copy_host_to_device(
+        ) = copy_host_to_device(
             (tt_h, tt_xattn_mask, tt_full_text_mask_expand_1NSH, tt_position_id, tt_rope_id, tt_page_table)
         )

@@ -475,21 +476,6 @@ def prepare_decode_inputs_host(
             page_table,
         )

-    def copy_host_to_device(self, host_tensors, device_tensors=None):
-        """
-        Helper function which copies host tensors to device tensors
-        """
-        if device_tensors is None:
-            ret = []
-            for i in range(len(host_tensors)):
-                on_device = ttnn.to_device(host_tensors[i], device=self.mesh_device) if host_tensors[i] else None
-                ret.append(on_device)
-            return ret
-        else:
-            for i in range(len(host_tensors)):
-                ttnn.copy_host_to_device_tensor(host_tensors[i], device_tensors[i])
-            return device_tensors
-
     def transform_decode_inputs_device(self, tt_h, tt_rope_id, tt_xattn_mask, tt_full_text_mask_expand_1NSH, B):
         """
         Does any transformations on device tensors which are necessary before ttnn_decode_forward
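
With the duplicate method on the vision model removed, both the generator and the vision model now go through the single copy_host_to_device helper in models/demos/llama3/tt/llama_common.py. The hypothetical wrapper below (not part of the commit) just illustrates the helper's two modes as they are used in the trace path above: allocate fresh device tensors when no trace exists yet, otherwise copy host data in place into the preallocated trace tensors.

from models.demos.llama3.tt.llama_common import copy_host_to_device


def push_decode_inputs(host_tensors, mesh_device, trace_tensors=None):
    # Hypothetical wrapper, for illustration of the shared helper's two modes.
    if trace_tensors is None:
        # Allocation mode: the helper calls ttnn.to_device per entry and
        # returns None for None entries; mesh_device is required here.
        return copy_host_to_device(host_tensors, mesh_device=mesh_device)
    # In-place mode: the helper calls ttnn.copy_host_to_device_tensor to
    # overwrite the preallocated trace tensors and returns them.
    return copy_host_to_device(host_tensors=host_tensors, device_tensors=trace_tensors)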
