Skip to content

Commit

Permalink
Move some Device methods to private section (#16259)
Browse files Browse the repository at this point in the history
### Ticket
None

### Problem description
There are plenty of Device methods which are only used by device itself
but they are in a public section

This PR depends on #16256

### What's changed
This PR does not add any new functionality, nor changes existing. It
shuffles things around.
* Moved near every method thats only used by Device to the private
section (19)
* Re-groupped methods logically: methods that proxy to allocator, to
cluster
* Moved a couple implementations to cpp
* Reviewed usage of some methods, left comments where its desirable to
review usage with owners
* Removed template where it was not used

Whats next:
* Discuss and align what we can/want to do about the two outlined groups
of methods
* Review notes with owners, align on next steps

### Checklist
- [ ] [Post commit
CI](https://github.com/tenstorrent/tt-metal/actions/runs/12475902305)
  • Loading branch information
ayerofieiev-tt authored Dec 25, 2024
1 parent 9f4eb98 commit a927a47
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 168 deletions.
3 changes: 2 additions & 1 deletion tests/tt_metal/tt_metal/eth/test_ring_gather_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ std::vector<v1::DeviceHandle> get_device_ring(std::vector<tt::tt_metal::v1::Devi
std::vector<std::vector<int>> adj(devices.size(), std::vector<int>(devices.size(), 0));
for (uint32_t i = 0; i < devices.size(); ++i) {
const auto& device = devices[i];
for (const auto& connected_device_id : device->get_ethernet_connected_device_ids()) {
auto ethernet_connected_device_ids = tt::Cluster::instance().get_ethernet_connected_device_ids(device->id());
for (const auto& connected_device_id : ethernet_connected_device_ids) {
for (uint32_t j = 0; j < devices.size(); ++j) {
if (devices[j]->id() == connected_device_id) {
adj[i][j] = 1;
Expand Down
67 changes: 56 additions & 11 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ bool Device::is_inactive_ethernet_core(CoreCoord logical_core) const {
return inactive_ethernet_cores.find(logical_core) != inactive_ethernet_cores.end();
}

std::tuple<chip_id_t, CoreCoord> Device::get_connected_ethernet_core(CoreCoord eth_core) const {
return tt::Cluster::instance().get_connected_ethernet_core(std::make_tuple(this->id_, eth_core));
}

std::vector<CoreCoord> Device::get_ethernet_sockets(chip_id_t connected_chip_id) const {
return tt::Cluster::instance().get_ethernet_sockets(this->id_, connected_chip_id);
}

bool Device::is_mmio_capable() const {
return tt::Cluster::instance().get_associated_mmio_device(this->id_) == this->id_;
}

CoreRangeSet Device::worker_cores(HalProgrammableCoreType core_type, SubDeviceId sub_device_id) const {
return this->active_sub_device_manager_->sub_device(sub_device_id).cores(core_type);
}
Expand Down Expand Up @@ -3223,15 +3235,15 @@ CoreCoord Device::logical_grid_size() const {
return tt::Cluster::instance().get_soc_desc(id_).worker_grid_size;
}

CoreCoord Device::dram_grid_size() const {
return tt::Cluster::instance().get_soc_desc(id_).get_dram_grid_size();
}

CoreCoord Device::compute_with_storage_grid_size() const {
const auto &dispatch_core_config = dispatch_core_manager::instance().get_dispatch_core_config(id_);
return tt::get_compute_grid_size(id_, num_hw_cqs_, dispatch_core_config);
}

CoreCoord Device::dram_grid_size() const {
return tt::Cluster::instance().get_soc_desc(id_).get_dram_grid_size();
}

CoreType Device::core_type_from_physical_core(const CoreCoord &physical_coord) const {
const metal_SocDescriptor &soc_desc = tt::Cluster::instance().get_soc_desc(this->id_);
if (soc_desc.physical_cores.find(physical_coord) == soc_desc.physical_cores.end())
Expand All @@ -3249,7 +3261,6 @@ CoreType Device::core_type_from_virtual_core(const CoreCoord &virtual_coord) con
return this->core_type_from_physical_core(virtual_coord);
}


CoreCoord Device::virtual_noc0_coordinate(uint8_t noc_index, CoreCoord coord) const {
if (coord.x >= this->grid_size().x || coord.y >= this->grid_size().y) {
// Coordinate already in virtual space: NOC0 and NOC1 are the same
Expand Down Expand Up @@ -3304,6 +3315,7 @@ std::vector<CoreCoord> Device::ethernet_cores_from_logical_cores(const std::vect
}
return eth_cores;
}

CoreCoord Device::virtual_core_from_logical_core(const CoreCoord &logical_coord, const CoreType& core_type) const {
return tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(this->id_, logical_coord, core_type);
}
Expand Down Expand Up @@ -3679,14 +3691,18 @@ void Device::enable_async(bool enable) {
}

bool Device::using_slow_dispatch() const {
return not (this->using_fast_dispatch_);
return !using_fast_dispatch();
}

bool Device::using_fast_dispatch() const {
return using_fast_dispatch_;
}

void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
ZoneScoped;
TracyTTMetalBeginTrace(this->id(), tid);
TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
this->MarkAllocationsSafe();
this->mark_allocations_safe();
// Create an empty trace buffer here. This will get initialized in end_trace
TT_FATAL(this->active_sub_device_manager_->get_trace(tid) == nullptr, "Trace already exists for tid {} on device {}'s active sub-device manager {}", tid, this->id_, this->active_sub_device_manager_id_);
auto &trace_buffer = this->active_sub_device_manager_->create_trace(tid);
Expand All @@ -3701,7 +3717,7 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
TT_FATAL(trace_buffer != nullptr, "Trace instance {} must exist on device {}'s active sub-device manager {}", tid, this->id_, this->active_sub_device_manager_id_);
this->hw_command_queues_[cq_id]->record_end();
Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer);
this->MarkAllocationsUnsafe();
this->mark_allocations_unsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
Expand All @@ -3724,7 +3740,7 @@ void Device::release_trace(const uint32_t tid) {

// Only enable allocations once all captured traces are released
if (this->trace_buffers_size_ == 0) {
this->MarkAllocationsSafe();
this->mark_allocations_safe();
}
}

Expand All @@ -3750,11 +3766,11 @@ std::size_t Device::num_program_cache_entries() {
return program_cache_.num_entries();
}

void Device::MarkAllocationsUnsafe() {
void Device::mark_allocations_unsafe() {
tt::tt_metal::allocator::mark_allocations_unsafe(*this->get_initialized_allocator());
}

void Device::MarkAllocationsSafe() {
void Device::mark_allocations_safe() {
tt::tt_metal::allocator::mark_allocations_safe(*this->get_initialized_allocator());
}

Expand Down Expand Up @@ -3964,6 +3980,35 @@ std::vector<CoreCoord> Device::get_optimal_dram_bank_to_logical_worker_assignmen
return this->optimal_dram_bank_to_logical_worker_assignment_;
}

HalProgrammableCoreType Device::get_programmable_core_type(CoreCoord virtual_core) const {
if (!tt::Cluster::instance().is_ethernet_core(virtual_core, this->id_)) {
return HalProgrammableCoreType::TENSIX;
}

// Eth pcores have a different address, but only active ones.
CoreCoord logical_core = this->logical_core_from_ethernet_core(virtual_core);
if (this->is_active_ethernet_core(logical_core)) {
return HalProgrammableCoreType::ACTIVE_ETH;
}

return HalProgrammableCoreType::IDLE_ETH;
}

// TODO: Find a better home for this function
// Extracts all the pairs of noc multicast encodings given a set of core ranges
std::vector<std::pair<transfer_info_cores, uint32_t>> Device::extract_dst_noc_multicast_info(const std::vector<CoreRange>& ranges, const CoreType core_type) {
std::vector<std::pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info;
dst_noc_multicast_info.reserve(ranges.size());
for (const CoreRange& core_range : ranges) {
CoreCoord virtual_start = this->virtual_core_from_logical_core(core_range.start_coord, core_type);
CoreCoord virtual_end = this->virtual_core_from_logical_core(core_range.end_coord, core_type);

uint32_t num_receivers = core_range.size();
dst_noc_multicast_info.push_back(std::make_pair(CoreRange(virtual_start, virtual_end), num_receivers));
}
return dst_noc_multicast_info;
}



size_t v1::GetNumAvailableDevices() { return tt::Cluster::instance().number_of_user_devices(); }
Expand Down
Loading

0 comments on commit a927a47

Please sign in to comment.