Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new HAL APIs #15645

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "tt_metal/types.hpp"

// FIXME: ARCH_NAME specific
#include "noc/noc_parameters.h" // NOC_XY_ENCODING
#include "eth_l1_address_map.h"

namespace tt {
Expand Down Expand Up @@ -1290,7 +1289,7 @@ void Device::update_workers_build_settings(std::vector<std::vector<std::tuple<tt
compile_args[13] = 0; // unused: remote ds semaphore
compile_args[14] = 0; // preamble size
compile_args[15] = true, // split_prefetcher
compile_args[16] = NOC_XY_ENCODING(prefetch_physical_core.x, prefetch_physical_core.y),
compile_args[16] = tt::tt_metal::hal.noc_xy_encoding(prefetch_physical_core.x, prefetch_physical_core.y),
compile_args[17] = prefetch_h_settings.producer_semaphore_id, // sem_id on prefetch_h that dispatch_d is meant to increment, to resume sending of cmds post exec_buf stall
compile_args[18] = dispatch_constants::get(dispatch_core_type).mux_buffer_pages(num_hw_cqs), // XXXX should this be mux pages?
compile_args[19] = settings.num_compute_cores;
Expand Down Expand Up @@ -1345,7 +1344,7 @@ void Device::update_workers_build_settings(std::vector<std::vector<std::tuple<tt
compile_args[13] = 0; // unused: remote ds semaphore
compile_args[14] = 0; // preamble size
compile_args[15] = true, // split_prefetcher
compile_args[16] = NOC_XY_ENCODING(prefetch_physical_core.x, prefetch_physical_core.y),
compile_args[16] = tt::tt_metal::hal.noc_xy_encoding(prefetch_physical_core.x, prefetch_physical_core.y),
compile_args[17] = prefetch_h_settings.producer_semaphore_id, // sem_id on prefetch_h that dispatch_d is meant to increment, to resume sending of cmds post exec_buf stall
compile_args[18] = mux_settings.cb_pages,
compile_args[19] = settings.num_compute_cores;
Expand Down Expand Up @@ -3146,7 +3145,7 @@ std::vector<CoreCoord> Device::ethernet_cores_from_logical_cores(const std::vect

uint32_t Device::get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const {
const auto& grid_size = this->grid_size();
return NOC_XY_ENCODING(
return tt::tt_metal::hal.noc_xy_encoding(
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.x, physical_core.x),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.y, physical_core.y)
);
Expand All @@ -3157,14 +3156,14 @@ uint32_t Device::get_noc_multicast_encoding(uint8_t noc_index, const CoreRange&

// NOC 1 mcasts from bottom left to top right, so we need to reverse the coords
if (noc_index == 0) {
return NOC_MULTICAST_ENCODING(
return tt::tt_metal::hal.noc_multicast_encoding(
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.x, physical_cores.start_coord.x),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.y, physical_cores.start_coord.y),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.x, physical_cores.end_coord.x),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.y, physical_cores.end_coord.y)
);
} else {
return NOC_MULTICAST_ENCODING(
return tt::tt_metal::hal.noc_multicast_encoding(
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.x, physical_cores.end_coord.x),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.y, physical_cores.end_coord.y),
tt::tt_metal::hal.noc_coordinate(noc_index, grid_size.x, physical_cores.start_coord.x),
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/llrt/blackhole/bh_hal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ void Hal::initialize_bh() {
((addr >= NOC1_REGS_START_ADDR) && (addr < NOC1_REGS_START_ADDR + 0x1000)) ||
(addr == RISCV_DEBUG_REG_SOFT_RESET_0));
};

this->noc_xy_encoding_func_ = [](uint32_t x, uint32_t y) { return NOC_XY_ENCODING(x, y); };
this->noc_multicast_encoding_func_ = [](uint32_t x_start, uint32_t y_start, uint32_t x_end, uint32_t y_end) {
return NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end);
};
}

} // namespace tt_metal
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/llrt/grayskull/gs_hal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ void Hal::initialize_gs() {
((addr >= NOC1_REGS_START_ADDR) && (addr < NOC1_REGS_START_ADDR + 0x1000)) ||
(addr == RISCV_DEBUG_REG_SOFT_RESET_0));
};

this->noc_xy_encoding_func_ = [](uint32_t x, uint32_t y) { return NOC_XY_ENCODING(x, y); };
this->noc_multicast_encoding_func_ = [](uint32_t x_start, uint32_t y_start, uint32_t x_end, uint32_t y_end) {
return NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end);
};
}

} // namespace tt_metal
Expand Down
9 changes: 9 additions & 0 deletions tt_metal/llrt/hal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class Hal {
public:
using RelocateFunc = std::function<uint64_t(uint64_t, uint64_t)>;
using ValidRegAddrFunc = std::function<bool(uint32_t)>;
using NOCXYEncodingFunc = std::function<uint32_t(uint32_t, uint32_t)>;
using NOCMulticastEncodingFunc = std::function<uint32_t(uint32_t, uint32_t, uint32_t, uint32_t)>;

private:
tt::ARCH arch_;
Expand All @@ -153,6 +155,8 @@ class Hal {
// Functions where implementation varies by architecture
RelocateFunc relocate_func_;
ValidRegAddrFunc valid_reg_addr_func_;
NOCXYEncodingFunc noc_xy_encoding_func_;
NOCMulticastEncodingFunc noc_multicast_encoding_func_;

public:
Hal();
Expand All @@ -165,6 +169,11 @@ class Hal {
return noc_index == 0 ? coord : (noc_size - 1 - coord);
}

uint32_t noc_xy_encoding(uint32_t x, uint32_t y) const { return noc_xy_encoding_func_(x, y); }
uint32_t noc_multicast_encoding(uint32_t x_start, uint32_t y_start, uint32_t x_end, uint32_t y_end) const {
return noc_multicast_encoding_func_(x_start, y_start, x_end, y_end);
}

uint32_t get_programmable_core_type_count() const;
HalProgrammableCoreType get_programmable_core_type(uint32_t core_type_index) const;
uint32_t get_programmable_core_type_index(HalProgrammableCoreType programmable_core_type_index) const;
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/llrt/wormhole/wh_hal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ void Hal::initialize_wh() {
((addr >= NOC1_REGS_START_ADDR) && (addr < NOC1_REGS_START_ADDR + 0x1000)) ||
(addr == RISCV_DEBUG_REG_SOFT_RESET_0));
};

this->noc_xy_encoding_func_ = [](uint32_t x, uint32_t y) { return NOC_XY_ENCODING(x, y); };
this->noc_multicast_encoding_func_ = [](uint32_t x_start, uint32_t y_start, uint32_t x_end, uint32_t y_end) {
return NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end);
};
}

} // namespace tt_metal
Expand Down
Loading