Commit

#0: Add pybindings for sub-device, global semaphore, and global circular buffer apis

tt-aho committed Dec 4, 2024
1 parent 2500267 commit 643ca05
Showing 29 changed files with 830 additions and 19 deletions.
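
For orientation, the sketch below pieces together the Python surface these bindings add, using only calls exercised by the new unit tests in this commit; the ttnn.open_device/ttnn.close_device setup is an assumption (the tests obtain their device from a pytest fixture), and exact defaults may differ.

import ttnn

# Device setup is assumed here; the unit tests below get `device` from a pytest fixture.
device = ttnn.open_device(device_id=0)

# Global semaphore over a 4x4 grid of tensix cores.
tensix_cores = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
global_sem = ttnn.create_global_semaphore(device, tensix_cores, 1)
print(ttnn.get_global_semaphore_address(global_sem))
ttnn.reset_global_semaphore_value(global_sem)

# Global circular buffer: one sender core mapped to a receiver core range, size 3200 as in the tests.
sender_receiver_mapping = {
    ttnn.CoreCoord(1, 1): ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(4, 4), ttnn.CoreCoord(4, 4))})
}
global_cb = ttnn.create_global_circular_buffer(device, sender_receiver_mapping, 3200)

# Sub-device manager: describe sub-devices, create/load a manager, then tear it down.
# 3200 is the local L1 size the tests pass.
sub_device = ttnn.SubDevice([tensix_cores])
manager_id = device.create_sub_device_manager([sub_device], 3200)
device.load_sub_device_manager(manager_id)
device.clear_loaded_sub_device_manager()
device.remove_sub_device_manager(manager_id)

ttnn.close_device(device)
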
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/api/test_global_semaphores.cpp
@@ -53,7 +53,7 @@ TEST_F(DispatchFixture, CreateMultipleGlobalSemaphoresOnSameCore) {
}
for (auto device : devices_) {
{
std::vector<std::unique_ptr<tt::tt_metal::GlobalSemaphore>> global_semaphores;
std::vector<std::shared_ptr<tt::tt_metal::GlobalSemaphore>> global_semaphores;
global_semaphores.reserve(cores.size());
std::vector<DeviceAddr> addresses;
addresses.reserve(cores.size());
6 changes: 3 additions & 3 deletions tests/tt_metal/tt_metal/dispatch/sub_device_test_utils.hpp
@@ -9,7 +9,7 @@
// TODO: ARCH_NAME specific, must remove
#include "eth_l1_address_map.h"

inline std::tuple<Program, CoreCoord, std::unique_ptr<GlobalSemaphore>> create_single_sync_program(
inline std::tuple<Program, CoreCoord, std::shared_ptr<GlobalSemaphore>> create_single_sync_program(
Device* device, SubDevice sub_device) {
auto syncer_coord = sub_device.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord;
auto syncer_core = CoreRangeSet(CoreRange(syncer_coord, syncer_coord));
@@ -26,7 +26,7 @@ inline std::tuple<Program, CoreCoord, std::unique_ptr<GlobalSemaphore>> create_s
return {std::move(syncer_program), std::move(syncer_coord), std::move(global_sem)};
}

inline std::tuple<Program, Program, Program, std::unique_ptr<GlobalSemaphore>> create_basic_sync_program(
inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> create_basic_sync_program(
Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) {
auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::TENSIX).ranges().at(0).start_coord;
auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord));
@@ -70,7 +70,7 @@ inline std::tuple<Program, Program, Program, std::unique_ptr<GlobalSemaphore>> c
std::move(waiter_program), std::move(syncer_program), std::move(incrementer_program), std::move(global_sem)};
}

inline std::tuple<Program, Program, Program, std::unique_ptr<GlobalSemaphore>> create_basic_eth_sync_program(
inline std::tuple<Program, Program, Program, std::shared_ptr<GlobalSemaphore>> create_basic_eth_sync_program(
Device* device, const SubDevice& sub_device_1, const SubDevice& sub_device_2) {
auto waiter_coord = sub_device_2.cores(HalProgrammableCoreType::ACTIVE_ETH).ranges().at(0).start_coord;
auto waiter_core = CoreRangeSet(CoreRange(waiter_coord, waiter_coord));
42 changes: 42 additions & 0 deletions tests/ttnn/unit_tests/test_global_circular_buffer.py
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn


def run_global_circular_buffer(device):
sender_cores = [ttnn.CoreCoord(1, 1), ttnn.CoreCoord(2, 2)]
receiver_cores = [
ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
),
}
),
ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(2, 3),
ttnn.CoreCoord(2, 4),
),
}
),
]
sender_receiver_mapping = dict(zip(sender_cores, receiver_cores))

global_circular_buffer = ttnn.create_global_circular_buffer(device, sender_receiver_mapping, 3200)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_circular_buffer(device, enable_async_mode):
run_global_circular_buffer(device)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_circular_buffer_mesh(mesh_device, enable_async_mode):
run_global_circular_buffer(mesh_device)
42 changes: 42 additions & 0 deletions tests/ttnn/unit_tests/test_global_semaphore.py
@@ -0,0 +1,42 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn


def run_global_semaphore(device):
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
),
}
)
tensix_cores1 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
),
}
)
global_sem0 = ttnn.create_global_semaphore(device, tensix_cores0, 1)
global_sem1 = ttnn.create_global_semaphore(device, tensix_cores1, 2)

assert ttnn.get_global_semaphore_address(global_sem0) != ttnn.get_global_semaphore_address(global_sem1)

ttnn.reset_global_semaphore_value(global_sem0)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_semaphore(device, enable_async_mode):
run_global_semaphore(device)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_semaphore_mesh(mesh_device, enable_async_mode):
run_global_semaphore(mesh_device)
112 changes: 112 additions & 0 deletions tests/ttnn/unit_tests/test_sub_device.py
@@ -0,0 +1,112 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn


def run_sub_devices(device):
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
),
}
)
tensix_cores1 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
),
}
)
sub_device_1 = ttnn.SubDevice([tensix_cores0])
sub_device_2 = ttnn.SubDevice([tensix_cores1])
sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200)
sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200)
device.load_sub_device_manager(sub_device_manager1)
device.load_sub_device_manager(sub_device_manager2)
device.clear_loaded_sub_device_manager()
device.remove_sub_device_manager(sub_device_manager1)
device.remove_sub_device_manager(sub_device_manager2)


def run_sub_devices_program(device):
is_mesh_device = isinstance(device, ttnn.MeshDevice)
if is_mesh_device:
inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0)
output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0)
num_devices = device.get_num_devices()
else:
inputs_mesh_mapper = None
output_mesh_composer = None
num_devices = 1
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
),
}
)
tensix_cores1 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
),
}
)
sub_device_1 = ttnn.SubDevice([tensix_cores0])
sub_device_2 = ttnn.SubDevice([tensix_cores1])
sub_device_manager = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200)
device.load_sub_device_manager(sub_device_manager)

x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16)
xt = ttnn.from_torch(
x,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
mesh_mapper=inputs_mesh_mapper,
)

grid_size = device.compute_with_storage_grid_size()
shard_size = [32, 64]
shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
yt = ttnn.interleaved_to_sharded(
xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16
)
y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer)

eq = torch.equal(x, y)
assert eq

device.clear_loaded_sub_device_manager()
device.remove_sub_device_manager(sub_device_manager)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_devices(device, enable_async_mode):
run_sub_devices(device)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_devices_mesh(mesh_device, enable_async_mode):
run_sub_devices(mesh_device)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_device_program(device, enable_async_mode):
run_sub_devices_program(device)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_device_program_mesh(mesh_device, enable_async_mode):
run_sub_devices_program(mesh_device)
44 changes: 44 additions & 0 deletions tt_metal/distributed/mesh_device.cpp
@@ -441,4 +441,48 @@ size_t MeshDevice::num_program_cache_entries() const {
return total_entries;
}

MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
for (uint32_t i = 0; i < this->num_devices(); i++) {
auto* device = this->devices[i];
auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() {
sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size);
});
}
for (auto* device : this->devices) {
device->synchronize();
}
return mesh_sub_device_manager_id;
}
void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) {
for (uint32_t i = 0; i < this->num_devices(); i++) {
auto* device = this->devices[i];
auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
device->push_work([device, sub_device_manager_id]() {
device->load_sub_device_manager(sub_device_manager_id);
});
}
}
void MeshDevice::clear_loaded_sub_device_manager() {
for (auto* device : this->devices) {
device->push_work([device]() {
device->clear_loaded_sub_device_manager();
});
}
}
void MeshDevice::remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) {
for (uint32_t i = 0; i < this->num_devices(); i++) {
auto* device = this->devices[i];
auto sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
device->push_work([device, sub_device_manager_id]() {
device->remove_sub_device_manager(sub_device_manager_id);
});
}
}

MeshSubDeviceManagerId::MeshSubDeviceManagerId(const MeshDevice& mesh_device) {
this->sub_device_manager_ids.resize(mesh_device.num_devices());
}

} // namespace tt::tt_metal::distributed
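
On a MeshDevice, each of these calls fans out to every underlying device's worker thread via push_work, and create_sub_device_manager gathers the per-device ids into a single MeshSubDeviceManagerId, so the Python-side flow matches the single-device case. A minimal sketch, assuming a mesh_device handle such as the pytest fixture used in test_sub_device.py:

import ttnn

# `mesh_device` is assumed to come from the ttnn mesh_device pytest fixture (or an equivalent open call).
def exercise_mesh_sub_devices(mesh_device):
    tensix_cores = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
    sub_device = ttnn.SubDevice([tensix_cores])
    # Returns one mesh-level id that internally holds a SubDeviceManagerId per device.
    manager_id = mesh_device.create_sub_device_manager([sub_device], 3200)
    mesh_device.load_sub_device_manager(manager_id)
    mesh_device.clear_loaded_sub_device_manager()
    mesh_device.remove_sub_device_manager(manager_id)
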
20 changes: 19 additions & 1 deletion tt_metal/distributed/mesh_device.hpp
@@ -9,8 +9,10 @@
#include <optional>
#include <vector>

#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/distributed/mesh_device_view.hpp"
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/impl/sub_device/sub_device_types.hpp"
#include "tt_metal/tt_stl/span.hpp"

namespace tt::tt_metal::distributed {

@@ -19,6 +21,8 @@ using MeshDeviceID = size_t;
using MeshOffset = std::pair<size_t, size_t>;
class MeshDeviceView;

struct MeshSubDeviceManagerId;

struct MeshDeviceConfig {
MeshShape mesh_shape;
MeshOffset offset;
@@ -171,6 +175,12 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

size_t num_program_cache_entries() const;

MeshSubDeviceManagerId create_sub_device_manager(
tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size);
void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);
void clear_loaded_sub_device_manager();
void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);

static std::shared_ptr<MeshDevice> fetch_mesh_device(const std::vector<Device*>& devices);
static std::shared_ptr<MeshDevice> create(
const MeshDeviceConfig& config,
@@ -182,4 +192,12 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device);

// TODO: This will be removed once we have DistributedDevice
// Currently required since each device manages its own sub-device manager ids
struct MeshSubDeviceManagerId {
MeshSubDeviceManagerId(const MeshDevice& mesh_device);

std::vector<SubDeviceManagerId> sub_device_manager_ids;
};

} // namespace tt::tt_metal::distributed
8 changes: 4 additions & 4 deletions tt_metal/host_api.hpp
@@ -297,7 +297,7 @@ uint32_t CreateSemaphore(
* Initializes a global semaphore on all cores within the specified CoreRangeSet.
* This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL.
*
* Return value: std::unique_ptr<GlobalSemaphore>
* Return value: std::shared_ptr<GlobalSemaphore>
*
* | Argument | Description | Type | Valid Range | Required |
* |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------|
@@ -307,15 +307,15 @@ uint32_t CreateSemaphore(
* | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No |
*/
// clang-format on
std::unique_ptr<GlobalSemaphore> CreateGlobalSemaphore(
std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1);

// clang-format off
/**
* Initializes a global semaphore on all cores within the specified CoreRangeSet.
* This only supports tensix cores, and can only use L1 buffer types like BufferType::L1 and BufferType::L1_SMALL.
*
* Return value: std::unique_ptr<GlobalSemaphore>
* Return value: std::shared_ptr<GlobalSemaphore>
*
* | Argument | Description | Type | Valid Range | Required |
* |---------------|------------------------------------------------------|-----------------------------------------------------------|--------------|----------|
@@ -325,7 +325,7 @@ std::unique_ptr<GlobalSemaphore> CreateGlobalSemaphore(
* | buffer_type | Buffer type to store the semaphore | BufferType | L1 types | No |
*/
// clang-format on
std::unique_ptr<GlobalSemaphore> CreateGlobalSemaphore(
std::shared_ptr<GlobalSemaphore> CreateGlobalSemaphore(
Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1);

// clang-format off
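The CreateGlobalSemaphore docstrings above describe the C++ arguments; at the Python layer, the first three (device, cores, initial_value) are passed the same way in tests/ttnn/unit_tests/test_global_semaphore.py, while exposing buffer_type from Python is not demonstrated in this diff and appears below only as a hypothetical, commented-out call.

import ttnn

# `device` is assumed to come from the ttnn device pytest fixture.
def create_semaphore_like_the_docstring(device):
    # cores: a tensix CoreRangeSet; initial_value: starting value, as in the argument table above.
    cores = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
    sem = ttnn.create_global_semaphore(device, cores, 0)
    # buffer_type defaults to BufferType::L1 on the C++ side; a Python-side override would be a
    # hypothetical keyword not exercised by this commit's tests:
    # sem = ttnn.create_global_semaphore(device, cores, 0, buffer_type=ttnn.BufferType.L1)
    return ttnn.get_global_semaphore_address(sem)
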
6 changes: 4 additions & 2 deletions tt_metal/impl/buffers/global_semaphore.cpp
@@ -53,15 +53,17 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type) {
this->reset_semaphore_value();
}

std::unique_ptr<GlobalSemaphore> GlobalSemaphore::create(
std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) {
return std::make_unique<GlobalSemaphore>(device, cores, initial_value, buffer_type);
}
std::unique_ptr<GlobalSemaphore> GlobalSemaphore::create(
std::shared_ptr<GlobalSemaphore> GlobalSemaphore::create(
Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type) {
return std::make_unique<GlobalSemaphore>(device, std::move(cores), initial_value, buffer_type);
}

Device* GlobalSemaphore::device() const { return device_; }

DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); }

void GlobalSemaphore::reset_semaphore_value() {