#0: Cleanup and add support for seq_len = 1024
tt-asaigal committed Apr 18, 2024
1 parent 962fb05 commit 00a3504
Showing 15 changed files with 124 additions and 130 deletions.
models/demos/falcon7b/demo/demo.py (106 changes: 51 additions & 55 deletions)
@@ -268,7 +268,7 @@ def run_falcon_demo_kv(
tt_prefill_attention_mask[i].deallocate()

tt_logits[i].deallocate()
# exit(0)

synchronize_devices(devices)
logger.info("Finished 1st run prefill stage with compile!")

@@ -409,60 +409,56 @@ def run_falcon_demo_kv(
else:
N = 15
N_warmup = 5
for i in range(1):
kv_cache_len = num_input_tokens # This will increment by one after each decode
for output_token_index in range(N):
time_decode_inference_start = time.time()
(
tt_decode_embeddings,
tt_decode_attention_mask,
) = tt_FalconCausalLM.model_preprocessing(
"decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1
)
assert tt_decode_attention_mask is not None

tt_logits, kv_cache = tt_FalconCausalLM(
input_embeddings=tt_decode_embeddings,
llm_mode="decode",
attention_mask=tt_decode_attention_mask,
layer_past=kv_cache,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
synchronize_devices(devices)
time_decode_inference_end = time.time()
if output_token_index >= N_warmup:
time_decode_inference += time_decode_inference_end - time_decode_inference_start

logits = torch.concat([tt2torch_tensor(tt_logits[i]).squeeze(1) for i in range(num_devices)], dim=-2)

for i in range(num_devices):
tt_decode_embeddings[i].deallocate()
if tt_decode_attention_mask is not None:
tt_decode_attention_mask[i].deallocate()
tt_logits[i].deallocate()

if not perf_mode:
if greedy_sampling:
decode_ids = post_processor(logits=logits, index=...).reshape(global_batch, 1)
else:
decode_ids = top_pk_logits_efficient(logits.reshape(global_batch, -1)).reshape(global_batch, 1)

for user_id, user_decode_id in enumerate(decode_ids[:num_users]):
if user_decode_id == END_OF_TEXT:
prompt_is_done[user_id] = True
if prompt_is_done[user_id]:
decode_ids[user_id] = SPACE

if all(prompt_is_done):
break

generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
kv_cache_len += 1

# TODO: Remove if we don't want to print per generated token
os.system("clear")
print_output_prompts(generated_ids, tokenizer, batch_size)
for output_token_index in range(N):
time_decode_inference_start = time.time()
(
tt_decode_embeddings,
tt_decode_attention_mask,
) = tt_FalconCausalLM.model_preprocessing("decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1)
assert tt_decode_attention_mask is not None

tt_logits, kv_cache = tt_FalconCausalLM(
input_embeddings=tt_decode_embeddings,
llm_mode="decode",
attention_mask=tt_decode_attention_mask,
layer_past=kv_cache,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
synchronize_devices(devices)
time_decode_inference_end = time.time()
if output_token_index >= N_warmup:
time_decode_inference += time_decode_inference_end - time_decode_inference_start

logits = torch.concat([tt2torch_tensor(tt_logits[i]).squeeze(1) for i in range(num_devices)], dim=-2)

for i in range(num_devices):
tt_decode_embeddings[i].deallocate()
if tt_decode_attention_mask is not None:
tt_decode_attention_mask[i].deallocate()
tt_logits[i].deallocate()

if not perf_mode:
if greedy_sampling:
decode_ids = post_processor(logits=logits, index=...).reshape(global_batch, 1)
else:
decode_ids = top_pk_logits_efficient(logits.reshape(global_batch, -1)).reshape(global_batch, 1)

for user_id, user_decode_id in enumerate(decode_ids[:num_users]):
if user_decode_id == END_OF_TEXT:
prompt_is_done[user_id] = True
if prompt_is_done[user_id]:
decode_ids[user_id] = SPACE

if all(prompt_is_done):
break

generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
kv_cache_len += 1

# TODO: Remove if we don't want to print per generated token
os.system("clear")
print_output_prompts(generated_ids, tokenizer, batch_size)

logger.info("Finished inference decode stage!")
num_tokens_generated_decode = global_batch * (output_token_index - N_warmup + 1)
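
For reference, the throughput figure typically derived from these counters follows directly from the loop above: time_decode_inference only accumulates once output_token_index reaches N_warmup, and num_tokens_generated_decode counts one token per user in the global batch for each timed iteration. A minimal standalone sketch of that arithmetic, assuming the loop runs all N iterations without an early break (the values for N, N_warmup, global_batch, and the measured time below are illustrative placeholders, not the demo's runtime state):

# Illustrative values only; the demo derives these at runtime.
N = 15                  # decode iterations issued
N_warmup = 5            # iterations excluded from timing
global_batch = 32 * 8   # e.g. 32 users per device across 8 devices (assumed)

timed_iterations = N - N_warmup                        # 10 timed iterations
num_tokens_generated_decode = global_batch * timed_iterations

time_decode_inference = 4.0                            # placeholder, seconds
tokens_per_s = num_tokens_generated_decode / time_decode_inference
tokens_per_s_per_user = tokens_per_s / global_batch
print(f"{tokens_per_s:.1f} tokens/s total, {tokens_per_s_per_user:.2f} tokens/s/user")
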
models/demos/falcon7b/tt/falcon_decoder.py (5 changes: 0 additions & 5 deletions)
@@ -103,7 +103,6 @@ def forward(
assert not output_attentions

layernorm_output = []
# print("=== Running Decoder Layernorm ===")
for i in range(self.num_devices):
layernorm_output.append(
tt_lib.tensor.layernorm(
@@ -112,7 +111,6 @@
output_mem_config=self.model_config["INPUT_LAYERNORM_OUTPUT_MEMCFG"],
)
)
# print("=== Running Decoder BcastMul ===")
for i in range(self.num_devices):
layernorm_output[i] = tt_lib.tensor.bcast(
layernorm_output[i],
@@ -121,7 +119,6 @@
tt_lib.tensor.BcastOpDim.H,
output_mem_config=self.model_config["INPUT_LAYERNORM_OUTPUT_MEMCFG"],
)
# print("=== Running Decoder BcastAdd ===")
for i in range(self.num_devices):
layernorm_output[i] = tt_lib.tensor.bcast(
layernorm_output[i],
@@ -134,7 +131,6 @@
residual = hidden_states

# Self Attention
# print("=== Running Decoder SelfAttn ===")
attn_outputs = self.self_attn(
hidden_states=layernorm_output,
alibi=alibi,
@@ -150,7 +146,6 @@

# MLP
# mlp will deallocate layernorm_output
# print("=== Running Decoder MLP ===")
mlp_output = self.mlp(layernorm_output)

output = []
models/demos/falcon7b/tt/falcon_model.py (3 changes: 0 additions & 3 deletions)
@@ -235,14 +235,12 @@ def forward(
layer_output = layer_output[0]

# apply final norm layer
# print(" === Running Layernorm ===")
for i in range(self.num_devices):
layer_output[i] = tt_lib.tensor.layernorm(
layer_output[i],
self.layernorm_eps,
output_mem_config=self.model_config["LN_F_OUTPUT_MEMCFG"],
)
# print(" === Running BcastMul ===")
for i in range(self.num_devices):
layer_output[i] = tt_lib.tensor.bcast(
layer_output[i],
@@ -251,7 +249,6 @@
tt_lib.tensor.BcastOpDim.H,
output_mem_config=self.model_config["LN_F_OUTPUT_MEMCFG"],
)
# print(" === Running BcastAdd ===")
for i in range(self.num_devices):
layer_output[i] = tt_lib.tensor.bcast(
layer_output[i],
models/demos/t3000/falcon7b/demo_t3000.py (6 changes: 5 additions & 1 deletion)
@@ -8,6 +8,7 @@


@pytest.mark.parametrize("perf_mode", (True,)) # Option to measure perf using max seq length (with invalid outputs)
@pytest.mark.parametrize("async_mode", (False,)) # Option to run Falcon in Async mode
@pytest.mark.parametrize("num_devices", (1, 2, 3, 4, 5, 6, 7, 8))
def test_demo_multichip(
perf_mode,
@@ -17,14 +18,17 @@ def test_demo_multichip(
get_tt_cache_path,
all_devices,
use_program_cache,
async_mode,
):
assert is_wormhole_b0(), "Multi-chip is only supported for Wormhole B0"
devices = get_devices_for_t3000(all_devices, num_devices)

for device in devices:
device.enable_async(async_mode)
return run_falcon_demo_kv(
user_input=user_input,
batch_size=32,
max_seq_len=128,
max_seq_len=1024,
model_config_strs_prefill_decode=["BFLOAT16-DRAM", "BFLOAT16-L1_SHARDED"],
model_location_generator=model_location_generator,
get_tt_cache_path=get_tt_cache_path,
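
As a side note on how the stacked parametrize decorators above compose (standard pytest behavior, shown here with a self-contained hypothetical test rather than this repo's fixtures): each decorator multiplies the set of generated cases, so the perf_mode, async_mode, and num_devices values above expand to 1 x 1 x 8 = 8 test cases.

# Minimal sketch of stacked parametrization; names are hypothetical.
import pytest

@pytest.mark.parametrize("perf_mode", (True,))
@pytest.mark.parametrize("async_mode", (False,))
@pytest.mark.parametrize("num_devices", (1, 2, 3, 4, 5, 6, 7, 8))
def test_param_product(perf_mode, async_mode, num_devices):
    # pytest generates one case per combination: 1 * 1 * 8 = 8 in total.
    assert isinstance(num_devices, int)
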
tt_eager/tensor/tensor.cpp (27 changes: 16 additions & 11 deletions)
@@ -82,7 +82,7 @@ Tensor::~Tensor() {
}

void Tensor::deallocate(bool force) {
ZoneScopedN("Deallocate");
ZoneScopedN("TensorDeallocate");
if (this->tensor_attributes.use_count()) {
// Check if the attributes didn't get moved to another tensor.
// If not, we can deallocate this tensor.
@@ -92,7 +92,6 @@ void Tensor::deallocate(bool force) {
// This is a special case, where storage type cannot change for multi
// device tensors (see assert in launch_op). Hence, this only applies
// to the single device case, where metadata populated == tensor populated.
ZoneScopedN("WaitForMDPopulated");
this->wait_for_tensor_metadata_populated();
}
std::visit(
@@ -105,10 +104,11 @@
} else if constexpr (std::is_same_v<T, DeviceStorage>) {
if (this->workers.at(0)->in_main_thread()) {
// If owned by the main thread, deallocate this tensor only from the main thread
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count.load();
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count;
if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
this->tensor_attributes->deallocated = true;
uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(this->workers.at(0));
// Record ref count before sending to worker
uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count();
this->workers.at(0)->push_work([force, *this] () mutable {
std::visit([force, this] (auto&& s) {
using type = std::decay_t<decltype(s)>;
@@ -123,6 +123,7 @@
}
}, this->tensor_attributes->storage);
});
// Update ref count after sending to worker
this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count);
}
} else {
@@ -134,9 +135,11 @@
}
} else if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
if (this->workers.at(0)->in_main_thread()) {
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count.load();
uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count;
if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) {
this->tensor_attributes->deallocated = true;
// Record ref count before sending to workers
uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count();
for (auto worker : this->workers) {
worker->push_work([force, *this, worker] () mutable {
std::visit([force, worker] (auto&& s) {
@@ -152,6 +155,8 @@
}, this->tensor_attributes->storage);
});
}
// Update ref count after sending to workers
this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count);
}
}
} else if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
@@ -318,8 +323,8 @@ Tensor Tensor::to(CommandQueue & queue, const MemoryConfig & mem_config) const {
// functions running in main can get storage type without blocking
Tensor device_tensor({target_device});
// Record main thread ref count for tensors before pushing to queue.
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0));
uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0));
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count();
queue.device()->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable {
if (async_safe_tensor.storage_type() == StorageType::DEVICE) {
TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices");
@@ -347,8 +352,8 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const {
// functions running in main can get storage type without blocking
Tensor device_tensor({target_device});
// Record main thread ref count for tensors before pushing to queue.
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0));
uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0));
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count();
target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable {
if (async_safe_tensor.storage_type() == StorageType::DEVICE) {
TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices");
@@ -374,7 +379,7 @@ Tensor Tensor::to(DeviceMesh *device_mesh, const MemoryConfig &mem_config) const
auto workers = std::vector<Device*>(all_workers.begin(), all_workers.end());
TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)");
Tensor multi_device_tensor = Tensor(workers);
uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count(multi_device_tensor.workers.at(0));
uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count();

for (auto& target_device : workers) {
target_device->push_work([*this, multi_device_tensor, mem_config, target_device] () mutable {
@@ -407,7 +412,7 @@ Tensor Tensor::cpu(bool blocking) const {
}
TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)");
Tensor host_tensor({}, workers.size());
uint32_t original_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(workers.at(0));
uint32_t original_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count();
for (auto target_device : workers) {
target_device->push_work([host_tensor, blocking, target_device, *this, workers] () mutable {
TT_ASSERT(this->storage_type() == StorageType::DEVICE or this->storage_type() == StorageType::MULTI_DEVICE, "Can only use worker queue for cpu call if tensor is on device.");
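
One pattern worth noting in this file is the record_main_thread_ref_count() / update_main_thread_ref_count(...) pair wrapped around each push_work call. The lambdas handed to the workers capture the tensors by value (see the [force, *this] captures above), and the in-code comments indicate the main thread snapshots its reference count before enqueueing and restores it afterwards, presumably so that copies created for the worker are not counted as main-thread references. A loose conceptual sketch of that snapshot/restore idea, written in Python with hypothetical names (this is not the tt-metal API, and the real accounting in tensor.cpp differs in detail):

# Conceptual sketch only; names and mechanics are hypothetical.
import copy
import queue

class ToyAttributes:
    def __init__(self):
        self.main_thread_ref_count = 1  # references believed to be owned by the main thread

class ToyTensor:
    def __init__(self, attrs):
        self.attrs = attrs
    def __deepcopy__(self, memo):
        # Copying on the main thread bumps the main-thread count; in the C++
        # code a copy is made implicitly when the lambda captures *this.
        self.attrs.main_thread_ref_count += 1
        return ToyTensor(self.attrs)

worker_queue = queue.Queue()  # stands in for worker->push_work(...)

def deallocate(tensor):
    attrs = tensor.attrs
    # Record ref count before sending to worker.
    recorded = attrs.main_thread_ref_count
    captured = copy.deepcopy(tensor)  # plays the role of the lambda capture
    worker_queue.put(lambda: print("worker frees buffers of", id(captured)))
    # Update ref count after sending to worker: the captured copy belongs to
    # the worker, so the main thread restores its snapshot.
    attrs.main_thread_ref_count = recorded

if __name__ == "__main__":
    deallocate(ToyTensor(ToyAttributes()))
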
(Diffs for the remaining 10 changed files are not shown.)
