Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some Device methods to private section #16259

Merged
merged 4 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/tt_metal/tt_metal/eth/test_ring_gather_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ std::vector<v1::DeviceHandle> get_device_ring(std::vector<tt::tt_metal::v1::Devi
std::vector<std::vector<int>> adj(devices.size(), std::vector<int>(devices.size(), 0));
for (uint32_t i = 0; i < devices.size(); ++i) {
const auto& device = devices[i];
for (const auto& connected_device_id : device->get_ethernet_connected_device_ids()) {
auto ethernet_connected_device_ids = tt::Cluster::instance().get_ethernet_connected_device_ids(device->id());
for (const auto& connected_device_id : ethernet_connected_device_ids) {
for (uint32_t j = 0; j < devices.size(); ++j) {
if (devices[j]->id() == connected_device_id) {
adj[i][j] = 1;
Expand Down
67 changes: 56 additions & 11 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ bool Device::is_inactive_ethernet_core(CoreCoord logical_core) const {
return inactive_ethernet_cores.find(logical_core) != inactive_ethernet_cores.end();
}

// Forwards to the cluster-level lookup, keyed by this device's id together with `eth_core`,
// and returns the (chip id, core) tuple the cluster reports.
std::tuple<chip_id_t, CoreCoord> Device::get_connected_ethernet_core(CoreCoord eth_core) const {
    auto& cluster = tt::Cluster::instance();
    return cluster.get_connected_ethernet_core(std::make_tuple(id_, eth_core));
}

// Asks the cluster for the ethernet sockets between this device and `connected_chip_id`.
std::vector<CoreCoord> Device::get_ethernet_sockets(chip_id_t connected_chip_id) const {
    auto& cluster = tt::Cluster::instance();
    return cluster.get_ethernet_sockets(id_, connected_chip_id);
}

bool Device::is_mmio_capable() const {
return tt::Cluster::instance().get_associated_mmio_device(this->id_) == this->id_;
}

// Looks up `sub_device_id` in the active sub-device manager and returns that
// sub-device's cores of the requested programmable core type.
CoreRangeSet Device::worker_cores(HalProgrammableCoreType core_type, SubDeviceId sub_device_id) const {
    const auto& selected_sub_device = active_sub_device_manager_->sub_device(sub_device_id);
    return selected_sub_device.cores(core_type);
}
Expand Down Expand Up @@ -3223,15 +3235,15 @@ CoreCoord Device::logical_grid_size() const {
return tt::Cluster::instance().get_soc_desc(id_).worker_grid_size;
}

// Reads the DRAM grid size from this device's SoC descriptor.
CoreCoord Device::dram_grid_size() const {
    const metal_SocDescriptor& descriptor = tt::Cluster::instance().get_soc_desc(this->id_);
    return descriptor.get_dram_grid_size();
}

// Computes the compute-with-storage grid size from this device's id, its number of
// hardware command queues, and the dispatch core config registered for the device.
CoreCoord Device::compute_with_storage_grid_size() const {
    const auto& core_config = dispatch_core_manager::instance().get_dispatch_core_config(this->id_);
    return tt::get_compute_grid_size(this->id_, this->num_hw_cqs_, core_config);
}

// Reads the DRAM grid size from this device's SoC descriptor.
CoreCoord Device::dram_grid_size() const {
    const metal_SocDescriptor& descriptor = tt::Cluster::instance().get_soc_desc(this->id_);
    return descriptor.get_dram_grid_size();
}

CoreType Device::core_type_from_physical_core(const CoreCoord &physical_coord) const {
const metal_SocDescriptor &soc_desc = tt::Cluster::instance().get_soc_desc(this->id_);
if (soc_desc.physical_cores.find(physical_coord) == soc_desc.physical_cores.end())
Expand All @@ -3249,7 +3261,6 @@ CoreType Device::core_type_from_virtual_core(const CoreCoord &virtual_coord) con
return this->core_type_from_physical_core(virtual_coord);
}


CoreCoord Device::virtual_noc0_coordinate(uint8_t noc_index, CoreCoord coord) const {
if (coord.x >= this->grid_size().x || coord.y >= this->grid_size().y) {
// Coordinate already in virtual space: NOC0 and NOC1 are the same
Expand Down Expand Up @@ -3304,6 +3315,7 @@ std::vector<CoreCoord> Device::ethernet_cores_from_logical_cores(const std::vect
}
return eth_cores;
}

// Translates a logical coordinate of the given core type into the virtual coordinate
// the cluster reports for this device.
CoreCoord Device::virtual_core_from_logical_core(const CoreCoord &logical_coord, const CoreType& core_type) const {
    auto& cluster = tt::Cluster::instance();
    return cluster.get_virtual_coordinate_from_logical_coordinates(id_, logical_coord, core_type);
}
Expand Down Expand Up @@ -3679,14 +3691,18 @@ void Device::enable_async(bool enable) {
}

// Slow dispatch is reported as the logical complement of fast dispatch.
// The body previously held two consecutive return statements; the second was unreachable.
// Only the accessor-based form is kept so the flag is read in one place.
bool Device::using_slow_dispatch() const {
    return !using_fast_dispatch();
}

bool Device::using_fast_dispatch() const {
return using_fast_dispatch_;
}

void Device::begin_trace(const uint8_t cq_id, const uint32_t tid) {
ZoneScoped;
TracyTTMetalBeginTrace(this->id(), tid);
TT_FATAL(!this->hw_command_queues_[cq_id]->tid.has_value(), "CQ {} is already being used for tracing tid {}", (uint32_t)cq_id, tid);
this->MarkAllocationsSafe();
this->mark_allocations_safe();
// Create an empty trace buffer here. This will get initialized in end_trace
TT_FATAL(this->active_sub_device_manager_->get_trace(tid) == nullptr, "Trace already exists for tid {} on device {}'s active sub-device manager {}", tid, this->id_, this->active_sub_device_manager_id_);
auto &trace_buffer = this->active_sub_device_manager_->create_trace(tid);
Expand All @@ -3701,7 +3717,7 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) {
TT_FATAL(trace_buffer != nullptr, "Trace instance {} must exist on device {}'s active sub-device manager {}", tid, this->id_, this->active_sub_device_manager_id_);
this->hw_command_queues_[cq_id]->record_end();
Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer);
this->MarkAllocationsUnsafe();
this->mark_allocations_unsafe();
}

void Device::replay_trace(const uint8_t cq_id, const uint32_t tid, const bool blocking) {
Expand All @@ -3724,7 +3740,7 @@ void Device::release_trace(const uint32_t tid) {

// Only enable allocations once all captured traces are released
if (this->trace_buffers_size_ == 0) {
this->MarkAllocationsSafe();
this->mark_allocations_safe();
}
}

Expand All @@ -3750,11 +3766,11 @@ std::size_t Device::num_program_cache_entries() {
return program_cache_.num_entries();
}

void Device::MarkAllocationsUnsafe() {
void Device::mark_allocations_unsafe() {
tt::tt_metal::allocator::mark_allocations_unsafe(*this->get_initialized_allocator());
}

void Device::MarkAllocationsSafe() {
void Device::mark_allocations_safe() {
tt::tt_metal::allocator::mark_allocations_safe(*this->get_initialized_allocator());
}

Expand Down Expand Up @@ -3964,6 +3980,35 @@ std::vector<CoreCoord> Device::get_optimal_dram_bank_to_logical_worker_assignmen
return this->optimal_dram_bank_to_logical_worker_assignment_;
}

// Classifies `virtual_core` on this device as TENSIX, ACTIVE_ETH or IDLE_ETH.
// Any core the cluster does not report as an ethernet core is treated as TENSIX.
HalProgrammableCoreType Device::get_programmable_core_type(CoreCoord virtual_core) const {
    const bool on_ethernet = tt::Cluster::instance().is_ethernet_core(virtual_core, id_);
    if (not on_ethernet) {
        return HalProgrammableCoreType::TENSIX;
    }

    // Eth pcores have a different address, but only active ones.
    const CoreCoord eth_logical_core = logical_core_from_ethernet_core(virtual_core);
    return is_active_ethernet_core(eth_logical_core) ? HalProgrammableCoreType::ACTIVE_ETH
                                                     : HalProgrammableCoreType::IDLE_ETH;
}

// TODO: Find a better home for this function
// Extracts all the pairs of noc multicast encodings given a set of core ranges.
//
// For every logical CoreRange in `ranges`, both corners are converted to virtual
// coordinates for `core_type`, and the resulting virtual range is paired with the
// number of cores in the original range (`core_range.size()`).
// The returned vector has one entry per input range, in input order.
std::vector<std::pair<transfer_info_cores, uint32_t>> Device::extract_dst_noc_multicast_info(const std::vector<CoreRange>& ranges, const CoreType core_type) {
    std::vector<std::pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info;
    dst_noc_multicast_info.reserve(ranges.size());
    for (const CoreRange& core_range : ranges) {
        CoreCoord virtual_start = this->virtual_core_from_logical_core(core_range.start_coord, core_type);
        CoreCoord virtual_end = this->virtual_core_from_logical_core(core_range.end_coord, core_type);

        uint32_t num_receivers = core_range.size();
        // Construct the pair in place rather than building a temporary pair and converting it.
        dst_noc_multicast_info.emplace_back(CoreRange(virtual_start, virtual_end), num_receivers);
    }
    return dst_noc_multicast_info;
}



// Returns the cluster's count of user devices.
size_t v1::GetNumAvailableDevices() {
    auto& cluster = tt::Cluster::instance();
    return cluster.number_of_user_devices();
}
Expand Down
Loading
Loading