Skip to content

Commit

Permalink
Remove ARCH_NAME specific includes from erisc_datamover_builder (#16505)
Browse files Browse the repository at this point in the history
  • Loading branch information
blozano-tt authored Jan 8, 2025
1 parent 7f4eb32 commit d0bb408
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 95 deletions.
4 changes: 0 additions & 4 deletions .clang-format-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp
ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp
ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.cpp
ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp
ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/multi_core/all_gather_op_multi_core_new.cpp
ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp
ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp
Expand All @@ -58,8 +56,6 @@ ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp
ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp
ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.cpp
ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_host_commands.hpp
ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp
ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp
ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp
ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp
ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp
Expand Down
25 changes: 12 additions & 13 deletions tests/ttnn/unit_tests/gtests/ccl/test_ccl_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,37 +58,36 @@ TEST(CclHelpers, EriscDatamoverConfig_GetEdmHandshakeAddress_GT_0) {
if (arch == tt::ARCH::GRAYSKULL) {
GTEST_SKIP();
}
ttnn::ccl::EriscDatamoverConfig config;
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() > 0);
ASSERT_TRUE(config.get_edm_handshake_address() > 0);
}
}
TEST(CclHelpers, EriscDatamoverConfig_GetSemaphoresBaseAddress_GT_0) {
ttnn::ccl::EriscDatamoverConfig config;
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(
ttnn::ccl::EriscDatamoverConfig::get_semaphores_base_address(i) >=
(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() +
ttnn::ccl::EriscDatamoverConfig::handshake_location_size +
ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
config.get_semaphores_base_address(i) >=
(config.get_edm_handshake_address() + config.handshake_location_size +
config.edm_receiver_first_level_ack_source_word_size));
}
}

TEST(CclHelpers, EriscDatamoverConfig_GetBuffersBaseAddress_GT_0) {
ttnn::ccl::EriscDatamoverConfig config;
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(
ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >=
(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() +
ttnn::ccl::EriscDatamoverConfig::handshake_location_size +
ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
config.get_buffers_base_address(i) >= (config.get_edm_handshake_address() + config.handshake_location_size +
config.edm_receiver_first_level_ack_source_word_size));
}
}

TEST(CclHelpers, EriscDatamoverConfig_ComputeBufferSize_GT_0) {
ttnn::ccl::EriscDatamoverConfig config;
for (std::size_t i = 0; i < 8; i++) {
ASSERT_TRUE(
ttnn::ccl::EriscDatamoverConfig::get_buffers_base_address(i) >=
(ttnn::ccl::EriscDatamoverConfig::get_edm_handshake_address() +
ttnn::ccl::EriscDatamoverConfig::handshake_location_size +
ttnn::ccl::EriscDatamoverConfig::edm_receiver_first_level_ack_source_word_size));
config.get_buffers_base_address(i) >= (config.get_edm_handshake_address() + config.handshake_location_size +
config.edm_receiver_first_level_ack_source_word_size));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "tt_metal/common/math.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/experimental/hal.hpp"
#include "tt_metal/impl/kernels/kernel.hpp"
#include "tt_metal/test_utils/comparison.hpp"
#include "tt_metal/test_utils/df/df.hpp"
Expand All @@ -31,6 +32,7 @@
using namespace tt;
using namespace tt::test_utils;
using namespace tt::test_utils::df;
using namespace tt::tt_metal::experimental;

// Copied from ccl_common because of a dependency issue, so it is duplicated here for now. Resolve this before
// merging.
Expand Down Expand Up @@ -378,7 +380,7 @@ bool RunWriteBWTest(
tt_metal::detail::WriteToBuffer(buffer_id, all_zeros);
}

uint32_t erisc_handshake_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
uint32_t erisc_handshake_address = hal::get_erisc_l1_unreserved_base();

uint32_t chip0_next_buffer_address = erisc_handshake_address + 16;
std::vector<uint32_t> chip0_edm_args = {erisc_handshake_address};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,18 @@ bool RunLoopbackTest(
auto const& worker_core = worker_cores.at(0);
log_trace(tt::LogTest, "Worker {}. On Core x={},y={}", 0, worker_core.x, worker_core.y);

std::vector<ttnn::ccl::edm_termination_info_t> const& edm_termination_infos =
const auto& edm_config = ttnn::ccl::FabricEriscDatamoverConfig(edm_buffer_size, 1, 2);
const std::vector<ttnn::ccl::edm_termination_info_t>& edm_termination_infos =
enable_persistent_fabric ? std::vector<ttnn::ccl::edm_termination_info_t>{}
: std::vector<ttnn::ccl::edm_termination_info_t>{
{1,
sender_device->ethernet_core_from_logical_core(eth_receiver_core).x,
sender_device->ethernet_core_from_logical_core(eth_receiver_core).y,
ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address},
chip_0_edm_builder.config.termination_signal_address},
{0,
sender_device->ethernet_core_from_logical_core(eth_sender_core).x,
sender_device->ethernet_core_from_logical_core(eth_sender_core).y,
ttnn::ccl::FabricEriscDatamoverConfig::termination_signal_address}};
chip_0_edm_builder.config.termination_signal_address}};

TT_ASSERT(
(enable_persistent_fabric && edm_termination_infos.size() == 0) ||
Expand Down
17 changes: 17 additions & 0 deletions tt_metal/experimental/hal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "tt_metal/experimental/hal.hpp"
#include "tt_metal/llrt/hal.hpp"
#include <umd/device/types/arch.h>

using tt::tt_metal::HalL1MemAddrType;
using tt::tt_metal::HalMemType;
Expand All @@ -24,4 +25,20 @@ uint32_t get_l1_alignment() { return HalSingleton::getInstance().get_alignment(H

uint32_t get_pcie_alignment() { return HalSingleton::getInstance().get_alignment(HalMemType::HOST); }

// Reports the base address of the unreserved L1 region for active ethernet cores, as provided by the HAL.
// Grayskull is special-cased: the HAL is not queried for it and 0 is returned instead
// (presumably because that architecture has no active ethernet cores -- confirm against the HAL).
uint32_t get_erisc_l1_unreserved_base() {
    auto& hal_instance = HalSingleton::getInstance();
    const bool is_grayskull = (hal_instance.get_arch() == tt::ARCH::GRAYSKULL);
    if (is_grayskull) {
        return 0;
    }
    return hal_instance.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::UNRESERVED);
}

// Reports the size of the unreserved L1 region for active ethernet cores, as provided by the HAL.
// Grayskull is special-cased: the HAL is not queried for it and 0 is returned instead
// (presumably because that architecture has no active ethernet cores -- confirm against the HAL).
uint32_t get_erisc_l1_unreserved_size() {
    auto& hal_instance = HalSingleton::getInstance();
    const bool is_grayskull = (hal_instance.get_arch() == tt::ARCH::GRAYSKULL);
    if (is_grayskull) {
        return 0;
    }
    return hal_instance.get_dev_size(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::UNRESERVED);
}

} // namespace tt::tt_metal::experimental::hal
16 changes: 16 additions & 0 deletions tt_metal/experimental/hal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,20 @@ uint32_t get_l1_alignment();
*/
uint32_t get_pcie_alignment();

/**
* @brief Uses the hardware abstraction layer to inform the client of an architecture-specific address.
* This address corresponds to the beginning of the free (unreserved) space in the ERISC's L1 SRAM.
*
* @return the base address of the unreserved region; 0 is returned when the architecture is Grayskull
*/
uint32_t get_erisc_l1_unreserved_base();

/**
* @brief Uses the hardware abstraction layer to inform the client of an architecture-specific size.
* This size corresponds to the total free (unreserved) space in the ERISC's L1 SRAM available for host usage.
*
* @return size in bytes; 0 is returned when the architecture is Grayskull
*/
uint32_t get_erisc_l1_unreserved_size();

} // namespace tt::tt_metal::experimental::hal
28 changes: 14 additions & 14 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
#include "ttnn/operations/math.hpp"

#include "tt_metal/host_api.hpp"
#include "tt_metal/experimental/hal.hpp"

#include "ttnn/tensor/tensor_utils.hpp"

#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "eth_l1_address_map.h"
#include "ttnn/cpp/ttnn/operations/copy.hpp"

using namespace tt::tt_metal::experimental;

namespace ttnn {
namespace ccl {
namespace all_gather_detail {
Expand Down Expand Up @@ -50,8 +52,7 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con
return AllGatherBidirectionalMode::FULL_TENSOR;
}

std::size_t eth_l1_capacity =
eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
std::size_t eth_l1_capacity = hal::get_erisc_l1_unreserved_size();
std::size_t tensor_size_bytes = input_tensor.volume() * input_tensor.element_size();
// This is currently a guesstimate. We need a lot more hard data to identify where this dividing line is.
bool perf_degradation_from_full_tensor_mode = tensor_size_bytes > (2 * eth_l1_capacity);
Expand All @@ -62,8 +63,8 @@ AllGatherBidirectionalMode AllGatherConfig::choose_bidirectional_mode(Tensor con
}

AllGatherConfig::AllGatherConfig(
Tensor const& input_tensor,
Tensor const& output_tensor,
const Tensor& input_tensor,
const Tensor& output_tensor,
uint32_t dim,
uint32_t ring_size,
uint32_t num_links,
Expand All @@ -75,7 +76,7 @@ AllGatherConfig::AllGatherConfig(
semaphore_size(32),
ring_size(ring_size),

erisc_handshake_address(tt::round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)),
erisc_handshake_address(tt::round_up(hal::get_erisc_l1_unreserved_base(), 16)),
topology(topology),
enable_bidirectional(topology == ttnn::ccl::Topology::Ring),

Expand All @@ -86,8 +87,8 @@ AllGatherConfig::AllGatherConfig(
enable_merged_payload_and_channel_sync(true),
num_edm_buffers_per_channel(num_edm_buffers_per_channel) {
TT_FATAL(num_edm_buffers_per_channel > 0, "num_edm_buffers_per_channel must be > 0");
TT_ASSERT(erisc_handshake_address >= eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE);
TT_ASSERT(erisc_handshake_address < eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + 16);
TT_ASSERT(erisc_handshake_address >= hal::get_erisc_l1_unreserved_base());
TT_ASSERT(erisc_handshake_address < hal::get_erisc_l1_unreserved_base() + 16);
TT_ASSERT((erisc_handshake_address & (16 - 1)) == 0);
if (input_tensor.get_layout() == Layout::TILE && dim != 3) {
// See issue #6448
Expand All @@ -106,8 +107,7 @@ AllGatherConfig::AllGatherConfig(
(topology == ttnn::ccl::Topology::Ring && bidirectional_mode != AllGatherBidirectionalMode::FULL_TENSOR) ? 1
: 2;

constexpr uint32_t total_l1_buffer_space =
eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
uint32_t total_l1_buffer_space = hal::get_erisc_l1_unreserved_size();

this->is_sharded = input_tensor.is_sharded();
if (user_defined_num_workers.has_value()) {
Expand All @@ -117,9 +117,10 @@ AllGatherConfig::AllGatherConfig(
(this->enable_bidirectional ? 8 /*1*/ : (topology != ttnn::ccl::Topology::Linear ? 8 : 4));
}

constexpr std::int32_t MAX_NUM_CONCURRENT_TRANSACTIONS = 8;
if (bidirectional_mode == AllGatherBidirectionalMode::FULL_TENSOR) {
this->num_eth_buffers = std::min(
this->num_eth_buffers, eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
this->num_eth_buffers =
std::min(this->num_eth_buffers, MAX_NUM_CONCURRENT_TRANSACTIONS / num_duplicate_directions);
}

this->num_workers_per_link = this->num_eth_buffers;
Expand All @@ -145,8 +146,7 @@ AllGatherConfig::AllGatherConfig(
total_l1_buffer_space,
"Error");
TT_FATAL(
eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <=
eth_l1_mem::address_map::MAX_NUM_CONCURRENT_TRANSACTIONS,
eth_buffer_size == 0 or (this->num_eth_buffers * num_duplicate_directions) <= MAX_NUM_CONCURRENT_TRANSACTIONS,
"Error");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
#include "ttnn/operations/ccl/ccl_common.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/experimental/hal.hpp"
#include "tt_metal/impl/buffers/circular_buffer_types.hpp"

#include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "ttnn/operations/eltwise/binary/common/binary_op_utils.hpp"

using namespace tt::tt_metal;
using namespace tt::tt_metal::experimental;

namespace ttnn::ccl::barrier::detail {

Expand Down Expand Up @@ -51,16 +53,15 @@ static std::tuple<std::array<uint32_t, 7>, std::array<uint32_t, 10>, std::array<
CoreCoord const& sem_init_core) {
const uint32_t worker_sem0 = CreateSemaphore(program, sem_init_core, 0, CoreType::WORKER);
const uint32_t worker_sem1 = CreateSemaphore(program, sem_init_core, 0, CoreType::WORKER);
constexpr uint32_t start_semaphore_address =
eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + EriscDatamoverConfig::eth_word_size_bytes;
constexpr uint32_t erisc_semaphore_address =
eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + (EriscDatamoverConfig::eth_word_size_bytes * 2);
constexpr uint32_t erisc_buffer_address =
eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE + (EriscDatamoverConfig::eth_word_size_bytes * 3);
uint32_t start_semaphore_address = hal::get_erisc_l1_unreserved_base() + EriscDatamoverConfig::eth_word_size_bytes;
uint32_t erisc_semaphore_address =
hal::get_erisc_l1_unreserved_base() + (EriscDatamoverConfig::eth_word_size_bytes * 2);
uint32_t erisc_buffer_address =
hal::get_erisc_l1_unreserved_base() + (EriscDatamoverConfig::eth_word_size_bytes * 3);

const std::array<uint32_t, 10> receiver_rt_args = {
static_cast<uint32_t>(is_starting_core ? 1 : 0),
eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE,
hal::get_erisc_l1_unreserved_base(),
static_cast<uint32_t>(device->ethernet_core_from_logical_core(eth_sender_core).x),
static_cast<uint32_t>(device->ethernet_core_from_logical_core(eth_sender_core).y),
erisc_semaphore_address,
Expand All @@ -70,8 +71,8 @@ static std::tuple<std::array<uint32_t, 7>, std::array<uint32_t, 10>, std::array<
static_cast<uint32_t>(device->virtual_core_from_logical_core(sem_init_core, CoreType::WORKER).y),
worker_sem0};
const std::array<uint32_t, 7> sender_rt_args = {
static_cast<uint32_t>(is_starting_core ? 1 : 0), // is_ring_start
eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, // handshake_addr
static_cast<uint32_t>(is_starting_core ? 1 : 0), // is_ring_start
hal::get_erisc_l1_unreserved_base(), // handshake_addr
erisc_buffer_address,
erisc_semaphore_address,
static_cast<uint32_t>(device->virtual_core_from_logical_core(sem_init_core, CoreType::WORKER).x),
Expand Down
9 changes: 5 additions & 4 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,17 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder(
std::size_t num_buffers_per_channel,
ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode,
ccl::EriscDataMoverTerminationMode termination_mode) {
ccl::EriscDatamoverConfig config;
TT_ASSERT(num_channels > 0);
std::vector<uint32_t> edm_sem_addresses(num_channels, 0);
std::vector<uint32_t> edm_buffer_addresses(num_channels, 0);

uint32_t edm_sem_addr = ccl::EriscDatamoverConfig::get_semaphores_base_address(num_channels);
uint32_t edm_buffer_addr = ccl::EriscDatamoverConfig::get_buffers_base_address(num_channels);
uint32_t edm_sem_addr = config.get_semaphores_base_address(num_channels);
uint32_t edm_buffer_addr = config.get_buffers_base_address(num_channels);
TT_ASSERT(edm_sem_addr > 0);
TT_ASSERT(edm_buffer_addr > 0);
const uint32_t channel_buffer_size =
ccl::EriscDatamoverConfig::compute_buffer_size(num_channels, num_buffers_per_channel, page_size);
config.compute_buffer_size(num_channels, num_buffers_per_channel, page_size);
for (std::size_t c = 0; c < num_channels; ++c) {
edm_sem_addresses.at(c) = edm_sem_addr;
edm_sem_addr += ccl::EriscDatamoverConfig::semaphore_size;
Expand All @@ -352,7 +353,7 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder(

return ccl::EriscDatamoverBuilder(
channel_buffer_size,
ccl::EriscDatamoverConfig::get_edm_handshake_address(),
config.get_edm_handshake_address(),
edm_sem_addresses,
edm_buffer_addresses,
buffer_sharing_mode,
Expand Down
15 changes: 7 additions & 8 deletions ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include "eth_l1_address_map.h"
#include "tt_metal/experimental/hal.hpp"
#include "ttnn/cpp/ttnn/tensor/tensor_impl.hpp"
#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_types.hpp"
Expand All @@ -15,9 +15,8 @@ namespace ttnn {
namespace ccl {

struct EriscDatamoverConfig {
static constexpr std::size_t total_l1_buffer_space =
eth_l1_mem::address_map::MAX_L1_LOADING_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
static constexpr std::size_t usable_l1_base_address = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE;
std::size_t total_l1_buffer_space = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_size();
std::size_t usable_l1_base_address = tt::tt_metal::experimental::hal::get_erisc_l1_unreserved_base();

static constexpr std::size_t semaphore_size = 32;
static constexpr std::size_t handshake_location_size = 16; // ethernet word size
Expand All @@ -32,14 +31,14 @@ struct EriscDatamoverConfig {
static constexpr std::size_t eth_word_size_bytes = 16;
static constexpr bool enable_merged_payload_and_channel_sync = true;
static std::size_t get_eth_channel_sync_size_bytes();
static uint32_t get_edm_handshake_address();
uint32_t get_edm_handshake_address();
static std::size_t get_semaphores_region_size(std::size_t num_edm_channels);
static std::size_t get_semaphores_region_start_offset(std::size_t num_edm_channels);
static uint32_t get_semaphores_base_address(std::size_t num_edm_channels);
uint32_t get_semaphores_base_address(std::size_t num_edm_channels);
static uint32_t get_buffers_region_start_offset(std::size_t num_edm_channels);
static std::size_t get_eth_word_size();
static uint32_t get_buffers_base_address(std::size_t num_edm_channels);
static uint32_t compute_buffer_size(
uint32_t get_buffers_base_address(std::size_t num_edm_channels);
uint32_t compute_buffer_size(
std::size_t num_edm_channels,
std::size_t num_buffers_per_channel = 1,
uint32_t page_size = eth_word_size_bytes);
Expand Down
Loading

0 comments on commit d0bb408

Please sign in to comment.