Skip to content

Commit

Permalink
#162: TTNN: Support for multi-device system descriptor.
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Aug 8, 2024
1 parent 5c61b9c commit 5bf6652
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 45 deletions.
7 changes: 4 additions & 3 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet(
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet(
MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx,
unsigned endpoint0,
unsigned endpoint1);
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipChannelAttrGet(
MlirContext ctx, unsigned deviceId0, int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size, unsigned deviceId1,
int64_t *ethernetCoreCoord1, size_t ethernetCoreCoord1Size);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet(
MlirContext ctx, MlirAttribute *chipDescs, size_t chipDescsSize,
Expand Down
7 changes: 5 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ def TT_ChipChannelAttr : TT_Attr<"ChipChannel", "chip_channel"> {
TT chip_channel attribute
}];

let parameters = (ins "unsigned":$endpoint0, "unsigned":$endpoint1);
let assemblyFormat = "`<` $endpoint0 `,` $endpoint1 `>`";
let parameters = (ins "unsigned":$deviceId0,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord0,
"unsigned":$deviceId1,
ArrayRefParameter<"int64_t">:$ethernetCoreCoord1);
let assemblyFormat = "`<` $deviceId0 `,` $ethernetCoreCoord0 `,` $deviceId1 `,` $ethernetCoreCoord1 `>`";
}

def TT_SystemDescAttr : TT_Attr<"SystemDesc", "system_desc"> {
Expand Down
6 changes: 4 additions & 2 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ struct ChipCoord {
}

struct ChipChannel {
endpoint0: uint32;
endpoint1: uint32;
device_id0: uint32;
ethernet_core_coord0: Dim2d;
device_id1: uint32;
ethernet_core_coord1: Dim2d;
}

table SystemDesc {
Expand Down
21 changes: 13 additions & 8 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,24 @@ toFlatbuffer(FlatbufferObjectCache &, ChipCapabilityAttr capabilityAttr) {
}

inline ::tt::target::ChipCoord toFlatbuffer(FlatbufferObjectCache &cache,
ChipCoordAttr chipCoord) {
const ChipCoordAttr &chipCoord) {
return ::tt::target::ChipCoord(chipCoord.getRack(), chipCoord.getShelf(),
chipCoord.getY(), chipCoord.getX());
}

inline ::tt::target::ChipChannel toFlatbuffer(FlatbufferObjectCache &cache,
ChipChannelAttr chipChannel) {
return ::tt::target::ChipChannel(chipChannel.getEndpoint0(),
chipChannel.getEndpoint1());
inline ::tt::target::ChipChannel
toFlatbuffer(FlatbufferObjectCache &cache, const ChipChannelAttr &chipChannel) {
return ::tt::target::ChipChannel(
chipChannel.getDeviceId0(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord0()[0],
chipChannel.getEthernetCoreCoord0()[1]),
chipChannel.getDeviceId1(),
::tt::target::Dim2d(chipChannel.getEthernetCoreCoord1()[0],
chipChannel.getEthernetCoreCoord1()[1]));
}

inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache,
GridAttr arch) {
const GridAttr &arch) {
assert(arch.getShape().size() == 2 && "expected a 2D grid");
return ::tt::target::Dim2d(arch.getShape()[0], arch.getShape()[1]);
}
Expand Down Expand Up @@ -188,7 +193,7 @@ toFlatbuffer(FlatbufferObjectCache &cache, ::llvm::ArrayRef<T> arr) {
}

inline flatbuffers::Offset<::tt::target::ChipDesc>
toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) {
toFlatbuffer(FlatbufferObjectCache &cache, const ChipDescAttr &chipDesc) {
assert(chipDesc.getGrid().size() == 2 && "expected a 2D grid");
auto grid = ::tt::target::Dim2d(chipDesc.getGrid()[0], chipDesc.getGrid()[1]);
return ::tt::target::CreateChipDesc(
Expand All @@ -200,7 +205,7 @@ toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) {
}

inline flatbuffers::Offset<::tt::target::SystemDesc>
toFlatbuffer(FlatbufferObjectCache &cache, SystemDescAttr systemDesc) {
toFlatbuffer(FlatbufferObjectCache &cache, const SystemDescAttr &systemDesc) {
auto chipDescs = toFlatbuffer(cache, systemDesc.getChipDescs());
auto chipDescIndices = toFlatbuffer(cache, systemDesc.getChipDescIndices());
auto chipCapabilities = toFlatbuffer(cache, systemDesc.getChipCapabilities());
Expand Down
15 changes: 12 additions & 3 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,18 @@ MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack,
return wrap(ChipCoordAttr::get(unwrap(ctx), rack, shelf, y, x));
}

MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned endpoint0,
unsigned endpoint1) {
return wrap(ChipChannelAttr::get(unwrap(ctx), endpoint0, endpoint1));
MlirAttribute ttmlirTTChipChannelAttrGet(MlirContext ctx, unsigned deviceId0,
int64_t *ethernetCoreCoord0,
size_t ethernetCoreCoord0Size,
unsigned deviceId1,
int64_t *ethernetCoreCoord1,
size_t ethernetCoreCoord1Size) {
std::vector<int64_t> ethCoord0Vec(
ethernetCoreCoord0, ethernetCoreCoord0 + ethernetCoreCoord0Size);
std::vector<int64_t> ethCoord1Vec(
ethernetCoreCoord1, ethernetCoreCoord1 + ethernetCoreCoord1Size);
return wrap(ChipChannelAttr::get(unwrap(ctx), deviceId0, ethCoord0Vec,
deviceId1, ethCoord1Vec));
}

MlirAttribute ttmlirTTSystemDescAttrGet(
Expand Down
13 changes: 8 additions & 5 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ void populateTTModule(py::module &m) {
});

py::class_<tt::ChipChannelAttr>(m, "ChipChannelAttr")
.def_static("get",
[](MlirContext ctx, unsigned endpoint0, unsigned endpoint1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), endpoint0,
endpoint1));
});
.def_static("get", [](MlirContext ctx, unsigned deviceId0,
std::vector<int64_t> ethernetCoreCoord0,
unsigned deviceId1,
std::vector<int64_t> ethernetCoreCoord1) {
return wrap(tt::ChipChannelAttr::get(unwrap(ctx), deviceId0,
ethernetCoreCoord0, deviceId1,
ethernetCoreCoord1));
});

py::class_<tt::SystemDescAttr>(m, "SystemDescAttr")
.def_static("get", [](MlirContext ctx,
Expand Down
124 changes: 103 additions & 21 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,24 @@
//
// SPDX-License-Identifier: Apache-2.0
#include "tt/runtime/runtime.h"
#include "hostdevcommon/common_values.hpp"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/utils.h"
#include "utils.h"
#include <numeric>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wsign-compare"
#include "ttnn/multi_device.hpp"
#pragma clang diagnostic pop

#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
#include <numeric>

namespace tt::runtime::ttnn {

static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
switch (arch) {
case ::tt::ARCH::GRAYSKULL:
Expand All @@ -26,44 +35,91 @@ static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
throw std::runtime_error("Unsupported arch");
}

static ::tt::target::Dim2d toFlatbuffer(CoreCoord coreCoord) {
static ::tt::target::Dim2d toFlatbuffer(const CoreCoord &coreCoord) {
return ::tt::target::Dim2d(coreCoord.y, coreCoord.x);
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
std::vector<int> chipIds;
static std::vector<::tt::target::ChipChannel>
getAllDeviceConnections(const vector<::ttnn::Device *> &devices) {
std::set<std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>>
connectionSet;

auto addConnection = [&connectionSet](
chip_id_t deviceId0, CoreCoord ethCoreCoord0,
chip_id_t deviceId1, CoreCoord ethCoreCoord1) {
if (deviceId0 > deviceId1) {
std::swap(deviceId0, deviceId1);
std::swap(ethCoreCoord0, ethCoreCoord1);
}
connectionSet.emplace(deviceId0, ethCoreCoord0, deviceId1, ethCoreCoord1);
};

for (const ::ttnn::Device *device : devices) {
std::unordered_set<CoreCoord> activeEthernetCores =
device->get_active_ethernet_cores(true);
for (const CoreCoord &ethernetCore : activeEthernetCores) {
std::tuple<chip_id_t, CoreCoord> connectedDevice =
device->get_connected_ethernet_core(ethernetCore);
addConnection(device->id(), ethernetCore, std::get<0>(connectedDevice),
std::get<1>(connectedDevice));
}
}

std::vector<::tt::target::ChipChannel> allConnections;
allConnections.resize(connectionSet.size());

std::transform(
connectionSet.begin(), connectionSet.end(), allConnections.begin(),
[](const std::tuple<chip_id_t, CoreCoord, chip_id_t, CoreCoord>
&connection) {
return ::tt::target::ChipChannel(
std::get<0>(connection), toFlatbuffer(std::get<1>(connection)),
std::get<2>(connection), toFlatbuffer(std::get<3>(connection)));
});

return allConnections;
}

static std::unique_ptr<SystemDesc>
getCurrentSystemDescImpl(const ::ttnn::multi_device::DeviceMesh &deviceMesh) {
std::vector<::ttnn::Device *> devices = deviceMesh.get_devices();
std::sort(devices.begin(), devices.end(),
[](const ::ttnn::Device *a, const ::ttnn::Device *b) {
return a->id() < b->id();
});

std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> chipDescs;
std::vector<uint32_t> chipDescIndices;
std::vector<::tt::target::ChipCapability> chipCapabilities;
// Ignore for now
std::vector<::tt::target::ChipCoord> chipCoords;
::flatbuffers::FlatBufferBuilder fbb;
for (size_t deviceId = 0; deviceId < numDevices; deviceId++) {
auto &device = ::ttnn::open_device(deviceId);
chipIds.push_back(device.id());
::tt::target::Dim2d deviceGrid = toFlatbuffer(device.logical_grid_size());
chipDescs.emplace_back(::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device.arch()), &deviceGrid,
device.l1_size_per_core(), device.num_dram_channels(),
device.dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,

for (const ::ttnn::Device *device : devices) {
// Construct chip descriptor
::tt::target::Dim2d deviceGrid = toFlatbuffer(device->logical_grid_size());
chipDescs.push_back(::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device->arch()), &deviceGrid,
device->l1_size_per_core(), device->num_dram_channels(),
device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,
DRAM_ALIGNMENT));
chipDescIndices.push_back(deviceId);
chipDescIndices.push_back(device->id());
// Derive chip capability
::tt::target::ChipCapability chipCapability =
::tt::target::ChipCapability::NONE;
if (device.is_mmio_capable()) {
if (device->is_mmio_capable()) {
chipCapability = chipCapability | ::tt::target::ChipCapability::PCIE |
::tt::target::ChipCapability::HostMMIO;
}
chipCapabilities.push_back(chipCapability);
int x, y, rack, shelf;
std::tie(x, y, rack, shelf) = device.get_chip_location();
chipCoords.emplace_back(::tt::target::ChipCoord(rack, shelf, y, x));
::ttnn::close_device(device);
}
std::vector<::tt::target::ChipChannel> chipChannel;
// Extract chip connected channels
std::vector<::tt::target::ChipChannel> allConnections =
getAllDeviceConnections(devices);
// Create SystemDesc
auto systemDesc = ::tt::target::CreateSystemDescDirect(
fbb, &chipDescs, &chipDescIndices, &chipCapabilities, &chipCoords,
&chipChannel);
&allConnections);
::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion();
::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor,
ttmlirVersion.patch);
Expand All @@ -78,7 +134,33 @@ std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
auto size = fbb.GetSize();
auto handle = ::tt::runtime::utils::malloc_shared(size);
std::memcpy(handle.get(), buf, size);
return std::make_pair(SystemDesc(handle), chipIds);
return std::make_unique<SystemDesc>(handle);
}

std::pair<SystemDesc, DeviceIds> getCurrentSystemDesc() {
size_t numDevices = ::tt::tt_metal::GetNumAvailableDevices();
size_t numPciDevices = ::tt::tt_metal::GetNumPCIeDevices();
TT_FATAL(numDevices % numPciDevices == 0,
"Unexpected non-rectangular grid of devices");
std::vector<chip_id_t> deviceIds(numDevices);
std::iota(deviceIds.begin(), deviceIds.end(), 0);
::ttnn::multi_device::DeviceGrid deviceGrid(numDevices / numPciDevices,
numPciDevices);
::ttnn::multi_device::DeviceMesh deviceMesh =
::ttnn::multi_device::open_device_mesh(deviceGrid, deviceIds,
DEFAULT_L1_SMALL_SIZE);
std::exception_ptr eptr = nullptr;
std::unique_ptr<SystemDesc> desc;
try {
desc = getCurrentSystemDescImpl(deviceMesh);
} catch (...) {
eptr = std::current_exception();
}
deviceMesh.close_devices();
if (eptr) {
std::rethrow_exception(eptr);
}
return std::make_pair(*desc, deviceIds);
}

template <typename T>
Expand Down
1 change: 1 addition & 0 deletions runtime/test/ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_runtime_gtest(subtract_test test_subtract.cpp)
add_runtime_gtest(sys_desc_sanity test_generate_sys_desc.cpp)
12 changes: 12 additions & 0 deletions runtime/test/ttnn/test_generate_sys_desc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#ifndef TT_RUNTIME_ENABLE_TTNN
#error "TT_RUNTIME_ENABLE_TTNN must be defined"
#endif
#include "tt/runtime/runtime.h"
#include <gtest/gtest.h>

TEST(TTNNSysDesc, Sanity) {
auto sysDesc = ::tt::runtime::getCurrentSystemDesc();
}
4 changes: 4 additions & 0 deletions runtime/test/ttnn/test_subtract.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#ifndef TT_RUNTIME_ENABLE_TTNN
#error "TT_RUNTIME_ENABLE_TTNN must be defined"
#endif
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/runtime.h"
#include "tt/runtime/utils.h"
#include <cstring>
Expand Down
3 changes: 2 additions & 1 deletion runtime/tools/python/ttrt/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def run(args):

# execution
print("executing action for all provided flatbuffers")
device = ttrt.runtime.open_device(device_ids)
system_desc, device_ids = ttrt.runtime.get_current_system_desc()
device = ttrt.runtime.open_device([device_ids[0]])
atexit.register(lambda: ttrt.runtime.close_device(device))

torch.manual_seed(args.seed)
Expand Down

0 comments on commit 5bf6652

Please sign in to comment.