Add noc read/write burst command support to CCL command kernel. Also add automated command lowering to these noc commands (#16461)

Add two new commands to CCL command infrastructure:
- noc read burst
- noc write burst

The program factory can specify bursts of noc reads and writes by
specifying a base address and then sequences of sources/destinations and
sizes (whether the per-transfer address is a source or a destination
depends on whether the burst is a read or a write). The new commands are
currently implemented only in the existing reference kernel; future work
will add a dedicated kernel that supports only noc reads/writes and can
be more easily optimized to reach peak utilization.
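
As a rough illustration, here is a minimal, self-contained sketch of the
burst descriptor shape this implies. The field names mirror those read by
the encoder in the diff below (`bank_base_address`, `num_transfers_total`,
`transfer_burst_groupings`, etc.); the standalone struct names and example
values are assumptions, not the real
`ttnn::ccl::cmd::HostCclCommandNocTransferBurst` headers:

```cpp
// Sketch only: mirrors the fields read by
// generate_ccl_noc_transfer_burst_command_args in the diff below.
#include <cstdint>
#include <vector>

struct noc_transfer_info {
    uint64_t noc_addr;                  // encoded as two 32-bit runtime args
    uint32_t noc_transfer_size_bytes;
};

struct noc_transfer_burst_grouping {
    uint32_t num_transfers_per_packet;
    std::vector<noc_transfer_info> transfer_infos;
};

struct noc_transfer_burst {
    uint32_t bank_base_address;    // base address for one side of the transfer
    uint32_t num_transfers_total;  // must be > 0 (TT_FATAL in the encoder)
    std::vector<noc_transfer_burst_grouping> transfer_burst_groupings;
};

int main() {
    // A hypothetical write burst of four 2 KiB pages, two transfers per
    // packet; each noc_addr is the far-side src/dest per the description
    // above.
    noc_transfer_burst burst{
        /*bank_base_address=*/0x10000,
        /*num_transfers_total=*/4,
        {{2, {{0x0000, 2048}, {0x0800, 2048}}},
         {2, {{0x1000, 2048}, {0x1800, 2048}}}}};
    return burst.num_transfers_total == 4 ? 0 : 1;
}
```

Grouping transfers per packet presumably lets the kernel pack several small
reads/writes into one fabric packet, which would explain why
`num_transfers_per_packet` is encoded once per group.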

Additionally, added an initial command lowering pass, enabled
conditionally in all-gather, which lowers tensor slice commands to noc
read/write burst command streams. This eliminates all tensor iteration
and page address lookup overheads at runtime from within the kernel. The
lowering is performed with a single call in the program factory. With
this approach, runtime arg overrides change, and the details about which
runtime args require overriding become hidden from the user. To account
for this, infrastructure was added to automatically track which runtime
args require updates as the program is invoked over time (with different
tensors).
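
As a sketch of the tracking idea: the record-then-patch flow below mirrors
the `tensor_address_runtime_args_overrider` calls visible in the diff
(`add_tensor`, `add_runtime_arg_index`, `size`), but the class body and the
patch loop are an illustrative mock, not the real ttnn API:

```cpp
// Self-contained mock of the overrider concept, not the real ttnn class.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct tensor_address_runtime_args_overrider {
    // Per-tensor list of runtime-arg indices that hold that tensor's address.
    std::vector<std::vector<std::size_t>> indices_per_tensor;
    void add_tensor() { indices_per_tensor.emplace_back(); }
    std::size_t size() const { return indices_per_tensor.size(); }
    void add_runtime_arg_index(std::size_t tensor, std::size_t arg_index) {
        indices_per_tensor.at(tensor).push_back(arg_index);
    }
};

int main() {
    tensor_address_runtime_args_overrider overrider;
    std::vector<uint32_t> rt_args;

    // Program-factory time: while encoding a command, record the index at
    // which an address-carrying arg is about to be pushed (as the encoder
    // does for bank_base_address in the diff below).
    overrider.add_tensor();
    rt_args.push_back(/*command header*/ 0);
    overrider.add_runtime_arg_index(/*tensor=*/0, rt_args.size());
    rt_args.push_back(/*bank_base_address*/ 0x1000);

    // Op rerun with a different tensor: patch only the flagged args; no
    // tensor iteration or page lookup is repeated.
    const uint32_t new_address = 0x2000;
    for (std::size_t arg_idx : overrider.indices_per_tensor[0]) {
        rt_args[arg_idx] = new_address;
    }
    std::printf("patched rt_args[1] = 0x%x\n", (unsigned)rt_args[1]);
    return 0;
}
```

The key invariant is that the index is recorded immediately before the
address is pushed, so the recorded position always points at the arg to
patch on rerun.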

This further generalizes command stream use and simplifies adoption,
since the user no longer needs to manually carry state (implicitly or
explicitly) through the program factory to manage runtime argument
overrides on op reruns.

Note that for the time being, the noc burst count is limited by runtime
arg counts. This limitation will be lifted in the future.

### Ticket
[Link to Github Issue](#16395)
SeanNijjar authored Jan 8, 2025
1 parent 924f017 commit 7764bf5
Showing 23 changed files with 1,315 additions and 288 deletions.
90 changes: 57 additions & 33 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py
@@ -184,27 +184,34 @@ def run_all_gather_impl(
output_mem_config = mem_config
###

if rand_tensor:
output_tensor = torch.rand(output_shape).bfloat16()
else:
output_tensor = torch.zeros(output_shape)
tile_id = 1
for w in range(output_shape[0]):
for z in range(output_shape[1]):
for y in range(0, output_shape[2], 32):
for x in range(0, output_shape[3], 32):
output_tensor[w, z, y : y + 32, x : x + 32] = tile_id
tile_id += 1

input_tensors = torch.chunk(output_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
tt_input_tensors.append(
ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config)
)
logger.info(f"using device {mesh_device.get_devices()[i].id()}")
input_tensor_mesh_list = []
output_tensor_goldens_list = []

for i in range(num_iters):
if rand_tensor:
output_tensor = torch.rand(output_shape).bfloat16()
else:
output_tensor = torch.zeros(output_shape)
tile_id = 1
for w in range(output_shape[0]):
for z in range(output_shape[1]):
for y in range(0, output_shape[2], 32):
for x in range(0, output_shape[3], 32):
output_tensor[w, z, y : y + 32, x : x + 32] = tile_id
tile_id += 1

output_tensor_goldens_list.append(output_tensor)
input_tensors = torch.chunk(output_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
tt_input_tensors.append(
ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], input_mem_config)
)
logger.info(f"using device {mesh_device.get_devices()[i].id()}")

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)

input_tensor_mesh_list.append(input_tensor_mesh)

compute_grid_size = mesh_device.compute_with_storage_grid_size()
worker_sub_device = ttnn.SubDevice(
@@ -220,22 +227,24 @@ def run_all_gather_impl(
mesh_device, [worker_sub_device], 0, 0, enable_persistent_fabric
)

tt_out_tensor_list = []
if trace_mode:
tt_out_tensor = run_with_trace(
mesh_device,
all_gather_topology,
input_tensor_mesh,
input_tensor_mesh_list[0],
dim,
num_links,
output_mem_config,
num_iter=num_iters,
subdevice_id=worker_sub_device_id,
)
tt_out_tensor_list.append(tt_out_tensor)
else:
for i in range(num_iters):
if use_cluster_axis_api:
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor_mesh,
input_tensor_mesh_list[i],
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
@@ -249,14 +258,15 @@

else:
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor_mesh,
input_tensor_mesh_list[i],
dim,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
tt_out_tensor_list.append(tt_out_tensor)

logger.info(f"Waiting for op {i}")
for d in mesh_device.get_devices():
@@ -266,17 +276,20 @@
if enable_persistent_fabric and teardown_persistent_fabric:
teardown_fabric_interface(mesh_device)

for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
logger.info(f"Checking for device {t.device().id()}")
for tensor_index in range(len(tt_out_tensor_list)):
tt_out_tensor = tt_out_tensor_list[tensor_index]
output_tensor = output_tensor_goldens_list[tensor_index]
for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
logger.info(f"Checking for device {t.device().id()}")

if input_dtype == ttnn.bfloat16:
eq, output = comp_equal(tt_output_tensor, output_tensor)
else:
eq, output = comp_pcc(tt_output_tensor, output_tensor)
if not eq:
logger.error(f"output mismatch for tensor {i}")
assert eq, f"{i} FAILED: {output}"
if input_dtype == ttnn.bfloat16:
eq, output = comp_equal(tt_output_tensor, output_tensor)
else:
eq, output = comp_pcc(tt_output_tensor, output_tensor)
if not eq:
logger.error(f"output mismatch for tensor {i}")
assert eq, f"{i} FAILED: {output}"


# Enumerate the post-commit cases explicitly
@@ -386,6 +399,15 @@ def test_all_gather(
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}),
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
),
(
4,
[1, 4, 32, 1280],
3,
ttnn.TILE_LAYOUT,
(32, 128),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 4))}),
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
),
],
)
@pytest.mark.parametrize("num_links", [1])
@@ -431,6 +453,8 @@ def test_all_gather_sharded(
input_shard_shape=input_shard_shape,
shard_grid=shard_grid,
tensor_mem_layout=tensor_mem_layout,
create_persistent_fabric=True,
teardown_persistent_fabric=True,
)


2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/CMakeLists.txt
@@ -5,7 +5,9 @@ set(CCL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/ccl_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl_host_datastructures.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/types/ccl_types_args_emitters.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/host/command_backend_runtime_args_overrider.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_command.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/uops/command_lowering.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/uops/ccl_host_commands.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_worker_builder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common/host/ccl_command_stream_builders.cpp
123 changes: 100 additions & 23 deletions ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp
@@ -20,6 +20,7 @@

#include <optional>
#include <variant>
#include <vector>

namespace ttnn::ccl::worker_detail {

@@ -571,6 +572,32 @@ size_t generate_ccl_core_descriptor_info_command_args(
return num_ccl_command_args_added;
}

static size_t generate_ccl_noc_transfer_burst_command_args(
const ttnn::ccl::cmd::HostCclCommandNocTransferBurst& noc_burst_descriptor,
size_t tensor_index,
ttnn::ccl::tensor_address_runtime_args_overrider &rt_args_overrider_out,
std::vector<uint32_t>& args_out) {
ttnn::ccl::cmd::CclCommandArgHeader hdr;
hdr.code = ttnn::ccl::cmd::CclCommandArgCode::SET_NOC_TRANSFER_BURST_START_INFO;
TT_FATAL(noc_burst_descriptor.num_transfers_total > 0, "Internal Error. num_transfers_total uninitialized when generating runtime args for noc read/write commands");
hdr.inline_value0 = noc_burst_descriptor.num_transfers_total;
// Bank base address must be set in the next arg since we may need the full 32-bit value
args_out.push_back(hdr.to_uint32());
rt_args_overrider_out.add_runtime_arg_index(tensor_index, args_out.size());
args_out.push_back(noc_burst_descriptor.bank_base_address);

for (auto const& transfer_group : noc_burst_descriptor.transfer_burst_groupings) {
args_out.push_back(transfer_group.num_transfers_per_packet);
for (auto const& transfer : transfer_group.transfer_infos) {
args_out.push_back(transfer.noc_addr & 0xFFFFFFFF);
args_out.push_back(transfer.noc_addr >> 32);
args_out.push_back(transfer.noc_transfer_size_bytes);
}
}

return 1;
}

void validate_ccl_command_dest_args(ttnn::ccl::cmd::CclCommandDestArgs const& dest_args) {
bool valid = std::holds_alternative<ttnn::ccl::cmd::UnicastCommandDestArgs>(dest_args) ||
std::holds_alternative<ttnn::ccl::cmd::MulticastCommandDestArgs>(dest_args) ||
@@ -597,8 +624,14 @@ void validate_command(ttnn::ccl::cmd::CclHostLowLevelWorkerCommand const& comman

void generate_ccl_command_stream_to_kernel_args(
std::vector<ttnn::ccl::cmd::CclHostLowLevelWorkerCommand> const& ccl_command_stream,
std::optional<size_t> tensor_index,
std::optional<std::vector<size_t>> const& tensor_indices,
ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider_out,
std::vector<uint32_t>& rt_args_out) {
std::optional<v2::TensorSlice> last_tensor_slice = std::nullopt;

bool fill_args_overrider = rt_args_overrider_out != nullptr;
TT_FATAL(!fill_args_overrider || tensor_index != std::nullopt, "Internal Error: When generating CCL command stream to kernel args, a runtime args overrider was provided but no tensor command index map was provided.");
std::optional<std::pair<ttnn::ccl::cmd::CclCommandAddrType, ttnn::ccl::cmd::CclCommandAddrArgs>>
last_src_addr_type = std::nullopt;
std::optional<std::pair<ttnn::ccl::cmd::CclCommandAddrType, ttnn::ccl::cmd::CclCommandAddrArgs>>
@@ -618,9 +651,30 @@
static_assert(sizeof(ttnn::ccl::cmd::CclCommandHeader) == sizeof(uint32_t));
const size_t old_rt_args_start_index = rt_args_out.size();
rt_args_out.push_back(0);

// populate the body (ccl command args) of the command
size_t num_ccl_command_args_added = 0;

// populate the src_addr_type
num_ccl_command_args_added += generate_ccl_address_info_command_args(
last_src_addr_type,
{command.source_addr_type, command.source_addr_args},
ttnn::ccl::cmd::SRC_DEST_TYPE::SRC,
rt_args_out);
last_src_addr_type = {command.source_addr_type, command.source_addr_args};

// populate the dest_addr_type
num_ccl_command_args_added += generate_ccl_address_info_command_args(
last_dest_addr_type,
{command.dest_addr_type, command.dest_addr_args},
ttnn::ccl::cmd::SRC_DEST_TYPE::DEST,
rt_args_out);
last_dest_addr_type = {command.dest_addr_type, command.dest_addr_args};

// populate the core_desc_type
num_ccl_command_args_added += generate_ccl_core_descriptor_info_command_args(
last_core_descriptor, {command.core_desc_type, command.core_desc_args}, rt_args_out);
last_core_descriptor = {command.core_desc_type, command.core_desc_args};

switch (command.command_code) {
case ttnn::ccl::cmd::CclCommandCode::STREAM_CB_TO_TENSOR:
case ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_CB: {
@@ -645,6 +699,26 @@
std::get<ttnn::ccl::cmd::CclCommandWaitValue>(command.command_args), rt_args_out);
break;

case ttnn::ccl::cmd::CclCommandCode::NOC_READ_BURST:
TT_FATAL(fill_args_overrider, "Internal Error: When generating noc read burst command args, an rt args override must be provided so that runtime args can be overridden on re-invocations of the owning operation");
num_ccl_command_args_added += generate_ccl_noc_transfer_burst_command_args(
std::get<ttnn::ccl::cmd::HostCclCommandNocTransferBurst>(command.command_args), tensor_indices->at(tensor_index.value()), *rt_args_overrider_out, rt_args_out);
break;

case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_BURST:
TT_FATAL(fill_args_overrider, "Internal Error: When generating noc write burst command args, an rt args override must be provided so that runtime args can be overridden on re-invocations of the owning operation");
num_ccl_command_args_added += generate_ccl_noc_transfer_burst_command_args(
std::get<ttnn::ccl::cmd::HostCclCommandNocTransferBurst>(command.command_args), tensor_indices->at(tensor_index.value()), *rt_args_overrider_out, rt_args_out);
break;

case ttnn::ccl::cmd::CclCommandCode::FLOW_CONTROLLED_NOC_READ_BURST:
TT_THROW("Command encoding support for CclCommandCode::FLOW_CONTROLLED_NOC_READ_BURST is unimplemented");
break;

case ttnn::ccl::cmd::CclCommandCode::NOC_WRITE_AND_ATOMIC_INC:
TT_THROW("Command encoding support for CclCommandCode::NOC_WRITE_AND_ATOMIC_INC is unimplemented");
break;

case ttnn::ccl::cmd::CclCommandCode::STREAM_EDM_TO_TENSOR:
TT_THROW(
"CCL command STREAM_EDM_TO_TENSOR is not useable, supported, or intended to be supported in CCL "
@@ -660,26 +734,6 @@ void generate_ccl_command_stream_to_kernel_args(
break;
}

// populate the src_addr_type
num_ccl_command_args_added += generate_ccl_address_info_command_args(
last_src_addr_type,
{command.source_addr_type, command.source_addr_args},
ttnn::ccl::cmd::SRC_DEST_TYPE::SRC,
rt_args_out);
last_src_addr_type = {command.source_addr_type, command.source_addr_args};

// populate the dest_addr_type
num_ccl_command_args_added += generate_ccl_address_info_command_args(
last_dest_addr_type,
{command.dest_addr_type, command.dest_addr_args},
ttnn::ccl::cmd::SRC_DEST_TYPE::DEST,
rt_args_out);
last_dest_addr_type = {command.dest_addr_type, command.dest_addr_args};
// populate the core_desc_type
num_ccl_command_args_added += generate_ccl_core_descriptor_info_command_args(
last_core_descriptor, {command.core_desc_type, command.core_desc_args}, rt_args_out);
last_core_descriptor = {command.core_desc_type, command.core_desc_args};

// populate the fabric_transfer_type
// Handled by header
log_trace(
Expand Down Expand Up @@ -916,6 +970,9 @@ static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence co
[&ss](CclCommandWaitValue const& a) { ss << fmt::format("(wait_value: {})", a.target_value); },
[&ss](CclCommandInlineReadWrite const& a) { ss << fmt::format("(value: {})", a.value); },
[&ss](CclCommandReadWrite const& a) { ss << fmt::format("(size_bytes: {})", a.size_bytes); },
[&ss](HostCclCommandNocTransferBurst const& a) {
ss << fmt::format("(base_addr: {}, n_transfers: {})", a.bank_base_address, a.num_transfers_total);
},
[&ss](auto const&&) { ss << "ERROR"; }},
args);
};
@@ -934,6 +991,7 @@ static void log_command_stream(ttnn::ccl::cmd::CclHostLowLevelCommandSequence co
a.noc0_end_x,
a.noc0_end_y);
},
[&ss](CclCommandCoreDescriptorTypeNone const& a) { ss << fmt::format("(None)"); },
},
args);
};
@@ -999,7 +1057,22 @@ void generate_multi_input_command_stream_kernel_rt_args(
std::optional<ttnn::ccl::cmd::CclHostLowLevelCommandSequence> const& ccl_command_stream1,
std::optional<ttnn::ccl::SenderWorkerAdapterSpec> const& forward_fabric_connections,
std::optional<ttnn::ccl::SenderWorkerAdapterSpec> const& backward_fabric_connections,
std::optional<std::unordered_map<const Tensor*, Device*>> const& tensor_device_override) {
std::optional<std::unordered_map<const Tensor*, Device*>> const& tensor_device_override,
std::optional<std::vector<size_t>> const& tensor_indices,
ttnn::ccl::tensor_address_runtime_args_overrider *rt_args_overrider) {

bool fill_args_overrider = rt_args_overrider != nullptr;

if (fill_args_overrider) {
TT_FATAL(tensor_indices.has_value(), "Internal Error. Tensor indices must be provided when using rt_args_overrider");
TT_FATAL(tensor_indices.value().size() == tensors.size(), "Internal Error. Tensor indices must match the number of tensors");
for (auto tensor_index : tensor_indices.value()) {
while (rt_args_overrider->size() <= tensor_index) {
rt_args_overrider->add_tensor();
}
}
}

// TODO: see if we can pull the kernel defines to understand if we built the kernel in single command stream mode
log_trace(
tt::LogOp,
Expand Down Expand Up @@ -1031,6 +1104,9 @@ void generate_multi_input_command_stream_kernel_rt_args(
rt_args.reserve(100);
for (size_t i = 0; i < tensors.size(); i++) {
if (tensors[i]) {
if (fill_args_overrider) {
rt_args_overrider->add_runtime_arg_index(tensor_indices.value()[i], rt_args.size());
}
rt_args.push_back(tensors[i]->buffer()->address());
} else {
// take up the rt arg with filler value in case user built a kernel across a core range
@@ -1095,7 +1171,7 @@ void generate_multi_input_command_stream_kernel_rt_args(
// Update the command stream start arg index argument to point to here (i.e. where
// this command stream's commands will start)
rt_args[command_stream_start_arg_indices[i]] = rt_args.size();
generate_ccl_command_stream_to_kernel_args((*command_streams[i]), rt_args);
generate_ccl_command_stream_to_kernel_args((*command_streams[i]), i, tensor_indices, rt_args_overrider, rt_args);
}

log_trace(tt::LogOp, "\tMulti-input command processor RT Args");
@@ -1104,6 +1180,7 @@
log_trace(tt::LogOp, "\t\t{}: {}", i, arg);
}
tt::tt_metal::SetRuntimeArgs(program, kernel_id, worker_core_range, rt_args);

}

void generate_multi_command_stream_kernel_rt_args(
(Diffs for the remaining changed files not loaded.)