Skip to content

Commit

Permalink
#9045: Enable switching between 1 and 2 cqs in the same process. Add num hw cqs to the build key since this affects the bank mappings, and have device pool compile as needed
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jun 28, 2024
1 parent 86ab828 commit 4e70862
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 27 deletions.
44 changes: 24 additions & 20 deletions tt_metal/common/core_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ inline uint32_t get_l1_bank_size(chip_id_t device_id, const uint8_t num_hw_cqs)
}

// Returns the logical storage core coordinates for (device_id, num_hw_cqs).
// Memoized per device AND per num_hw_cqs: the core descriptor config (and
// therefore the storage-core set) depends on the number of HW command queues,
// so a cache keyed only by device would return stale results after switching
// between 1 and 2 CQs in the same process.
inline const std::vector<CoreCoord> &get_logical_storage_cores(chip_id_t device_id, const uint8_t num_hw_cqs) {
    // Two-level cache: device -> num_hw_cqs -> cores.
    // NOTE(review): function-local static with no locking — assumes callers
    // serialize first access per (device, num_hw_cqs); confirm.
    static std::unordered_map<chip_id_t, std::unordered_map<uint8_t, std::vector<CoreCoord>>> logical_storage_cores_by_device;
    auto& logical_storage_cores_by_cq = logical_storage_cores_by_device[device_id];
    if (auto it = logical_storage_cores_by_cq.find(num_hw_cqs); it != logical_storage_cores_by_cq.end()) {
        return it->second;
    }
    CoreCoord grid_size = tt::Cluster::instance().get_soc_desc(device_id).worker_grid_size;
    // Create the cache slot first so the cores are built in place and a
    // stable reference into the static map can be returned.
    std::vector<CoreCoord> &logical_storage_cores = logical_storage_cores_by_cq[num_hw_cqs];
    const core_descriptor_t &core_desc = get_core_descriptor_config(device_id, num_hw_cqs);
    // Translate descriptor-relative coordinates into absolute logical
    // coordinates for this device's (possibly harvested) worker grid.
    std::transform(core_desc.relative_storage_cores.cbegin(), core_desc.relative_storage_cores.cend(), std::back_inserter(logical_storage_cores),
        [&grid_size](RelativeCoreCoord rel_coord) { return get_core_coord_from_relative(rel_coord, grid_size); });
    return logical_storage_cores;
Expand All @@ -105,38 +106,41 @@ inline CoreCoord get_compute_grid_size(chip_id_t device_id, const uint8_t num_hw
}

// Returns the logical compute core coordinates for (device_id, num_hw_cqs).
// Memoized per device AND per num_hw_cqs, mirroring
// get_logical_storage_cores: the compute-core set comes from the core
// descriptor config, which varies with the number of HW command queues.
inline const std::vector<CoreCoord> &get_logical_compute_cores(chip_id_t device_id, const uint8_t num_hw_cqs) {
    // Two-level cache: device -> num_hw_cqs -> cores.
    // NOTE(review): function-local static with no locking — assumes callers
    // serialize first access per (device, num_hw_cqs); confirm.
    static std::unordered_map<chip_id_t, std::unordered_map<uint8_t, std::vector<CoreCoord>>> logical_compute_cores_by_device;
    auto& logical_compute_cores_by_cq = logical_compute_cores_by_device[device_id];
    if (auto it = logical_compute_cores_by_cq.find(num_hw_cqs); it != logical_compute_cores_by_cq.end()) {
        return it->second;
    }
    CoreCoord grid_size = tt::Cluster::instance().get_soc_desc(device_id).worker_grid_size;
    // Create the cache slot first so the cores are built in place and a
    // stable reference into the static map can be returned.
    std::vector<CoreCoord> &logical_compute_cores = logical_compute_cores_by_cq[num_hw_cqs];
    const core_descriptor_t &core_desc = get_core_descriptor_config(device_id, num_hw_cqs);
    // Translate descriptor-relative coordinates into absolute logical
    // coordinates for this device's worker grid.
    std::transform(core_desc.relative_compute_cores.cbegin(), core_desc.relative_compute_cores.cend(), std::back_inserter(logical_compute_cores),
        [&grid_size](RelativeCoreCoord rel_coord) { return get_core_coord_from_relative(rel_coord, grid_size); });
    return logical_compute_cores;
}

// Returns the logical dispatch core coordinates for (device_id, num_hw_cqs).
// Memoized per device AND per num_hw_cqs, mirroring the storage/compute core
// getters above: the dispatch-core set comes from the core descriptor config,
// which varies with the number of HW command queues.
inline const std::vector<CoreCoord> &get_logical_dispatch_cores(chip_id_t device_id, const uint8_t num_hw_cqs) {
    // Two-level cache: device -> num_hw_cqs -> cores.
    // NOTE(review): function-local static with no locking — assumes callers
    // serialize first access per (device, num_hw_cqs); confirm.
    static std::unordered_map<chip_id_t, std::unordered_map<uint8_t, std::vector<CoreCoord>>> logical_dispatch_cores_by_device;
    auto& logical_dispatch_cores_by_cq = logical_dispatch_cores_by_device[device_id];
    if (auto it = logical_dispatch_cores_by_cq.find(num_hw_cqs); it != logical_dispatch_cores_by_cq.end()) {
        return it->second;
    }
    CoreCoord grid_size = tt::Cluster::instance().get_soc_desc(device_id).worker_grid_size;
    // Create the cache slot first so the cores are built in place and a
    // stable reference into the static map can be returned.
    std::vector<CoreCoord> &logical_dispatch_cores = logical_dispatch_cores_by_cq[num_hw_cqs];
    const core_descriptor_t &core_desc = get_core_descriptor_config(device_id, num_hw_cqs);
    // Translate descriptor-relative coordinates into absolute logical
    // coordinates for this device's worker grid.
    std::transform(core_desc.relative_dispatch_cores.cbegin(), core_desc.relative_dispatch_cores.cend(), std::back_inserter(logical_dispatch_cores),
        [&grid_size](RelativeCoreCoord rel_coord) { return get_core_coord_from_relative(rel_coord, grid_size); });
    return logical_dispatch_cores;
}

// Returns the dispatch core type (e.g. worker vs eth) for
// (device_id, num_hw_cqs), as declared by the core descriptor config.
// Memoized per device AND per num_hw_cqs for consistency with the core
// getters above.
inline const CoreType get_dispatch_core_type(chip_id_t device_id, const uint8_t num_hw_cqs) {
    // Two-level cache: device -> num_hw_cqs -> core type.
    // NOTE(review): function-local static with no locking — assumes callers
    // serialize first access per (device, num_hw_cqs); confirm.
    static std::unordered_map<chip_id_t, std::unordered_map<uint8_t, CoreType>> dispatch_core_type_by_device;
    auto& dispatch_core_type_by_cq = dispatch_core_type_by_device[device_id];
    if (auto it = dispatch_core_type_by_cq.find(num_hw_cqs); it != dispatch_core_type_by_cq.end()) {
        return it->second;
    }
    const core_descriptor_t &core_desc = get_core_descriptor_config(device_id, num_hw_cqs);
    dispatch_core_type_by_cq[num_hw_cqs] = core_desc.dispatch_core_type;
    return core_desc.dispatch_core_type;
}

Expand Down
7 changes: 3 additions & 4 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ Device::Device(
chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal, uint32_t worker_core) :
id_(device_id), worker_thread_core(worker_core), work_executor(worker_core, device_id) {
ZoneScoped;
TT_ASSERT(num_hw_cqs > 0 and num_hw_cqs <= Device::max_num_hw_cqs, "num_hw_cqs can be between 1 and {}", Device::max_num_hw_cqs);
this->build_key_ = tt::Cluster::instance().get_harvesting_mask(device_id);
tunnel_device_dispatch_workers_ = {};
this->initialize(num_hw_cqs, l1_small_size, trace_region_size, l1_bank_remap, minimal);
}
Expand Down Expand Up @@ -1606,10 +1604,11 @@ void Device::initialize_synchronous_sw_cmd_queue() {
bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal) {
    ZoneScoped;
    log_info(tt::LogMetal, "Initializing device {}. Program cache is {}enabled", this->id_, this->program_cache.is_enabled() ? "": "NOT ");
    TT_FATAL(num_hw_cqs > 0 and num_hw_cqs <= Device::max_num_hw_cqs, "num_hw_cqs can be between 1 and {}", Device::max_num_hw_cqs);
    this->using_fast_dispatch = false;
    this->num_hw_cqs_ = num_hw_cqs;
    // The build key selects compiled firmware. num_hw_cqs affects the bank
    // mappings, so it is folded into the key alongside the harvesting mask:
    // the mask occupies the low 12 bits, num_hw_cqs the bits above them.
    constexpr uint32_t harvesting_map_bits = 12;
    this->build_key_ = ((uint32_t)this->num_hw_cqs_ << harvesting_map_bits) | tt::Cluster::instance().get_harvesting_mask(this->id());
    this->initialize_cluster();
    this->initialize_allocator(l1_small_size, trace_region_size, l1_bank_remap);
    this->initialize_build();
Expand Down
9 changes: 8 additions & 1 deletion tt_metal/impl/device/device_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,20 @@ void DevicePool::activate_device(chip_id_t id) {
int core_assigned_to_device = this->device_to_core_map.at(id);
auto dev =
new Device(id, this->num_hw_cqs, this->l1_small_size, this->trace_region_size, this->l1_bank_remap, false, core_assigned_to_device);
dev->build_firmware();
if (!this->firmware_built_keys.contains(dev->build_key())) {
dev->build_firmware();
this->firmware_built_keys.insert(dev->build_key());
}
this->devices[id] = std::unique_ptr<Device>(dev);
} else {
const auto& dev = this->devices[id];
log_debug(tt::LogMetal, "DevicePool re-initialize device {}", id);
if (not dev->is_initialized()) {
dev->initialize(num_hw_cqs, this->l1_small_size, this->trace_region_size, this->l1_bank_remap);
if (!this->firmware_built_keys.contains(dev->build_key())) {
dev->build_firmware();
this->firmware_built_keys.insert(dev->build_key());
}
} else {
TT_THROW("Cannot re-initialize device {}, must first call close()", id);
}
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/device/device_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class DevicePool {
std::mutex lock;
std::vector<std::unique_ptr<Device>> devices;
bool skip_remote_devices;
std::unordered_set<uint32_t> firmware_built_keys;

// Determine which CPU cores the worker threads need to be placed on for each device
std::unordered_map<uint32_t, uint32_t> device_to_core_map;
Expand Down
8 changes: 6 additions & 2 deletions tt_metal/impl/dispatch/dispatch_core_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ class dispatch_core_manager {

// Ugly to accept num HW CQs here but it is needed to pull the correct number of initially available dispatch cores for assignment
// Returns the dispatch_core_manager instance for the given number of HW CQs.
// One manager is lazily created per num_hw_cqs value (instead of a single
// process-wide singleton), so switching between 1 and 2 CQs in the same
// process pulls the correct set of available dispatch cores.
// NOTE(review): the static map is not guarded by a lock — assumes first
// access per num_hw_cqs is serialized; confirm.
static dispatch_core_manager &get(uint8_t num_hw_cqs) {
    static std::unordered_map<uint8_t, std::unique_ptr<dispatch_core_manager>> dispatch_core_managers;
    if (dispatch_core_managers[num_hw_cqs] == nullptr) {
        // Raw new (not make_unique) because the dispatch_core_manager
        // constructor is private and inaccessible to std::make_unique.
        dispatch_core_managers[num_hw_cqs] = std::unique_ptr<dispatch_core_manager>(new dispatch_core_manager(num_hw_cqs));
    }
    return *dispatch_core_managers[num_hw_cqs];
}

/// @brief Gets the location of the kernel desginated to read from the issue queue region from a particular command queue
Expand Down

0 comments on commit 4e70862

Please sign in to comment.