diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
index 2fbb66aa34e..77f7d34d08a 100644
--- a/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
+++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_tg.cpp
@@ -5,6 +5,8 @@
 #include "common/bfloat16.hpp"
 #include "ttnn/tensor/tensor.hpp"
 #include "ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.hpp"
+#include "ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp"
+#include "ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
 #include "ttnn/cpp/ttnn/multi_device.hpp"
 #include "ttnn/async_runtime.hpp"
 #include "ttnn_multi_command_queue_fixture.hpp"
@@ -102,3 +104,113 @@ TEST(TGTests, TestAllGatherDeadlock) {
     }
     ttnn::multi_device::close_device_mesh(mesh);
 }
+
+TEST(TGTests, TestReduceScatterDeadlock) {
+    if (not tt::Cluster::instance().is_galaxy_cluster()) {
+        GTEST_SKIP() << "Skipping Galaxy test, since this is not a Galaxy System";
+    }
+    // Construct the remote devices in this cluster. TTNN Device Mesh APIs need this to be passed in.
+    // Do this using TT Cluster APIs, since device IDs may change in the future.
+    uint32_t num_devices_in_tunnel = tt::Cluster::instance().get_mmio_device_max_tunnel_depth(0);
+    uint32_t num_mmio_devices = tt::Cluster::instance().number_of_pci_devices();
+    uint32_t cluster_tunnel_count = tt::Cluster::instance().get_mmio_device_tunnel_count(0);
+    TT_FATAL(num_devices_in_tunnel == 4, "Expected Galaxy to have tunnel depth of 4");
+    TT_FATAL(num_mmio_devices * cluster_tunnel_count == 8, "Expected 8 tunnels in a Galaxy");
+
+    std::vector<chip_id_t> all_device_ids = {};
+    for (uint32_t mmio_idx = 0; mmio_idx < num_mmio_devices; mmio_idx++) {
+        auto tunnels_from_mmio = tt::Cluster::instance().get_tunnels_from_mmio_device(mmio_idx);
+        for (uint32_t tunnel_idx = 0; tunnel_idx < tunnels_from_mmio.size(); tunnel_idx++) {
+            auto remote_devices_in_tunnel = tunnels_from_mmio.at(tunnel_idx);
+            all_device_ids.insert(all_device_ids.end(), remote_devices_in_tunnel.begin(), remote_devices_in_tunnel.end());
+        }
+    }
+
+    // Create the device mesh: Grid size is <num_tunnels, tunnel_depth>.
+    auto mesh = ttnn::multi_device::open_device_mesh({cluster_tunnel_count * num_mmio_devices, num_devices_in_tunnel}, all_device_ids, 0, 0, 1);
+    // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the
+    // first tunnel (forward path).
+    std::vector<Device*> ring_devices = mesh.get_devices_on_row(0); // Tunnel 0
+    std::vector<Device*> ring_devices_1 = mesh.get_devices_on_column(3); // Orthogonal to tunnel .. no deadlocks
+    ring_devices_1 = std::vector<Device*>(ring_devices_1.begin() + 1, ring_devices_1.end());
+    std::vector<Device*> ring_devices_2 = mesh.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering
+    std::reverse(ring_devices_2.begin(), ring_devices_2.end());
+    ring_devices_2 = std::vector<Device*>(ring_devices_2.begin() + 1, ring_devices_2.end());
+    std::vector<Device*> ring_devices_3 = mesh.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks
+    std::reverse(ring_devices_3.begin(), ring_devices_3.end());
+    ring_devices_3 = std::vector<Device*>(ring_devices_3.begin() + 1, ring_devices_3.end() - 1);
+
+    ring_devices.insert(ring_devices.end(), ring_devices_1.begin(), ring_devices_1.end());
+    ring_devices.insert(ring_devices.end(), ring_devices_2.begin(), ring_devices_2.end());
+    ring_devices.insert(ring_devices.end(), ring_devices_3.begin(), ring_devices_3.end());
+
+    // Set up input data and output data containers.
+    MemoryConfig mem_cfg = MemoryConfig{
+        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
+        .buffer_type = BufferType::DRAM,
+        .shard_spec = std::nullopt};
+    ttnn::Shape shape = ttnn::Shape(Shape({1, 2, 256, 256 * ring_devices.size()}));
+    uint32_t buf_size_datums = 2 * 256 * 256 * 20;
+    uint32_t datum_size_bytes = 2;
+    // Output of reduce scatter is input_numel / num_devices_used_in_scatter_op.
+    auto host_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums]);
+    auto readback_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums / ring_devices.size()]);
+    uint32_t scatter_dim = 3;
+    uint32_t outer_loops = 500;
+
+    // Input to the CCL is a tensor of 1s. The output contains 20x (ring size) fewer entries along the innermost dim,
+    // with each entry == 20 * 1.
+    for (int j = 0; j < buf_size_datums; j++) {
+        host_data[j] = bfloat16(static_cast<float>(1));
+    }
+    std::vector<uint32_t> device_ids = {};
+
+    for (auto dev : ring_devices) {
+        dev->enable_program_cache();
+        device_ids.push_back(dev->id());
+    }
+
+    log_info(LogTest, "Running Reduce Scatter Op for {} loops", outer_loops);
+    log_info(LogTest, "Devices in Ring: {}", device_ids);
+    // Run reduce scatter multiple times.
+    // For the first tunnel, send adversarial traffic that can clog the forward path if the op is not tagged correctly.
+    for (int i = 0; i < outer_loops; i++) {
+        std::vector<Tensor> output_tensors = {};
+        uint32_t dev_idx = 0;
+        if (i % 100 == 0) {
+            log_info(LogTest, "Running iteration {}", i);
+        }
+        for (auto& dev : ring_devices) {
+            auto input_buffer = ttnn::allocate_buffer_on_device(buf_size_datums * datum_size_bytes, dev, shape, DataType::BFLOAT16, Layout::TILE, mem_cfg);
+            auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
+            Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE);
+            // Push inputs.
+            ttnn::write_buffer(0, input_tensor, {host_data});
+            // Configure the CCL running on this device.
+            uint32_t receiver_device_id = device_ids[(dev_idx + 1) % ring_devices.size()];
+            uint32_t sender_device_id = device_ids[(dev_idx + ring_devices.size() - 1) % ring_devices.size()];
+            auto all_gather_op = ttnn::ReduceScatter{
+                ttnn::operations::binary::BinaryOpType::ADD, scatter_dim, 1, ring_devices.size(), dev_idx, receiver_device_id, sender_device_id, input_tensor.memory_config(), ttnn::all_gather_op::Topology::Ring};
+            // Send the CCL to this device. All CCLs will complete simultaneously.
+            output_tensors.push_back(ttnn::run_operation(0, all_gather_op, {input_tensor}).at(0));
+            // Expose deadlock: after the CCL is sent to a device in the first tunnel, send enough data to it to backpressure prefetch_h. This will block the
+            // demux, which will prevent the CCL from being sent to additional chips on the tunnel. If the CCL has been tagged as having multi-device dependencies,
+            // the deadlock should be bypassed.
+            // if (dev_idx < 3) {
+            for (int j = 0; j < 16; j++) {
+                ttnn::write_buffer(0, input_tensor, {host_data});
+            }
+            // }
+            dev_idx++;
+        }
+        // Read back data and verify correctness.
+        for (auto& tensor : output_tensors) {
+            ASSERT_EQ(tensor.get_shape(), ttnn::Shape(Shape({1, 2, 256, 256})));
+            ttnn::read_buffer(0, tensor, {readback_data});
+            for (int j = 0; j < 512 * 256; j++) {
+                ASSERT_EQ(readback_data[j].to_float(), 20);
+            }
+        }
+    }
+    ttnn::multi_device::close_device_mesh(mesh);
+}
diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
index a65df40b027..192f982a7df 100644
--- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
@@ -729,6 +729,8 @@ operation::ProgramWithCallbacks reduce_scatter_with_workers(
     //////////////////
     tt::tt_metal::Program program{};
+    // Issue #10978: CCLs need to be tagged as having multi-device dependencies when running on Galaxy.
+    program.capture_multi_device_dependencies();
     const auto& device = local_chip_tensor.device();
     auto const& topology_config =
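
For reference, the correctness checks in TestReduceScatterDeadlock follow directly from the ring construction in the test: the ring stitches together 4 devices from row 0, 7 of the 8 devices in column 3, 3 of the 4 devices in row 7, and 6 of the 8 devices in column 0, for a ring size of 20 (which is also the hard-coded factor in buf_size_datums). Reduce-scatter over dim 3 leaves each device with 1/20th of the input along the innermost dim and sums one element from every ring member, so every output element is 20 * 1 = 20. The snippet below is a standalone plain-C++ restatement of that arithmetic, not part of the patch; all constants are taken from the test above.

#include <cassert>
#include <cstdint>

int main() {
    // Ring composition from the test: row 0 contributes 4 devices, column 3
    // contributes 7, row 7 contributes 3, and column 0 contributes 6.
    const uint32_t ring_size = 4 + 7 + 3 + 6;  // == 20

    // Per-device input: shape {1, 2, 256, 256 * ring_size} of bfloat16 (2 bytes each).
    const uint32_t buf_size_datums = 2 * 256 * 256 * ring_size;

    // Reduce-scatter along dim 3 leaves each device with 1/ring_size of the data,
    // i.e. shape {1, 2, 256, 256}, which is what the ASSERT_EQ on get_shape() checks.
    const uint32_t output_datums = buf_size_datums / ring_size;
    assert(output_datums == 512 * 256);  // matches the readback loop bound in the test

    // Every input element is 1.0 and each output element sums one element from all
    // ring members, so the expected readback value is ring_size == 20.
    const float expected_value = 1.0f * ring_size;
    assert(expected_value == 20.0f);

    // Ring neighbours as computed per device in the test.
    for (uint32_t dev_idx = 0; dev_idx < ring_size; dev_idx++) {
        const uint32_t receiver_idx = (dev_idx + 1) % ring_size;
        const uint32_t sender_idx = (dev_idx + ring_size - 1) % ring_size;
        assert(receiver_idx < ring_size && sender_idx < ring_size);
    }
    return 0;
}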