Skip to content

Commit

Permalink
#15290: Fix UAF issue with Singleton destruction order
Browse files Browse the repository at this point in the history
- Decouple MeshDevice destruction from SystemMesh destruction.
- Refactor SystemMesh to encapsulate internals in pImpl class.
- Added a new test case for SystemMesh tear down with static mesh in the distributed tests.
  • Loading branch information
cfjchu committed Dec 4, 2024
1 parent 37fc6b6 commit cab16ca
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 112 deletions.
13 changes: 12 additions & 1 deletion tests/ttnn/distributed/test_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DistributedTest : public ::testing::Test {
void TearDown() override {}
};

TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose_LocalMesh) {
auto& sys = tt::tt_metal::distributed::SystemMesh::instance();
auto mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);
Expand All @@ -26,4 +26,15 @@ TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose) {
EXPECT_GT(cols, 0);
}

TEST_F(DistributedTest, TestSystemMeshTearDownWithoutClose_StaticMesh) {
static std::shared_ptr<ttnn::MeshDevice> mesh;
auto& sys = tt::tt_metal::distributed::SystemMesh::instance();
mesh = ttnn::distributed::open_mesh_device(
{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER);

auto [rows, cols] = sys.get_shape();
EXPECT_GT(rows, 0);
EXPECT_GT(cols, 0);
}

} // namespace ttnn::distributed::test
175 changes: 108 additions & 67 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,47 @@ static std::unordered_map<LogicalCoordinate, PhysicalCoordinate> load_translatio
return result;
}

MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) {
class SystemMesh::Impl {
private:
using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;

std::unordered_map<MeshDeviceID, std::map<chip_id_t, Device*>> opened_devices;
std::unordered_map<MeshDeviceID, std::vector<chip_id_t>> assigned_devices;
std::unordered_map<MeshDeviceID, std::weak_ptr<MeshDevice>> assigned_mesh_device_devices;

MeshShape logical_mesh_shape;
std::unordered_map<LogicalCoordinate, PhysicalCoordinate> logical_to_physical_coordinates;
std::unordered_map<PhysicalCoordinate, chip_id_t> physical_coordinate_to_device_id;
std::unordered_map<chip_id_t, PhysicalCoordinate> physical_device_id_to_coordinate;


public:
Impl() = default;
~Impl() = default;

bool is_system_mesh_initialized() const;
void initialize();
const MeshShape& get_shape() const;
size_t get_num_devices() const;
std::vector<Device *> map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig &config);
std::vector<chip_id_t> get_mapped_physical_device_ids(const MeshDeviceConfig& config) const;
void remove_expired_mesh_devices();
Device* get_device(const chip_id_t physical_device_id) const;
void register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices);

static MeshShape get_system_mesh_shape(size_t system_num_devices);
static std::unordered_map<LogicalCoordinate, PhysicalCoordinate> get_system_mesh_translation_map(size_t system_num_devices);
};

// Implementation of private static methods
MeshShape SystemMesh::Impl::get_system_mesh_shape(size_t system_num_devices) {
const std::unordered_map<size_t, MeshShape> system_mesh_to_shape = {
{1, MeshShape{1, 1}}, // single-device
{2, MeshShape{1, 2}}, // N300
Expand All @@ -67,7 +107,7 @@ MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) {
return shape;
}

std::unordered_map<LogicalCoordinate, PhysicalCoordinate> SystemMesh::get_system_mesh_translation_map(size_t system_num_devices) {
std::unordered_map<LogicalCoordinate, PhysicalCoordinate> SystemMesh::Impl::get_system_mesh_translation_map(size_t system_num_devices) {
const std::unordered_map<size_t, std::string> system_mesh_translation_map = {
{1, "device.json"},
{2, "N300.json"},
Expand All @@ -80,36 +120,30 @@ std::unordered_map<LogicalCoordinate, PhysicalCoordinate> SystemMesh::get_system
return load_translation_map(translation_config_file, "logical_to_physical_coordinates");
}

bool SystemMesh::is_system_mesh_initialized() const {
// Implementation of public methods
bool SystemMesh::Impl::is_system_mesh_initialized() const {
return this->physical_coordinate_to_device_id.size() > 0;
}

SystemMesh& SystemMesh::instance() {
static SystemMesh instance;
if (!instance.is_system_mesh_initialized()) {
instance.initialize();
}
return instance;
}
void SystemMesh::initialize() {
void SystemMesh::Impl::initialize() {
this->physical_device_id_to_coordinate = tt::Cluster::instance().get_user_chip_ethernet_coordinates();
for (const auto& [chip_id, physical_coordinate] : this->physical_device_id_to_coordinate) {
this->physical_coordinate_to_device_id.emplace(physical_coordinate, chip_id);
}

// Initialize the system mesh shape and translation map
auto num_devices = physical_coordinate_to_device_id.size();
this->logical_mesh_shape = SystemMesh::get_system_mesh_shape(num_devices);
this->logical_to_physical_coordinates = SystemMesh::get_system_mesh_translation_map(num_devices);
this->logical_mesh_shape = get_system_mesh_shape(num_devices);
this->logical_to_physical_coordinates = get_system_mesh_translation_map(num_devices);
}

const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; }
size_t SystemMesh::get_num_devices() const {
const MeshShape& SystemMesh::Impl::get_shape() const { return this->logical_mesh_shape; }
size_t SystemMesh::Impl::get_num_devices() const {
auto [num_rows, num_cols] = this->get_shape();
return num_rows * num_cols;
}

std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const {

std::vector<chip_id_t> SystemMesh::Impl::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const {
std::vector<chip_id_t> physical_device_ids;
auto [system_mesh_rows, system_mesh_cols] = this->get_shape();
auto [requested_rows, requested_cols] = config.mesh_shape;
Expand Down Expand Up @@ -142,7 +176,7 @@ std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDevi
}
return physical_device_ids;
}
void SystemMesh::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices) {
void SystemMesh::Impl::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices) {
std::vector<chip_id_t> physical_device_ids;
for (auto device : devices) {
physical_device_ids.push_back(device->id());
Expand All @@ -151,13 +185,14 @@ void SystemMesh::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_de
this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids});
}

std::vector<Device*> SystemMesh::map_mesh_device(
std::vector<Device*> SystemMesh::Impl::map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig& config) {
this->remove_expired_mesh_devices();

auto [requested_num_rows, requested_num_cols] = mesh_device->shape();
auto [max_num_rows, max_num_cols] = this->logical_mesh_shape;
Expand All @@ -179,34 +214,65 @@ std::vector<Device*> SystemMesh::map_mesh_device(
for (auto physical_device_id : physical_device_ids) {
auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id);
mapped_devices.push_back(mapped_device);
this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device});
}

this->register_mesh_device(mesh_device, mapped_devices); // here
return mapped_devices;
}

void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) {
auto mesh_id = mesh_device->get_mesh_id();
this->assigned_mesh_device_devices.erase(mesh_id);
void SystemMesh::Impl::remove_expired_mesh_devices() {
std::vector<MeshDeviceID> stale_ids;
for (const auto& [mesh_id, weak_mesh_device] : assigned_mesh_device_devices) {
if (weak_mesh_device.expired()) {
stale_ids.push_back(mesh_id);
}
}
for (auto mesh_id : stale_ids) {
this->assigned_mesh_device_devices.erase(mesh_id);

// Close the devices
if (mesh_device->is_parent_mesh()) {
for (auto physical_id : this->assigned_devices.at(mesh_id)) {
this->assigned_physical_id_to_device.erase(physical_id);
if (assigned_devices.count(mesh_id)) {
assigned_devices.erase(mesh_id);
}

if (opened_devices.count(mesh_id)) {
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
}
tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id));
this->opened_devices.erase(mesh_id);
}
this->assigned_devices.erase(mesh_id);
}

Device* SystemMesh::get_device(const chip_id_t physical_device_id) const {
auto it = this->assigned_physical_id_to_device.find(physical_device_id);
if (it == this->assigned_physical_id_to_device.end()) {
TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
SystemMesh::SystemMesh() : pimpl(std::make_unique<Impl>()) {}
SystemMesh::~SystemMesh() = default;

SystemMesh& SystemMesh::instance() {
static SystemMesh instance;
if (!instance.pimpl->is_system_mesh_initialized()) {
instance.pimpl->initialize();
}
return it->second;
return instance;
}

const MeshShape& SystemMesh::get_shape() const { return pimpl->get_shape(); }

size_t SystemMesh::get_num_devices() const { return pimpl->get_num_devices(); }

void SystemMesh::register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices) {
pimpl->register_mesh_device(mesh_device, devices);
}

std::vector<Device*> SystemMesh::map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig& config) {
return pimpl->map_mesh_device(mesh_device, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config);
}


std::vector<chip_id_t> SystemMesh::get_mapped_physical_device_ids(const MeshDeviceConfig& config) const {
return pimpl->get_mapped_physical_device_ids(config);
}

static MeshDeviceID generate_unique_mesh_id() {
Expand Down Expand Up @@ -309,7 +375,12 @@ Device* MeshDevice::get_device_index(size_t logical_device_id) const {
}

Device* MeshDevice::get_device(chip_id_t physical_device_id) const {
return SystemMesh::instance().get_device(physical_device_id);
for (auto device : this->devices) {
if (device->id() == physical_device_id) {
return device;
}
}
TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id);
}

std::vector<Device*> MeshDevice::get_devices() const { return this->primary_view->get_devices(this->type); }
Expand Down Expand Up @@ -344,9 +415,7 @@ void MeshDevice::close_devices() {
for (const auto& submesh : this->submeshes) {
submesh->close_devices();
}
if (not this->devices.empty()) {
SystemMesh::instance().unmap_mesh_device(this);
}
this->submeshes.clear();
this->parent_mesh.reset();
this->devices.clear();
this->primary_view.reset();
Expand All @@ -364,34 +433,6 @@ MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; }

bool MeshDevice::is_parent_mesh() const { return this->parent_mesh.expired(); }

std::shared_ptr<MeshDevice> SystemMesh::get_mesh_device(const std::vector<chip_id_t>& physical_device_ids) {
log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids);
std::unordered_set<chip_id_t> input_set(physical_device_ids.begin(), physical_device_ids.end());

for (const auto& [mesh_id, weak_mesh_device] : this->assigned_mesh_device_devices) {
if (auto mesh_device = weak_mesh_device.lock()) {
const auto& assigned_devices = this->assigned_devices.at(mesh_id);
std::unordered_set<chip_id_t> assigned_set(assigned_devices.begin(), assigned_devices.end());
log_trace(LogMetal, "Assigned devices: {}", assigned_devices);

if (input_set == assigned_set) {
return mesh_device;
}
}
}
TT_THROW("No mesh device found for the provided devices");
}

std::shared_ptr<MeshDevice> MeshDevice::fetch_mesh_device(const std::vector<Device*>& devices) {
TT_FATAL(devices.size() > 0, "No devices provided");
auto& instance = SystemMesh::instance();
std::vector<chip_id_t> physical_device_ids;
for (auto device : devices) {
physical_device_ids.push_back(device->id());
}
return instance.get_mesh_device(physical_device_ids);
}

std::vector<std::shared_ptr<MeshDevice>> MeshDevice::get_submeshes() const { return this->submeshes; }

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); }
Expand Down
59 changes: 15 additions & 44 deletions tt_metal/distributed/mesh_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,62 +50,34 @@ struct MeshDeviceConfig {
// device resources.
class SystemMesh {
private:
using LogicalCoordinate = Coordinate;
using PhysicalCoordinate = eth_coord_t;
friend class MeshDevice;
class Impl; // Forward declaration only
std::unique_ptr<Impl> pimpl;
SystemMesh();
~SystemMesh();

// Keep track of the devices that were opened so we can close them later. We shouldn't
// to keep track of this but DevicePool seems to open all devices associated with an MMIO device id
std::unordered_map<MeshDeviceID, std::map<chip_id_t, Device*>> opened_devices;
std::unordered_map<MeshDeviceID, std::vector<chip_id_t>> assigned_devices;
std::unordered_map<MeshDeviceID, std::weak_ptr<MeshDevice>> assigned_mesh_device_devices;
std::unordered_map<chip_id_t, Device *> assigned_physical_id_to_device;

// Logical mesh shape and coordinates
MeshShape logical_mesh_shape;
std::unordered_map<LogicalCoordinate, PhysicalCoordinate> logical_to_physical_coordinates;

// Handling of physical coordinates
std::unordered_map<PhysicalCoordinate, chip_id_t> physical_coordinate_to_device_id;
std::unordered_map<chip_id_t, PhysicalCoordinate> physical_device_id_to_coordinate;
void register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices);
std::vector<Device *> map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig &config);

SystemMesh() = default;
public:
static SystemMesh &instance();
SystemMesh(const SystemMesh &) = delete;
SystemMesh &operator=(const SystemMesh &) = delete;
SystemMesh(SystemMesh &&) = delete;
SystemMesh &operator=(SystemMesh &&) = delete;

static MeshShape get_system_mesh_shape(size_t system_num_devices);
static std::unordered_map<LogicalCoordinate, PhysicalCoordinate> get_system_mesh_translation_map(
size_t system_num_devices);

bool is_system_mesh_initialized() const;

public:
static SystemMesh &instance();

void initialize();

// Return the shape of the logical mesh
const MeshShape &get_shape() const;
size_t get_num_devices() const;

// Get the physical device IDs mapped to a MeshDevice
std::vector<chip_id_t> get_mapped_physical_device_ids(const MeshDeviceConfig &config) const;
void register_mesh_device(const std::shared_ptr<MeshDevice> &mesh_device, const std::vector<Device*>& devices);

// Map MeshDevice to physical devices
std::vector<Device *> map_mesh_device(
std::shared_ptr<MeshDevice> mesh_device,
size_t num_command_queues,
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const MeshDeviceConfig &config);

// Unmap MeshDevice, releasing the associated physical devices.
void unmap_mesh_device(const MeshDevice* mesh_device);
std::shared_ptr<MeshDevice> get_mesh_device(const std::vector<chip_id_t>& physical_device_ids);
Device* get_device(const chip_id_t physical_device_id) const;
};

class MeshDevice : public std::enable_shared_from_this<MeshDevice> {
Expand Down Expand Up @@ -177,7 +149,6 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

size_t num_program_cache_entries() const;

static std::shared_ptr<MeshDevice> fetch_mesh_device(const std::vector<Device*>& devices);
static std::shared_ptr<MeshDevice> create(
const MeshDeviceConfig &config,
size_t l1_small_size = DEFAULT_L1_SMALL_SIZE,
Expand Down

0 comments on commit cab16ca

Please sign in to comment.