From ae94cb180cb1d896b7c9730325d03943ff2abd9f Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Mon, 23 Sep 2024 23:29:20 -0400 Subject: [PATCH] #695: Runtime multi-device bootstrap - Replace all uses of Device with MeshDevice (#726) --- runtime/include/tt/runtime/detail/ttmetal.h | 5 +- runtime/include/tt/runtime/detail/ttnn.h | 11 ++-- runtime/include/tt/runtime/runtime.h | 5 +- runtime/include/tt/runtime/types.h | 4 -- runtime/lib/runtime.cpp | 18 +++++- runtime/lib/ttmetal/runtime.cpp | 52 +++++++++-------- .../lib/ttnn/include/tt/runtime/ttnn/types.h | 48 ++++++++++++++-- .../ttnn/operations/context/get_device.cpp | 56 +++++++++++++++---- runtime/lib/ttnn/operations/conv/conv2d.cpp | 8 ++- .../lib/ttnn/operations/creation/empty.cpp | 16 +++--- runtime/lib/ttnn/operations/creation/full.cpp | 7 ++- .../ttnn/operations/data_movement/concat.cpp | 5 +- .../ttnn/operations/data_movement/reshape.cpp | 2 +- .../operations/data_movement/transpose.cpp | 2 +- .../lib/ttnn/operations/deletion/dealloc.cpp | 2 +- .../lib/ttnn/operations/eltwise/binary.cpp | 2 +- runtime/lib/ttnn/operations/eltwise/unary.cpp | 2 +- .../ttnn/operations/embedding/embedding.cpp | 2 +- .../tt/runtime/ttnn/operations/utils.cpp | 7 --- .../tt/runtime/ttnn/operations/utils.h | 3 - .../ttnn/operations/layout/from_device.cpp | 2 +- .../lib/ttnn/operations/layout/to_device.cpp | 8 ++- .../lib/ttnn/operations/layout/to_layout.cpp | 7 ++- .../operations/layout/to_memory_config.cpp | 11 ++-- runtime/lib/ttnn/operations/matmul/matmul.cpp | 2 +- .../ttnn/operations/normalization/softmax.cpp | 2 +- .../lib/ttnn/operations/pool/maxpool2d.cpp | 9 +-- .../ttnn/operations/reduction/reduction.cpp | 2 +- runtime/lib/ttnn/program.cpp | 15 +++-- runtime/lib/ttnn/runtime.cpp | 46 ++++++++++----- runtime/test/ttnn/test_subtract.cpp | 5 +- runtime/tools/python/ttrt/common/run.py | 4 +- runtime/tools/python/ttrt/runtime/module.cpp | 14 ++--- 33 files changed, 251 insertions(+), 133 deletions(-) diff --git 
a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index e655c4fb0..740841b29 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -63,8 +63,9 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { tt::target::DataType getTensorDataType(Tensor tensor); -Device openDevice(std::vector const &deviceIds = {0}, - std::vector const &numHWCQs = {}); +size_t getNumAvailableDevices(); + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); void closeDevice(Device device); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index c02765224..e03e5f7bd 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -42,6 +42,9 @@ #pragma clang diagnostic ignored "-Wc99-extensions" #define FMT_HEADER_ONLY +#include "host_api.hpp" +#include "hostdevcommon/common_values.hpp" +#include "impl/device/mesh_device.hpp" #include "ttnn/device.hpp" #include "ttnn/operations/conv/conv2d/conv2d.hpp" #include "ttnn/operations/copy.hpp" @@ -59,7 +62,6 @@ #include "ttnn/operations/reduction/generic/generic_reductions.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" - #pragma clang diagnostic pop #include "tt/runtime/types.h" @@ -84,8 +86,9 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { tt::target::DataType getTensorDataType(Tensor tensor); -Device openDevice(std::vector const &deviceIds = {0}, - std::vector const &numHWCQs = {}); +size_t getNumAvailableDevices(); + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); void closeDevice(Device device); @@ -97,7 +100,7 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex, void wait(Event event); -void runProgram(::ttnn::Device &device, +void runProgram(::ttnn::MeshDevice &meshDevice, ::tt::target::ttnn::Program const *program, 
std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 395e7551b..05971f160 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -43,8 +43,9 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { tt::target::DataType getTensorDataType(Tensor tensor); -Device openDevice(std::vector const &deviceIds = {0}, - std::vector const &numHWCQs = {}); +size_t getNumAvailableDevices(); + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); void closeDevice(Device device); diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index bfb7e4ba5..d582a964c 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -104,10 +104,6 @@ struct Binary : public Flatbuffer { struct Device : public detail::RuntimeCheckedObjectImpl { using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; - - template static Device borrow(T &object, DeviceRuntime runtime) { - return Device(utils::unsafe_borrow_shared(&object), runtime); - } }; struct Event : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 7b34f04e5..31a5866ef 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -138,8 +138,22 @@ tt::target::DataType getTensorDataType(Tensor tensor) { throw std::runtime_error("runtime is not enabled"); } -Device openDevice(std::vector const &deviceIds, - std::vector const &numHWCQs) { +size_t getNumAvailableDevices() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getNumAvailableDevices(); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getNumAvailableDevices(); + } +#endif + throw 
std::runtime_error("runtime is not enabled"); +} + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { return ::tt::runtime::ttnn::openDevice(deviceIds, numHWCQs); diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 8e6b74472..247c6df1a 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -15,7 +15,7 @@ namespace tt::runtime::ttmetal { using ::tt::runtime::DeviceRuntime; constexpr inline std::size_t kHostBufferCommandQueueId = 0; using Events = std::vector>; -using DeviceMesh = std::vector<::tt::tt_metal::Device *>; +using DeviceList = std::vector<::tt::tt_metal::Device *>; using MetalTensor = std::variant>; @@ -58,29 +58,33 @@ tt::target::DataType getTensorDataType(Tensor tensor) { return ::tt::target::DataType::Float32; } -Device openDevice(std::vector const &deviceIds, - std::vector const &numHWCQs) { - assert(numHWCQs.empty() || numHWCQs.size() == deviceIds.size()); - std::shared_ptr deviceMesh = std::make_shared(); - int i = 0; - for (int deviceId : deviceIds) { - uint8_t num_hw_cqs = numHWCQs.empty() ? 
1 : numHWCQs[i]; - deviceMesh->push_back(CreateDevice(deviceId, num_hw_cqs)); - ++i; - } - return Device(static_pointer_cast(deviceMesh), DeviceRuntime::TTMetal); +size_t getNumAvailableDevices() { + return ::tt::tt_metal::GetNumAvailableDevices(); +} + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { + assert(deviceIds.size() && "No devices specified"); + ::tt::tt_metal::MeshShape grid = std::make_pair(1, deviceIds.size()); + std::shared_ptr<::tt::tt_metal::MeshDevice> meshDevice = + ::tt::tt_metal::MeshDevice::create( + grid, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + ::tt::tt_metal::DispatchCoreType::WORKER); + + return Device(std::static_pointer_cast(meshDevice), + DeviceRuntime::TTMetal); } void closeDevice(Device device) { - DeviceMesh &deviceMesh = device.as(DeviceRuntime::TTMetal); - for (::tt::tt_metal::Device *device : deviceMesh) { - ::tt::tt_metal::CloseDevice(device); - } + ::tt::tt_metal::MeshDevice &ttmetalMeshDevice = + device.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal); + ttmetalMeshDevice.close_devices(); } void deallocateBuffers(Device deviceHandle) { - DeviceMesh &deviceMesh = deviceHandle.as(DeviceRuntime::TTMetal); - for (::tt::tt_metal::Device *device : deviceMesh) { + ::tt::tt_metal::MeshDevice &meshDevice = + deviceHandle.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal); + + for (::tt::tt_metal::Device *device : meshDevice.get_devices()) { device->deallocate_buffers(); } } @@ -166,13 +170,17 @@ Event submit(Device deviceHandle, Binary executableHandle, ::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle); ::tt::target::metal::Program const *program = fbb.programs()->Get(programIndex); - DeviceMesh &deviceMesh = deviceHandle.as(DeviceRuntime::TTMetal); - assert(deviceMesh.size() == 1 && "Only one device is supported for now"); + ::tt::tt_metal::MeshDevice &meshDevice = + deviceHandle.as<::tt::tt_metal::MeshDevice>(DeviceRuntime::TTMetal); + DeviceList allDevices = 
meshDevice.get_devices(); + assert(allDevices.size() > 0 && "Unexpected empty device mesh"); + DeviceList deviceList = {allDevices[0]}; + assert(deviceList.size() == 1 && "Only one device is supported for now"); std::shared_ptr events = std::make_shared(); - assert(program->device_programs()->size() == deviceMesh.size() && + assert(program->device_programs()->size() == deviceList.size() && "Device programs size mismatch"); for (std::size_t i = 0; i < program->device_programs()->size(); ++i) { - ::tt::tt_metal::Device *device = deviceMesh[i]; + ::tt::tt_metal::Device *device = deviceList[i]; ZoneScoped; std::string zoneName = "submit_" + std::string(program->name()->c_str()) + diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h index 2fdc41485..5da980435 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h @@ -9,7 +9,6 @@ namespace tt::runtime::ttnn { -using DeviceMap = std::unordered_map; using TensorMap = std::unordered_map; struct ProgramTensorPool { ProgramTensorPool(const TensorMap &liveTensors) : liveTensors(liveTensors) {} @@ -56,13 +55,50 @@ struct ProgramTensorPool { std::unordered_map intermedTensors; }; -struct ProgramContext { +class ProgramContext { +public: + ProgramContext(const TensorMap &liveTensors, ::ttnn::MeshDevice *meshDevice) + : tensorPool(ProgramTensorPool(liveTensors)), meshDevice(meshDevice) {} + + const ::ttnn::MeshDevice &getMeshDevice() const { + assert(meshDevice && "Mesh device not initialized"); + return *meshDevice; + } + + ::ttnn::MeshDeviceView &getMeshView(uint32_t globalId) { + assert(meshViews.contains(globalId) && + "Mesh view with global id not initialized"); + return *(meshViews.at(globalId)); + } + + ProgramTensorPool &getTensorPool() { return tensorPool; } + + void addMeshView(uint32_t globalId, + std::unique_ptr<::ttnn::MeshDeviceView> view) { + assert(not meshViews.contains(globalId) && + "Mesh 
view with globalId already set"); + meshViews.try_emplace(globalId, std::move(view)); + } + + ::ttnn::Device &getDeviceFromView(uint32_t globalId, int deviceId) { + assert(meshViews.contains(globalId) && "Mesh view not initialized"); + ::tt::tt_metal::Coordinate deviceCoord = + meshViews.at(globalId)->find_device(deviceId); + return *( + meshViews.at(globalId)->get_device(deviceCoord.row, deviceCoord.col)); + } + +private: ProgramTensorPool tensorPool; - DeviceMap allDevices; - DeviceMap devicePool; - ProgramContext(const TensorMap &liveTensors, const DeviceMap &allDevices) - : tensorPool(ProgramTensorPool(liveTensors)), allDevices(allDevices) {} + // Contains all devices borrowed from the user that are available to the + // program + ::ttnn::MeshDevice *meshDevice = nullptr; + + // Contains various views of meshDevice that is used by the program + // Will be populated by get_device ops + std::unordered_map> + meshViews; }; } // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/operations/context/get_device.cpp b/runtime/lib/ttnn/operations/context/get_device.cpp index 36aa40be3..545368747 100644 --- a/runtime/lib/ttnn/operations/context/get_device.cpp +++ b/runtime/lib/ttnn/operations/context/get_device.cpp @@ -8,16 +8,52 @@ #include "ttmlir/Target/TTNN/program_generated.h" namespace tt::runtime::ttnn::operations::context { -void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) { - DeviceMap &devicePool = context.devicePool; - DeviceMap &allDevices = context.allDevices; - const flatbuffers::Vector *chipIds = op->chip_ids(); - assert(chipIds->size() == 1 && "Expected 1 chip id"); - for (const uint32_t chipId : *chipIds) { - assert(allDevices.contains(chipId) && "Device not found"); - auto [iter, inserted] = - devicePool.try_emplace(chipId, allDevices.at(chipId)); - assert(inserted && "Duplicate device"); + +static std::pair<::tt::tt_metal::Coordinate, ::tt::tt_metal::Coordinate> +deriveMeshViewCoordinates(const ::ttnn::MeshDevice 
&meshDevice, + const std::unordered_set &desiredDeviceIds, + const ::tt::target::Dim2d *meshViewShape) { + ::tt::tt_metal::Coordinate topLeft, bottomRight; + for (int row = 0; row < meshDevice.num_rows(); row++) { + for (int col = 0; col < meshDevice.num_cols(); col++) { + const ::ttnn::Device *currDevice = meshDevice.get_device(row, col); + if (desiredDeviceIds.contains(currDevice->id())) { + topLeft.row = row; + topLeft.col = col; + // coords are inclusive when constructing mesh view + bottomRight.row = topLeft.row + meshViewShape->y() - 1; + bottomRight.col = topLeft.col + meshViewShape->x() - 1; + return std::make_pair(topLeft, bottomRight); + } + } } + throw std::runtime_error("Device not found in mesh for get device op"); +} + +static std::unique_ptr<::ttnn::MeshDeviceView> +constructMeshView(const ::ttnn::MeshDevice &meshDevice, + const std::unordered_set &desiredDeviceIds, + const ::tt::target::Dim2d *meshViewShape) { + // Carve out a mesh view from MeshDevice + auto [topLeft, bottomRight] = + deriveMeshViewCoordinates(meshDevice, desiredDeviceIds, meshViewShape); + + return std::make_unique<::ttnn::MeshDeviceView>(meshDevice, topLeft, + bottomRight); +} + +void run(const ::tt::target::ttnn::GetDeviceOp *op, ProgramContext &context) { + const ::ttnn::MeshDevice &meshDevice = context.getMeshDevice(); + const ::tt::target::Dim2d *meshViewShape = op->mesh(); + assert(meshViewShape->y() == 1 && + "Expected 1xN mesh shape for get device op"); + const ::flatbuffers::Vector *deviceIds = op->chip_ids(); + std::unordered_set desiredDeviceIds(deviceIds->begin(), + deviceIds->end()); + assert(desiredDeviceIds.size() == deviceIds->size() && + "Duplicate device ids in get device op"); + std::unique_ptr<::ttnn::MeshDeviceView> meshView = + constructMeshView(meshDevice, desiredDeviceIds, meshViewShape); + context.addMeshView(op->out()->global_id(), std::move(meshView)); } } // namespace tt::runtime::ttnn::operations::context diff --git 
a/runtime/lib/ttnn/operations/conv/conv2d.cpp b/runtime/lib/ttnn/operations/conv/conv2d.cpp index c3b800dd8..7d030fa1e 100644 --- a/runtime/lib/ttnn/operations/conv/conv2d.cpp +++ b/runtime/lib/ttnn/operations/conv/conv2d.cpp @@ -9,8 +9,11 @@ namespace tt::runtime::ttnn::operations::conv { void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; + ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + // Investigate how to handle multi device in conv2d + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id()); std::optional<::ttnn::Tensor> bias = @@ -19,7 +22,6 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig(); config.dtype = utils::getDataType(op->input()); config.weights_dtype = utils::getDataType(op->weight()); - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); ::ttnn::Tensor out = std::get<0>(::ttnn::operations::conv::conv2d::conv2d<::ttnn::Device>( input, weight, &device, op->in_channels(), op->out_channels(), diff --git a/runtime/lib/ttnn/operations/creation/empty.cpp b/runtime/lib/ttnn/operations/creation/empty.cpp index 83255ba06..62ddffd9e 100644 --- a/runtime/lib/ttnn/operations/creation/empty.cpp +++ b/runtime/lib/ttnn/operations/creation/empty.cpp @@ -9,10 +9,11 @@ namespace tt::runtime::ttnn::operations::creation { void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; - ::ttnn::DataType dtype = - ::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype()); + 
ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); + ::ttnn::DataType dtype = utils::getDataType(op->out()); // TODO(bug #582): ttnn::empty doesn't work properly with tile layout, // using ROW_MAJOR until we fix it ::ttnn::Layout layout __attribute__((unused)) = @@ -22,13 +23,12 @@ void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) { ::tt::runtime::ttnn::utils::toShapeFromFBShape( *op->out()->desc()->shape()))); - const tt::target::DeviceRef *device = op->device(); + const tt::target::DeviceRef *deviceRef = op->device(); ::ttnn::Tensor out; - if (device) { + if (deviceRef) { ::ttnn::MemoryConfig memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out()); - out = ::ttnn::empty(shape, dtype, layout, - &utils::getDevice(device, devicePool), memoryConfig); + out = ::ttnn::empty(shape, dtype, layout, &device, memoryConfig); } else { out = ::ttnn::zeros(shape, dtype, layout); } diff --git a/runtime/lib/ttnn/operations/creation/full.cpp b/runtime/lib/ttnn/operations/creation/full.cpp index 5ecef9d0b..7e5586d15 100644 --- a/runtime/lib/ttnn/operations/creation/full.cpp +++ b/runtime/lib/ttnn/operations/creation/full.cpp @@ -9,9 +9,10 @@ namespace tt::runtime::ttnn::operations::creation { void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap devicePool = context.devicePool; - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); + ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); ::ttnn::DataType outputDataType = utils::getDataType(op->out()); auto shape = ::ttnn::Shape(::tt::tt_metal::LegacyShape( 
::tt::runtime::ttnn::utils::toShapeFromFBShape( diff --git a/runtime/lib/ttnn/operations/data_movement/concat.cpp b/runtime/lib/ttnn/operations/data_movement/concat.cpp index e904adc2a..d798bf6bd 100644 --- a/runtime/lib/ttnn/operations/data_movement/concat.cpp +++ b/runtime/lib/ttnn/operations/data_movement/concat.cpp @@ -7,12 +7,13 @@ namespace tt::runtime::ttnn::operations::data_movement { void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) { + ProgramTensorPool &tensorPool = context.getTensorPool(); std::vector<::ttnn::Tensor> inputs; for (const auto &input : *op->inputs()) { - inputs.push_back(context.tensorPool.at(input->global_id())); + inputs.push_back(tensorPool.at(input->global_id())); } int32_t dim = op->dim(); ::ttnn::Tensor out = ::ttnn::concat(inputs, dim); - context.tensorPool.insert_or_assign(op->out()->global_id(), out); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::data_movement diff --git a/runtime/lib/ttnn/operations/data_movement/reshape.cpp b/runtime/lib/ttnn/operations/data_movement/reshape.cpp index d91788687..dc2788984 100644 --- a/runtime/lib/ttnn/operations/data_movement/reshape.cpp +++ b/runtime/lib/ttnn/operations/data_movement/reshape.cpp @@ -25,7 +25,7 @@ static ::ttnn::Tensor invoke_reshape(const ::ttnn::Tensor &tensor, } void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); const auto *fbShape = op->shape(); std::vector shape(fbShape->begin(), fbShape->end()); diff --git a/runtime/lib/ttnn/operations/data_movement/transpose.cpp b/runtime/lib/ttnn/operations/data_movement/transpose.cpp index bcf0b4c33..f299d5129 100644 --- a/runtime/lib/ttnn/operations/data_movement/transpose.cpp +++ b/runtime/lib/ttnn/operations/data_movement/transpose.cpp @@ -8,7 +8,7 
@@ namespace tt::runtime::ttnn::operations::data_movement { void run(const ::tt::target::ttnn::TransposeOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); int32_t dim0 = op->dim0(); int32_t dim1 = op->dim1(); diff --git a/runtime/lib/ttnn/operations/deletion/dealloc.cpp b/runtime/lib/ttnn/operations/deletion/dealloc.cpp index 0fc992583..adb87ff9c 100644 --- a/runtime/lib/ttnn/operations/deletion/dealloc.cpp +++ b/runtime/lib/ttnn/operations/deletion/dealloc.cpp @@ -7,7 +7,7 @@ namespace tt::runtime::ttnn::operations::deletion { void run(const ::tt::target::ttnn::DeallocOp *op, ProgramContext &context) { bool force = true; - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); ::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id()); tensor.deallocate(force); tensorPool.erase(op->in()->global_id()); diff --git a/runtime/lib/ttnn/operations/eltwise/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary.cpp index b6143d358..25cca3fb2 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary.cpp @@ -60,7 +60,7 @@ static void runEltwiseBinaryCompositeOP( void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { assert(isBinaryOp(op) && "Expected binary operation"); - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); switch (op->type()) { /* Eltwise Binary */ case ::tt::target::ttnn::EltwiseOpType::Add: { diff --git a/runtime/lib/ttnn/operations/eltwise/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary.cpp index 18b00bd8c..abcaae632 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary.cpp @@ -54,7 +54,7 @@ static void runEltwiseUnaryWithFastAndApproximateModeOP( 
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { assert(isUnaryOp(op) && "Expected binary operation"); - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); switch (op->type()) { case ::tt::target::ttnn::EltwiseOpType::Abs: { runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs); diff --git a/runtime/lib/ttnn/operations/embedding/embedding.cpp b/runtime/lib/ttnn/operations/embedding/embedding.cpp index 433428eda..b3c4bfac8 100644 --- a/runtime/lib/ttnn/operations/embedding/embedding.cpp +++ b/runtime/lib/ttnn/operations/embedding/embedding.cpp @@ -9,7 +9,7 @@ namespace tt::runtime::ttnn::operations::embedding { void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &input = tensorPool.at(op->input()->global_id()); const ::ttnn::Tensor &weight = tensorPool.at(op->weight()->global_id()); // default params for embedding op diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index 0f7792b74..929d3aa55 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -27,13 +27,6 @@ ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) { tensorRef->desc()->layout()->memory_desc()->data_type()); } -::ttnn::Device &getDevice(const ::tt::target::DeviceRef *deviceRef, - DeviceMap &devicePool) { - uint32_t deviceId = deviceRef->global_id(); - assert(devicePool.contains(deviceId) && "Device not found in device pool"); - return *devicePool.at(deviceId); -} - CoreRangeSet toCoreRangeSet( const ::flatbuffers::Vector *coreRangeSet) { std::set coreRanges; diff --git 
a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h index b977303de..3d6d207e4 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h @@ -19,9 +19,6 @@ bool isOnDevice(const ::ttnn::Tensor &tensor); ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef); -::ttnn::Device &getDevice(const ::tt::target::DeviceRef *deviceRef, - DeviceMap &devicePool); - CoreRangeSet toCoreRangeSet( const ::flatbuffers::Vector *coreRangeSet); diff --git a/runtime/lib/ttnn/operations/layout/from_device.cpp b/runtime/lib/ttnn/operations/layout/from_device.cpp index 613859135..56835cc59 100644 --- a/runtime/lib/ttnn/operations/layout/from_device.cpp +++ b/runtime/lib/ttnn/operations/layout/from_device.cpp @@ -9,7 +9,7 @@ namespace tt::runtime::ttnn::operations::layout { void run(const ::tt::target::ttnn::FromDeviceOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); assert(utils::isOnDevice(inputTensor) && "Calling ttnn::from_device on a host tensor"); diff --git a/runtime/lib/ttnn/operations/layout/to_device.cpp b/runtime/lib/ttnn/operations/layout/to_device.cpp index df8fcd6c0..5f16b903c 100644 --- a/runtime/lib/ttnn/operations/layout/to_device.cpp +++ b/runtime/lib/ttnn/operations/layout/to_device.cpp @@ -9,15 +9,17 @@ namespace tt::runtime::ttnn::operations::layout { void run(const ::tt::target::ttnn::ToDeviceOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; + ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + 
context.getDeviceFromView(op->device()->global_id(), 0); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); assert(utils::isOnHost(inputTensor) && "Calling ttnn::to_device on a device tensor"); ::ttnn::MemoryConfig memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out()); - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); + ::ttnn::Tensor out = ::ttnn::to_device(inputTensor, &device, memoryConfig); tensorPool.try_emplace(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/layout/to_layout.cpp b/runtime/lib/ttnn/operations/layout/to_layout.cpp index adfeb8b43..86c5e5ae6 100644 --- a/runtime/lib/ttnn/operations/layout/to_layout.cpp +++ b/runtime/lib/ttnn/operations/layout/to_layout.cpp @@ -8,8 +8,10 @@ namespace tt::runtime::ttnn::operations::layout { void run(const ::tt::target::ttnn::ToLayoutOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; + ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); assert((utils::isOnHost(inputTensor) or utils::isOnDevice(inputTensor)) && "Unsupported storage type"); @@ -27,7 +29,6 @@ void run(const ::tt::target::ttnn::ToLayoutOp *op, ProgramContext &context) { break; } - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); ::ttnn::Tensor out = ::ttnn::to_layout(inputTensor, layout, std::nullopt, std::nullopt, &device); diff --git a/runtime/lib/ttnn/operations/layout/to_memory_config.cpp b/runtime/lib/ttnn/operations/layout/to_memory_config.cpp index dfc3442ff..dfeb39604 100644 --- a/runtime/lib/ttnn/operations/layout/to_memory_config.cpp +++ b/runtime/lib/ttnn/operations/layout/to_memory_config.cpp @@ -197,8 +197,7 @@ 
handleToL1MemoryConfigOp(::ttnn::Device &device, void run(const ::tt::target::ttnn::ToMemoryConfigOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in0()->global_id()); assert(utils::isOnHost(inputTensor) or utils::isOnDevice(inputTensor) && "Unsupported storage type"); @@ -220,12 +219,16 @@ void run(const ::tt::target::ttnn::ToMemoryConfigOp *op, break; } case ::tt::target::MemorySpace::DeviceDRAM: { - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); handleToDramMemoryConfigOp(device, op->in0(), op->out(), tensorPool); break; } case ::tt::target::MemorySpace::DeviceL1: { - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); + // TODO (jnie): Update this once we support multi device tensors + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); handleToL1MemoryConfigOp(device, op->in0(), op->out(), tensorPool); break; } diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index 458cfe444..7565605d1 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -9,7 +9,7 @@ namespace tt::runtime::ttnn::operations::matmul { // ANCHOR: adding_an_op_matmul_runtime void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &lhs = tensorPool.at(op->in0()->global_id()); const ::ttnn::Tensor &rhs = tensorPool.at(op->in1()->global_id()); ::ttnn::DataType outputDataType = utils::getDataType(op->out()); diff --git 
a/runtime/lib/ttnn/operations/normalization/softmax.cpp b/runtime/lib/ttnn/operations/normalization/softmax.cpp index eba2323d4..f02467803 100644 --- a/runtime/lib/ttnn/operations/normalization/softmax.cpp +++ b/runtime/lib/ttnn/operations/normalization/softmax.cpp @@ -8,7 +8,7 @@ namespace tt::runtime::ttnn::operations::normalization { void run(const ::tt::target::ttnn::SoftmaxOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); int32_t dimension = op->dimension(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index 2e1100990..e2b7d16fb 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -37,13 +37,14 @@ preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, } void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; - DeviceMap &devicePool = context.devicePool; + ProgramTensorPool &tensorPool = context.getTensorPool(); + // TODO (jnie): Update this once we support multi device tensors + // Investigate how to handle multi device in maxpool2d + ::ttnn::Device &device = + context.getDeviceFromView(op->device()->global_id(), 0); const ::ttnn::operations::pool::MaxPool2DOp operation = ::ttnn::operations::pool::MaxPool2DOp(); - ::ttnn::Device &device = utils::getDevice(op->device(), devicePool); - const ::ttnn::Tensor preShardedInput = preshardForMaxPool2d(op, device, tensorPool); diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp index 88925e3d2..de705a7da 100644 --- a/runtime/lib/ttnn/operations/reduction/reduction.cpp +++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp @@ -33,7 
+33,7 @@ static void runReductionOp( } void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) { - ProgramTensorPool &tensorPool = context.tensorPool; + ProgramTensorPool &tensorPool = context.getTensorPool(); switch (op->type()) { case ::tt::target::ttnn::ReductionOpType::Sum: { runReductionOp(op, tensorPool, ::ttnn::sum); diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 379286d52..00f4ca424 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -26,9 +26,8 @@ namespace tt::runtime::ttnn { struct ProgramExecutor { - ProgramContext context; - ProgramExecutor(const TensorMap &liveTensors, const DeviceMap &allDevices) - : context(ProgramContext(liveTensors, allDevices)) {} + ProgramExecutor(const TensorMap &liveTensors, ::ttnn::MeshDevice *meshDevice) + : context(ProgramContext(liveTensors, meshDevice)) {} void execute(const ::tt::target::ttnn::Program *program) { for (const ::tt::target::ttnn::Operation *op : *program->operations()) { @@ -36,7 +35,10 @@ struct ProgramExecutor { } } + ProgramContext &getContext() { return context; } + private: + ProgramContext context; void runOperation(const ::tt::target::ttnn::Operation *op); }; @@ -127,7 +129,7 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program, return isNop; } -void runProgram(::ttnn::Device &device, +void runProgram(::ttnn::MeshDevice &meshDevice, ::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs) { @@ -135,11 +137,8 @@ void runProgram(::ttnn::Device &device, return; } TensorMap liveTensors; - DeviceMap allDevices; int inputIndex = 0; assert(program->inputs()->size() == inputs.size()); - // Assuming single device for now until we support multichip - allDevices.try_emplace(device.id(), &device); for (::tt::target::TensorRef const *input : *program->inputs()) { auto [iter, inserted] = liveTensors.try_emplace(input->global_id(), 
inputs[inputIndex++]); @@ -153,7 +152,7 @@ void runProgram(::ttnn::Device &device, liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]); assert(inserted && "Duplicate output tensor"); } - ProgramExecutor executor(liveTensors, allDevices); + ProgramExecutor executor(liveTensors, &meshDevice); executor.execute(program); } diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index fde7287bd..d4888d7a6 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -57,26 +57,45 @@ tt::target::DataType getTensorDataType(Tensor tensor) { return utils::fromTTNNDataType(nnTensor.get_dtype()); } -Device openDevice(std::vector const &deviceIds, - std::vector const &numHWCQs) { - assert(deviceIds.size() == 1 && "Only one device is supported for now"); - assert(numHWCQs.empty() && "HWCQs are not supported for now"); - auto &device = ::ttnn::open_device(deviceIds.front(), kL1SmallSize); +size_t getNumAvailableDevices() { + return ::tt::tt_metal::GetNumAvailableDevices(); +} + +Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { + assert(deviceIds.size() && "No devices specified"); + ::tt::tt_metal::MeshShape grid = std::make_pair(1, deviceIds.size()); + std::shared_ptr<::ttnn::MeshDevice> meshDevice = ::ttnn::MeshDevice::create( + grid, kL1SmallSize, DEFAULT_TRACE_REGION_SIZE, numHWCQs, + ::tt::tt_metal::DispatchCoreType::WORKER); + bool enableAsync = debug::Env::get().enableAsyncTTNN; - device.enable_async(enableAsync); - return Device::borrow(device, DeviceRuntime::TTNN); + for (::ttnn::Device *device : meshDevice->get_devices()) { + device->enable_async(enableAsync); + } + + return Device(std::static_pointer_cast(meshDevice), + DeviceRuntime::TTNN); } void closeDevice(Device device) { - auto &ttnn_device = device.as<::ttnn::Device>(DeviceRuntime::TTNN); + ::ttnn::MeshDevice &ttnnMeshDevice = + device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + #if defined(TT_RUNTIME_ENABLE_PERF_TRACE) - 
::tt::tt_metal::detail::DumpDeviceProfileResults(&ttnn_device); + for (::ttnn::Device *ttnnDevice : ttnnMeshDevice.get_devices()) { + ::tt::tt_metal::detail::DumpDeviceProfileResults(ttnnDevice); + } #endif - ::ttnn::close_device(ttnn_device); + + ttnnMeshDevice.close_devices(); } void deallocateBuffers(Device deviceHandle) { - deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN).deallocate_buffers(); + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + for (::ttnn::Device *device : meshDevice.get_devices()) { + device->deallocate_buffers(); + } } static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { @@ -92,7 +111,8 @@ Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, std::vector const &outputHandles) { - ::ttnn::Device &device = deviceHandle.as<::ttnn::Device>(DeviceRuntime::TTNN); + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); std::vector<::ttnn::Tensor *> inputs; inputs.reserve(inputHandles.size()); @@ -106,7 +126,7 @@ Event submit(Device deviceHandle, Binary executableHandle, assert(output.matchesRuntime(DeviceRuntime::TTNN)); outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); } - tt::runtime::ttnn::runProgram(device, fbb.programs()->Get(programIndex), + tt::runtime::ttnn::runProgram(meshDevice, fbb.programs()->Get(programIndex), inputs, outputs); return Event(nullptr, DeviceRuntime::TTNN); } diff --git a/runtime/test/ttnn/test_subtract.cpp b/runtime/test/ttnn/test_subtract.cpp index f2f94a940..00aebe20f 100644 --- a/runtime/test/ttnn/test_subtract.cpp +++ b/runtime/test/ttnn/test_subtract.cpp @@ -46,7 +46,10 @@ TEST(TTNNSubtract, Equal) { outputTensors.emplace_back(::tt::runtime::createTensor(data, desc)); } - auto device = ::tt::runtime::openDevice(); + size_t numDevices 
= ::tt::runtime::getNumAvailableDevices(); + std::vector deviceIds(numDevices); + std::iota(deviceIds.begin(), deviceIds.end(), 0); + auto device = ::tt::runtime::openDevice(deviceIds); auto ev = ::tt::runtime::submit(device, fbb, 0, inputTensors, outputTensors); ::tt::runtime::closeDevice(device); diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index a2215527b..831f0003f 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -241,8 +241,8 @@ def _execute(binaries): self.logging.debug(f"setting torch manual seed={self['--seed']}") torch.manual_seed(self["--seed"]) ttrt.runtime.set_compatible_runtime(binaries[0].fbb) - self.logging.debug(f"opening device id={self.query.device_ids[0]}") - device = ttrt.runtime.open_device([self.query.device_ids[0]]) + self.logging.debug(f"opening devices={self.query.device_ids}") + device = ttrt.runtime.open_device(self.query.device_ids) try: for bin in binaries: diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 8d8f6601d..158c6a8ff 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -14,8 +14,7 @@ namespace py = pybind11; PYBIND11_MODULE(_C, m) { m.doc() = "ttrt.runtime python extension for interacting with the " - "Tenstorrent devies"; - + "Tenstorrent devices"; py::class_(m, "Device") .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); py::class_(m, "Event"); @@ -58,11 +57,12 @@ PYBIND11_MODULE(_C, m) { shape, stride, itemsize, dataType); }, "Create a tensor with borrowed memory"); - m.def("open_device", &tt::runtime::openDevice, - py::arg("device_ids") = std::vector{0}, - py::arg("num_hw_cqs") = std::vector{}, - "Open a device for execution"); - m.def("close_device", &tt::runtime::closeDevice, "Close a device"); + m.def("get_num_available_devices", &tt::runtime::getNumAvailableDevices, + "Get the 
number of available devices"); + m.def("open_device", &tt::runtime::openDevice, py::arg("device_ids"), + py::arg("num_hw_cqs") = size_t{1}, + "Open a mesh of devices for execution"); + m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device"); m.def("submit", &tt::runtime::submit, py::arg("device"), py::arg("executable"), py::arg("program_index"), py::arg("inputs"), py::arg("outputs"), "Submit a binary for execution");