diff --git a/.clang-format-ignore b/.clang-format-ignore index 2125ca6ce0a..7041c5d9150 100644 --- a/.clang-format-ignore +++ b/.clang-format-ignore @@ -40,10 +40,8 @@ ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp -ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.cpp ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/multi_core/all_gather_op_multi_core_new.cpp -ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp @@ -58,8 +56,6 @@ ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp -ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp -ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_helpers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_helpers.cpp index 7dbea301adf..65c4fed634c 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_ccl_helpers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_ccl_helpers.cpp @@ -58,37 +58,36 @@ TEST(CclHelpers, EriscDatamoverConfig_GetEdmHandshakeAddress_GT_0) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } + ttnn::ccl::EriscDatamoverConfig config; for (std::size_t i = 0; i < 8; i++) { - ASSERT_TRUE(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0); + ASSERT_TRUE(config.get_edm_handshake_address() > 0); } } TEST(CclHelpers, EriscDatamoverConfig_GetSemaphoresBaseAddress_GT_0) { + ttnn::ccl::EriscDatamoverConfig config; for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ttnn::ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >= - (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ttnn::ccl::EriscDatamoverConfig::handshake_location_size + - ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + config.get_semaphores_base_address(i) >= + (config.get_edm_handshake_address() + config.handshake_location_size + + config.edm_receiver_first_level_ack_source_word_size)); } } TEST(CclHelpers, EriscDatamoverConfig_GetBuffersBaseAddress_GT_0) { + ttnn::ccl::EriscDatamoverConfig config; for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= - (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ttnn::ccl::EriscDatamoverConfig::handshake_location_size + - ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + config.get_buffers_base_address(i) >= (config.get_edm_handshake_address() + config.handshake_location_size + + config.edm_receiver_first_level_ack_source_word_size)); } } TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) { + ttnn::ccl::EriscDatamoverConfig config; for (std::size_t i = 0; i < 8; i++) { ASSERT_TRUE( - ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >= - (ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() + - ttnn::ccl::EriscDatamoverConfig::handshake_location_size + - ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size)); + config.get_buffers_base_address(i) >= (config.get_edm_handshake_address() + config.handshake_location_size + + config.edm_receiver_first_level_ack_source_word_size)); } } diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp index 9d84e63c926..647d374595d 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_erisc_data_mover_with_workers.cpp @@ -16,6 +16,7 @@ #include "tt_metal/common/math.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/experimental/hal.hpp" #include "tt_metal/impl/kernels/kernel.hpp" #include "tt_metal/test_utils/comparison.hpp" #include "tt_metal/test_utils/df/df.hpp" @@ -31,6 +32,7 @@ using namespace tt; using namespace tt::test_utils; using namespace tt::test_utils::df; +using namespace tt::tt_metal::experimental; // Taken from ccl_common... some dependency annoyance to deal with so just copying it here for now... resolve before // merging @@ -378,7 +380,7 @@ bool RunWriteBWTest( tt_metal::detail::WriteToBuffer(buffer_id, all_zeros); } - uint32_t erisc_handshake_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + uint32_t erisc_handshake_address = hal::get_erisc_l1_unreserved_base(); uint32_t chip0_next_buffer_address = erisc_handshake_address + 16; std::vector chip0_edm_args = {erisc_handshake_address}; diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index 8a86411c840..34bdf9844b1 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -432,17 +432,18 @@ bool RunLoopbackTest( auto const& worker_core = worker_cores.at(0); log_trace(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y); - std::vector const& edm_termination_infos = + const auto& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); + const std::vector& edm_termination_infos = enable_persistent_fabric ? std::vector{} : std::vector{ {1, sender_device->ethernet_core_from_logical_core(eth_receiver_core).x, sender_device->ethernet_core_from_logical_core(eth_receiver_core).y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}, + chip_0_edm_builder.config.termination_signal_address}, {0, sender_device->ethernet_core_from_logical_core(eth_sender_core).x, sender_device->ethernet_core_from_logical_core(eth_sender_core).y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}}; + chip_0_edm_builder.config.termination_signal_address}}; TT_ASSERT( (enable_persistent_fabric && edm_termination_infos.size() == 0) || diff --git a/tt_metal/experimental/hal.cpp b/tt_metal/experimental/hal.cpp index 9590d910f89..c748f34e2ed 100644 --- a/tt_metal/experimental/hal.cpp +++ b/tt_metal/experimental/hal.cpp @@ -6,6 +6,7 @@ #include "tt_metal/experimental/hal.hpp" #include "tt_metal/llrt/hal.hpp" +#include using tt::tt_metal::HalL1MemAddrType; using tt::tt_metal::HalMemType; @@ -24,4 +25,20 @@ uint32_t get_l1_alignment() { return HalSingleton::getInstance().get_alignment(H uint32_t get_pcie_alignment() { return HalSingleton::getInstance().get_alignment(HalMemType::HOST); } +uint32_t get_erisc_l1_unreserved_base() { + auto& hal = HalSingleton::getInstance(); + if (hal.get_arch() != tt::ARCH::GRAYSKULL) { + return hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::UNRESERVED); + } + return 0; +} + +uint32_t get_erisc_l1_unreserved_size() { + auto& hal = HalSingleton::getInstance(); + if (hal.get_arch() != tt::ARCH::GRAYSKULL) { + return hal.get_dev_size(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::UNRESERVED); + } + return 0; +} + } // namespace tt::tt_metal::experimental::hal diff --git a/tt_metal/experimental/hal.hpp b/tt_metal/experimental/hal.hpp index 30afd86a85a..3d9b4108913 100644 --- a/tt_metal/experimental/hal.hpp +++ b/tt_metal/experimental/hal.hpp @@ -36,4 +36,20 @@ uint32_t get_l1_alignment(); */ uint32_t get_pcie_alignment(); +/** + * @brief Uses the hardware abstraction layer to inform client of architecture specific address. + * this address corresponds to the beginning of free space in the ERISC's L1 SRAM + * + * @return address + */ +uint32_t get_erisc_l1_unreserved_base(); + +/** + * @brief Uses the hardware abstraction layer to inform client of architecture specific size. + * this size corresponds to the total free space in the ERISC's L1 SRAM for host usage + * + * @return size in bytes + */ +uint32_t get_erisc_l1_unreserved_size(); + } // namespace tt::tt_metal::experimental::hal diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 54c6b3540a6..7c28a33ee80 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -6,13 +6,15 @@ #include "ttnn/operations/math.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/experimental/hal.hpp" #include "ttnn/tensor/tensor_utils.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" -#include "eth_l1_address_map.h" #include "ttnn/cpp/ttnn/operations/copy.hpp" +using namespace tt::tt_metal::experimental; + namespace ttnn { namespace ccl { namespace all_gather_detail { @@ -50,8 +52,7 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con return AllGatherBidirectionalMode::FULL_TENSOR; } - std::size_t eth_l1_capacity = - eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + std::size_t eth_l1_capacity = hal::get_erisc_l1_unreserved_size(); std::size_t tensor_size_bytes = input_tensor.volume() * input_tensor.element_size(); // This is currently a guestimate. We need a lot more hard data to identify where this dividing line is. bool perf_degradation_from_full_tensor_mode = tensor_size_bytes > (2 * eth_l1_capacity); @@ -62,8 +63,8 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con } AllGatherConfig::AllGatherConfig( - Tensor const& input_tensor, - Tensor const& output_tensor, + const Tensor& input_tensor, + const Tensor& output_tensor, uint32_t dim, uint32_t ring_size, uint32_t num_links, @@ -75,7 +76,7 @@ AllGatherConfig::AllGatherConfig( semaphore_size(32), ring_size(ring_size), - erisc_handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)), + erisc_handshake_address(tt::round_up(hal::get_erisc_l1_unreserved_base(), 16)), topology(topology), enable_bidirectional(topology == ttnn::ccl::Topology::Ring), @@ -86,8 +87,8 @@ AllGatherConfig::AllGatherConfig( enable_merged_payload_and_channel_sync(true), num_edm_buffers_per_channel(num_edm_buffers_per_channel) { TT_FATAL(num_edm_buffers_per_channel > 0, "num_edm_buffers_per_channel must be > 0"); - TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE); - TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16); + TT_ASSERT(erisc_handshake_address >= hal::get_erisc_l1_unreserved_base()); + TT_ASSERT(erisc_handshake_address < hal::get_erisc_l1_unreserved_base() + 16); TT_ASSERT((erisc_handshake_address & (16 - 1)) == 0); if (input_tensor.get_layout() == Layout::TILE && dim != 3) { // See issue #6448 @@ -106,8 +107,7 @@ AllGatherConfig::AllGatherConfig( (topology == ttnn::ccl::Topology::Ring && bidirectional_mode != AllGatherBidirectionalMode::FULL_TENSOR) ? 1 : 2; - constexpr uint32_t total_l1_buffer_space = - eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + uint32_t total_l1_buffer_space = hal::get_erisc_l1_unreserved_size(); this->is_sharded = input_tensor.is_sharded(); if (user_defined_num_workers.has_value()) { @@ -117,9 +117,10 @@ AllGatherConfig::AllGatherConfig( (this->enable_bidirectional ? 8 /*1*/ : (topology != ttnn::ccl::Topology::Linear ? 8 : 4)); } + constexpr std::int32_t MAX_NUM_CONCURRENT_TRANSACTIONS = 8; if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) { - this->num_eth_buffers = std::min( - this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions); + this->num_eth_buffers = + std::min(this->num_eth_buffers, MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions); } this->num_workers_per_link = this->num_eth_buffers; @@ -145,8 +146,7 @@ AllGatherConfig::AllGatherConfig( total_l1_buffer_space, "Error"); TT_FATAL( - eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= - eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS, + eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= MAX_NUM_CONCURRENT_TRANSACTIONS, "Error"); } diff --git a/ttnn/cpp/ttnn/operations/ccl/barrier/device/host/barrier_full_worker_grid.cpp b/ttnn/cpp/ttnn/operations/ccl/barrier/device/host/barrier_full_worker_grid.cpp index ff308938ec5..062781a7de5 100644 --- a/ttnn/cpp/ttnn/operations/ccl/barrier/device/host/barrier_full_worker_grid.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/barrier/device/host/barrier_full_worker_grid.cpp @@ -12,12 +12,14 @@ #include "ttnn/operations/ccl/ccl_common.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/experimental/hal.hpp" #include "tt_metal/impl/buffers/circular_buffer_types.hpp" #include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp" #include "ttnn/operations/eltwise/binary/common/binary_op_utils.hpp" using namespace tt::tt_metal; +using namespace tt::tt_metal::experimental; namespace ttnn::ccl::barrier::detail { @@ -51,16 +53,15 @@ static std::tuple, std::array, std::array< CoreCoord const& sem_init_core) { const uint32_t worker_sem0 = CreateSemaphore(program, sem_init_core, 0, CoreType::WORKER); const uint32_t worker_sem1 = CreateSemaphore(program, sem_init_core, 0, CoreType::WORKER); - constexpr uint32_t start_semaphore_address = - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + EriscDatamoverConfig::eth_word_size_bytes; - constexpr uint32_t erisc_semaphore_address = - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + (EriscDatamoverConfig::eth_word_size_bytes * 2); - constexpr uint32_t erisc_buffer_address = - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + (EriscDatamoverConfig::eth_word_size_bytes * 3); + uint32_t start_semaphore_address = hal::get_erisc_l1_unreserved_base() + EriscDatamoverConfig::eth_word_size_bytes; + uint32_t erisc_semaphore_address = + hal::get_erisc_l1_unreserved_base() + (EriscDatamoverConfig::eth_word_size_bytes * 2); + uint32_t erisc_buffer_address = + hal::get_erisc_l1_unreserved_base() + (EriscDatamoverConfig::eth_word_size_bytes * 3); const std::array receiver_rt_args = { static_cast(is_starting_core ? 1 : 0), - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, + hal::get_erisc_l1_unreserved_base(), static_cast(device->ethernet_core_from_logical_core(eth_sender_core).x), static_cast(device->ethernet_core_from_logical_core(eth_sender_core).y), erisc_semaphore_address, @@ -70,8 +71,8 @@ static std::tuple, std::array, std::array< static_cast(device->virtual_core_from_logical_core(sem_init_core, CoreType::WORKER).y), worker_sem0}; const std::array sender_rt_args = { - static_cast(is_starting_core ? 1 : 0), // is_ring_start - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, // handshake_addr + static_cast(is_starting_core ? 1 : 0), // is_ring_start + hal::get_erisc_l1_unreserved_base(), // handshake_addr erisc_buffer_address, erisc_semaphore_address, static_cast(device->virtual_core_from_logical_core(sem_init_core, CoreType::WORKER).x), diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 980833f5d49..931c1429764 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -326,16 +326,17 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( std::size_t num_buffers_per_channel, ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, ccl::EriscDataMoverTerminationMode termination_mode) { + ccl::EriscDatamoverConfig config; TT_ASSERT(num_channels > 0); std::vector edm_sem_addresses(num_channels, 0); std::vector edm_buffer_addresses(num_channels, 0); - uint32_t edm_sem_addr = ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels); - uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels); + uint32_t edm_sem_addr = config.get_semaphores_base_address(num_channels); + uint32_t edm_buffer_addr = config.get_buffers_base_address(num_channels); TT_ASSERT(edm_sem_addr > 0); TT_ASSERT(edm_buffer_addr > 0); const uint32_t channel_buffer_size = - ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size); + config.compute_buffer_size(num_channels, num_buffers_per_channel, page_size); for (std::size_t c = 0; c < num_channels; ++c) { edm_sem_addresses.at(c) = edm_sem_addr; edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size; @@ -352,7 +353,7 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( return ccl::EriscDatamoverBuilder( channel_buffer_size, - ccl::EriscDatamoverConfig::get_edm_handshake_address(), + config.get_edm_handshake_address(), edm_sem_addresses, edm_buffer_addresses, buffer_sharing_mode, diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp index 340d8532256..1cd5377edd6 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp @@ -4,7 +4,7 @@ #pragma once -#include "eth_l1_address_map.h" +#include "tt_metal/experimental/hal.hpp" #include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp" #include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/ccl_host_types.hpp" @@ -15,9 +15,8 @@ namespace ttnn { namespace ccl { struct EriscDatamoverConfig { - static constexpr std::size_t total_l1_buffer_space = - eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - static constexpr std::size_t usable_l1_base_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + std::size_t total_l1_buffer_space = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_size(); + std::size_t usable_l1_base_address = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_base(); static constexpr std::size_t semaphore_size = 32; static constexpr std::size_t handshake_location_size = 16; // ethernet word size @@ -32,14 +31,14 @@ struct EriscDatamoverConfig { static constexpr std::size_t eth_word_size_bytes = 16; static constexpr bool enable_merged_payload_and_channel_sync = true; static std::size_t get_eth_channel_sync_size_bytes(); - static uint32_t get_edm_handshake_address(); + uint32_t get_edm_handshake_address(); static std::size_t get_semaphores_region_size(std::size_t num_edm_channels); static std::size_t get_semaphores_region_start_offset(std::size_t num_edm_channels); - static uint32_t get_semaphores_base_address(std::size_t num_edm_channels); + uint32_t get_semaphores_base_address(std::size_t num_edm_channels); static uint32_t get_buffers_region_start_offset(std::size_t num_edm_channels); static std::size_t get_eth_word_size(); - static uint32_t get_buffers_base_address(std::size_t num_edm_channels); - static uint32_t compute_buffer_size( + uint32_t get_buffers_base_address(std::size_t num_edm_channels); + uint32_t compute_buffer_size( std::size_t num_edm_channels, std::size_t num_buffers_per_channel = 1, uint32_t page_size = eth_word_size_bytes); diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp index e3ed7dc9d98..7b89e0d4849 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp @@ -6,7 +6,6 @@ #include "common/math.hpp" #include "erisc_datamover_builder.hpp" -#include "eth_l1_address_map.h" #include "sub_device/sub_device_types.hpp" #include "tt_metal/common/assert.hpp" #include "ttnn/operations/ccl/ccl_common.hpp" @@ -19,11 +18,15 @@ #include "tt_metal/detail/tt_metal.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" +#include "tt_metal/experimental/hal.hpp" #include #include #include #include + +using namespace tt::tt_metal::experimental; + namespace ttnn::ccl { @@ -41,6 +44,7 @@ namespace ttnn::ccl { FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size) { + TT_FATAL(sender_channel_1_buffer_index_address != sender_channel_0_buffer_index_address, "FabricEriscDatamoverConfig was constructed with illegal buffer index address"); const size_t min_buffer_size = sizeof(tt::fabric::PacketHeader) + 2 * FabricEriscDatamoverConfig::eth_channel_sync_size; TT_FATAL(channel_buffer_size_bytes >= min_buffer_size, "FabricEriscDatamoverConfig was constructed with `channel_buffer_size_bytes` argument set smaller than minimum size of {}", min_buffer_size); const std::size_t channel_buffer_size_with_channel_sync = @@ -81,7 +85,7 @@ FabricEriscDatamoverConfig::FabricEriscDatamoverConfig( this->available_channel_buffering_space, "Internal error when computing channel sizes. Total channel size exceeds available space"); TT_FATAL( this->receiver_channel_base_address + this->receiver_channel_size_bytes < - eth_l1_mem::address_map::MAX_L1_LOADING_SIZE, "Internal error - channel buffers spilled past the end of usable L1 region."); + this->max_l1_loading_size, "Internal error - channel buffers spilled past the end of usable L1 region."); } void get_runtime_args_for_edm_termination_infos(std::vector const& edm_termination_infos, std::vector& args_out) { @@ -164,7 +168,7 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( config(config), my_chip_id(my_chip_id), peer_chip_id(peer_chip_id), - handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, FabricEriscDatamoverConfig::eth_channel_sync_size)), + handshake_address(tt::round_up(hal::get_erisc_l1_unreserved_base(), FabricEriscDatamoverConfig::eth_channel_sync_size)), channel_buffer_size(config.channel_buffer_size_bytes), sender_0_num_buffers(config.sender_0_num_buffers), sender_1_num_buffers(config.sender_1_num_buffers), @@ -179,17 +183,17 @@ FabricEriscDatamoverBuilder::FabricEriscDatamoverBuilder( sender_channel_0_buffer_index_semaphore_id(sender_channel_0_buffer_index_semaphore_id), sender_channel_1_buffer_index_semaphore_id(sender_channel_1_buffer_index_semaphore_id), - receiver_channel_local_buffer_index_address(FabricEriscDatamoverConfig::receiver_channel_local_buffer_index_address), + receiver_channel_local_buffer_index_address(config.receiver_channel_local_buffer_index_address), local_sender_channel_0_buffer_address(config.sender_0_channel_base_address), local_sender_channel_0_connection_info_addr( - FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address), + config.sender_channel_0_worker_connection_info_address), local_sender_channel_1_buffer_address(config.sender_1_channel_base_address), local_sender_channel_1_connection_info_addr( - FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address), + config.sender_channel_1_worker_connection_info_address), local_receiver_channel_buffer_address(config.receiver_channel_base_address), - termination_signal_ptr(FabricEriscDatamoverConfig::termination_signal_address), + termination_signal_ptr(config.termination_signal_address), enable_persistent_mode(enable_persistent_mode), build_in_worker_connection_mode(build_in_worker_connection_mode) {} @@ -214,9 +218,9 @@ std::vector FabricEriscDatamoverBuilder::get_compile_time_args() const this->receiver_num_buffers, config.sender_0_channel_base_address, - FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + config.sender_channel_0_worker_connection_info_address, config.sender_1_channel_base_address, - FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + config.sender_channel_1_worker_connection_info_address, config.receiver_channel_base_address, config.receiver_channel_base_address, @@ -259,11 +263,11 @@ FabricEriscDatamoverBuilder FabricEriscDatamoverBuilder::build( bool build_in_worker_connection_mode) { if (enable_persistent_mode) { auto sender_channel_0_buffer_index_semaphore_address = - FabricEriscDatamoverConfig::sender_channel_0_buffer_index_semaphore_address; + config.sender_channel_0_buffer_index_semaphore_address; auto sender_channel_0_flow_control_semaphore_address = - FabricEriscDatamoverConfig::sender_channel_0_local_flow_control_semaphore_address; + config.sender_channel_0_local_flow_control_semaphore_address; auto sender_channel_0_connection_semaphore_address = - FabricEriscDatamoverConfig::sender_channel_0_connection_semaphore_address; + config.sender_channel_0_connection_semaphore_address; std::optional receiver_channel_downstream_flow_control_semaphore_address = build_in_worker_connection_mode ? 0: tt::tt_metal::CreateSemaphore(program, ethernet_core, 0, CoreType::ETH); @@ -342,7 +346,7 @@ SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_worker_ this->sender_0_num_buffers, this->sender_channel_0_flow_control_semaphore_id, this->sender_channel_0_connection_semaphore_id, - FabricEriscDatamoverConfig::sender_channel_0_worker_connection_info_address, + this->config.sender_channel_0_worker_connection_info_address, this->config.channel_buffer_size_bytes, this->sender_channel_0_buffer_index_semaphore_id, this->enable_persistent_mode @@ -358,7 +362,7 @@ SenderWorkerAdapterSpec FabricEriscDatamoverBuilder::build_connection_to_fabric_ this->sender_1_num_buffers, this->sender_channel_1_flow_control_semaphore_id, this->sender_channel_1_connection_semaphore_id, - FabricEriscDatamoverConfig::sender_channel_1_worker_connection_info_address, + this->config.sender_channel_1_worker_connection_info_address, this->config.channel_buffer_size_bytes, this->sender_channel_1_buffer_index_semaphore_id, false @@ -635,7 +639,7 @@ std::vector EdmLineFabricOpInterface::generate_local_chi 0, edm_builder.my_noc_x, edm_builder.my_noc_y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}; + edm_builder.config.termination_signal_address}; }; std::vector edm_termination_infos; edm_termination_infos.reserve(this->num_links * 2); @@ -655,6 +659,8 @@ std::vector EdmLineFabricOpInterface::generate_local_chi } std::vector EdmLineFabricOpInterface::generate_ordered_termination_info_farthest_to_nearest() const { + static constexpr std::size_t edm_buffer_size = 4096 + sizeof(tt::fabric::PacketHeader); + static const auto config = FabricEriscDatamoverConfig(edm_buffer_size, 1, 2); TT_ASSERT(device_sequence.size() > 0); const size_t num_hops = device_sequence.size() - 1; TT_ASSERT(num_hops > 0); @@ -678,7 +684,7 @@ std::vector EdmLineFabricOpInterface::generate_ordered_t {distance_receiver, farther_edm.my_noc_x, farther_edm.my_noc_y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}); + config.termination_signal_address}); } for (size_t l = 0; l < this->num_links; l++) { auto &nearer_edm = nearer_edms.at(l); @@ -687,7 +693,7 @@ std::vector EdmLineFabricOpInterface::generate_ordered_t {distance_sender, nearer_edm.my_noc_x, nearer_edm.my_noc_y, - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}); + config.termination_signal_address}); } } log_trace(tt::LogOp, "Done Generating termination infos"); @@ -700,7 +706,7 @@ void FabricEriscDatamoverBuilder::teardown_from_host(IDevice*d, tt::fabric::Term d->push_work([&](){tt::tt_metal::detail::WriteToDeviceL1( d, d->logical_core_from_ethernet_core(CoreCoord(this->my_noc_x, this->my_noc_y)), - ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address, + config.termination_signal_address, val, CoreType::ETH);}, true); } diff --git a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp index 5319c194556..17f451ea73e 100644 --- a/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp @@ -8,7 +8,6 @@ #include #include -#include "eth_l1_address_map.h" #include "ttnn/distributed/types.hpp" #include "umd/device/types/cluster_descriptor_types.h" #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" @@ -17,6 +16,7 @@ #include "tt_metal/device.hpp" #include "tt_metal/impl/program/program.hpp" +#include "tt_metal/experimental/hal.hpp" #include #include @@ -33,54 +33,53 @@ struct FabricEriscDatamoverConfig { // Global static constexpr std::size_t eth_channel_sync_size = 16; - static constexpr std::size_t handshake_addr = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE/* + 1024*/; - static constexpr std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size; - static constexpr std::size_t termination_signal_address = + std::size_t handshake_addr = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_base()/* + 1024*/; + std::size_t edm_channel_ack_addr = handshake_addr + eth_channel_sync_size; + std::size_t termination_signal_address = edm_channel_ack_addr + (2 * eth_channel_sync_size); // pad extra bytes to match old EDM so handshake logic will still work // ----------- Sender Channel 0 - static constexpr std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size; - static constexpr std::size_t sender_channel_0_worker_connection_info_address = + std::size_t sender_channel_0_buffer_index_address = termination_signal_address + field_size; + std::size_t sender_channel_0_worker_connection_info_address = sender_channel_0_buffer_index_address + field_size; - static constexpr std::size_t sender_channel_0_local_flow_control_semaphore_address = + std::size_t sender_channel_0_local_flow_control_semaphore_address = sender_channel_0_worker_connection_info_address + field_size; // persistent mode field - static constexpr std::size_t sender_channel_0_connection_semaphore_address = + std::size_t sender_channel_0_connection_semaphore_address = sender_channel_0_local_flow_control_semaphore_address + field_size; // persistent mode field - static constexpr std::size_t sender_channel_0_buffer_index_semaphore_address = + std::size_t sender_channel_0_buffer_index_semaphore_address = sender_channel_0_connection_semaphore_address + field_size; static_assert(field_size >= sizeof(tt::fabric::EDMChannelWorkerLocationInfo)); // ----------- Sender Channel 1 - static constexpr std::size_t sender_channel_1_buffer_index_address = + std::size_t sender_channel_1_buffer_index_address = sender_channel_0_buffer_index_semaphore_address + field_size; - static constexpr std::size_t sender_channel_1_worker_connection_info_address = + std::size_t sender_channel_1_worker_connection_info_address = sender_channel_1_buffer_index_address + field_size; - static constexpr std::size_t sender_channel_1_local_flow_control_semaphore_address = + std::size_t sender_channel_1_local_flow_control_semaphore_address = sender_channel_1_worker_connection_info_address + field_size; // persistent mode field - static constexpr std::size_t sender_channel_1_connection_semaphore_address = + std::size_t sender_channel_1_connection_semaphore_address = sender_channel_1_local_flow_control_semaphore_address + field_size; // persistent mode field - static constexpr std::size_t sender_channel_1_buffer_index_semaphore_address = + std::size_t sender_channel_1_buffer_index_semaphore_address = sender_channel_1_connection_semaphore_address + field_size; // ----------- Receiver Channel - static constexpr std::size_t receiver_channel_local_buffer_index_address = + std::size_t receiver_channel_local_buffer_index_address = sender_channel_1_buffer_index_semaphore_address + field_size; // persistent mode field - static constexpr std::size_t receiver_channel_downstream_flow_control_semaphore_address = + std::size_t receiver_channel_downstream_flow_control_semaphore_address = receiver_channel_local_buffer_index_address + field_size; // Channel Allocations - static constexpr std::size_t buffer_region_start = + std::size_t max_l1_loading_size = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_size() + tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_base(); + std::size_t buffer_region_start = (receiver_channel_downstream_flow_control_semaphore_address + field_size + buffer_alignment) & ~(buffer_alignment - 1); // Align - static constexpr std::size_t available_channel_buffering_space = - eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - buffer_region_start; - - static_assert(sender_channel_1_buffer_index_address != sender_channel_0_buffer_index_address); + std::size_t available_channel_buffering_space = + max_l1_loading_size - buffer_region_start; FabricEriscDatamoverConfig( std::size_t channel_buffer_size_bytes, std::size_t sender_ratio_size, std::size_t receiver_ratio_size);