#0: Add Async-Mode support for Falcon 7B:
  - Uplift all falcon ops to use launch_op API
  - Resolve issues in async mode
  - Add async mode tests to multichip CI (test pattern sketched below the file summary)
tt-asaigal committed Apr 15, 2024
1 parent 06828e3 commit 0e98d33
Showing 22 changed files with 321 additions and 203 deletions.
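
The test-side pattern these diffs add is sketched below. This is a minimal illustration assuming the `device_mesh` pytest fixture and the ttnn multi-device calls visible in the diffs (`get_device_ids`, `get_device`, `enable_async`); the test name and the loop body placeholder are hypothetical, not part of the commit.

    import pytest


    @pytest.mark.parametrize(
        "enable_async, num_loops",
        ((True, 20), (False, 1)),
    )
    def test_with_async_mode(device_mesh, enable_async, num_loops):
        # Enable (or disable) async execution on every device in the mesh before running the model.
        for device_id in device_mesh.get_device_ids():
            device_mesh.get_device(device_id).enable_async(enable_async)

        # In async mode the model is looped several times to exercise the worker queues repeatedly.
        for _ in range(num_loops):
            # Run the Falcon module under test and compare against the torch reference here
            # (omitted in this sketch); see the per-test bodies in the diffs below.
            pass

        # Restore synchronous mode so subsequent tests start from a known state.
        for device_id in device_mesh.get_device_ids():
            device_mesh.get_device(device_id).enable_async(False)
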
@@ -65,6 +65,10 @@ def torch_model():
],
indirect=True,
)
@pytest.mark.parametrize(
"enable_async",
[True, False],
)
def test_falcon_attention(
device_mesh,
model_name,
@@ -75,7 +79,11 @@ def test_falcon_attention(
expected_pcc,
model_config_str,
torch_model,
enable_async,
):
for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(enable_async)

torch.manual_seed(0)
batch = device_batch_size * device_mesh.get_num_devices()
if llm_mode == "decode":
@@ -178,3 +186,6 @@ def test_falcon_attention(
assert_with_pcc(
pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc
)

for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(False)
72 changes: 45 additions & 27 deletions models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py
@@ -62,6 +62,10 @@
],
indirect=True,
)
@pytest.mark.parametrize(
"enable_async, num_loops",
((True, 20), (False, 1)),
)
def test_falcon_causal_lm(
device_mesh,
use_program_cache,
@@ -73,7 +77,12 @@ def test_falcon_causal_lm(
num_layers,
expected_pcc,
model_config_str,
enable_async,
num_loops,
):
for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(enable_async)

torch.manual_seed(0)
batch = device_batch_size * device_mesh.get_num_devices()
if llm_mode == "decode":
@@ -159,35 +168,41 @@ def convert_to_ttnn(model, name):
)
# TODO: Generate embeddings and attention_mask on device
if llm_mode == "prefill":
tt_outs = []
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=seq_len
)
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
user_id=0,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=True,
)
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1)
for loop in range(num_loops):
tt_outs = []
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=seq_len
)
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
user_id=0,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=True,
)
# Explicitly move the tensor to host first; in async mode this is faster than calling ttnn.to_torch
# directly on the device tensor, due to parallelization of tensor shards
tt_out = ttnn.from_device(tt_out)
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1)

elif llm_mode == "decode":
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len
)
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=True,
)
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1)
tt_out = tt_out.transpose(0, 1)
for loop in range(num_loops):
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len
)
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=True,
)
tt_out = ttnn.from_device(tt_out)
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1)
tt_out = tt_out.transpose(0, 1)

passed, pcc = assert_with_pcc(pytorch_out, tt_out.to(pytorch_out.dtype), expected_pcc)
logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}")
@@ -223,3 +238,6 @@ def convert_to_ttnn(model, name):
logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}")

logger.info("Falcon CausalLM Passed!")

for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(False)
11 changes: 11 additions & 0 deletions models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py
@@ -65,6 +65,10 @@ def torch_model():
],
indirect=True,
)
@pytest.mark.parametrize(
"enable_async",
[True, False],
)
def test_falcon_decoder(
device_mesh,
model_name,
@@ -75,7 +79,11 @@ def test_falcon_decoder(
expected_pcc,
model_config_str,
torch_model,
enable_async,
):
for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(enable_async)

torch.manual_seed(0)
batch = device_batch_size * device_mesh.get_num_devices()
if llm_mode == "decode":
@@ -176,3 +184,6 @@ def test_falcon_decoder(
assert_with_pcc(
pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc
)

for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(False)
11 changes: 11 additions & 0 deletions models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py
@@ -56,6 +56,10 @@ def torch_model():
],
indirect=True,
)
@pytest.mark.parametrize(
"enable_async",
[True, False],
)
def test_falcon_mlp(
device_mesh,
model_name,
@@ -64,7 +68,11 @@ def test_falcon_mlp(
expected_pcc,
model_config_str,
torch_model,
enable_async,
):
for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(enable_async)

torch.manual_seed(0)

configuration = transformers.FalconConfig.from_pretrained(PRETRAINED_MODEL_NAME)
@@ -100,3 +108,6 @@ def test_falcon_mlp(
expected_pcc,
)
logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}")

for device in device_mesh.get_device_ids():
device_mesh.get_device(device).enable_async(False)
1 change: 1 addition & 0 deletions tests/scripts/run_frequent_regressions_multi_device.sh
@@ -19,6 +19,7 @@ pytest tests/ttnn/unit_tests/test_multi_device.py
pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_mlp
pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_attention
pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_decoder
pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_causallm

# Llama2_70b related cached files and tests (the test should parse env variables similar to these)
export LLAMA_CKPT_DIR=/mnt/MLPerf/tt_dnn-models/llama-2/llama-2-70b-repacked/
15 changes: 11 additions & 4 deletions tt_eager/tensor/tensor.cpp
@@ -136,10 +136,12 @@ void Tensor::deallocate(bool force) {
std::visit([force, worker] (auto&& s) {
using type = std::decay_t<decltype(s)>;
if constexpr (std::is_same_v<type, MultiDeviceStorage>) {
if (force or s.buffers.at(worker->id()).use_count() == 1) {
DeallocateBuffer(*(s.buffers.at(worker->id())));
if (s.num_buffers()) {
if (force or s.buffers.at(worker->id()).use_count() == 1) {
DeallocateBuffer(*(s.buffers.at(worker->id())));
}
s.buffers.at(worker->id()).reset();
}
s.buffers.at(worker->id()).reset();
}
}, this->tensor_attributes->storage);
});
@@ -392,7 +394,12 @@ Tensor Tensor::cpu(bool blocking) const {
else {
host_tensor.set_populated(target_device);
}
}, blocking);
});
}
if (blocking) {
for (auto target_device : workers) {
target_device->synchronize();
}
}
// Update main_thread_ref_count for tensor after pushing to queue.
this->tensor_attributes->update_main_thread_ref_count(workers.at(0), original_tensor_ref_count);
78 changes: 42 additions & 36 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -74,42 +74,48 @@ inline Tensor bcast(
BcastOpMath bcast_op,
BcastOpDim bcast_dim,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;

if (bcast_dim == BcastOpDim::W) {
TT_FATAL(input_tensor_a.get_legacy_shape()[2] == input_tensor_b.get_legacy_shape()[2]);
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else {
TT_FATAL(false, "Unsupported layout");
}
} else if (bcast_dim == BcastOpDim::H) {
TT_FATAL(input_tensor_a.get_legacy_shape()[3] == input_tensor_b.get_legacy_shape()[3]);
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(input_tensor_b.get_legacy_shape()[2] == 1 || input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
} else {
TT_FATAL(false, "Unsupported layout");
}
} else if (bcast_dim == BcastOpDim::HW) {
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(
input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(
(input_tensor_b.get_legacy_shape()[2] == 1 && input_tensor_b.get_legacy_shape()[3] == 1) ||
(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH));
}
}
return operation::run_with_autoformat(
EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config, false}, {input_tensor_a, input_tensor_b})
.at(0);

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_with_autoformat(
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;
auto& input_tensor_a = input_tensors.at(0);
auto& input_tensor_b = input_tensors.at(1);
if (bcast_dim == BcastOpDim::W) {
TT_FATAL(input_tensor_a.get_legacy_shape()[2] == input_tensor_b.get_legacy_shape()[2]);
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else {
TT_FATAL(false, "Unsupported layout");
}
} else if (bcast_dim == BcastOpDim::H) {
TT_FATAL(input_tensor_a.get_legacy_shape()[3] == input_tensor_b.get_legacy_shape()[3]);
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(input_tensor_b.get_legacy_shape()[2] == 1 || input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
} else {
TT_FATAL(false, "Unsupported layout");
}
} else if (bcast_dim == BcastOpDim::HW) {
if (input_tensor_b.get_layout() == Layout::TILE) {
TT_FATAL(
input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
} else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
TT_FATAL(
(input_tensor_b.get_legacy_shape()[2] == 1 && input_tensor_b.get_legacy_shape()[3] == 1) ||
(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH));
}
}
return operation::run_with_autoformat(
EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b});
}, {input_tensor_a, input_tensor_b}, output_tensors);
return output_tensors.at(0);
}

} // namespace tt_metal
@@ -291,7 +291,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
all_cores,
tt_metal::DataMovementConfig{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines});


KernelHandle mm_kernel_in0_receiver_id = 0;
if (!in0_is_sharded) {
mm_kernel_in0_receiver_id = tt_metal::CreateKernel(
@@ -872,7 +871,6 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
mcast_sender,
tt_metal::DataMovementConfig{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines});


auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
6 changes: 0 additions & 6 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -339,12 +339,6 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector<Tenso
//unary op version tie
template<BcastOpMath OP>
Tensor tie_binop_to_unary(const Tensor& input_tensor, float value, const MemoryConfig& output_mem_config) {
if (is_multi_device_tensor(input_tensor)) {
return transform(input_tensor, [&](const Tensor& tensor) {
return tie_binop_to_unary<OP>(tensor, value, output_mem_config);
});
}

Tensor t_value = mk_tiled_scalar(value);
return bcast(input_tensor, t_value, OP, BcastOpDim::HW);
}