diff --git a/include/ttmlir/Target/TTMetal/binary.fbs b/include/ttmlir/Target/TTMetal/binary.fbs index 99fe502a2..c3ca5bdda 100644 --- a/include/ttmlir/Target/TTMetal/binary.fbs +++ b/include/ttmlir/Target/TTMetal/binary.fbs @@ -5,6 +5,8 @@ include "command.fbs"; namespace tt.target.metal; table DeviceProgram { + inputs: [TensorRef]; + outputs: [TensorRef]; command_queues: [CommandQueue]; } diff --git a/lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp b/lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp index 59e4d7e14..e75c96ed6 100644 --- a/lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp +++ b/lib/Dialect/TTMetal/Transforms/SerializeToBinary.cpp @@ -222,7 +222,8 @@ class TTMetalSerializeToBinary std::vector<::flatbuffers::Offset<::tt::target::metal::DeviceProgram>> devicePrograms = { - ::tt::target::metal::CreateDeviceProgramDirect(fbb, &commandQueues), + ::tt::target::metal::CreateDeviceProgramDirect( + fbb, &cqBuilder.inputs, &cqBuilder.outputs, &commandQueues), }; std::vector<::flatbuffers::Offset<::tt::target::metal::Program>> programs = diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 3345f73f1..b8c72d13f 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -64,15 +64,97 @@ Event submit(Device device, Binary executable, std::uint32_t programIndex, void wait(Event event); -std::shared_ptr<::tt::tt_metal::Event> executeCommandQueue( - ::tt::tt_metal::Device *device, ::tt::target::metal::CommandQueue const *cq, - std::size_t cq_id, - std::vector< - std::pair>> const - &inputs, - std::vector< - std::pair>> const - &outputs); +using InputBuffer = + std::tuple, + std::shared_ptr<::tt::tt_metal::Event>>; + +using OutputBuffer = + std::tuple>; + +std::shared_ptr<::tt::tt_metal::Event> +executeCommandQueue(::tt::tt_metal::Device *device, + ::tt::target::metal::CommandQueue const *cq, + std::size_t cq_id, std::vector const &inputs, + std::vector const &outputs); + +// Utils + +inline CoreRangeSet toCoreRangeSet( + ::flatbuffers::Vector const *coreRangeSet) { + std::set coreRanges; + for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { + CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); + // End is inclusive + CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1, + coreRange->loc().y() + coreRange->size().y() - 1); + coreRanges.emplace(start, end); + } + return CoreRangeSet(coreRanges); +} + +#pragma clang diagnostic push +// Needed to construct ShardedBufferConfig +#pragma clang diagnostic ignored "-Wc++20-designator" + +inline std::shared_ptr<::tt::tt_metal::Buffer> +createBufferFromTensorRef(::tt::tt_metal::Device *device, + ::tt::target::TensorRef const *tensorRef) { + ::tt::target::TensorDesc const *tensorDesc = tensorRef->desc(); + ::tt::target::LayoutDesc const *layout = tensorDesc->layout(); + CoreRangeSet coreRangeSet = toCoreRangeSet(layout->core_range_set()); + auto shardRank = layout->memory_desc()->shape()->size(); + ::tt::target::Dim2d const *tile_shape = layout->memory_desc()->tile_shape(); + std::array shardShape; + shardShape[1] = + layout->memory_desc()->shape()->Get(shardRank - 1) * tile_shape->x(); + shardShape[0] = tile_shape->y(); + for (unsigned i = 0; i < shardRank - 1; ++i) { + shardShape[0] *= layout->memory_desc()->shape()->Get(i); + } + ShardSpec shardSpec(coreRangeSet, shardShape); + std::array pageShape = {static_cast(tile_shape->y()), + shardShape[1]}; + + auto tensorRank = layout->stride()->size(); + auto innerDim = layout->stride()->Get(tensorRank - 2); + assert(layout->stride()->size() >= 2); + assert((layout->stride()->Get(0) * tensorDesc->shape()->Get(0)) % + (pageShape[0] * innerDim) == + 0); + assert(innerDim % pageShape[1] == 0); + std::array tensorShape = { + (layout->stride()->Get(0) * tensorDesc->shape()->Get(0)) / + (pageShape[0] * innerDim), + innerDim / pageShape[1], + }; + + ShardSpecBuffer shardSpecBuffer(shardSpec, pageShape, tensorShape); + assert(layout->memory_desc()->memory_space() == + ::tt::target::MemorySpace::DeviceDRAM || + layout->memory_desc()->memory_space() == + ::tt::target::MemorySpace::DeviceL1); + BufferType bufferType = layout->memory_desc()->memory_space() == + ::tt::target::MemorySpace::DeviceDRAM + ? BufferType::DRAM + : BufferType::L1; + uint64_t pageSize = + pageShape[0] * pageShape[1] * 4; // FIXME: Hardcoded data type size + uint64_t size = tensorShape[0] * tensorShape[1] * pageSize; + auto shardedBufferConfig = ShardedBufferConfig{ + .device = device, + .size = size, + .page_size = pageSize, + .buffer_type = bufferType, + .buffer_layout = TensorMemoryLayout::BLOCK_SHARDED, + .shard_parameters = shardSpecBuffer, + }; + std::shared_ptr<::tt::tt_metal::Buffer> buffer = + ::tt::tt_metal::CreateBuffer(shardedBufferConfig); + assert(tensorRef->address()); + buffer->set_address(tensorRef->address()); + return buffer; +} +#pragma clang diagnostic pop } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/binary.cpp b/runtime/lib/binary.cpp index acf7aceee..db84a8e23 100644 --- a/runtime/lib/binary.cpp +++ b/runtime/lib/binary.cpp @@ -138,7 +138,9 @@ std::vector getProgramInputs(Flatbuffer binary, std::uint32_t programIndex) { std::vector inputs; auto const *program = getBinary(binary)->programs()->Get(programIndex); - for (auto const *input : *program->inputs()) { + assert(program->device_programs()->size() == 1 && + "Currently only one device is supported"); + for (auto const *input : *program->device_programs()->Get(0)->inputs()) { TensorDesc desc; desc.shape = {input->desc()->shape()->begin(), input->desc()->shape()->end()}; @@ -156,7 +158,8 @@ std::vector getProgramOutputs(Flatbuffer binary, std::uint32_t programIndex) { std::vector outputs; auto const *program = getBinary(binary)->programs()->Get(programIndex); - for (auto const *output : *program->outputs()) { + assert(program->device_programs()->size() == 1); + for (auto const *output : *program->device_programs()->Get(0)->outputs()) { TensorDesc desc; desc.shape = {output->desc()->shape()->begin(), output->desc()->shape()->end()}; diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 6612cc7e8..3010a35e7 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -83,6 +83,14 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif } -void wait(Event) { throw std::runtime_error("Not implemented"); } +void wait(Event event) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + return ::tt::runtime::ttnn::wait(event); +#elif defined(TT_RUNTIME_ENABLE_TTMETAL) + return ::tt::runtime::ttmetal::wait(event); +#else + throw std::runtime_error("runtime is not enabled"); +#endif +} } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/command_queue.cpp b/runtime/lib/ttmetal/command_queue.cpp index 97b727734..d7e171963 100644 --- a/runtime/lib/ttmetal/command_queue.cpp +++ b/runtime/lib/ttmetal/command_queue.cpp @@ -11,27 +11,20 @@ #include "ttmlir/Target/TTMetal/Target.h" #include "ttmlir/Version.h" -// Needed to construct ShardedBufferConfig -#pragma clang diagnostic ignored "-Wc++20-designator" - namespace tt::runtime::ttmetal { struct CQExecutor { ::tt::tt_metal::Device *device; + std::vector> initEvents; std::unordered_map> buffers; std::unordered_map> events; ::tt::tt_metal::CommandQueue *cq; - CQExecutor( - ::tt::tt_metal::Device *device, std::size_t cq_id, - std::vector>> const - &inputs, - std::vector>> const - &outputs); + CQExecutor(::tt::tt_metal::Device *device, std::size_t cq_id, + std::vector const &inputs, + std::vector const &outputs); std::shared_ptr<::tt::tt_metal::Event> execute(::tt::target::metal::CommandQueue const *commandQueue); @@ -49,21 +42,21 @@ struct CQExecutor { void execute(::tt::target::metal::FinishCommand const *command); }; -CQExecutor::CQExecutor( - ::tt::tt_metal::Device *device, std::size_t cq_id, - std::vector< - std::pair>> const - &inputs, - std::vector< - std::pair>> const - &outputs) +CQExecutor::CQExecutor(::tt::tt_metal::Device *device, std::size_t cq_id, + std::vector const &inputs, + std::vector const &outputs) : device(device) { for (std::size_t i = 0; i < inputs.size(); ++i) { - buffers[inputs[i].first] = inputs[i].second; + auto [global_id, buffer, event] = inputs[i]; + buffers[global_id] = buffer; + if (event) { + initEvents.push_back(event); + } } for (std::size_t i = 0; i < outputs.size(); ++i) { - buffers[outputs[i].first] = outputs[i].second; + auto [global_id, buffer] = outputs[i]; + buffers[global_id] = buffer; } cq = &device->command_queue(cq_id); @@ -71,6 +64,11 @@ CQExecutor::CQExecutor( std::shared_ptr<::tt::tt_metal::Event> CQExecutor::execute(::tt::target::metal::CommandQueue const *commandQueue) { + for (auto const &event : initEvents) { + ::tt::tt_metal::EnqueueWaitForEvent(*cq, event); + } + initEvents.clear(); + for (::tt::target::metal::Command const *command : *commandQueue->commands()) { execute(command); @@ -134,18 +132,6 @@ void CQExecutor::execute(::tt::target::metal::Command const *command) { } } -static CoreRangeSet toCoreRangeSet( - ::flatbuffers::Vector const *coreRangeSet) { - std::set coreRanges; - for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { - CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); - CoreCoord end(coreRange->loc().x() + coreRange->size().x(), - coreRange->loc().y() + coreRange->size().y()); - coreRanges.emplace(start, end); - } - return CoreRangeSet(coreRanges); -} - static void writeFile(std::string const &fileName, char const *data, std::size_t size) { std::ofstream file(fileName); @@ -256,54 +242,8 @@ void CQExecutor::execute( void CQExecutor::execute( ::tt::target::metal::CreateBufferCommand const *command) { - ::tt::target::LayoutDesc const *layout = command->ref()->desc()->layout(); - CoreRangeSet coreRangeSet = toCoreRangeSet(layout->core_range_set()); - auto shardRank = layout->memory_desc()->shape()->size(); - std::array shardShape; - shardShape[1] = layout->memory_desc()->shape()->Get(shardRank - 1) * - layout->memory_desc()->tile_shape()->x(); - shardShape[0] = layout->memory_desc()->tile_shape()->y(); - for (unsigned i = 0; i < shardRank - 1; ++i) { - shardShape[0] *= layout->memory_desc()->shape()->Get(i); - } - ShardSpec shardSpec(coreRangeSet, shardShape); - - auto tensorRank = layout->stride()->size(); - std::array tensorShape; - assert(layout->stride()->size() >= 2); - tensorShape[1] = layout->stride()->Get(tensorRank - 2); - tensorShape[0] = - layout->stride()->Get(0) * command->ref()->desc()->shape()->Get(0); - - auto pageShape = shardShape; - ShardSpecBuffer shardSpecBuffer(shardSpec, pageShape, tensorShape); - - uint64_t gridVolume = 1; - for (auto dim2dRange : *layout->core_range_set()) { - gridVolume *= dim2dRange->size().x() * dim2dRange->size().y(); - } - - assert(layout->memory_desc()->memory_space() == - ::tt::target::MemorySpace::DeviceDRAM || - layout->memory_desc()->memory_space() == - ::tt::target::MemorySpace::DeviceL1); - BufferType bufferType = layout->memory_desc()->memory_space() == - ::tt::target::MemorySpace::DeviceDRAM - ? BufferType::DRAM - : BufferType::L1; - uint64_t size = gridVolume * layout->memory_desc()->size(); - auto shardedBufferConfig = ShardedBufferConfig{ - .device = device, - .size = size, - .page_size = size, - .buffer_type = bufferType, - .buffer_layout = TensorMemoryLayout::HEIGHT_SHARDED, - .shard_parameters = shardSpecBuffer, - }; - std::shared_ptr<::tt::tt_metal::Buffer> buffer = - ::tt::tt_metal::CreateBuffer(shardedBufferConfig); - buffer->set_address(command->ref()->address()); - buffers[command->ref()->global_id()] = buffer; + buffers[command->ref()->global_id()] = + createBufferFromTensorRef(device, command->ref()); } void CQExecutor::execute( @@ -352,15 +292,11 @@ void CQExecutor::execute(::tt::target::metal::FinishCommand const *) { ::tt::tt_metal::Finish(*cq); } -std::shared_ptr<::tt::tt_metal::Event> executeCommandQueue( - ::tt::tt_metal::Device *device, - ::tt::target::metal::CommandQueue const *commandQueue, std::size_t cq_id, - std::vector< - std::pair>> const - &inputs, - std::vector< - std::pair>> const - &outputs) { +std::shared_ptr<::tt::tt_metal::Event> +executeCommandQueue(::tt::tt_metal::Device *device, + ::tt::target::metal::CommandQueue const *commandQueue, + std::size_t cq_id, std::vector const &inputs, + std::vector const &outputs) { CQExecutor executor(device, cq_id, inputs, outputs); return executor.execute(commandQueue); } diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 010125997..76513521e 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -2,8 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "tt/runtime/runtime.h" +#include + #include "tt/runtime/detail/ttmetal.h" +#include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" #include "ttmlir/Target/TTMetal/Target.h" @@ -11,8 +13,11 @@ namespace tt::runtime::ttmetal { +constexpr inline std::size_t kHostBufferCommandQueueId = 0; using Events = std::vector>; using DeviceMesh = std::vector<::tt::tt_metal::Device *>; +using MetalTensor = + std::variant>; static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { switch (arch) { @@ -52,7 +57,8 @@ std::pair getCurrentSystemDesc() { ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, ttmlirVersion.patch); - ::tt::target::Dim2d deviceGrid = toFlatbuffer(device->logical_grid_size()); + ::tt::target::Dim2d deviceGrid = + toFlatbuffer(device->compute_with_storage_grid_size()); std::vector<::flatbuffers::Offset> chipDescs = { ::tt::target::CreateChipDesc( fbb, toFlatbuffer(device->arch()), &deviceGrid, (1 << 20), 12, @@ -95,12 +101,13 @@ Tensor createTensor(std::shared_ptr data, std::vector const &shape, std::vector const &stride, std::uint32_t itemsize, ::tt::target::DataType dataType) { - std::shared_ptr desc = std::make_shared(); - desc->shape = shape; - desc->stride = stride; - desc->itemsize = itemsize; - desc->dataType = dataType; - return Tensor(static_pointer_cast(desc), data); + TensorDesc desc; + desc.shape = shape; + desc.stride = stride; + desc.itemsize = itemsize; + desc.dataType = dataType; + std::shared_ptr tensor = std::make_shared(desc); + return Tensor(static_pointer_cast(tensor), data); } Device openDevice(std::vector const &deviceIds, @@ -123,6 +130,77 @@ void closeDevice(Device device) { } } +static std::pair, + std::shared_ptr<::tt::tt_metal::Event>> +prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, + void *data, ::tt::target::TensorRef const *tensorRef) { + if (std::holds_alternative(metalTensor)) { + // todo assert that tensorDesc matches hostTensorDesc + std::shared_ptr<::tt::tt_metal::Buffer> buffer = + createBufferFromTensorRef(device, tensorRef); + auto event = std::make_shared<::tt::tt_metal::Event>(); + ::tt::tt_metal::CommandQueue &cq = + device->command_queue(kHostBufferCommandQueueId); + bool const blocking = false; + ::tt::tt_metal::EnqueueWriteBuffer(cq, buffer, data, blocking); + ::tt::tt_metal::EnqueueRecordEvent(cq, event); + return std::make_pair(buffer, event); + } else if (std::holds_alternative>( + metalTensor)) { + std::shared_ptr<::tt::tt_metal::Buffer> buffer = + std::get>(metalTensor); + throw std::runtime_error("Input from buffer not supported yet"); + } + assert(false && "Unsupported tensor type"); + return std::make_pair(nullptr, nullptr); +} + +static std::shared_ptr<::tt::tt_metal::Buffer> +prepareOutput(::tt::tt_metal::Device *device, MetalTensor const *metalTensor, + ::tt::target::TensorRef const *tensorRef) { + assert(metalTensor != nullptr); + if (TensorDesc const *hostTensorDesc = std::get_if(metalTensor); + hostTensorDesc) { + return createBufferFromTensorRef(device, tensorRef); + } else if (std::shared_ptr<::tt::tt_metal::Buffer> const *buffer = + std::get_if>( + metalTensor); + buffer) { + return *buffer; + } + assert(false && "Unsupported tensor type"); + return nullptr; +} + +Events maybeCopyHostOutputs(::tt::tt_metal::Device *device, + std::vector const &outputHandles, + std::vector submitOutputs, + Events submitEvents) { + Events copyEvents; + int i = 0; + for (Tensor const &outputHandle : outputHandles) { + if (TensorDesc const *hostTensor = + std::get_if(&outputHandle.as()); + hostTensor) { + ::tt::tt_metal::CommandQueue &cq = + device->command_queue(kHostBufferCommandQueueId); + for (auto submitEvent : submitEvents) { + ::tt::tt_metal::EnqueueWaitForEvent(cq, submitEvent); + } + submitEvents.clear(); + auto event = std::make_shared<::tt::tt_metal::Event>(); + bool const blocking = false; + auto [global_id, buffer] = submitOutputs[i]; + ::tt::tt_metal::EnqueueReadBuffer(cq, buffer, outputHandle.data.get(), + blocking); + ::tt::tt_metal::EnqueueRecordEvent(cq, event); + copyEvents.push_back(event); + } + ++i; + } + return copyEvents; +} + Event submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles, @@ -130,43 +208,59 @@ Event submit(Device deviceHandle, Binary executableHandle, ::tt::target::metal::TTMetalBinary const &fbb = *getBinary(executableHandle); ::tt::target::metal::Program const *program = fbb.programs()->Get(programIndex); - std::vector>> - inputs; - inputs.reserve(inputHandles.size()); - assert(inputHandles.size() == program->inputs()->size() && - "Input size mismatch"); - for (unsigned i = 0; i < inputHandles.size(); ++i) { - inputs.emplace_back( - program->inputs()->Get(i)->global_id(), - static_pointer_cast<::tt::tt_metal::Buffer>(inputHandles[i].handle)); - } - - std::vector>> - outputs; - outputs.reserve(outputHandles.size()); - assert(outputHandles.size() == program->outputs()->size() && - "Output size mismatch"); - for (unsigned i = 0; i < outputHandles.size(); ++i) { - outputs.emplace_back( - program->outputs()->Get(i)->global_id(), - static_pointer_cast<::tt::tt_metal::Buffer>(outputHandles[i].handle)); - } - DeviceMesh &deviceMesh = deviceHandle.as(); + assert(deviceMesh.size() == 1 && "Only one device is supported for now"); std::shared_ptr events = std::make_shared(); - std::size_t cq_id = 0; assert(program->device_programs()->size() == deviceMesh.size() && "Device programs size mismatch"); for (std::size_t i = 0; i < program->device_programs()->size(); ++i) { ::tt::tt_metal::Device *device = deviceMesh[i]; ::tt::target::metal::DeviceProgram const *deviceProgram = program->device_programs()->Get(i); + Events deviceEvents; + + std::vector inputs; + inputs.reserve(inputHandles.size()); + assert(inputHandles.size() == deviceProgram->inputs()->size() && + "Input size mismatch"); + for (unsigned i = 0; i < inputHandles.size(); ++i) { + ::tt::target::TensorRef const *tensorRef = + deviceProgram->inputs()->Get(i); + auto [buffer, event] = + prepareInput(device, inputHandles[i].as(), + inputHandles[i].data.get(), tensorRef); + inputs.emplace_back(deviceProgram->inputs()->Get(i)->global_id(), buffer, + event); + } + + std::vector outputs; + outputs.reserve(outputHandles.size()); + assert(outputHandles.size() == deviceProgram->outputs()->size() && + "Output size mismatch"); + for (unsigned i = 0; i < outputHandles.size(); ++i) { + ::tt::target::TensorRef const *tensorRef = + deviceProgram->outputs()->Get(i); + std::shared_ptr<::tt::tt_metal::Buffer> buffer = + prepareOutput(device, &outputHandles[i].as(), tensorRef); + outputs.emplace_back(deviceProgram->outputs()->Get(i)->global_id(), + buffer); + } + + std::size_t cq_id = 0; for (::tt::target::metal::CommandQueue const *cq : *deviceProgram->command_queues()) { - events->push_back( + deviceEvents.push_back( executeCommandQueue(device, cq, cq_id, inputs, outputs)); ++cq_id; } + + Events copyEvents = + maybeCopyHostOutputs(device, outputHandles, outputs, deviceEvents); + if (not copyEvents.empty()) { + std::swap(deviceEvents, copyEvents); + } + + events->insert(events->end(), deviceEvents.begin(), deviceEvents.end()); } return static_pointer_cast(events); diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 07622c1c3..5bc348c17 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -39,7 +39,8 @@ std::pair getCurrentSystemDesc() { ::ttmlir::Version ttmlirVersion = ::ttmlir::getVersion(); ::tt::target::Version version(ttmlirVersion.major, ttmlirVersion.minor, ttmlirVersion.patch); - ::tt::target::Dim2d deviceGrid = toFlatbuffer(device.logical_grid_size()); + ::tt::target::Dim2d deviceGrid = + toFlatbuffer(device.compute_with_storage_grid_size()); std::vector<::flatbuffers::Offset> chipDescs = { ::tt::target::CreateChipDesc( fbb, toFlatbuffer(device.arch()), &deviceGrid, (1 << 20), 12, diff --git a/runtime/tools/python/ttrt/common/api.py b/runtime/tools/python/ttrt/common/api.py index 71b7d54e3..a8e7f7ee4 100644 --- a/runtime/tools/python/ttrt/common/api.py +++ b/runtime/tools/python/ttrt/common/api.py @@ -217,11 +217,13 @@ def run(args): total_inputs.append(inputs) total_outputs.append(outputs) + event = None for loop in range(arg_loops): - ttrt.runtime.submit( + event = ttrt.runtime.submit( device, fbb, program_index, total_inputs[loop], total_outputs[loop] ) print(f"finished loop={loop}") + ttrt.runtime.wait(event) print("outputs:\n", torch_outputs) # save artifacts diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index d2bd0f982..be9d34a93 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -13,6 +13,7 @@ close_device, submit, create_tensor, + wait, ) except ModuleNotFoundError: raise ImportError( diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 1a383079c..35338f9cd 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -51,4 +51,5 @@ PYBIND11_MODULE(_C, m) { m.def("submit", &tt::runtime::submit, py::arg("device"), py::arg("executable"), py::arg("program_index"), py::arg("inputs"), py::arg("outputs"), "Submit a binary for execution"); + m.def("wait", &tt::runtime::wait, py::arg("event")); }