Skip to content

Commit

Permalink
add multi-iteration support to reduce scatter async (#16294)
Browse files Browse the repository at this point in the history
Add multi-iteration support to reduce scatter async via the program
override args callback
  • Loading branch information
SeanNijjar authored Dec 31, 2024
1 parent 3949130 commit 7a19dbe
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def run_reduce_scatter_test(
mem_config,
use_program_cache,
function_level_defaults,
num_iters,
enable_async=True,
num_iters=1,
topology=ttnn.Topology.Ring,
trace_mode=False,
):
Expand Down Expand Up @@ -196,6 +196,7 @@ def run_reduce_scatter_test(
subdevice_id=ttnn.SubDeviceId(0),
)
else:
logger.info(f"Running {num_iters} iterations of reduce scatter")
for i in range(num_iters):
output_tensor_mesh = ttnn.reduce_scatter_async(
input_tensor_mesh,
Expand All @@ -207,10 +208,10 @@ def run_reduce_scatter_test(
subdevice_id=worker_sub_device_id,
)

logger.info(f"Waiting for op {i}")
for device_id in mesh_device.get_device_ids():
ttnn.synchronize_device(mesh_device.get_device(device_id), sub_device_ids=[worker_sub_device_id])
logger.info(f"Done iteration {i}")
logger.info(f"Waiting for op to finish all iterations")
for device_id in mesh_device.get_device_ids():
ttnn.synchronize_device(mesh_device.get_device(device_id), sub_device_ids=[worker_sub_device_id])
logger.info(f"Done iterations")

teardown_fabric_interface(mesh_device)
# Compute golden
Expand Down Expand Up @@ -321,7 +322,7 @@ def test_line_reduce_scatter_async_post_commit(
function_level_defaults,
enable_async,
trace_mode,
num_iters=1,
num_iters=16,
):
run_reduce_scatter_test(
t3k_mesh_device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2089,8 +2089,75 @@ operation::ProgramWithCallbacks reduce_scatter_async_on_instantiated_edm_fabric(
const std::vector<Tensor>& output_tensors) {
const auto& input = input_tensors.at(0);
const auto& output = output_tensors.at(0);

auto& input_tensor = input_tensors.at(0);

auto& final_output_tensor = output_tensors.at(0);
auto& input_tensor_from_remote_forward_direction = output_tensors.at(1);
auto& input_tensor_from_remote_backward_direction = output_tensors.at(2);
std::array<const Tensor*, 2> input_tensor_from_remote = {
&input_tensor_from_remote_forward_direction, &input_tensor_from_remote_backward_direction};
auto& partial_output_tensor_forward_direction = output_tensors.at(3);
auto& partial_output_tensor_backward_direction = output_tensors.at(4);
std::array<const Tensor*, 2> partial_output_tensor = {
&partial_output_tensor_forward_direction, &partial_output_tensor_backward_direction};

auto& worker_reader_runtime_args_by_core = GetRuntimeArgs(program, kernel_ids.reader);
auto& worker_writer_runtime_args_by_core = GetRuntimeArgs(program, kernel_ids.writer);
if (topology_config.is_at_end_of_line()) {

for (auto const& core : worker_cores.final_reducers_vec) {
auto &worker_reader_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
worker_reader_runtime_args.at(0) = partial_output_tensor[LineDirection::FORWARD]->buffer()->address();;
worker_reader_runtime_args.at(1) = partial_output_tensor[LineDirection::BACKWARD]->buffer()->address();;

auto& worker_writer_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
worker_writer_runtime_args.at(0) = final_output_tensor.buffer()->address();;
}
for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) {
bool is_start_of_line = topology_config.is_first_device_in_line(direction);
for (auto const& core : worker_cores.partial_reducers_vec[direction]) {
auto &worker_reader_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
worker_reader_runtime_args.at(0) = input_tensor.buffer()->address();
if (is_start_of_line) {
worker_reader_runtime_args.at(1) = input_tensor_from_remote.at(direction)->buffer()->address();
}

auto& worker_writer_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
if (is_start_of_line) {
worker_writer_runtime_args.at(0) = input_tensor_from_remote[direction]->buffer()->address();;
} else {
worker_writer_runtime_args.at(0) = final_output_tensor.buffer()->address();;
}
}
}
} else {
for (auto direction : {LineDirection::FORWARD, LineDirection::BACKWARD}) {
for (auto const &core : worker_cores.partial_reducers_vec[direction]) {
auto &worker_reader_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
auto& worker_writer_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
worker_reader_runtime_args.at(0) = input_tensor.buffer()->address();
worker_reader_runtime_args.at(1) = input_tensor_from_remote[direction]->buffer()->address();

// input_tensor_from_remote and remote output partial result tensor share the same addresses
// because the input from remote of one chip is the partial result remote output of another
worker_writer_runtime_args.at(0) = input_tensor_from_remote[direction]->buffer()->address();
worker_writer_runtime_args.at(1) = partial_output_tensor[direction]->buffer()->address();
}
}

for (auto const& core : worker_cores.final_reducers_vec) {
auto &worker_reader_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];

worker_reader_runtime_args.at(0) = partial_output_tensor[LineDirection::FORWARD]->buffer()->address();
worker_reader_runtime_args.at(1) = partial_output_tensor[LineDirection::BACKWARD]->buffer()->address();

auto& worker_writer_runtime_args = worker_reader_runtime_args_by_core[core.x][core.y];
worker_writer_runtime_args.at(0) = final_output_tensor.buffer()->address();
}
}


};

log_trace(tt::LogOp, "Done program factory");
Expand Down

0 comments on commit 7a19dbe

Please sign in to comment.