From 29281af7749d726e0722a1773b540d96923a492a Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Tue, 3 Dec 2024 20:16:34 -0500 Subject: [PATCH] Runtime stitching APIs and sanity tests, ttnn runtime submit refactor (#1301) --- .github/workflows/build-and-test.yml | 96 ++++ runtime/CMakeLists.txt | 5 +- runtime/include/tt/runtime/detail/ttmetal.h | 12 +- runtime/include/tt/runtime/detail/ttnn.h | 55 ++- runtime/include/tt/runtime/runtime.h | 46 +- runtime/include/tt/runtime/test/utils.h | 17 + runtime/include/tt/runtime/types.h | 14 +- runtime/lib/binary.cpp | 37 +- runtime/lib/common/system_desc.cpp | 4 +- runtime/lib/runtime.cpp | 206 +++++++-- runtime/lib/ttmetal/command_queue.cpp | 14 +- runtime/lib/ttmetal/runtime.cpp | 28 +- runtime/lib/ttnn/CMakeLists.txt | 20 +- .../ttnn/include/tt/runtime/ttnn/types.cpp | 437 ++++++++++++++++++ .../lib/ttnn/include/tt/runtime/ttnn/types.h | 193 ++++---- .../ttnn/include/tt/runtime/ttnn/utils.cpp | 222 +++++++++ .../lib/ttnn/include/tt/runtime/ttnn/utils.h | 143 ++---- runtime/lib/ttnn/operations/CMakeLists.txt | 5 +- .../lib/ttnn/operations/ccl/all_gather.cpp | 3 +- runtime/lib/ttnn/operations/conv/conv2d.cpp | 4 +- .../lib/ttnn/operations/creation/arange.cpp | 2 +- .../lib/ttnn/operations/creation/empty.cpp | 13 +- runtime/lib/ttnn/operations/creation/full.cpp | 11 +- .../operations/data_movement/transpose.cpp | 3 +- .../ttnn/operations/deletion/deallocate.cpp | 7 - .../ttnn/operations/eltwise/binary/binary.cpp | 3 +- .../eltwise/binary/binary_composite.cpp | 3 +- .../operations/eltwise/ternary/ternary.cpp | 2 +- .../ttnn/operations/eltwise/unary/unary.cpp | 7 +- .../eltwise/unary/unary_composite.cpp | 16 +- .../ttnn/operations/embedding/embedding.cpp | 3 +- .../tt/runtime/ttnn/operations/utils.cpp | 107 +---- .../tt/runtime/ttnn/operations/utils.h | 17 - .../ttnn/operations/layout/from_device.cpp | 7 +- .../lib/ttnn/operations/layout/to_device.cpp | 2 +- .../lib/ttnn/operations/layout/to_layout.cpp | 2 +- .../lib/ttnn/operations/layout/typecast.cpp | 2 +- runtime/lib/ttnn/operations/matmul/matmul.cpp | 5 +- .../ttnn/operations/normalization/softmax.cpp | 3 +- .../lib/ttnn/operations/pool/maxpool2d.cpp | 3 +- .../ttnn/operations/reduction/reduction.cpp | 3 +- runtime/lib/ttnn/program.cpp | 124 ++++- runtime/lib/ttnn/runtime.cpp | 256 ++++++++-- runtime/test/CMakeLists.txt | 24 + .../include/tt/runtime/ttnn/test/utils.cpp | 50 ++ runtime/test/python/ttnn/conftest.py | 25 + runtime/test/python/ttnn/test_runtime_api.py | 160 +++++++ runtime/test/python/ttnn/utils.py | 66 +++ runtime/test/ttnn/test_subtract.cpp | 36 +- runtime/tools/python/CMakeLists.txt | 1 + runtime/tools/python/setup.py | 13 +- runtime/tools/python/ttrt/common/run.py | 40 +- runtime/tools/python/ttrt/common/util.py | 6 + runtime/tools/python/ttrt/runtime/__init__.py | 14 + runtime/tools/python/ttrt/runtime/module.cpp | 90 +++- .../unary/isfinite/simple_isfinite.mlir | 6 +- .../eltwise_binary_op_chain.mlir | 49 ++ .../Silicon/StableHLO/Unary/isfinite_op.mlir | 6 +- test/ttmlir/Silicon/StableHLO/select_op.mlir | 20 +- .../TTNN/perf_unit/test_perf_isfinite.mlir | 6 +- .../Silicon/TTNN/perf_unit/test_perf_le.mlir | 21 - .../TTNN/perf_unit/test_perf_where.mlir | 10 +- test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 16 +- 63 files changed, 2201 insertions(+), 620 deletions(-) create mode 100644 runtime/include/tt/runtime/test/utils.h create mode 100644 runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp create mode 100644 runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp create mode 100644 
runtime/test/include/tt/runtime/ttnn/test/utils.cpp create mode 100644 runtime/test/python/ttnn/conftest.py create mode 100644 runtime/test/python/ttnn/test_runtime_api.py create mode 100644 runtime/test/python/ttnn/utils.py create mode 100644 test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir delete mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 8ec0c93dc..c54d734b2 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -475,6 +475,102 @@ jobs: report_paths: ${{ steps.strings.outputs.test_report_path }} check_name: Run ttrt tests + run-runtime-api-tests: + + timeout-minutes: 30 + needs: + - build-image + - build-ttmlir + strategy: + fail-fast: false + matrix: + build: [ + {runs-on: n150, enable_perf: OFF, name: "run"}, + ] + + runs-on: + - in-service + - ${{ matrix.build.runs-on }} + + container: + image: ${{ needs.build-image.outputs.docker-image }} + options: --device /dev/tenstorrent/0 + volumes: + - /dev/hugepages:/dev/hugepages + - /dev/hugepages-1G:/dev/hugepages-1G + - /etc/udev/rules.d:/etc/udev/rules.d + - /lib/modules:/lib/modules + - /opt/tt_metal_infra/provisioning/provisioning_env:/opt/tt_metal_infra/provisioning/provisioning_env + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set reusable strings + id: strings + shell: bash + run: | + echo "work-dir=$(pwd)" >> "$GITHUB_OUTPUT" + echo "build-output-dir=$(pwd)/build" >> "$GITHUB_OUTPUT" + echo "install-output-dir=$(pwd)/install" >> "$GITHUB_OUTPUT" + + - name: Git safe dir + run: git config --global --add safe.directory ${{ steps.strings.outputs.work-dir }} + + - name: Use build artifacts + uses: actions/download-artifact@v4 + with: + name: install-artifacts-${{ matrix.build.name }} + path: ${{ steps.strings.outputs.install-output-dir }} + + # This is needed to preserve file permissions + # https://github.com/actions/upload-artifact?tab=readme-ov-file#permission-loss + - name: 'Untar install directory' + shell: bash + working-directory: ${{ steps.strings.outputs.install-output-dir }} + run: tar xvf artifact.tar + + - name: Remove existing whls files + shell: bash + run: | + rm -f *.whl + + - name: Download ttrt run whls + uses: actions/download-artifact@v4 + with: + name: ttrt-whl-${{ matrix.build.name }} + + # Runtime tests currently require ttrt whls to be installed + - name: Install ttrt run whls + shell: bash + run: | + source env/activate + pip show ttrt && pip uninstall -y ttrt + pip install ttrt-${{ env.version }}*.whl --force-reinstall + pip install pytest + + - name: Generate system descriptor + shell: bash + run: | + source env/activate + ttrt query --save-artifacts + + - name: Generate tests + shell: bash + run: | + source env/activate + export LD_LIBRARY_PATH="${TTMLIR_TOOLCHAIN_DIR}/lib:${LD_LIBRARY_PATH}" + export SYSTEM_DESC_PATH="${GITHUB_WORKSPACE}/ttrt-artifacts/system_desc.ttsys" + ln -sf ${{ steps.strings.outputs.install-output-dir }} ${{ steps.strings.outputs.build-output-dir }} + llvm-lit -sv ${{ steps.strings.outputs.build-output-dir }}/test + + - name: ttnn api tests + shell: bash + run: | + source env/activate + pytest -ssv runtime/test/python/ttnn/test_runtime_api.py + build-and-test-explorer: needs: build-image timeout-minutes: 60 diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt index c9dce1094..0a23c6dda 100644 --- a/runtime/CMakeLists.txt +++ b/runtime/CMakeLists.txt @@ -14,6 +14,7 @@ 
set(TT_RUNTIME_OPTIONS TT_RUNTIME_DEBUG TT_RUNTIME_ENABLE_PERF_TRACE TT_RUNTIME_WORKAROUNDS + TTMLIR_ENABLE_RUNTIME_TESTS ) foreach(OPTION ${TT_RUNTIME_OPTIONS}) @@ -24,6 +25,4 @@ endforeach() add_subdirectory(lib) add_subdirectory(tools) -if (TTMLIR_ENABLE_RUNTIME_TESTS) - add_subdirectory(test) -endif() +add_subdirectory(test) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 5544e1d70..1b043f6e5 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -39,12 +39,16 @@ void closeDevice(Device device); void deallocateBuffers(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, + std::vector const &outputs); + std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 67aa91a71..e7b8fbcf2 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -53,16 +53,27 @@ createTensor(std::vector> &data, ::tt::target::DataType dataType, std::unordered_map const &strategy); +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize); + inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType); + return ::tt::runtime::ttnn::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType); } inline Tensor createTensor(std::vector> &data, TensorDesc const &desc, std::unordered_map const &strategy) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType, strategy); + return ::tt::runtime::ttnn::createTensor( + data, desc.shape, desc.stride, desc.itemsize, desc.dataType, strategy); +} + +inline Tensor createTensor(Device device, Layout layout, + TensorDesc const &desc) { + return ::tt::runtime::ttnn::createTensor(device, layout, desc.shape, + desc.stride, desc.itemsize); } tt::target::DataType getTensorDataType(Tensor tensor); @@ -75,12 +86,23 @@ void closeDevice(Device device); void deallocateBuffers(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Tensor toHost(Tensor tensor, bool untilize = false); + +Tensor toLayout(Tensor tensor, Device device, Layout layout); + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex); + +void memcpy(Tensor dst, Tensor src); + +void deallocateTensor(Tensor &tensor, bool force = false); + std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); @@ -90,10 +112,27 @@ Tensor getOpOutputTensor(OpContext opContextHandle, std::vector getTensorData(Tensor tensor); +namespace legacy { +/* Will be deprecated soon once FEs migrate to new API */ + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, 
+ std::vector const &outputs); + void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, std::uint32_t programIndex, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs); +} // namespace legacy + +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputs); + +std::vector runProgram(::ttnn::MeshDevice &meshDevice, + Binary executableHandle, + std::uint32_t programIndex, + std::vector<::ttnn::Tensor *> const &inputs); } // namespace tt::runtime::ttnn diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index e4348da60..56666d564 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -43,16 +43,27 @@ createTensor(std::vector> &data, ::tt::target::DataType dataType, std::unordered_map const &strategy); +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize); + inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType); + return ::tt::runtime::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType); } inline Tensor createTensor(std::vector> &data, TensorDesc const &desc, std::unordered_map const &strategy) { - return createTensor(data, desc.shape, desc.stride, desc.itemsize, - desc.dataType, strategy); + return ::tt::runtime::createTensor(data, desc.shape, desc.stride, + desc.itemsize, desc.dataType, strategy); +} + +inline Tensor createTensor(Device device, Layout layout, + TensorDesc const &desc) { + return ::tt::runtime::createTensor(device, layout, desc.shape, desc.stride, + desc.itemsize); } tt::target::DataType getTensorDataType(Tensor tensor); @@ -63,12 +74,23 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1); void closeDevice(Device device); -Event submit(Device device, Binary executable, std::uint32_t programIndex, - std::vector const &inputs, - std::vector const &outputs); - void wait(Event event); +void wait(Tensor tensor); + +void wait(std::vector const &tensors); + +Tensor toHost(Tensor tensor, bool untilize = false); + +Tensor toLayout(Tensor tensor, Device device, Layout layout); + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex); + +void memcpy(Tensor dst, Tensor src); + +void deallocateTensor(Tensor &tensor, bool force = false); + std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); @@ -78,6 +100,14 @@ Tensor getOpOutputTensor(OpContext opContextHandle, std::vector getTensorData(Tensor tensor); +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputs); + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, std::vector const &inputs, + std::vector const &outputs); + } // namespace tt::runtime #endif diff --git a/runtime/include/tt/runtime/test/utils.h b/runtime/include/tt/runtime/test/utils.h new file mode 100644 index 000000000..e4323cc16 --- /dev/null +++ b/runtime/include/tt/runtime/test/utils.h @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TT_RUNTIME_TEST_UTILS_H +#define TT_RUNTIME_TEST_UTILS_H + +#include "tt/runtime/types.h" + +// Utility functions for testing TTNN runtime 
+namespace tt::runtime::ttnn::test { +Layout getDramInterleavedTileLayout(::tt::target::DataType dataType); +Layout getDramInterleavedRowMajorLayout(::tt::target::DataType dataType); +Layout getHostRowMajorLayout(::tt::target::DataType dataType); +} // namespace tt::runtime::ttnn::test + +#endif // TT_RUNTIME_TEST_UTILS_H diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index 8fd641195..cc2791e23 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -122,10 +122,20 @@ struct Event : public detail::RuntimeCheckedObjectImpl { struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; - + Event event; Tensor(std::shared_ptr handle, std::shared_ptr data, DeviceRuntime runtime) - : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {} + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(nullptr, runtime) {} + + Tensor(std::shared_ptr handle, std::shared_ptr data, + std::shared_ptr eventHandle, DeviceRuntime runtime) + : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), + event(eventHandle, runtime) {} +}; + +struct Layout : public detail::RuntimeCheckedObjectImpl { + using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl; }; struct CallbackContext : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/binary.cpp b/runtime/lib/binary.cpp index 92be39d27..1d8cbf38b 100644 --- a/runtime/lib/binary.cpp +++ b/runtime/lib/binary.cpp @@ -27,15 +27,12 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema, flatbuffers::Parser parser(opts); if (not parser.Deserialize(binarySchema, schemaSize)) { - throw std::runtime_error("Failed to deserialize schema"); + LOG_FATAL("Failed to deserialize schema"); } std::string text; const char *err = ::flatbuffers::GenerateText(parser, fbb, &text); - if (err) { - throw std::runtime_error("Failed to generate JSON: " + std::string(err)); - } - + LOG_ASSERT(not err, "Failed to generate JSON: ", err); return text; } @@ -44,9 +41,7 @@ namespace ttnn { ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTNN) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTNN, "Unsupported binary format"); return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get()); } @@ -128,9 +123,7 @@ ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { bool isTTMetal = ::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTMetal) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTMetal, "Unsupported binary format"); return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } @@ -207,7 +200,7 @@ namespace system_desc { ::tt::target::SystemDescRoot const *getBinary(Flatbuffer binary) { if (!::tt::target::SizePrefixedSystemDescRootBufferHasIdentifier( binary.handle.get())) { - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } return ::tt::target::GetSizePrefixedSystemDescRoot(binary.handle.get()); } @@ -234,10 +227,7 @@ std::string asJson(Flatbuffer binary) { Flatbuffer Flatbuffer::loadFromPath(char const *path) { // load a flatbuffer from path std::ifstream fbb(path, std::ios::binary | std::ios::ate); - if (!fbb.is_open()) { - throw std::runtime_error("Failed to open file: " + 
std::string(path)); - } - + LOG_ASSERT(fbb.is_open(), "Failed to open file: ", path); std::streampos size = fbb.tellg(); fbb.seekg(0, std::ios::beg); auto buffer = ::tt::runtime::utils::malloc_shared(size); @@ -269,7 +259,7 @@ std::string_view Flatbuffer::getFileIdentifier() const { return ::tt::target::SystemDescRootIdentifier(); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string Flatbuffer::getVersion() const { @@ -288,7 +278,7 @@ std::string Flatbuffer::getVersion() const { return system_desc::getVersion(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string_view Flatbuffer::getTTMLIRGitHash() const { @@ -307,7 +297,7 @@ std::string_view Flatbuffer::getTTMLIRGitHash() const { return system_desc::getTTMLIRGitHash(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::string Flatbuffer::asJson() const { @@ -326,7 +316,7 @@ std::string Flatbuffer::asJson() const { return system_desc::asJson(*this); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } SystemDesc SystemDesc::loadFromPath(char const *path) { @@ -349,7 +339,7 @@ Binary::getProgramInputs(std::uint32_t programIndex) const { return metal::getProgramInputs(*this, programIndex); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } std::vector @@ -364,7 +354,7 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const { return metal::getProgramOutputs(*this, programIndex); } - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } const ::tt::target::GoldenTensor * @@ -379,8 +369,7 @@ Binary::getDebugInfoGolden(std::string &loc) const { return metal::getDebugInfoGolden(*this, loc); } - throw std::runtime_error( - "Unsupported binary format for obtaining golden information"); + LOG_FATAL("Unsupported binary format for obtaining golden information"); } } // namespace tt::runtime diff --git a/runtime/lib/common/system_desc.cpp b/runtime/lib/common/system_desc.cpp index f1210d00a..3b4685901 100644 --- a/runtime/lib/common/system_desc.cpp +++ b/runtime/lib/common/system_desc.cpp @@ -32,7 +32,7 @@ static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) { break; } - throw std::runtime_error("Unsupported arch"); + LOG_FATAL("Unsupported arch"); } static std::vector<::tt::target::ChipChannel> @@ -246,7 +246,7 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl( ::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root); ::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize()); if (!::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) { - throw std::runtime_error("Failed to verify system desc root buffer"); + LOG_FATAL("Failed to verify system desc root buffer"); } uint8_t *buf = fbb.GetBufferPointer(); auto size = fbb.GetSize(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index a57ac3fcd..bf6113308 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -42,7 +42,7 @@ void deallocateBuffers(Device device) { return ::tt::runtime::ttmetal::deallocateBuffers(device); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } } // namespace detail @@ -91,15 +91,14 @@ void setCompatibleRuntime(const Binary &binary) { return setCurrentRuntime(DeviceRuntime::TTMetal); } #endif - 
throw std::runtime_error( - "Unsupported binary file identifier or runtime not enabled"); + LOG_FATAL("Unsupported binary file identifier or runtime not enabled"); } std::pair getCurrentSystemDesc() { #if defined(TT_RUNTIME_ENABLE_TTNN) || defined(TT_RUNTIME_ENABLE_TTMETAL) return system_desc::getCurrentSystemDesc(); #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Tensor createTensor(std::shared_ptr data, @@ -122,7 +121,7 @@ Tensor createTensor(std::shared_ptr data, dataType); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Tensor @@ -143,10 +142,32 @@ createTensor(std::vector> &data, #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - throw std::runtime_error("Not implemented"); + LOG_FATAL("Not implemented"); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); +} + +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize) { + LOG_ASSERT(not shape.empty()); + LOG_ASSERT(not stride.empty()); + LOG_ASSERT(itemsize > 0); +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::createTensor(device, layout, shape, stride, + itemsize); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("Not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); } tt::target::DataType getTensorDataType(Tensor tensor) { @@ -161,7 +182,7 @@ tt::target::DataType getTensorDataType(Tensor tensor) { return ::tt::runtime::ttmetal::getTensorDataType(tensor); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } size_t getNumAvailableDevices() { @@ -176,7 +197,7 @@ size_t getNumAvailableDevices() { return ::tt::runtime::ttmetal::getNumAvailableDevices(); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { @@ -191,7 +212,7 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs) { return ::tt::runtime::ttmetal::openDevice(deviceIds, numHWCQs); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } void closeDevice(Device device) { @@ -206,44 +227,130 @@ void closeDevice(Device device) { return ::tt::runtime::ttmetal::closeDevice(device); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, - std::vector const &inputHandles, - std::vector const &outputHandles) { +void wait(Event event) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); + LOG_WARNING("wait API will be deprecated for TTNN runtime."); + return ::tt::runtime::ttnn::wait(event); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); + return ::tt::runtime::ttmetal::wait(event); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not 
enabled"); } -void wait(Event event) { +void wait(Tensor tensor) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::wait(event); + return ::tt::runtime::ttnn::wait(tensor); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::wait(event); + return ::tt::runtime::ttmetal::wait(tensor); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); +} + +void wait(std::vector const &tensors) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::wait(tensors); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::wait(tensors); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Tensor toHost(Tensor tensor, bool untilize) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::toHost(tensor, untilize); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Tensor toLayout(Tensor tensor, Device device, Layout layout) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::toLayout(tensor, device, layout); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getLayout(executableHandle, programIndex, + inputIndex); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +void memcpy(Tensor dst, Tensor src) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::memcpy(dst, src); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +void deallocateTensor(Tensor &tensor, bool force) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::deallocateTensor(tensor, force); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); } std::string getOpDebugString(OpContext opContextHandle) { @@ -258,7 +365,7 @@ std::string getOpDebugString(OpContext opContextHandle) { return ::tt::runtime::ttmetal::getOpDebugString(opContextHandle); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } std::string getOpLocInfo(OpContext opContextHandle) { @@ -291,7 +398,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, programContextHandle); } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); } std::vector 
getTensorData(Tensor tensor) { @@ -307,7 +414,48 @@ std::vector getTensorData(Tensor tensor) { } #endif - throw std::runtime_error("runtime is not enabled"); + LOG_FATAL("runtime is not enabled"); +} + +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle, + programIndex, inputHandles); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + LOG_FATAL("not implemented"); + } +#endif + LOG_FATAL("runtime is not enabled"); } +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles, + std::vector const &outputHandles) { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + LOG_WARNING("This submit API will soon be deprecated. Please switch to the " + "new API."); + return ::tt::runtime::ttnn::legacy::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::submit(deviceHandle, executableHandle, + programIndex, inputHandles, + outputHandles); + } +#endif + LOG_FATAL("runtime is not enabled"); +} } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/command_queue.cpp b/runtime/lib/ttmetal/command_queue.cpp index 9a408a66b..3480458e6 100644 --- a/runtime/lib/ttmetal/command_queue.cpp +++ b/runtime/lib/ttmetal/command_queue.cpp @@ -137,7 +137,7 @@ void CQExecutor::execute(::tt::target::metal::Command const *command) { break; } default: - throw std::runtime_error("Unsupported command type"); + LOG_FATAL("Unsupported command type"); break; } } @@ -328,7 +328,7 @@ createKernelConfig(::tt::target::metal::KernelSource const *kernelSource) { break; } } - throw std::runtime_error("Unsupported kernel source type"); + LOG_FATAL("Unsupported kernel source type"); } static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) { @@ -346,7 +346,7 @@ static ::tt::DataFormat toDataFormat(::tt::target::DataType dataType) { case ::tt::target::DataType::UInt8: return ::tt::DataFormat::UInt8; default: - throw std::runtime_error("Unsupported data type"); + LOG_FATAL("Unsupported data type"); } } @@ -358,7 +358,7 @@ static CoreType toCoreType(::tt::target::metal::CoreType coreType) { case ::tt::target::metal::CoreType::ETH: return CoreType::ETH; } - throw std::runtime_error("Unsupported core type"); + LOG_FATAL("Unsupported core type"); } static ::tt::tt_metal::CircularBufferConfig createCircularBufferConfig( @@ -427,7 +427,7 @@ static void processRuntimeArgs( break; } case ::tt::target::metal::RuntimeArg::NONE: - throw std::runtime_error("Unsupported runtime arg type"); + LOG_FATAL("Unsupported runtime arg type"); } } @@ -516,7 +516,7 @@ void CQExecutor::execute( break; } default: - throw std::runtime_error("Unsupported HostBuffer type"); + LOG_FATAL("Unsupported HostBuffer type"); } } @@ -524,7 +524,7 @@ void CQExecutor::execute( ::tt::target::metal::EnqueueReadBufferCommand const *command) { ZoneScopedN("EnqueueReadBufferCommand"); // Maybe we will need this in the future, like paging to system mem? 
- throw std::runtime_error("Unsupported EnqueueReadBufferCommand"); + LOG_FATAL("Unsupported EnqueueReadBufferCommand"); } void CQExecutor::execute( diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 22d43ba36..2a66aa5e6 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -24,7 +24,7 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { ::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier( binary.handle.get()); if (not isTTMetal) { - throw std::runtime_error("Unsupported binary format"); + LOG_FATAL("Unsupported binary format"); } return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } @@ -56,7 +56,7 @@ tt::target::DataType getTensorDataType(Tensor tensor) { } if (std::holds_alternative>( metalTensor)) { - throw std::runtime_error("Datatype mapping from buffer not supported yet."); + LOG_FATAL("Datatype mapping from buffer not supported yet."); } LOG_ASSERT(false, "Unsupported tensor type"); return ::tt::target::DataType::Float32; @@ -96,6 +96,21 @@ void deallocateBuffers(Device deviceHandle) { } } +void wait(Event event) { + Events events = event.as(DeviceRuntime::TTMetal); + for (auto e : events) { + ::tt::tt_metal::EventSynchronize(e); + } +} + +void wait(Tensor tensor) { ::tt::runtime::ttmetal::wait(tensor.event); } + +void wait(std::vector const &tensors) { + for (Tensor tensor : tensors) { + ::tt::runtime::ttmetal::wait(tensor); + } +} + static std::pair, std::shared_ptr<::tt::tt_metal::Event>> prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, @@ -117,7 +132,7 @@ prepareInput(::tt::tt_metal::Device *device, MetalTensor const &metalTensor, metalTensor)) { std::shared_ptr<::tt::tt_metal::Buffer> buffer = std::get>(metalTensor); - throw std::runtime_error("Input from buffer not supported yet"); + LOG_FATAL("Input from buffer not supported yet"); } LOG_ASSERT(false, "Unsupported tensor type"); return std::make_pair(nullptr, nullptr); @@ -249,13 +264,6 @@ Event submit(Device deviceHandle, Binary executableHandle, return Event(static_pointer_cast(events), DeviceRuntime::TTMetal); } -void wait(Event event) { - Events events = event.as(DeviceRuntime::TTMetal); - for (auto e : events) { - ::tt::tt_metal::EventSynchronize(e); - } -} - std::string getOpDebugString(OpContext opContextHandle) { // Not implemented LOG_WARNING("obtaining op debug string for metal runtime not implemented"); diff --git a/runtime/lib/ttnn/CMakeLists.txt b/runtime/lib/ttnn/CMakeLists.txt index 92581cf46..6a68c4c7b 100644 --- a/runtime/lib/ttnn/CMakeLists.txt +++ b/runtime/lib/ttnn/CMakeLists.txt @@ -1,4 +1,22 @@ +add_library(TTRuntimeTTNNHelpers + STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/types.cpp +) +set_property(TARGET TTRuntimeTTNNHelpers PROPERTY CXX_STANDARD 20) +target_compile_options(TTRuntimeTTNNHelpers PUBLIC -mavx -mavx2 -fsized-deallocation) +target_include_directories(TTRuntimeTTNNHelpers PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common +) +target_include_directories(TTRuntimeTTNNHelpers SYSTEM PUBLIC "$") +add_dependencies(TTRuntimeTTNNHelpers TTNN_LIBRARY tt-metal FBS_GENERATION) +target_link_libraries(TTRuntimeTTNNHelpers PUBLIC TTNN_LIBRARY) + add_subdirectory(operations) + add_library(TTRuntimeTTNN 
STATIC runtime.cpp @@ -11,5 +29,5 @@ target_include_directories(TTRuntimeTTNN PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) target_include_directories(TTRuntimeTTNN SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps) +target_link_libraries(TTRuntimeTTNN PUBLIC TTRuntimeTTNNOps TTRuntimeTTNNHelpers) add_dependencies(TTRuntimeTTNN TTRuntimeTTNNOps) diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp new file mode 100644 index 000000000..87d081599 --- /dev/null +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/ttnn/types.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/ttnn/utils.h" + +namespace tt::runtime::ttnn { + +// +// LayoutConverter APIs +// +LayoutConverter::LayoutConverter(const LayoutDesc &inputDesc, + const LayoutDesc &outputDesc) + : inputDesc(inputDesc), outputDesc(outputDesc) { + shouldTilize = (inputDesc.layout == ::ttnn::Layout::ROW_MAJOR and + outputDesc.layout == ::ttnn::Layout::TILE); + shouldUntilize = (inputDesc.layout == ::ttnn::Layout::TILE and + outputDesc.layout == ::ttnn::Layout::ROW_MAJOR); + shouldTypecast = (inputDesc.dataType != outputDesc.dataType); + shouldToDevice = (inputDesc.isOnHost() and outputDesc.isOnDevice()); + shouldToMemoryConfig = (not shouldToDevice and outputDesc.isOnDevice() and + (inputDesc.memoryConfig != outputDesc.memoryConfig)); + shouldFromDevice = (inputDesc.isOnDevice() and outputDesc.isOnHost()); +} + +::ttnn::Tensor LayoutConverter::convertTensorLayout( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (inputDesc.isOnHost()) { + return convertHostTensorLayout(input, targetDevice); + } + return convertDeviceTensorLayout(input); +} + +::ttnn::Tensor LayoutConverter::toLayoutIfNeeded(const ::ttnn::Tensor &input) { + if (shouldTilize) { + return ::ttnn::to_layout(input, ::ttnn::Layout::TILE, std::nullopt, + std::nullopt, + static_cast<::ttnn::Device *>(nullptr)); + } + if (shouldUntilize) { + return ::ttnn::to_layout(input, ::ttnn::Layout::ROW_MAJOR, std::nullopt, + std::nullopt, + static_cast<::ttnn::Device *>(nullptr)); + } + return input; +} + +::ttnn::Tensor LayoutConverter::typecastIfNeeded(const ::ttnn::Tensor &input) { + if (shouldTypecast) { + return ::ttnn::typecast(input, outputDesc.dataType); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::toDeviceIfNeeded(const ::ttnn::Tensor &input, + std::optional targetDevice, + bool force) { + if (shouldToDevice or force) { + LOG_ASSERT(targetDevice.has_value()); + return std::visit( + [&](auto &&targetDevice) -> ::ttnn::Tensor { + return ::ttnn::to_device(input, &(targetDevice.get()), + outputDesc.memoryConfig); + }, + targetDevice.value()); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::toMemoryConfigIfNeeded(const ::ttnn::Tensor &input) { + if (shouldToMemoryConfig) { + LOG_ASSERT(outputDesc.memoryConfig.has_value()); + return ::ttnn::to_memory_config(input, outputDesc.memoryConfig.value()); + } + return input; +} + +::ttnn::Tensor +LayoutConverter::fromDeviceIfNeeded(const ::ttnn::Tensor &input) { + if (shouldFromDevice) { + return ::ttnn::from_device(input); + } + return input; +} + +::ttnn::Tensor LayoutConverter::handleHostInputNoLayoutNoTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = 
toMemoryConfigIfNeeded(out); + return out; +} + +::ttnn::Tensor LayoutConverter::handleHostInputLayoutNoTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (shouldUntilize) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = toLayoutIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleHostInputNoLayoutTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (outputDesc.layout == ::ttnn::Layout::TILE) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (outputDesc.layout != ::ttnn::Layout::TILE) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleHostInputLayoutTypecast( + const ::ttnn::Tensor &input, std::optional targetDevice) { + if (shouldUntilize) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(out); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toDeviceIfNeeded(input, targetDevice); + out = toLayoutIfNeeded(out); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and outputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toDeviceIfNeeded(out, targetDevice); + out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + outputDesc.dataType != ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(out); + out = toDeviceIfNeeded(out, targetDevice); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::convertHostTensorLayout( + const ::ttnn::Tensor &input, std::optional targetDevice) { + bool shouldToLayout = (shouldTilize or shouldUntilize); + LOG_ASSERT(not shouldToDevice or targetDevice.has_value(), + "Target device must be provided for ToDevice"); + if (not shouldToLayout and not shouldTypecast) { + return handleHostInputNoLayoutNoTypecast(input, targetDevice); + } + if (shouldToLayout and not shouldTypecast) { + return handleHostInputLayoutNoTypecast(input, targetDevice); + } + if (not shouldToLayout and shouldTypecast) { + return handleHostInputNoLayoutTypecast(input, targetDevice); + } + if (shouldToLayout and shouldTypecast) { + return handleHostInputLayoutTypecast(input, targetDevice); + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputNoLayoutNoTypecast( + const ::ttnn::Tensor &input) { + ::ttnn::Tensor out = 
toMemoryConfigIfNeeded(input); + out = fromDeviceIfNeeded(out); + return out; +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputLayoutNoTypecast( + const ::ttnn::Tensor &input) { + if (shouldUntilize and shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldUntilize and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device untilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + /* If we should tilize and the input data type is bfloat16, tilize on device + */ + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(out); + return out; + } + + /* If we should tilize and the input data type is not bfloat16, tilize on + * host */ + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device tilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor LayoutConverter::handleDeviceInputNoLayoutTypecast( + const ::ttnn::Tensor &input) { + if (inputDesc.isTilized()) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(input); + return out; + } + + if (not inputDesc.isTilized() and shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = typecastIfNeeded(out); + return out; + } + + if (not inputDesc.isTilized() and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device typecast."); + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + LOG_FATAL("Unreachable code path"); +} + +::ttnn::Tensor +LayoutConverter::handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input) { + if (shouldUntilize and shouldFromDevice) { + ::ttnn::Tensor out = typecastIfNeeded(input); + out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + return out; + } + + if (shouldUntilize and not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device untilize."); + ::ttnn::Tensor out = typecastIfNeeded(input); + out = toLayoutIfNeeded(input); + out = toMemoryConfigIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType == ::ttnn::DataType::BFLOAT16) { + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + out = fromDeviceIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + shouldFromDevice) { + ::ttnn::Tensor out = fromDeviceIfNeeded(input); + out = toLayoutIfNeeded(out); + out = typecastIfNeeded(out); + return out; + } + + if (shouldTilize and inputDesc.dataType != ::ttnn::DataType::BFLOAT16 and + not shouldFromDevice) { + LOG_WARNING("Currently no constraint checking for on-device tilize."); + ::ttnn::Tensor out = toLayoutIfNeeded(input); + out = typecastIfNeeded(out); + out = toMemoryConfigIfNeeded(out); + return out; + } + + LOG_FATAL("Unreachable code path"); +} + 
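// ----------------------------------------------------------------------------
// Illustrative sketch (editorial note, not part of this patch): the handlers
// above are selected purely from the two LayoutDesc values captured in the
// LayoutConverter constructor. For example, converting a host
// ROW_MAJOR/FLOAT32 tensor into a device TILE/BFLOAT16 tensor sets
// shouldTilize, shouldTypecast and shouldToDevice, so convertTensorLayout()
// routes through handleHostInputLayoutTypecast(), typecasting on host and
// tilizing after the move to device. The template argument of the
// std::optional target-device parameter is elided in this patch text, so it
// is left symbolic here as targetDevice.
//
//   LayoutDesc inputDesc(::ttnn::BufferType::SYSTEM_MEMORY,
//                        ::ttnn::Layout::ROW_MAJOR,
//                        ::ttnn::DataType::FLOAT32, std::nullopt);
//   LayoutDesc outputDesc(::ttnn::BufferType::DRAM, ::ttnn::Layout::TILE,
//                         ::ttnn::DataType::BFLOAT16, std::nullopt);
//   LayoutConverter converter(inputDesc, outputDesc);
//   ::ttnn::Tensor onDevice =
//       converter.convertTensorLayout(hostTensor, targetDevice);
// ----------------------------------------------------------------------------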
+::ttnn::Tensor +LayoutConverter::convertDeviceTensorLayout(const ::ttnn::Tensor &input) { + bool shouldToLayout = (shouldTilize or shouldUntilize); + if (not shouldToLayout and not shouldTypecast) { + return handleDeviceInputNoLayoutNoTypecast(input); + } + if (shouldToLayout and not shouldTypecast) { + return handleDeviceInputLayoutNoTypecast(input); + } + if (not shouldToLayout and shouldTypecast) { + return handleDeviceInputNoLayoutTypecast(input); + } + if (shouldToLayout and shouldTypecast) { + return handleDeviceInputLayoutTypecast(input); + } + LOG_FATAL("Unreachable code path"); +} + +// +// ProgramTensorPool APIs +// +std::pair::iterator, bool> +ProgramTensorPool::try_emplace(std::uint32_t globalId, + const ::ttnn::Tensor &tensor) { + auto it = liveTensors.find(globalId); + if (it != liveTensors.end()) { + return std::make_pair(it, false); + } + LOG_ASSERT(!intermedTensors.contains(globalId)); + intermedTensors.try_emplace(globalId, tensor); + return liveTensors.try_emplace(globalId, &intermedTensors.at(globalId)); +} + +std::pair::iterator, bool> +ProgramTensorPool::insert_or_assign(std::uint32_t globalId, + const ::ttnn::Tensor &tensor) { + intermedTensors.insert_or_assign(globalId, tensor); + return liveTensors.insert_or_assign(globalId, &intermedTensors.at(globalId)); +} + +::ttnn::Tensor &ProgramTensorPool::at(std::uint32_t globalId) { + LOG_ASSERT(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); +} + +const ::ttnn::Tensor &ProgramTensorPool::at(std::uint32_t globalId) const { + LOG_ASSERT(liveTensors.contains(globalId)); + return *liveTensors.at(globalId); +} + +size_t ProgramTensorPool::erase(std::uint32_t globalId) { + LOG_ASSERT(liveTensors.contains(globalId) && + intermedTensors.contains(globalId)); + intermedTensors.erase(globalId); + return liveTensors.erase(globalId); +} + +std::vector ProgramTensorPool::gatherOutputTensors() { + std::vector outputTensors; + outputTensors.reserve(programOutputs.size()); + std::transform( + programOutputs.begin(), programOutputs.end(), + std::back_inserter(outputTensors), [this](uint32_t outputGlobalId) { + return utils::createRuntimeTensorFromTTNN(this->at(outputGlobalId)); + }); + return outputTensors; +} + +// +// ProgramContext APIs +// +ProgramContext::ProgramContext( + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, ::ttnn::MeshDevice *parentMesh) + : tensorPool(ProgramTensorPool(liveTensors, programInputs, programOutputs)), + parentMesh(parentMesh) { + LOG_ASSERT(parentMesh, "Parent mesh cannot be null"); +} + +void ProgramContext::addSubMesh(uint32_t meshId, + std::shared_ptr<::ttnn::MeshDevice> subMesh) { + auto [it, inserted] = subMeshes.try_emplace(meshId, subMesh); + LOG_ASSERT(inserted, "Submesh already exists"); +} + +::ttnn::MeshDevice &ProgramContext::getSubMesh(uint32_t meshId) { + LOG_ASSERT(subMeshes.contains(meshId)); + return *subMeshes.at(meshId); +} + +size_t ProgramContext::subMeshSize(uint32_t meshId) const { + LOG_ASSERT(subMeshes.contains(meshId)); + return subMeshes.at(meshId)->num_devices(); +} + +::ttnn::Device &ProgramContext::getDeviceFromSubMesh(uint32_t meshId, + int physicalDeviceId) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + return *subMesh.get_device(physicalDeviceId); +} + +::ttnn::Device &ProgramContext::getDeviceIndexFromSubMesh(uint32_t meshId, + int deviceIndex) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + return 
*subMesh.get_device_index(deviceIndex); +} + +DeviceVariant ProgramContext::getTargetDevice(uint32_t meshId) { + LOG_ASSERT(subMeshes.contains(meshId)); + auto &subMesh = *subMeshes.at(meshId); + if (subMesh.num_devices() == 1) { + return std::ref(*subMesh.get_device_index(0)); + } + return std::ref(subMesh); +} + +} // namespace tt::runtime::ttnn diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h index 5cd08c7ed..a5ca800c3 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.h @@ -6,18 +6,88 @@ #define TT_RUNTIME_TTNN_TYPES_H #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/types.h" +#include +#include namespace tt::runtime::ttnn { - -using TensorMap = std::unordered_map; using DeviceVariant = std::variant, std::reference_wrapper<::ttnn::MeshDevice>>; +struct LayoutDesc { + ::ttnn::BufferType bufferType; + ::ttnn::Layout layout; + ::ttnn::DataType dataType; + std::optional<::ttnn::MemoryConfig> memoryConfig; + + LayoutDesc(const ::ttnn::BufferType &bufferType, const ::ttnn::Layout &layout, + const ::ttnn::DataType &dataType, + const std::optional<::ttnn::MemoryConfig> &memoryConfig) + : bufferType(bufferType), layout(layout), dataType(dataType), + memoryConfig(memoryConfig) {} + + bool isOnHost() const { + return bufferType == ::ttnn::BufferType::SYSTEM_MEMORY; + } + bool isOnDevice() const { return !isOnHost(); } + + bool isTilized() const { return layout == ::ttnn::Layout::TILE; } +}; + +class LayoutConverter { +public: + LayoutDesc inputDesc; + LayoutDesc outputDesc; + bool shouldTilize = false; + bool shouldUntilize = false; + bool shouldTypecast = false; + bool shouldToDevice = false; + bool shouldToMemoryConfig = false; + bool shouldFromDevice = false; + + LayoutConverter(const LayoutDesc &inputDesc, const LayoutDesc &outputDesc); + ::ttnn::Tensor convertTensorLayout(const ::ttnn::Tensor &input, + std::optional targetDevice); + +private: + ::ttnn::Tensor toLayoutIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor typecastIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor toDeviceIfNeeded(const ::ttnn::Tensor &input, + std::optional targetDevice, + bool force = false); + ::ttnn::Tensor toMemoryConfigIfNeeded(const ::ttnn::Tensor &input); + ::ttnn::Tensor fromDeviceIfNeeded(const ::ttnn::Tensor &input); + + ::ttnn::Tensor + handleHostInputNoLayoutNoTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputLayoutNoTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputNoLayoutTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + handleHostInputLayoutTypecast(const ::ttnn::Tensor &input, + std::optional targetDevice); + ::ttnn::Tensor + convertHostTensorLayout(const ::ttnn::Tensor &input, + std::optional targetDevice); + + ::ttnn::Tensor + handleDeviceInputNoLayoutNoTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputLayoutNoTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputNoLayoutTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor handleDeviceInputLayoutTypecast(const ::ttnn::Tensor &input); + ::ttnn::Tensor convertDeviceTensorLayout(const ::ttnn::Tensor &input); +}; + class ProgramTensorPool { public: - ProgramTensorPool(const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs) + ProgramTensorPool( + const 
std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs) : programInputs(programInputs), programOutputs(programOutputs), liveTensors(liveTensors) {} ProgramTensorPool(const ProgramTensorPool &) = delete; @@ -25,72 +95,38 @@ class ProgramTensorPool { ProgramTensorPool(ProgramTensorPool &&) = default; ProgramTensorPool &operator=(ProgramTensorPool &&) = default; - auto try_emplace(std::uint32_t globalId, const ::ttnn::Tensor &tensor) { - auto it = liveTensors.find(globalId); - if (it != liveTensors.end()) { - return std::make_pair(it, false); - } - assert(!intermedTensors.contains(globalId)); - intermedTensors.try_emplace(globalId, tensor); - return liveTensors.try_emplace(globalId, &intermedTensors.at(globalId)); - } + std::pair::iterator, bool> + try_emplace(std::uint32_t globalId, const ::ttnn::Tensor &tensor); - auto insert_or_assign(std::uint32_t globalId, const ::ttnn::Tensor &tensor) { - intermedTensors.insert_or_assign(globalId, tensor); - return liveTensors.insert_or_assign(globalId, - &intermedTensors.at(globalId)); - } + std::pair::iterator, bool> + insert_or_assign(std::uint32_t globalId, const ::ttnn::Tensor &tensor); - ::ttnn::Tensor &at(std::uint32_t globalId) { - assert(liveTensors.contains(globalId)); - return *liveTensors.at(globalId); - } + ::ttnn::Tensor &at(std::uint32_t globalId); - const ::ttnn::Tensor &at(std::uint32_t globalId) const { - assert(liveTensors.contains(globalId)); - return *liveTensors.at(globalId); - } + const ::ttnn::Tensor &at(std::uint32_t globalId) const; - size_t erase(std::uint32_t globalId) { - assert(liveTensors.contains(globalId) && - intermedTensors.contains(globalId)); - intermedTensors.erase(globalId); - return liveTensors.erase(globalId); - } + size_t erase(std::uint32_t globalId); - void copyTensorToUserOutput(std::uint32_t outputGlobalId, - const ::ttnn::Tensor &srcTensor) { - assert(liveTensors.contains(outputGlobalId)); - assert(isUserOutput(outputGlobalId)); - ::ttnn::Tensor &outputTensor = *liveTensors.at(outputGlobalId); - void *src = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); - size_t size = outputTensor.volume() * outputTensor.element_size(); - std::memcpy(dst, src, size); - } + std::vector gatherOutputTensors(); bool contains(std::uint32_t globalId) const { return liveTensors.contains(globalId); } - bool isUserOutput(std::uint32_t globalId) const { - return programOutputs.contains(globalId); - } - - const std::unordered_set &getProgramInputs() const { + const std::vector &getProgramInputs() const { return programInputs; } - const std::unordered_set &getProgramOutputs() const { + const std::vector &getProgramOutputs() const { return programOutputs; } private: - std::unordered_set programInputs; - std::unordered_set programOutputs; + std::vector programInputs; + std::vector programOutputs; // A superset of intermedTensors, containing pointers to all tensors created - // by the program and the input/output tensors passed in by the user - TensorMap liveTensors; + // by the program and the input tensors passed in by the user + std::unordered_map liveTensors; // A subset of liveTensors, containing values of any intermediate tensors // created by the program @@ -99,15 +135,11 @@ class ProgramTensorPool { class ProgramContext { public: - ProgramContext(const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs, - ::ttnn::MeshDevice *parentMesh) - : tensorPool( - 
ProgramTensorPool(liveTensors, programInputs, programOutputs)), - parentMesh(parentMesh) { - assert(parentMesh && "Parent mesh cannot be null"); - } + ProgramContext( + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, + ::ttnn::MeshDevice *parentMesh); ProgramContext(const ProgramContext &) = delete; ProgramContext &operator=(const ProgramContext &) = delete; ProgramContext(ProgramContext &&) = default; @@ -125,42 +157,17 @@ class ProgramContext { // // Sub Mesh Operations // - void addSubMesh(uint32_t meshId, - std::shared_ptr<::ttnn::MeshDevice> subMesh) { - auto [it, inserted] = subMeshes.try_emplace(meshId, subMesh); - assert(inserted && "Submesh already exists"); - } + void addSubMesh(uint32_t meshId, std::shared_ptr<::ttnn::MeshDevice> subMesh); - ::ttnn::MeshDevice &getSubMesh(uint32_t meshId) { - assert(subMeshes.contains(meshId)); - return *subMeshes.at(meshId); - } + ::ttnn::MeshDevice &getSubMesh(uint32_t meshId); - size_t subMeshSize(uint32_t meshId) const { - assert(subMeshes.contains(meshId)); - return subMeshes.at(meshId)->num_devices(); - } + size_t subMeshSize(uint32_t meshId) const; - ::ttnn::Device &getDeviceFromSubMesh(uint32_t meshId, int physicalDeviceId) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - return *subMesh.get_device(physicalDeviceId); - } + ::ttnn::Device &getDeviceFromSubMesh(uint32_t meshId, int physicalDeviceId); - ::ttnn::Device &getDeviceIndexFromSubMesh(uint32_t meshId, int deviceIndex) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - return *subMesh.get_device_index(deviceIndex); - } + ::ttnn::Device &getDeviceIndexFromSubMesh(uint32_t meshId, int deviceIndex); - DeviceVariant getTargetDevice(uint32_t meshId) { - assert(subMeshes.contains(meshId)); - auto &subMesh = *subMeshes.at(meshId); - if (subMesh.num_devices() == 1) { - return std::ref(*subMesh.get_device_index(0)); - } - return std::ref(subMesh); - } + DeviceVariant getTargetDevice(uint32_t meshId); // // Tensor Pool Operations diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp new file mode 100644 index 000000000..fa8aa82ed --- /dev/null +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/ttnn/utils.h" +#include "tt/runtime/detail/logger.h" + +namespace tt::runtime::ttnn::utils { + +// TODO (bug #701) +// Currently the memory layout/location in flatbuffer is incorrect +// These methods are workarounds for operations such that we query the info +// directly from the TTNN tensor. 
Ideally, we should be able to get all of this +// info directly from the flatbuffer +bool isOnHost(const ::ttnn::StorageType &storageType) { + return storageType == ::tt::tt_metal::StorageType::BORROWED or + storageType == ::tt::tt_metal::StorageType::OWNED or + storageType == ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST; +} + +bool isOnDevice(const ::ttnn::StorageType &storageType) { + return storageType == ::tt::tt_metal::StorageType::DEVICE or + storageType == ::tt::tt_metal::StorageType::MULTI_DEVICE; +} + +bool isValidTileShape(const ::tt::target::Dim2d *shape) { + return (shape->x() == 1 and shape->y() == 1) or + (shape->x() == 32 and shape->y() == 32); +} + +::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) { + switch (dataType) { + case ::tt::target::DataType::Float32: + return ::ttnn::DataType::FLOAT32; + case ::tt::target::DataType::BFloat16: + return ::ttnn::DataType::BFLOAT16; + case ::tt::target::DataType::BFP_BFloat8: + return ::ttnn::DataType::BFLOAT8_B; + case ::tt::target::DataType::BFP_BFloat4: + return ::ttnn::DataType::BFLOAT4_B; + case ::tt::target::DataType::UInt32: + return ::ttnn::DataType::UINT32; + case ::tt::target::DataType::UInt16: + return ::ttnn::DataType::UINT16; + + default: + LOG_FATAL("Unsupported data type"); + } +} + +::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) { + switch (dataType) { + case ::ttnn::DataType::FLOAT32: + return ::tt::target::DataType::Float32; + case ::ttnn::DataType::BFLOAT16: + return ::tt::target::DataType::BFloat16; + case ::ttnn::DataType::BFLOAT8_B: + return ::tt::target::DataType::BFP_BFloat8; + case ::ttnn::DataType::BFLOAT4_B: + return ::tt::target::DataType::BFP_BFloat4; + case ::ttnn::DataType::UINT32: + return ::tt::target::DataType::UInt32; + case ::ttnn::DataType::UINT16: + return ::tt::target::DataType::UInt16; + + default: + LOG_FATAL("Unsupported data type"); + } +} + +::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout) { + switch (layout) { + case ::tt::target::TensorLayout::Tile: + return ::ttnn::Layout::TILE; + case ::tt::target::TensorLayout::RowMajor: + return ::ttnn::Layout::ROW_MAJOR; + default: + LOG_FATAL("Unsupported layout"); + } +} + +::ttnn::TensorMemoryLayout +toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout) { + + switch (tensorMemoryLayout) { + case ::tt::target::TensorMemoryLayout::Interleaved: + return ::ttnn::TensorMemoryLayout::INTERLEAVED; + case ::tt::target::TensorMemoryLayout::SingleBank: + return ::ttnn::TensorMemoryLayout::SINGLE_BANK; + case ::tt::target::TensorMemoryLayout::HeightSharded: + return ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED; + case ::tt::target::TensorMemoryLayout::WidthSharded: + return ::ttnn::TensorMemoryLayout::WIDTH_SHARDED; + case ::tt::target::TensorMemoryLayout::BlockSharded: + return ::ttnn::TensorMemoryLayout::BLOCK_SHARDED; + case ::tt::target::TensorMemoryLayout::None: + LOG_FATAL("Unsupported tensor memory layout None"); + } +} + +// This method will be deprecated in favor of method below +// +::tt::tt_metal::BufferType +toTTNNBufferType(::tt::target::MemorySpace memorySpace) { + switch (memorySpace) { + case ::tt::target::MemorySpace::System: + case ::tt::target::MemorySpace::SystemMMIO: + return ::tt::tt_metal::BufferType::SYSTEM_MEMORY; + case ::tt::target::MemorySpace::DeviceDRAM: + return ::tt::tt_metal::BufferType::DRAM; + case ::tt::target::MemorySpace::DeviceL1: + return ::tt::tt_metal::BufferType::L1; + } +} + +// Prefer to use this method +// +::ttnn::BufferType 
toTTNNBufferType(::tt::target::BufferType bufferType) { + + switch (bufferType) { + case ::tt::target::BufferType::DRAM: + return ::ttnn::BufferType::DRAM; + case ::tt::target::BufferType::L1: + return ::ttnn::BufferType::L1; + case ::tt::target::BufferType::SystemMemory: + return ::ttnn::BufferType::SYSTEM_MEMORY; + case ::tt::target::BufferType::L1Small: + return ::ttnn::BufferType::L1_SMALL; + case ::tt::target::BufferType::Trace: + return ::ttnn::BufferType::TRACE; + } +}; + +std::vector +toShapeFromFBShape(const flatbuffers::Vector &vec) { + return std::vector(vec.begin(), vec.end()); +} + +::ttnn::Layout +inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef) { + const ::tt::target::Dim2d *tileShape = + tensorRef->desc()->layout()->memory_desc()->tile_shape(); + LOG_ASSERT(isValidTileShape(tileShape)); + if (tileShape->x() == 1 and tileShape->y() == 1) { + return ::ttnn::Layout::ROW_MAJOR; + } + return ::ttnn::Layout::TILE; +} + +CoreRangeSet +toCoreRangeSet(const ::flatbuffers::Vector + *coreRangeSet) { + std::set coreRanges; + for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { + CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); + // End is inclusive + CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1, + coreRange->loc().y() + coreRange->size().y() - 1); + + coreRanges.emplace(start, end); + } + return CoreRangeSet(coreRanges); +} + +::tt::tt_metal::MemoryConfig +createMemoryConfig(const ::tt::target::TensorRef *tensorRef) { + const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); + const ::tt::target::TensorMemoryLayout targetMemoryLayout = + layout->memory_desc()->memory_layout(); + const ::tt::target::MemorySpace targetMemorySpace = + layout->memory_desc()->memory_space(); + const ::flatbuffers::Vector + *targetCoreRangeSet = layout->core_range_set(); + const ::flatbuffers::Vector *targetShardShape = + layout->memory_desc()->shape(); + const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); + + LOG_ASSERT(targetCoreRangeSet->size() == 1, + "Currently only single core range/grid is supported"); + + LOG_ASSERT(targetShardShape->size() == 2, + "Only 2D shard shape is supported in TTNN backend"); + + LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape), + "Invalid tile shape"); + + CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); + std::array ttnnShardShape; + std::copy(targetShardShape->begin(), targetShardShape->end(), + ttnnShardShape.begin()); + + ttnnShardShape[0] *= tileShape->y(); + ttnnShardShape[1] *= tileShape->x(); + + ::tt::tt_metal::TensorMemoryLayout ttnnMemLayout = + toTTNNTensorMemoryLayout(targetMemoryLayout); + + ::tt::tt_metal::BufferType ttnnBufferType = + toTTNNBufferType(targetMemorySpace); + + ::tt::tt_metal::ShardSpec shardSpec( + ttnnCoreRangeSet, ttnnShardShape, + ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); + + std::optional<::tt::tt_metal::ShardSpec> shardSpecOpt = + ttnnMemLayout == tt_metal::TensorMemoryLayout::INTERLEAVED + ? 
std::nullopt + : std::make_optional(shardSpec); + + ::tt::tt_metal::MemoryConfig memoryConfig{.memory_layout = ttnnMemLayout, + .buffer_type = ttnnBufferType, + .shard_spec = shardSpecOpt}; + return memoryConfig; +} + +Tensor createRuntimeTensorFromTTNN(const ::ttnn::Tensor &tensor) { + auto tensorPtr = std::make_shared<::ttnn::Tensor>(tensor); + return Tensor(std::static_pointer_cast(tensorPtr), nullptr, + DeviceRuntime::TTNN); +} + +} // namespace tt::runtime::ttnn::utils diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h index 75b22d114..353195b8d 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h @@ -6,127 +6,50 @@ #define TT_RUNTIME_TTNN_UTILS_H #include "flatbuffers/vector.h" -#include "tt_metal/impl/buffers/buffer.hpp" +#include "tt/runtime/detail/ttnn.h" #include "ttmlir/Target/Common/types_generated.h" #include "ttmlir/Target/TTNN/Target.h" -#include "ttnn/types.hpp" namespace tt::runtime::ttnn::utils { -inline bool isValidTileShape(const ::tt::target::Dim2d *shape) { - return (shape->x() == 1 and shape->y() == 1) or - (shape->x() == 32 and shape->y() == 32); -} - -inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) { - switch (dataType) { - case ::tt::target::DataType::Float32: - return ::ttnn::DataType::FLOAT32; - case ::tt::target::DataType::BFloat16: - return ::ttnn::DataType::BFLOAT16; - case ::tt::target::DataType::BFP_BFloat8: - return ::ttnn::DataType::BFLOAT8_B; - case ::tt::target::DataType::BFP_BFloat4: - return ::ttnn::DataType::BFLOAT4_B; - case ::tt::target::DataType::UInt32: - return ::ttnn::DataType::UINT32; - case ::tt::target::DataType::UInt16: - return ::ttnn::DataType::UINT16; - - default: - throw std::runtime_error("Unsupported data type"); - } -} - -inline ::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType) { - switch (dataType) { - case ::ttnn::DataType::FLOAT32: - return ::tt::target::DataType::Float32; - case ::ttnn::DataType::BFLOAT16: - return ::tt::target::DataType::BFloat16; - case ::ttnn::DataType::BFLOAT8_B: - return ::tt::target::DataType::BFP_BFloat8; - case ::ttnn::DataType::BFLOAT4_B: - return ::tt::target::DataType::BFP_BFloat4; - case ::ttnn::DataType::UINT32: - return ::tt::target::DataType::UInt32; - case ::ttnn::DataType::UINT16: - return ::tt::target::DataType::UInt16; - - default: - throw std::runtime_error("Unsupported data type"); - } -} - -inline ::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout) { - switch (layout) { - case ::tt::target::TensorLayout::Tile: - return ::ttnn::Layout::TILE; - case ::tt::target::TensorLayout::RowMajor: - return ::ttnn::Layout::ROW_MAJOR; - default: - throw std::runtime_error("Unsupported layout"); - } -} - -inline ::ttnn::TensorMemoryLayout -toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout) { - - switch (tensorMemoryLayout) { - case ::tt::target::TensorMemoryLayout::Interleaved: - return ::ttnn::TensorMemoryLayout::INTERLEAVED; - case ::tt::target::TensorMemoryLayout::SingleBank: - return ::ttnn::TensorMemoryLayout::SINGLE_BANK; - case ::tt::target::TensorMemoryLayout::HeightSharded: - return ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED; - case ::tt::target::TensorMemoryLayout::WidthSharded: - return ::ttnn::TensorMemoryLayout::WIDTH_SHARDED; - case ::tt::target::TensorMemoryLayout::BlockSharded: - return ::ttnn::TensorMemoryLayout::BLOCK_SHARDED; - case ::tt::target::TensorMemoryLayout::None: - assert(false 
&& - "Unsupported tensor memory layout TensorMemoryLayout::None"); - } -} +bool isOnHost(const ::ttnn::StorageType &storageType); + +bool isOnDevice(const ::ttnn::StorageType &storageType); + +bool isValidTileShape(const ::tt::target::Dim2d *shape); + +::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType); + +::tt::target::DataType fromTTNNDataType(::ttnn::DataType dataType); + +::ttnn::Layout toTTNNLayout(::tt::target::TensorLayout layout); + +::ttnn::TensorMemoryLayout +toTTNNTensorMemoryLayout(::tt::target::TensorMemoryLayout tensorMemoryLayout); // This method will be deprecated in favor of method below // -inline ::tt::tt_metal::BufferType -toTTNNBufferType(::tt::target::MemorySpace memorySpace) { - switch (memorySpace) { - case ::tt::target::MemorySpace::System: - case ::tt::target::MemorySpace::SystemMMIO: - return ::tt::tt_metal::BufferType::SYSTEM_MEMORY; - case ::tt::target::MemorySpace::DeviceDRAM: - return ::tt::tt_metal::BufferType::DRAM; - case ::tt::target::MemorySpace::DeviceL1: - return ::tt::tt_metal::BufferType::L1; - } -} +::tt::tt_metal::BufferType +toTTNNBufferType(::tt::target::MemorySpace memorySpace); // Prefer to use this method // -inline ::ttnn::BufferType -toTTNNBufferType(::tt::target::BufferType bufferType) { - - switch (bufferType) { - case ::tt::target::BufferType::DRAM: - return ::ttnn::BufferType::DRAM; - case ::tt::target::BufferType::L1: - return ::ttnn::BufferType::L1; - case ::tt::target::BufferType::SystemMemory: - return ::ttnn::BufferType::SYSTEM_MEMORY; - case ::tt::target::BufferType::L1Small: - return ::ttnn::BufferType::L1_SMALL; - case ::tt::target::BufferType::Trace: - return ::ttnn::BufferType::TRACE; - } -}; - -inline std::vector -toShapeFromFBShape(const flatbuffers::Vector &vec) { - return std::vector(vec.begin(), vec.end()); -} +::ttnn::BufferType toTTNNBufferType(::tt::target::BufferType bufferType); + +std::vector +toShapeFromFBShape(const flatbuffers::Vector &vec); + +::ttnn::Layout +inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef); + +CoreRangeSet +toCoreRangeSet(const ::flatbuffers::Vector + *coreRangeSet); + +::tt::tt_metal::MemoryConfig +createMemoryConfig(const ::tt::target::TensorRef *tensorRef); + +Tensor createRuntimeTensorFromTTNN(const ::ttnn::Tensor &tensor); } // namespace tt::runtime::ttnn::utils diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index 38115803f..4d18e3f1c 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -46,12 +46,13 @@ target_include_directories(TTRuntimeTTNNOps PUBLIC ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/operations/include ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common ) + target_include_directories(TTRuntimeTTNNOps SYSTEM PUBLIC "$") -target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY) +target_link_libraries(TTRuntimeTTNNOps PUBLIC TTNN_LIBRARY TTRuntimeTTNNHelpers) if (TT_RUNTIME_ENABLE_PERF_TRACE) target_link_libraries(TTRuntimeTTNNOps PUBLIC TRACY_LIBRARY) endif() -add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION) +add_dependencies(TTRuntimeTTNNOps TTNN_LIBRARY tt-metal FBS_GENERATION TTRuntimeTTNNHelpers) diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.cpp b/runtime/lib/ttnn/operations/ccl/all_gather.cpp index 37bf7427b..eee27e7ba 100644 --- a/runtime/lib/ttnn/operations/ccl/all_gather.cpp +++ b/runtime/lib/ttnn/operations/ccl/all_gather.cpp @@ -5,6 +5,7 @@ #include "all_gather.h" #include "tt/runtime/detail/ttnn.h" 
#include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::ccl { void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) { @@ -13,7 +14,7 @@ void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) { int32_t dim = op->dim(); int32_t num_links = op->num_links(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/conv/conv2d.cpp b/runtime/lib/ttnn/operations/conv/conv2d.cpp index e6670c113..5e00b929e 100644 --- a/runtime/lib/ttnn/operations/conv/conv2d.cpp +++ b/runtime/lib/ttnn/operations/conv/conv2d.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" #include "ttnn/types.hpp" @@ -23,7 +24,8 @@ void run(const ::tt::target::ttnn::Conv2dOp *op, ProgramContext &context) { auto config = ::ttnn::operations::conv::conv2d::Conv2dConfig(); config.dtype = utils::getDataType(op->input()); config.weights_dtype = utils::getDataType(op->weight()); - ::ttnn::MemoryConfig outMemConfig = utils::createMemoryConfig(op->out()); + ::ttnn::MemoryConfig outMemConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); DeviceVariant targetDevice = context.getTargetDevice(op->device()->global_id()); ::ttnn::Tensor out = std::visit( diff --git a/runtime/lib/ttnn/operations/creation/arange.cpp b/runtime/lib/ttnn/operations/creation/arange.cpp index 446cdf72a..8ddb19913 100644 --- a/runtime/lib/ttnn/operations/creation/arange.cpp +++ b/runtime/lib/ttnn/operations/creation/arange.cpp @@ -41,6 +41,6 @@ void run(const ::tt::target::ttnn::ArangeOp *op, ProgramContext &context) { ::ttnn::Tensor out = ::ttnn::arange(op->start(), op->end(), op->step(), dtype, device, memoryConfig); - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/empty.cpp b/runtime/lib/ttnn/operations/creation/empty.cpp index bed68e0f1..d504a798b 100644 --- a/runtime/lib/ttnn/operations/creation/empty.cpp +++ b/runtime/lib/ttnn/operations/creation/empty.cpp @@ -62,11 +62,12 @@ createEmptyOnMultiDevice(ProgramContext &context, EmptyTensorConfig &config, ::tt::tt_metal::DistributedTensorConfig strategy = config.distributedTensorConfig(); std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(config.numShards); - std::generate_n( - tensorShards.begin(), config.numShards, [&config]() -> ::ttnn::Tensor { - return ::ttnn::zeros(config.shape, config.dtype, config.layout); - }); + tensorShards.reserve(config.numShards); + std::generate_n(std::back_inserter(tensorShards), config.numShards, + [&config]() -> ::ttnn::Tensor { + return ::ttnn::zeros(config.shape, config.dtype, + config.layout); + }); ::ttnn::Tensor out = ::ttnn::distributed::api::create_multi_device_tensor( tensorShards, ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, strategy); if (deviceRef) { @@ -101,6 +102,6 @@ void run(const ::tt::target::ttnn::EmptyOp *op, ProgramContext &context) { } else { LOG_FATAL("Unsupported num shards"); } - 
utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/creation/full.cpp b/runtime/lib/ttnn/operations/creation/full.cpp index 6a224f935..7f6a6c0b6 100644 --- a/runtime/lib/ttnn/operations/creation/full.cpp +++ b/runtime/lib/ttnn/operations/creation/full.cpp @@ -26,7 +26,7 @@ struct FullTensorConfig { fillValue(op->fill_value()), numShards(op->num_shards()), strategy(op->strategy()) { - layout = utils::inferLayoutFromTileShape(op->out()); + layout = ::tt::runtime::ttnn::utils::inferLayoutFromTileShape(op->out()); // TODO(bug #272), determine correct layout by tile shape in the future // currently tile shape is not set correctly, so as a workaround, hardcode @@ -42,8 +42,7 @@ struct FullTensorConfig { } if (!utils::inSystemMemory(op->out())) { - memoryConfig = - ::tt::runtime::ttnn::operations::utils::createMemoryConfig(op->out()); + memoryConfig = ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); } validate(); } @@ -72,8 +71,8 @@ createFullOnMultiDevice(ProgramContext &context, FullTensorConfig &config, ::tt::tt_metal::DistributedTensorConfig strategy = config.distributedTensorConfig(); std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(config.numShards); - std::generate_n(tensorShards.begin(), config.numShards, + tensorShards.reserve(config.numShards); + std::generate_n(std::back_inserter(tensorShards), config.numShards, [&config]() -> ::ttnn::Tensor { return ::ttnn::full(config.shape, config.fillValue, config.dtype, config.layout); @@ -116,6 +115,6 @@ void run(const ::tt::target::ttnn::FullOp *op, ProgramContext &context) { } else { LOG_FATAL("Unsupported num shards"); } - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::creation diff --git a/runtime/lib/ttnn/operations/data_movement/transpose.cpp b/runtime/lib/ttnn/operations/data_movement/transpose.cpp index ef8dcf1b1..c86c0ee10 100644 --- a/runtime/lib/ttnn/operations/data_movement/transpose.cpp +++ b/runtime/lib/ttnn/operations/data_movement/transpose.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::data_movement { void run(const ::tt::target::ttnn::TransposeOp *op, ProgramContext &context) { @@ -15,7 +16,7 @@ void run(const ::tt::target::ttnn::TransposeOp *op, ProgramContext &context) { int32_t dim0 = op->dim0(); int32_t dim1 = op->dim1(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::transpose(in, dim0, dim1, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/deletion/deallocate.cpp b/runtime/lib/ttnn/operations/deletion/deallocate.cpp index 6204945b3..e871a9ea6 100644 --- a/runtime/lib/ttnn/operations/deletion/deallocate.cpp +++ b/runtime/lib/ttnn/operations/deletion/deallocate.cpp @@ -11,13 +11,6 @@ void run(const ::tt::target::ttnn::DeallocateOp *op, ProgramContext &context) { ::ttnn::Tensor &tensor = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(tensor.is_allocated()); ::ttnn::deallocate(tensor, op->force()); - - // The tensor should be 
deallocated after the deallocate call. - // Still this assert may be hit in the future for multidevice/async ttnn - // support. In that case, we will reevaluate the assert/dealloc behaviour and - // adjust it accordingly. - // - DEBUG_ASSERT(!tensor.is_allocated()); tensorPool.erase(op->in()->global_id()); } } // namespace tt::runtime::ttnn::operations::deletion diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp index 591397119..ff47bdcdd 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" namespace tt::runtime::ttnn::operations::binary { @@ -26,7 +27,7 @@ static void runEltwiseBinaryOp( ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputDataType, outputMemoryConfig, std::nullopt, std::nullopt, std::nullopt); diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index 5c1d056f9..921b542ed 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::binary::composite { @@ -20,7 +21,7 @@ static void runEltwiseBinaryCompositeOp( getEltwiseBinaryOpInputTensors(op, tensorPool, &lhs, &rhs); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*lhs, *rhs, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp index 6afde5d66..44f141389 100644 --- a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp @@ -22,7 +22,7 @@ static void runEltwiseTernaryWhereOp( getEltwiseTernaryOpInputTensors(op, tensorPool, &first, &second, &third); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*first, *second, *third, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp index 5a09b43a9..50c53f8db 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" #include "ttnn/operations/copy.hpp" @@ -22,7 
+23,7 @@ static void runEltwiseUnaryOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig, std::nullopt); tensorPool.insert_or_assign(op->out()->global_id(), out); @@ -39,7 +40,7 @@ static void runEltwiseUnaryWithFastAndApproximateModeOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, false /* parameter */, outputMemoryConfig, std::nullopt); @@ -56,7 +57,7 @@ static void runEltwiseUnaryWithFloatParameterOp( float parameter = op->params_as_EltwiseOpWithFloatParams()->parameter(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, parameter, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp index fd378d5a2..31514f0fe 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include "ttnn/operations/eltwise/unary/unary_composite.hpp" namespace tt::runtime::ttnn::operations::unary::composite { @@ -20,27 +21,26 @@ static void runEltwiseUnaryCompositeOp( getEltwiseUnaryOpInputTensor(op, tensorPool, &in); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } -static void runEltwiseUnaryCompositeClampOP( +static void runEltwiseUnaryCompositeClampOp( const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool, - std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float, - const ::tt::tt_metal::MemoryConfig &)> - ttnnOp) { + const std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float, float, + const ::tt::tt_metal::MemoryConfig &)> + &ttnnOp) { ::ttnn::Tensor *in = nullptr; getEltwiseUnaryOpInputTensor(op, tensorPool, &in); float min = op->params_as_ClampOpParams()->min(); float max = op->params_as_ClampOpParams()->max(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ttnnOp(*in, min, max, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); - return; } void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { @@ -51,7 +51,7 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { break; } case ::tt::target::ttnn::EltwiseOpType::Clamp: { - runEltwiseUnaryCompositeClampOP(op, tensorPool, ::ttnn::clamp); + runEltwiseUnaryCompositeClampOp(op, tensorPool, ::ttnn::clamp); break; } case ::tt::target::ttnn::EltwiseOpType::Log1p: { diff --git a/runtime/lib/ttnn/operations/embedding/embedding.cpp 
b/runtime/lib/ttnn/operations/embedding/embedding.cpp index 47b27ca9a..511d8256d 100644 --- a/runtime/lib/ttnn/operations/embedding/embedding.cpp +++ b/runtime/lib/ttnn/operations/embedding/embedding.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::embedding { void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) { @@ -24,7 +25,7 @@ void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) { auto embeddingsType = ::ttnn::operations::embedding::EmbeddingsType::GENERIC; ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::ttnn::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::embedding(input, weight, padToken, layout, embeddingsType, outputDataType, outputMemoryConfig); diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index c595fe26b..60ee2ddc2 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -7,25 +7,6 @@ namespace tt::runtime::ttnn::operations::utils { -// TODO (bug #701) -// Currently the memory layout/location in flatbuffer is incorrect -// These methods are workarounds such that we query the info directly from the -// TTNN tensor Ideally, we should be able to get all of this info directly from -// the flatbuffer -bool isOnHost(const ::ttnn::Tensor &tensor) { - // Currently only supports borrowed or owned host storage - return tensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or - tensor.storage_type() == ::tt::tt_metal::StorageType::OWNED or - tensor.storage_type() == - ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST; -} - -bool isOnDevice(const ::ttnn::Tensor &tensor) { - // Currently only supports single device storage - return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE or - tensor.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE; -} - bool isTilized(const ::tt::target::TensorRef *tensorRef) { const ::tt::target::Dim2d *tileShape = tensorRef->desc()->layout()->memory_desc()->tile_shape(); @@ -43,96 +24,11 @@ bool inSystemMemory(const ::tt::target::TensorRef *tensorRef) { targetMemorySpace == ::tt::target::MemorySpace::SystemMMIO; } -void updateTensorPool(ProgramTensorPool &tensorPool, - const ::ttnn::Tensor &tensor, uint32_t outputGlobalId) { - if (tensorPool.isUserOutput(outputGlobalId)) { - tensorPool.copyTensorToUserOutput(outputGlobalId, tensor); - } else { - tensorPool.insert_or_assign(outputGlobalId, tensor); - } -} - ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef) { return ::tt::runtime::ttnn::utils::toTTNNDataType( tensorRef->desc()->layout()->memory_desc()->data_type()); } -::ttnn::Layout -inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef) { - const ::tt::target::Dim2d *tileShape = - tensorRef->desc()->layout()->memory_desc()->tile_shape(); - LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape)); - if (tileShape->x() == 1 and tileShape->y() == 1) { - return ::ttnn::Layout::ROW_MAJOR; - } - return ::ttnn::Layout::TILE; -} - -CoreRangeSet -toCoreRangeSet(const ::flatbuffers::Vector - *coreRangeSet) { - std::set coreRanges; - 
for (::tt::target::Dim2dRange const *coreRange : *coreRangeSet) { - CoreCoord start(coreRange->loc().x(), coreRange->loc().y()); - // End is inclusive - CoreCoord end(coreRange->loc().x() + coreRange->size().x() - 1, - coreRange->loc().y() + coreRange->size().y() - 1); - - coreRanges.emplace(start, end); - } - return CoreRangeSet(coreRanges); -} - -// This method will soon be deprecated, prefer to use the method below -// -::tt::tt_metal::MemoryConfig -createMemoryConfig(const ::tt::target::TensorRef *tensorRef) { - const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); - const ::tt::target::TensorMemoryLayout targetMemoryLayout = - layout->memory_desc()->memory_layout(); - const ::tt::target::MemorySpace targetMemorySpace = - layout->memory_desc()->memory_space(); - const ::flatbuffers::Vector - *targetCoreRangeSet = layout->core_range_set(); - const ::flatbuffers::Vector *targetShardShape = - layout->memory_desc()->shape(); - const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); - - LOG_ASSERT(targetCoreRangeSet->size() == 1, - "Currently only single core range/grid is supported"); - - LOG_ASSERT(targetShardShape->size() == 2, - "Only 2D shard shape is supported in TTNN backend"); - - LOG_ASSERT(::tt::runtime::ttnn::utils::isValidTileShape(tileShape), - "Invalid tile shape"); - - CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); - std::array ttnnShardShape; - std::copy(targetShardShape->begin(), targetShardShape->end(), - ttnnShardShape.begin()); - - ttnnShardShape[0] *= tileShape->y(); - ttnnShardShape[1] *= tileShape->x(); - - ::tt::tt_metal::ShardSpec shardSpec( - ttnnCoreRangeSet, ttnnShardShape, - ::tt::tt_metal::ShardOrientation::ROW_MAJOR, false); - - ::tt::tt_metal::TensorMemoryLayout ttnnMemLayout = - ::tt::runtime::ttnn::utils::toTTNNTensorMemoryLayout(targetMemoryLayout); - - ::tt::tt_metal::BufferType ttnnBufferType = - ::tt::runtime::ttnn::utils::toTTNNBufferType(targetMemorySpace); - - return {ttnnMemLayout, ttnnBufferType, - ttnnMemLayout == tt_metal::TensorMemoryLayout::INTERLEAVED - ? 
std::nullopt - : std::make_optional(shardSpec)}; -} - -// Prefer to use this method over the one above -// ::tt::tt_metal::MemoryConfig createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::TensorRef *tensorRef) { @@ -147,7 +43,8 @@ createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::LayoutDesc *layout = tensorRef->desc()->layout(); const ::flatbuffers::Vector *targetCoreRangeSet = layout->core_range_set(); - CoreRangeSet ttnnCoreRangeSet = toCoreRangeSet(targetCoreRangeSet); + CoreRangeSet ttnnCoreRangeSet = + ::tt::runtime::ttnn::utils::toCoreRangeSet(targetCoreRangeSet); const ::flatbuffers::Vector *shardShape = memcfg->shard_spec()->shard_shape(); const ::tt::target::Dim2d *tileShape = layout->memory_desc()->tile_shape(); diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h index b922e120a..269e0328f 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h @@ -13,32 +13,15 @@ namespace tt::runtime::ttnn::operations::utils { -bool isOnHost(const ::ttnn::Tensor &tensor); - -bool isOnDevice(const ::ttnn::Tensor &tensor); - bool isTilized(const ::tt::target::TensorRef *tensorRef); bool inSystemMemory(const ::tt::target::TensorRef *tensorRef); -void updateTensorPool(ProgramTensorPool &tensorPool, - const ::ttnn::Tensor &tensor, uint32_t outputGlobalId); - ::tt::target::MemorySpace getMemorySpace(const ::tt::target::TensorRef *tensorRef); ::ttnn::DataType getDataType(const ::tt::target::TensorRef *tensorRef); -::ttnn::Layout -inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef); - -CoreRangeSet -toCoreRangeSet(const ::flatbuffers::Vector - *coreRangeSet); - -::tt::tt_metal::MemoryConfig -createMemoryConfig(const ::tt::target::TensorRef *tensorRef); - ::tt::tt_metal::MemoryConfig createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, const ::tt::target::TensorRef *tensorRef); diff --git a/runtime/lib/ttnn/operations/layout/from_device.cpp b/runtime/lib/ttnn/operations/layout/from_device.cpp index b6820b6ec..e26e3be2a 100644 --- a/runtime/lib/ttnn/operations/layout/from_device.cpp +++ b/runtime/lib/ttnn/operations/layout/from_device.cpp @@ -12,10 +12,11 @@ void run(const ::tt::target::ttnn::FromDeviceOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(inputTensor.is_allocated()); - LOG_ASSERT(utils::isOnDevice(inputTensor), - "Calling ttnn::from_device on a host tensor"); + DEBUG_ASSERT( + ::tt::runtime::ttnn::utils::isOnDevice(inputTensor.storage_type()), + "Calling ttnn::from_device on a host tensor"); ::ttnn::Tensor out = ::ttnn::from_device(inputTensor); - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/layout/to_device.cpp b/runtime/lib/ttnn/operations/layout/to_device.cpp index 34af89f50..414afc9f0 100644 --- a/runtime/lib/ttnn/operations/layout/to_device.cpp +++ b/runtime/lib/ttnn/operations/layout/to_device.cpp @@ -14,7 +14,7 @@ void run(const ::tt::target::ttnn::ToDeviceOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); const ::ttnn::Tensor &inputTensor = 
tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(inputTensor.is_allocated()); - DEBUG_ASSERT(utils::isOnHost(inputTensor), + DEBUG_ASSERT(::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type()), "Calling ttnn::to_device on a device tensor"); std::optional<::ttnn::MemoryConfig> memoryConfig = std::nullopt; diff --git a/runtime/lib/ttnn/operations/layout/to_layout.cpp b/runtime/lib/ttnn/operations/layout/to_layout.cpp index 5e78a6718..bf80ef292 100644 --- a/runtime/lib/ttnn/operations/layout/to_layout.cpp +++ b/runtime/lib/ttnn/operations/layout/to_layout.cpp @@ -57,7 +57,7 @@ void run(const ::tt::target::ttnn::ToLayoutOp *op, ProgramContext &context) { out = ::ttnn::to_layout(inputTensor, layout, dtype, memoryConfig, static_cast<::ttnn::Device *>(nullptr)); } - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/layout/typecast.cpp b/runtime/lib/ttnn/operations/layout/typecast.cpp index 5529c6112..e59a64a40 100644 --- a/runtime/lib/ttnn/operations/layout/typecast.cpp +++ b/runtime/lib/ttnn/operations/layout/typecast.cpp @@ -17,7 +17,7 @@ void run(const ::tt::target::ttnn::TypecastOp *op, ProgramContext &context) { ::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype()); ::ttnn::Tensor out = ::ttnn::typecast(inputTensor, targetDataType); - utils::updateTensorPool(tensorPool, out, op->out()->global_id()); + tensorPool.insert_or_assign(op->out()->global_id(), out); } } // namespace tt::runtime::ttnn::operations::layout diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index a25102d9a..896797d59 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" #include namespace tt::runtime::ttnn::operations::matmul { @@ -18,7 +19,7 @@ void run(const ::tt::target::ttnn::MatmulOp *op, ProgramContext &context) { DEBUG_ASSERT(rhs.is_allocated()); ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); const std::optional memoryConfig = std::make_optional(outputMemoryConfig); @@ -49,7 +50,7 @@ void run(const ::tt::target::ttnn::LinearOp *op, ProgramContext &context) { ::ttnn::DataType outputDataType = utils::getDataType(op->out()); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); const std::optional memoryConfig = std::make_optional(outputMemoryConfig); diff --git a/runtime/lib/ttnn/operations/normalization/softmax.cpp b/runtime/lib/ttnn/operations/normalization/softmax.cpp index a83358567..432f92095 100644 --- a/runtime/lib/ttnn/operations/normalization/softmax.cpp +++ b/runtime/lib/ttnn/operations/normalization/softmax.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::normalization { void run(const ::tt::target::ttnn::SoftmaxOp *op, ProgramContext &context) { @@ -14,7 +15,7 @@ void run(const ::tt::target::ttnn::SoftmaxOp *op, 
ProgramContext &context) { DEBUG_ASSERT(in.is_allocated()); int32_t dimension = op->dimension(); ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = ::ttnn::softmax(in, dimension, outputMemoryConfig); tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index 4fc6fca87..c405a86f1 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -61,7 +61,8 @@ void run(const ::tt::target::ttnn::MaxPool2dOp *op, ProgramContext &context) { }, targetDevice); } - ::ttnn::MemoryConfig outMemConfig = utils::createMemoryConfig(op->out()); + ::ttnn::MemoryConfig outMemConfig = + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); ::ttnn::Tensor out = operation.invoke( 0, input, op->batch_size(), op->input_height(), op->input_width(), op->channels(), {op->kernel_height(), op->kernel_width()}, diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp index 3af46efc9..a74373ee9 100644 --- a/runtime/lib/ttnn/operations/reduction/reduction.cpp +++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp @@ -6,6 +6,7 @@ #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" +#include "tt/runtime/ttnn/utils.h" namespace tt::runtime::ttnn::operations::reduction { static void runReductionOp( @@ -17,7 +18,7 @@ static void runReductionOp( const std::optional<::ttnn::DeviceComputeKernelConfig> &, float)> &ttnnOp) { ::tt::tt_metal::MemoryConfig outputMemoryConfig = - utils::createMemoryConfig(op->out()); + ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(in.is_allocated()); diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 3aab3a94c..a45c2de9a 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -30,6 +30,7 @@ #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/types.h" +#include "tt/runtime/ttnn/utils.h" #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" @@ -49,36 +50,25 @@ void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) { static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) { bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( binary.handle.get()); - if (not isTTNN) { - throw std::runtime_error("Unsupported binary format"); - } + LOG_ASSERT(isTTNN, "Unsupported binary format"); return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get()); } class ProgramExecutor { public: - ProgramExecutor(Binary &executableHandle, const TensorMap &liveTensors, - const std::unordered_set &programInputs, - const std::unordered_set &programOutputs, - ::ttnn::MeshDevice *meshDevice) + ProgramExecutor( + const Binary &executableHandle, + const std::unordered_map &liveTensors, + const std::vector &programInputs, + const std::vector &programOutputs, + ::ttnn::MeshDevice *meshDevice) : executableHandle(executableHandle), context(ProgramContext(liveTensors, programInputs, programOutputs, meshDevice)) {} void runCallback(Binary &executableHandle, const ::tt::target::ttnn::Operation *opContext, - ProgramContext *programContext) { - if 
(auto callback = debug::Hooks::get().getOperatorCallback(); callback) { - std::shared_ptr programContextPtr = - ::tt::runtime::utils::unsafe_borrow_shared(programContext); - std::shared_ptr opContextPtr = - ::tt::runtime::utils::unsafe_borrow_shared( - const_cast<::tt::target::ttnn::Operation *>(opContext)); - (*callback)(executableHandle, - CallbackContext(programContextPtr, DeviceRuntime::TTNN), - OpContext(opContextPtr, DeviceRuntime::TTNN)); - } - } + ProgramContext *programContext); void execute(const ::tt::target::ttnn::Program *program) { for (const ::tt::target::ttnn::Operation *op : *program->operations()) { @@ -91,6 +81,9 @@ class ProgramExecutor { } ProgramContext &getContext() { return context; } + std::vector gatherOutputTensors() { + return context.getTensorPool().gatherOutputTensors(); + } private: Binary executableHandle; @@ -99,6 +92,21 @@ class ProgramExecutor { void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op); }; +void ProgramExecutor::runCallback( + Binary &executableHandle, const ::tt::target::ttnn::Operation *opContext, + ProgramContext *programContext) { + if (auto callback = debug::Hooks::get().getOperatorCallback(); callback) { + std::shared_ptr programContextPtr = + ::tt::runtime::utils::unsafe_borrow_shared(programContext); + std::shared_ptr opContextPtr = + ::tt::runtime::utils::unsafe_borrow_shared( + const_cast<::tt::target::ttnn::Operation *>(opContext)); + (*callback)(executableHandle, + CallbackContext(programContextPtr, DeviceRuntime::TTNN), + OpContext(opContextPtr, DeviceRuntime::TTNN)); + } +} + void ProgramExecutor::runEltwiseOperation( const ::tt::target::ttnn::EltwiseOp *op) { auto runUnaryOp = [&]() { @@ -211,6 +219,26 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { } // Nop is single input, output tensor where input is returned as output. 
+static bool isNopProgram(const ::tt::target::ttnn::Program *program) { + return program->inputs()->size() == 1 && program->outputs()->size() == 1 && + program->inputs()->Get(0)->global_id() == + program->outputs()->Get(0)->global_id(); +} + +static ::ttnn::Tensor +handleNopProgram(::tt::target::ttnn::Program const *program, + std::vector<::ttnn::Tensor *> const &inputs) { + const ::ttnn::Tensor &input = *inputs[0]; + ::ttnn::Tensor output = + ::ttnn::zeros(input.get_shape(), input.get_dtype(), input.get_layout()); + const void *src = ::tt::tt_metal::get_raw_host_data_ptr(input); + void *dst = ::tt::tt_metal::get_raw_host_data_ptr(output); + std::memcpy(dst, src, input.volume() * input.element_size()); + return output; +} + +namespace legacy { + static bool handleNopProgram(::tt::target::ttnn::Program const *program, std::vector<::ttnn::Tensor *> const &inputs, std::vector<::ttnn::Tensor *> const &outputs) { @@ -239,8 +267,8 @@ void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, if (handleNopProgram(program, inputs, outputs)) { return; } - TensorMap liveTensors; - std::unordered_set programInputs; + std::unordered_map liveTensors; + std::vector programInputs; int inputIndex = 0; LOG_ASSERT(program->inputs()->size() == inputs.size(), "Program input size mismatch: ", program->inputs()->size(), @@ -249,21 +277,69 @@ void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, auto [iter, inserted] = liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]); LOG_ASSERT(inserted, "Duplicate input tensor"); - programInputs.emplace(input->global_id()); + programInputs.push_back(input->global_id()); } int outputIndex = 0; - std::unordered_set programOutputs; + std::vector programOutputs; LOG_ASSERT(program->outputs()->size() == outputs.size()); for (::tt::target::TensorRef const *output : *program->outputs()) { auto [iter, inserted] = liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]); LOG_ASSERT(inserted, "Duplicate output tensor"); - programOutputs.emplace(output->global_id()); + programOutputs.push_back(output->global_id()); + } + ProgramExecutor executor(executableHandle, liveTensors, programInputs, + programOutputs, &meshDevice); + executor.execute(program); + outputIndex = 0; + for (uint32_t outputId : programOutputs) { + const ::ttnn::Tensor &src = + executor.getContext().getTensorPool().at(outputId); + const ::ttnn::Tensor &dst = *(outputs[outputIndex++]); + size_t srcSize = src.volume() * src.element_size(); + size_t dstSize = dst.volume() * dst.element_size(); + LOG_ASSERT(srcSize == dstSize, "Output tensor size mismatch"); + const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(src); + void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dst); + std::memcpy(dstPtr, srcPtr, dstSize); + } +} +} // namespace legacy + +std::vector runProgram(::ttnn::MeshDevice &meshDevice, + Binary executableHandle, + std::uint32_t programIndex, + std::vector<::ttnn::Tensor *> const &inputs) { + ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); + ::tt::target::ttnn::Program const *program = + fbb.programs()->Get(programIndex); + if (isNopProgram(program)) { + Tensor out = + utils::createRuntimeTensorFromTTNN(handleNopProgram(program, inputs)); + return {out}; + } + std::unordered_map liveTensors; + std::vector programInputs; + int inputIndex = 0; + LOG_ASSERT(program->inputs()->size() == inputs.size(), + "Program input size mismatch: ", program->inputs()->size(), + " != ", inputs.size()); + for (::tt::target::TensorRef 
const *input : *program->inputs()) {
+    auto [iter, inserted] =
+        liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]);
+    LOG_ASSERT(inserted, "Duplicate input tensor");
+    programInputs.push_back(input->global_id());
+  }
+  std::vector<uint32_t> programOutputs;
+  for (::tt::target::TensorRef const *output : *program->outputs()) {
+    programOutputs.push_back(output->global_id());
   }
   ProgramExecutor executor(executableHandle, liveTensors, programInputs,
                            programOutputs, &meshDevice);
   executor.execute(program);
+  std::vector<Tensor> outputTensors = executor.gatherOutputTensors();
+  return outputTensors;
 }
 } // namespace tt::runtime::ttnn
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index 2dfc07788..466bf318b 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -1,7 +1,6 @@
 // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 //
 // SPDX-License-Identifier: Apache-2.0
-#include "tt/runtime/runtime.h"
 #include "tt/runtime/detail/debug.h"
 #include "tt/runtime/detail/logger.h"
 #include "tt/runtime/detail/ttnn.h"
@@ -21,16 +20,28 @@ using ::tt::tt_metal::DistributedTensorConfig;
 using ::tt::tt_metal::OwnedStorage;
 using ::tt::tt_metal::raise_unsupported_storage;
 
+template <typename ElementType>
+static OwnedStorage createOwnedStorage(ElementType *ptr,
+                                       std::uint32_t numElements) {
+  ::tt::tt_metal::owned_buffer::Buffer<ElementType> buffer;
+  if (ptr != nullptr) {
+    auto data = std::vector<ElementType>(ptr, ptr + numElements);
+    buffer = ::tt::tt_metal::owned_buffer::create(std::move(data));
+  } else {
+    buffer = ::tt::tt_metal::owned_buffer::create<ElementType>(numElements);
+  }
+  return OwnedStorage(std::move(buffer));
+}
+
 template <typename StorageType, typename ElementType>
 static StorageType createStorage(ElementType *ptr, std::uint32_t numElements) {
   if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
+    LOG_ASSERT(ptr != nullptr, "Cannot create borrowed storage from nullptr");
     return BorrowedStorage(
         ::tt::tt_metal::borrowed_buffer::Buffer(ptr, numElements), [] {},
         [] {});
   } else if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
-    auto data = std::vector<ElementType>(ptr, ptr + numElements);
-    auto buffer = ::tt::tt_metal::owned_buffer::create(std::move(data));
-    return OwnedStorage(std::move(buffer));
+    return createOwnedStorage(ptr, numElements);
   } else {
     raise_unsupported_storage<StorageType>();
   }
@@ -76,6 +87,21 @@ static Tensor createNullTensor() {
   return Tensor(nullptr, nullptr, DeviceRuntime::TTNN);
 }
 
+static DeviceVariant getTargetDevice(::ttnn::MeshDevice &meshDevice) {
+  if (meshDevice.num_devices() == 1) {
+    return std::ref(*(meshDevice.get_device_index(0)));
+  }
+  return std::ref(meshDevice);
+}
+
+static ::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
+  bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
+      binary.handle.get());
+  LOG_ASSERT(isTTNN, "Unsupported binary format");
+  return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get());
+}
+
+// Create a borrowed tensor from user-owned data
 Tensor createTensor(std::shared_ptr<void> data,
                     std::vector<std::uint32_t> const &shape,
                     std::vector<std::uint32_t> const &stride,
@@ -89,10 +115,11 @@ Tensor createTensor(std::shared_ptr<void> data,
       createStorage<BorrowedStorage>(data.get(), numElements, dataType),
       ::ttnn::Shape(small_vector_shape), utils::toTTNNDataType(dataType),
       ::ttnn::Layout::ROW_MAJOR);
-  return Tensor(std::static_pointer_cast<void>(tensor), data,
+  return Tensor(std::static_pointer_cast<void>(tensor), nullptr,
                 DeviceRuntime::TTNN);
 }
 
+// Create an owned multi-device host tensor from user-owned data
 Tensor
 createTensor(std::vector<std::shared_ptr<void>> &data,
              std::vector<std::uint32_t> const &shape,
@@ -100,8 +127,8 @@ createTensor(std::vector<std::shared_ptr<void>> &data,
              ::tt::target::DataType dataType,
std::unordered_map const &strategy) { std::vector<::ttnn::Tensor> tensorShards; - tensorShards.resize(data.size()); - std::transform(data.begin(), data.end(), tensorShards.begin(), + tensorShards.reserve(data.size()); + std::transform(data.begin(), data.end(), std::back_inserter(tensorShards), [&](std::shared_ptr &dataShard) -> ::ttnn::Tensor { return createOwnedTensor(dataShard, shape, stride, itemsize, dataType); @@ -112,13 +139,35 @@ createTensor(std::vector> &data, ::ttnn::distributed::api::create_multi_device_tensor( tensorShards, ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, distributionStrategy)); - std::shared_ptr>> borrowedData = - std::make_shared>>(data); - return Tensor(std::static_pointer_cast(tensor), - std::static_pointer_cast(borrowedData), + return Tensor(std::static_pointer_cast(tensor), nullptr, DeviceRuntime::TTNN); } +// Create an owned empty tensor on host/device +Tensor createTensor(Device device, Layout layout, + std::vector const &shape, + std::vector const &stride, + std::uint32_t itemsize) { + const LayoutDesc &layoutDesc = layout.as(DeviceRuntime::TTNN); + if (layoutDesc.isOnHost()) { + ::ttnn::Tensor tensor = + createOwnedTensor(nullptr, shape, stride, itemsize, + utils::fromTTNNDataType(layoutDesc.dataType)); + Tensor out = utils::createRuntimeTensorFromTTNN(tensor); + return toLayout(out, device, layout); + } + DeviceVariant targetDevice = + getTargetDevice(device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN)); + ::ttnn::Tensor tensor = std::visit( + [&](auto &&device) -> ::ttnn::Tensor { + return ::ttnn::operations::core::allocate_tensor_on_device( + ::ttnn::Shape(shape), layoutDesc.dataType, layoutDesc.layout, + &(device.get()), layoutDesc.memoryConfig); + }, + targetDevice); + return utils::createRuntimeTensorFromTTNN(tensor); +} + tt::target::DataType getTensorDataType(Tensor tensor) { const ::ttnn::Tensor &nnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); @@ -166,34 +215,110 @@ void deallocateBuffers(Device deviceHandle) { } } -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, - std::vector const &inputHandles, - std::vector const &outputHandles) { - ::ttnn::MeshDevice &meshDevice = - deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); - std::vector<::ttnn::Tensor *> inputs; - inputs.reserve(inputHandles.size()); - for (auto &input : inputHandles) { - LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); - inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); +void wait(Event event) { + // Nothing to do for ttnn runtime + LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN)); +} + +void wait(Tensor tensor) { + LOG_ASSERT(tensor.matchesRuntime(DeviceRuntime::TTNN), + "Expected ttnn tensor"); + ::tt::runtime::ttnn::wait(tensor.event); +} + +void wait(std::vector const &tensors) { + for (const Tensor &tensor : tensors) { + ::tt::runtime::ttnn::wait(tensor); } +} - std::vector<::ttnn::Tensor *> outputs; - outputs.reserve(outputHandles.size()); - for (auto &output : outputHandles) { - LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); - outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); +Tensor toHost(Tensor tensor, bool untilize) { + const ::ttnn::Tensor &deviceTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + std::shared_ptr<::ttnn::Tensor> hostTensor = + std::make_shared<::ttnn::Tensor>(::ttnn::from_device(deviceTensor)); + + if (untilize) { + hostTensor = std::make_shared<::ttnn::Tensor>(::ttnn::to_layout( + *hostTensor, ::ttnn::Layout::ROW_MAJOR, 
std::nullopt, std::nullopt, + static_cast<::ttnn::Device *>(nullptr))); } - tt::runtime::ttnn::runProgram(meshDevice, executableHandle, programIndex, - inputs, outputs); - return Event(nullptr, DeviceRuntime::TTNN); + return Tensor(std::static_pointer_cast(hostTensor), nullptr, + DeviceRuntime::TTNN); } -void wait(Event event) { - // Not implemented - LOG_ASSERT(event.matchesRuntime(DeviceRuntime::TTNN)); +Tensor toLayout(Tensor tensor, Device device, Layout layout) { + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + const ::ttnn::Layout &inputLayout = ttnnTensor.get_layout(); + const ::ttnn::DataType &inputDataType = ttnnTensor.get_dtype(); + LayoutDesc inputLayoutDesc(::ttnn::BufferType::SYSTEM_MEMORY, inputLayout, + inputDataType, std::nullopt); + + const LayoutDesc &outputLayoutDesc = + layout.as(DeviceRuntime::TTNN); + + ::ttnn::MeshDevice &meshDevice = + device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + DeviceVariant targetDevice = getTargetDevice(meshDevice); + LayoutConverter converter(inputLayoutDesc, outputLayoutDesc); + std::shared_ptr<::ttnn::Tensor> out = std::make_shared<::ttnn::Tensor>( + converter.convertTensorLayout(ttnnTensor, targetDevice)); + + return Tensor(std::static_pointer_cast(out), nullptr, + DeviceRuntime::TTNN); +} + +Layout getLayout(Binary executableHandle, std::uint32_t programIndex, + std::uint32_t inputIndex) { + const ::tt::target::ttnn::TTNNBinary &fbb = *getBinary(executableHandle); + LOG_ASSERT(programIndex < fbb.programs()->size(), "Invalid program index"); + const ::tt::target::ttnn::Program *program = + fbb.programs()->Get(programIndex); + LOG_ASSERT(inputIndex < program->inputs()->size(), "Invalid input index"); + const ::tt::target::TensorRef *input = program->inputs()->Get(inputIndex); + + ::ttnn::BufferType inputBufferType = utils::toTTNNBufferType( + input->desc()->layout()->memory_desc()->memory_space()); + ::ttnn::Layout inputLayout = utils::inferLayoutFromTileShape(input); + ::ttnn::DataType inputDataType = utils::toTTNNDataType( + input->desc()->layout()->memory_desc()->data_type()); + std::optional<::ttnn::MemoryConfig> inputMemoryConfig = std::nullopt; + if (inputBufferType != ::ttnn::BufferType::SYSTEM_MEMORY) { + inputMemoryConfig = utils::createMemoryConfig(input); + } + + std::shared_ptr layoutDesc = std::make_shared( + inputBufferType, inputLayout, inputDataType, inputMemoryConfig); + + return Layout(std::static_pointer_cast(layoutDesc), + DeviceRuntime::TTNN); +} + +void memcpy(Tensor dst, Tensor src) { + ::ttnn::Tensor &dstTensor = dst.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + const ::ttnn::Tensor &srcTensor = src.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + LOG_ASSERT(srcTensor.volume() * srcTensor.element_size() == + dstTensor.volume() * dstTensor.element_size(), + "Input output tensor size mismatch in memcpy: ", + srcTensor.volume(), " * ", srcTensor.element_size(), + " != ", dstTensor.volume(), " * ", dstTensor.element_size()); + + if (utils::isOnHost(srcTensor.storage_type()) and + utils::isOnHost(dstTensor.storage_type())) { + void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dstTensor); + void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(srcTensor); + size_t size = srcTensor.volume() * srcTensor.element_size(); + std::memcpy(dstPtr, srcPtr, size); + } else { + ::tt::tt_metal::memcpy(dstTensor, srcTensor); + } +} + +void deallocateTensor(Tensor &tensor, bool force) { + ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + ::ttnn::deallocate(ttnnTensor, 
force); } std::string getOpDebugString(OpContext opContextHandle) { @@ -305,7 +430,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, return createNullTensor(); } default: { - throw std::runtime_error("Unsupported operation type"); + LOG_FATAL("Unsupported operation type"); } } @@ -332,12 +457,13 @@ Tensor getOpOutputTensor(OpContext opContextHandle, outCopy.shape().value, ::ttnn::DataType::FLOAT32, ::ttnn::Layout::ROW_MAJOR); - return Tensor(std::static_pointer_cast(tensor), data, + return Tensor(std::static_pointer_cast(tensor), nullptr, DeviceRuntime::TTNN); } std::vector getTensorData(Tensor tensor) { - ::ttnn::Tensor *nnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + const ::ttnn::Tensor *nnTensor = + static_cast<::ttnn::Tensor *>(tensor.handle.get()); if (nnTensor == nullptr) { return {}; } @@ -347,4 +473,62 @@ std::vector getTensorData(Tensor tensor) { static_cast(dataPtr) + nnTensor->volume()); } +namespace legacy { + +Event submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles, + std::vector const &outputHandles) { + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + std::vector<::ttnn::Tensor *> inputs; + inputs.reserve(inputHandles.size()); + for (auto &input : inputHandles) { + LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); + inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); + } + + std::vector<::ttnn::Tensor *> outputs; + outputs.reserve(outputHandles.size()); + for (auto &output : outputHandles) { + LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); + outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); + } + + tt::runtime::ttnn::legacy::runProgram(meshDevice, executableHandle, + programIndex, inputs, outputs); + return Event(nullptr, DeviceRuntime::TTNN); +} +} // namespace legacy + +std::vector submit(Device deviceHandle, Binary executableHandle, + std::uint32_t programIndex, + std::vector const &inputHandles) { + ::ttnn::MeshDevice &meshDevice = + deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); + + // Convert input tensors to the layout expected by the program + std::vector inputsWithLayout; + inputsWithLayout.reserve(inputHandles.size()); + std::transform( + inputHandles.begin(), inputHandles.end(), + std::back_inserter(inputsWithLayout), [&](const Tensor &input) -> Tensor { + Layout inputLayout = ::tt::runtime::ttnn::getLayout( + executableHandle, programIndex, inputsWithLayout.size()); + return ::tt::runtime::ttnn::toLayout(input, deviceHandle, inputLayout); + }); + + std::vector<::ttnn::Tensor *> ttnnInputs; + ttnnInputs.reserve(inputsWithLayout.size()); + std::transform(inputsWithLayout.begin(), inputsWithLayout.end(), + std::back_inserter(ttnnInputs), + [](Tensor &input) -> ::ttnn::Tensor * { + return &input.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + }); + + std::vector outputs = ::tt::runtime::ttnn::runProgram( + meshDevice, executableHandle, programIndex, ttnnInputs); + return outputs; +} + } // namespace tt::runtime::ttnn diff --git a/runtime/test/CMakeLists.txt b/runtime/test/CMakeLists.txt index e4a7adc40..f55a6c761 100644 --- a/runtime/test/CMakeLists.txt +++ b/runtime/test/CMakeLists.txt @@ -1,7 +1,31 @@ +if (NOT TTMLIR_ENABLE_RUNTIME_TESTS) + add_library(TTRuntimeTTNNTestHelpers INTERFACE) + return() +endif() + if (NOT TTMLIR_ENABLE_RUNTIME OR (NOT TT_RUNTIME_ENABLE_TTNN AND NOT TT_RUNTIME_ENABLE_TTMETAL)) message(FATAL_ERROR "Runtime tests require -DTTMLIR_ENABLE_RUNTIME=ON 
and at least one backend runtime to be enabled") endif() +if (NOT TT_RUNTIME_ENABLE_TTNN) + add_library(TTRuntimeTTNNTestHelpers INTERFACE) +else() + add_library(TTRuntimeTTNNTestHelpers + STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/test/utils.cpp + ) + set_property(TARGET TTRuntimeTTNNTestHelpers PROPERTY CXX_STANDARD 20) + target_compile_options(TTRuntimeTTNNTestHelpers PUBLIC -mavx -mavx2 -fsized-deallocation) + target_include_directories(TTRuntimeTTNNTestHelpers PUBLIC + ${PROJECT_SOURCE_DIR}/runtime/include + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include + ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common + ) + target_include_directories(TTRuntimeTTNNTestHelpers SYSTEM PUBLIC "$") + add_dependencies(TTRuntimeTTNNTestHelpers TTRuntime tt-metal FBS_GENERATION) + target_link_libraries(TTRuntimeTTNNTestHelpers PUBLIC TTRuntime TTNN_LIBRARY) +endif() + enable_testing() include(FetchContent) FetchContent_Declare( diff --git a/runtime/test/include/tt/runtime/ttnn/test/utils.cpp b/runtime/test/include/tt/runtime/ttnn/test/utils.cpp new file mode 100644 index 000000000..e0cc969b7 --- /dev/null +++ b/runtime/test/include/tt/runtime/ttnn/test/utils.cpp @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt/runtime/test/utils.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/runtime.h" +#include "tt/runtime/ttnn/types.h" +#include "tt/runtime/ttnn/utils.h" +#include "tt/runtime/types.h" + +namespace tt::runtime::ttnn::test { +using ::tt::runtime::DeviceRuntime; +Layout getDramInterleavedTileLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::DRAM, + ::ttnn::Layout::TILE, ttnnDataType, + std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +Layout getDramInterleavedRowMajorLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::DRAM, + ::ttnn::Layout::ROW_MAJOR, + ttnnDataType, std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +::tt::runtime::Layout getHostRowMajorLayout(::tt::target::DataType dataType) { + LOG_ASSERT(getCurrentRuntime() == DeviceRuntime::TTNN); + ::ttnn::DataType ttnnDataType = + ::tt::runtime::ttnn::utils::toTTNNDataType(dataType); + ::tt::runtime::ttnn::LayoutDesc layoutDesc(::ttnn::BufferType::SYSTEM_MEMORY, + ::ttnn::Layout::ROW_MAJOR, + ttnnDataType, std::nullopt); + return Layout( + std::static_pointer_cast( + std::make_shared<::tt::runtime::ttnn::LayoutDesc>(layoutDesc)), + ::tt::runtime::DeviceRuntime::TTNN); +} +} // namespace tt::runtime::ttnn::test diff --git a/runtime/test/python/ttnn/conftest.py b/runtime/test/python/ttnn/conftest.py new file mode 100644 index 000000000..854cb42a3 --- /dev/null +++ b/runtime/test/python/ttnn/conftest.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +try: + import ttrt +except (ImportError, ModuleNotFoundError): + raise ImportError( + 
"Error: runtime python tests require ttrt to built and installed. Please run `cmake --build build -- ttrt`" + ) +import ttrt.runtime +from ttrt.common.api import API +from utils import Helper +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def initialize(): + API.initialize_apis() + ttrt.runtime.set_current_runtime(ttrt.runtime.DeviceRuntime.TTNN) + + +@pytest.fixture(scope="module") +def helper(): + helper = Helper() + yield helper diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py new file mode 100644 index 000000000..fe914d0c9 --- /dev/null +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import ttrt +import ttrt.runtime +import torch +from ttrt.common.util import * +from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc + + +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_to_layout(helper: Helper, shape, dtype, request): + helper.initialize(request.node.name) + helper.check_constraints() + torch_input_tensor = torch.randn(shape, dtype=dtype) + torch_result_tensor = torch.zeros(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + runtime_input_tensor = ttrt.runtime.create_tensor( + torch_input_tensor.data_ptr(), + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + runtime_dtype, + ) + runtime_output_tensor = ttrt.runtime.create_tensor( + torch_result_tensor.data_ptr(), + list(torch_result_tensor.shape), + list(torch_result_tensor.stride()), + torch_result_tensor.element_size(), + runtime_dtype, + ) + device_layout = ttrt.runtime.testing.get_dram_interleaved_tile_layout(runtime_dtype) + host_layout = ttrt.runtime.testing.get_host_row_major_layout(runtime_dtype) + with DeviceContext([helper.query.device_ids[0]]) as device: + device_tensor = ttrt.runtime.to_layout( + runtime_input_tensor, device, device_layout + ) + host_tensor = ttrt.runtime.to_layout(device_tensor, device, host_layout) + ttrt.runtime.deallocate_tensor(device_tensor, force=True) + ttrt.runtime.memcpy(runtime_output_tensor, host_tensor) + ttrt.runtime.deallocate_tensor(host_tensor, force=True) + + lambda: assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.999) + helper.teardown() + + +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_create_tensor_memcpy(helper: Helper, shape, dtype, request): + helper.initialize(request.node.name) + helper.check_constraints() + torch_input_tensor = torch.randn(shape, dtype=dtype) + torch_result_tensor = torch.zeros(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + runtime_input_tensor = ttrt.runtime.create_tensor( + torch_input_tensor.data_ptr(), + list(torch_input_tensor.shape), + list(torch_input_tensor.stride()), + torch_input_tensor.element_size(), + runtime_dtype, + ) + runtime_output_tensor = ttrt.runtime.create_tensor( + torch_result_tensor.data_ptr(), + list(torch_result_tensor.shape), + list(torch_result_tensor.stride()), + torch_result_tensor.element_size(), + runtime_dtype, + ) + device_layout = ttrt.runtime.testing.get_dram_interleaved_row_major_layout( + runtime_dtype + ) + with DeviceContext([helper.query.device_ids[0]]) as device: + device_tensor = ttrt.runtime.create_empty_tensor( + device, + 
+            device_layout,
+            list(torch_input_tensor.shape),
+            list(torch_input_tensor.stride()),
+            torch_input_tensor.element_size(),
+        )
+        ttrt.runtime.memcpy(device_tensor, runtime_input_tensor)
+        host_tensor = ttrt.runtime.to_host(device_tensor, untilize=True)
+        ttrt.runtime.deallocate_tensor(device_tensor, force=True)
+        ttrt.runtime.memcpy(runtime_output_tensor, host_tensor)
+        ttrt.runtime.deallocate_tensor(host_tensor, force=True)
+    assert_pcc(torch_input_tensor, torch_result_tensor, threshold=0.999)
+    helper.teardown()
+
+
+def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request):
+    binary_path = f"{TT_MLIR_HOME}/build/test/ttmlir/Runtime/TTNN/runtime_stitching/Output/eltwise_binary_op_chain.mlir.tmp.ttnn"
+    helper.initialize(request.node.name, binary_path)
+    helper.check_constraints()
+    first_program: Binary.Program = helper.binary.get_program(0)
+    assert first_program.num_inputs() == 2
+    inputs_torch = []
+    inputs_runtime = []
+    input_layouts = []
+    for i in first_program.program["inputs"]:
+        torch_tensor = torch.randn(
+            i["desc"]["shape"],
+            dtype=Binary.Program.from_data_type(
+                i["desc"]["layout"]["memory_desc"]["data_type"]
+            ),
+        )
+        runtime_dtype = Binary.Program.to_data_type(torch_tensor.dtype)
+        inputs_torch.append(torch_tensor)
+        runtime_tensor = ttrt.runtime.create_tensor(
+            torch_tensor.data_ptr(),
+            list(torch_tensor.shape),
+            list(torch_tensor.stride()),
+            torch_tensor.element_size(),
+            runtime_dtype,
+        )
+        inputs_runtime.append(runtime_tensor)
+        input_layouts.append(
+            ttrt.runtime.testing.get_dram_interleaved_row_major_layout(runtime_dtype)
+        )
+
+    activations, weights = inputs_runtime
+    activations_layout, weights_layout = input_layouts
+    with DeviceContext([helper.query.device_ids[0]]) as device:
+        activations = ttrt.runtime.to_layout(activations, device, activations_layout)
+        weights = ttrt.runtime.to_layout(weights, device, weights_layout)
+        program_indices = list(range(helper.binary.get_num_programs()))
+        for program_index in program_indices:
+            program = helper.binary.get_program(program_index)
+            assert program.num_inputs() == 2 and program.num_outputs() == 1
+            outputs = ttrt.runtime.submit(
+                device, helper.binary.fbb, program_index, [activations, weights]
+            )
+            activations = ttrt.runtime.to_layout(outputs[0], device, activations_layout)
+            ttrt.runtime.deallocate_tensor(outputs[0], force=True)
+        activations = ttrt.runtime.to_host(activations, untilize=True)
+        ttrt.runtime.deallocate_tensor(weights, force=True)
+
+    last_program: Binary.Program = helper.binary.get_program(program_indices[-1])
+    torch_result_tensor = torch.randn(
+        last_program.program["outputs"][0]["desc"]["shape"],
+        dtype=Binary.Program.from_data_type(
+            last_program.program["outputs"][0]["desc"]["layout"]["memory_desc"][
+                "data_type"
+            ]
+        ),
+    )
+    runtime_result_tensor = ttrt.runtime.create_tensor(
+        torch_result_tensor.data_ptr(),
+        list(torch_result_tensor.shape),
+        list(torch_result_tensor.stride()),
+        torch_result_tensor.element_size(),
+        Binary.Program.to_data_type(torch_result_tensor.dtype),
+    )
+    ttrt.runtime.memcpy(runtime_result_tensor, activations)
+    golden = (
+        (inputs_torch[0] + inputs_torch[1]).mul(inputs_torch[1]).sub(inputs_torch[1])
+    )
+    assert_pcc(golden, torch_result_tensor, threshold=0.999), program_index
+    helper.teardown()
diff --git a/runtime/test/python/ttnn/utils.py b/runtime/test/python/ttnn/utils.py
new file mode 100644
index 000000000..6596811ff
--- /dev/null
+++ b/runtime/test/python/ttnn/utils.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: (c)
2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import ttrt +import ttrt.runtime +import torch +from ttrt.common.query import Query +from ttrt.common.util import * + +TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "") + + +class Helper: + def __init__(self, logger=None): + self.artifacts_dir = f"{os.getcwd()}/ttrt-artifacts" + self.logger = logger if logger is not None else Logger() + self.logging = self.logger.get_logger() + self.file_manager = FileManager(self.logger) + self.artifacts = Artifacts( + self.logger, self.file_manager, artifacts_folder_path=self.artifacts_dir + ) + self.query = Query({"--quiet": True}, self.logger, self.artifacts) + self.query() + self.test_name = None + self.binary_path = None + self.binary = None + + def initialize(self, test_name, binary_path=None): + self.test_name = test_name + if binary_path: + self.binary_path = binary_path + self.binary = Binary(self.logger, self.file_manager, binary_path) + + def teardown(self): + self.test_name = None + self.binary_path = None + self.binary = None + + def check_constraints(self): + if not self.binary: + return + self.binary.check_version() + self.binary.check_system_desc(self.query) + + +class DeviceContext: + def __init__(self, device_ids): + self.device = ttrt.runtime.open_device(device_ids) + + def __enter__(self): + return self.device + + def __exit__(self, exc_type, exc_value, traceback): + ttrt.runtime.close_device(self.device) + + +def assert_tensors_match(tensor1, tensor2): + assert torch.allclose(tensor1, tensor2) + + +def assert_pcc(x, y, threshold=0.99): + combined = torch.stack([x.flatten(), y.flatten()]) + pcc = torch.corrcoef(combined)[0, 1].item() + assert pcc >= threshold, f"Expected pcc {pcc} >= {threshold}" diff --git a/runtime/test/ttnn/test_subtract.cpp b/runtime/test/ttnn/test_subtract.cpp index 00aebe20f..995b95665 100644 --- a/runtime/test/ttnn/test_subtract.cpp +++ b/runtime/test/ttnn/test_subtract.cpp @@ -21,12 +21,13 @@ TEST(TTNNSubtract, Equal) { const char *fbPath = std::getenv("TTMLIR_SUBTRACT_FB_PATH"); assert(fbPath && "Path to subtract flatbuffer must be provided"); ::tt::runtime::Binary fbb = ::tt::runtime::Binary::loadFromPath(fbPath); - EXPECT_EQ(fbb.getFileIdentifier(), "TTNN"); + ASSERT_EQ(fbb.getFileIdentifier(), "TTNN"); ::tt::runtime::setCompatibleRuntime(fbb); std::vector<::tt::runtime::TensorDesc> inputDescs = fbb.getProgramInputs(0); + assert(inputDescs.size() == 2); std::vector<::tt::runtime::TensorDesc> outputDescs = fbb.getProgramOutputs(0); - std::vector<::tt::runtime::Tensor> inputTensors, outputTensors; - + assert(outputDescs.size() == 1); + std::vector<::tt::runtime::Tensor> inputTensors; std::uint32_t tensorSize = inputDescs[0].itemsize; for (const int dim : inputDescs[0].shape) { tensorSize *= dim; @@ -38,26 +39,27 @@ TEST(TTNNSubtract, Equal) { std::memset(data.get(), 1, tensorSize); inputTensors.emplace_back(::tt::runtime::createTensor(data, desc)); } - for (const auto &desc : outputDescs) { - std::shared_ptr data = - ::tt::runtime::utils::malloc_shared(tensorSize); - // Set to wrong value on purpose here - std::memset(data.get(), 1, tensorSize); - outputTensors.emplace_back(::tt::runtime::createTensor(data, desc)); - } + + std::shared_ptr outputDataPtr = + ::tt::runtime::utils::malloc_shared(tensorSize); + // Set to wrong value on purpose here + std::memset(outputDataPtr.get(), 1, tensorSize); + ::tt::runtime::Tensor outputTensor = + ::tt::runtime::createTensor(outputDataPtr, outputDescs[0]); size_t numDevices = 
::tt::runtime::getNumAvailableDevices(); std::vector deviceIds(numDevices); std::iota(deviceIds.begin(), deviceIds.end(), 0); - auto device = ::tt::runtime::openDevice(deviceIds); - auto ev = ::tt::runtime::submit(device, fbb, 0, inputTensors, outputTensors); + auto device = ::tt::runtime::openDevice({deviceIds[0]}); + std::vector<::tt::runtime::Tensor> output = + ::tt::runtime::submit(device, fbb, 0, inputTensors); ::tt::runtime::closeDevice(device); - + assert(output.size() == 1); std::shared_ptr expected = ::tt::runtime::utils::malloc_shared(tensorSize); std::memset(expected.get(), 0, tensorSize); - for (const auto &outputTensor : outputTensors) { - EXPECT_EQ(std::memcmp(outputTensor.data.get(), expected.get(), tensorSize), - 0); - } + ::tt::runtime::Tensor submitOutput = output[0]; + ASSERT_NE(std::memcmp(outputDataPtr.get(), expected.get(), tensorSize), 0); + ::tt::runtime::memcpy(outputTensor, submitOutput); + ASSERT_EQ(std::memcmp(outputDataPtr.get(), expected.get(), tensorSize), 0); } diff --git a/runtime/tools/python/CMakeLists.txt b/runtime/tools/python/CMakeLists.txt index 353ebbe7d..966ee9681 100644 --- a/runtime/tools/python/CMakeLists.txt +++ b/runtime/tools/python/CMakeLists.txt @@ -9,6 +9,7 @@ add_custom_target(ttrt COMMAND TTMLIR_ENABLE_RUNTIME=${TTMLIR_ENABLE_RUNTIME} TT_RUNTIME_ENABLE_TTNN=${TT_RUNTIME_ENABLE_TTNN} TT_RUNTIME_ENABLE_TTMETAL=${TT_RUNTIME_ENABLE_TTMETAL} + TTMLIR_ENABLE_RUNTIME_TESTS=${TTMLIR_ENABLE_RUNTIME_TESTS} TT_RUNTIME_ENABLE_PERF_TRACE=${TT_RUNTIME_ENABLE_PERF_TRACE} TT_RUNTIME_DEBUG=${TT_RUNTIME_DEBUG} TT_RUNTIME_WORKAROUNDS=${TT_RUNTIME_WORKAROUNDS} diff --git a/runtime/tools/python/setup.py b/runtime/tools/python/setup.py index f5d148578..e22783502 100644 --- a/runtime/tools/python/setup.py +++ b/runtime/tools/python/setup.py @@ -31,6 +31,7 @@ enable_runtime = os.environ.get("TTMLIR_ENABLE_RUNTIME", "OFF") == "ON" enable_ttnn = os.environ.get("TT_RUNTIME_ENABLE_TTNN", "OFF") == "ON" enable_ttmetal = os.environ.get("TT_RUNTIME_ENABLE_TTMETAL", "OFF") == "ON" +enable_runtime_tests = os.environ.get("TTMLIR_ENABLE_RUNTIME_TESTS", "OFF") == "ON" enable_perf = os.environ.get("TT_RUNTIME_ENABLE_PERF_TRACE", "OFF") == "ON" debug_runtime = os.environ.get("TT_RUNTIME_DEBUG", "OFF") == "ON" configure_workarounds_runtime = os.environ.get("TT_RUNTIME_WORKAROUNDS", "OFF") == "ON" @@ -64,7 +65,15 @@ linklibs = ["TTBinary"] if enable_ttnn: runlibs += ["_ttnn.so"] - linklibs += ["TTRuntimeTTNN", "TTRuntimeTTNNOps", ":_ttnn.so"] + linklibs += [ + "TTRuntimeTTNN", + "TTRuntimeTTNNOps", + "TTRuntimeTTNNHelpers", + ":_ttnn.so", + ] + +if enable_ttnn and enable_runtime_tests: + linklibs += ["TTRuntimeTTNNTestHelpers"] if enable_ttmetal: runlibs += ["libtt_metal.so"] @@ -237,6 +246,7 @@ def package_files(directory): f"{ttmlir_build_dir}/runtime/lib/ttnn", f"{ttmlir_build_dir}/runtime/lib/ttnn/operations", f"{ttmlir_build_dir}/runtime/lib/ttmetal", + f"{ttmlir_build_dir}/runtime/test", f"{toolchain}/lib", f"{ttmlir_build_dir}/runtime/tools/python/ttrt/runtime", f"{metaldir}/lib", @@ -248,6 +258,7 @@ def package_files(directory): "TT_RUNTIME_WORKAROUNDS", "1" if configure_workarounds_runtime else "0", ), + ("TTMLIR_ENABLE_RUNTIME_TESTS", "1" if enable_runtime_tests else "0"), ], ) ) diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index c2ae10ac9..be9711587 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ b/runtime/tools/python/ttrt/common/run.py @@ -380,6 +380,7 @@ def _execute(binaries): self.logging.debug(f"setting 
torch manual seed={self['--seed']}") torch.manual_seed(self["--seed"]) ttrt.runtime.set_compatible_runtime(binaries[0].fbb) + current_runtime = ttrt.runtime.get_current_runtime() self.logging.debug(f"opening devices={self.query.device_ids}") device = ttrt.runtime.open_device(self.query.device_ids) @@ -459,20 +460,43 @@ def _execute(binaries): self.logging.debug( f"starting loop={loop+1}/{self['--loops']} for binary={bin.file_path}" ) + if ( + current_runtime + == ttrt.runtime.DeviceRuntime.TTMetal + ): + event = ttrt.runtime.submit( + device, + bin.fbb, + program_index, + total_inputs[loop], + total_outputs[loop], + ) - event = ttrt.runtime.submit( - device, - bin.fbb, - program_index, - total_inputs[loop], - total_outputs[loop], - ) + elif current_runtime == ttrt.runtime.DeviceRuntime.TTNN: + runtime_outputs = ttrt.runtime.submit( + device, + bin.fbb, + program_index, + total_inputs[loop], + ) + ttrt.runtime.wait(runtime_outputs) + for i, runtime_output_tensor in enumerate( + runtime_outputs + ): + ttrt.runtime.memcpy( + total_outputs[loop][i], + runtime_output_tensor, + ) + ttrt.runtime.deallocate_tensor( + runtime_output_tensor, force=True + ) self.logging.debug( f"finished loop={loop+1}/{self['--loops']} for binary={bin.file_path}" ) - ttrt.runtime.wait(event) + if event is not None: + ttrt.runtime.wait(event) if self["--identity"]: self.logging.debug( diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 370643e7d..45e0a9db9 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -586,6 +586,12 @@ def __init__(self, index, program): self.input_tensors = [] self.output_tensors = [] + def num_inputs(self): + return len(self.program["inputs"]) + + def num_outputs(self): + return len(self.program["outputs"]) + def populate_inputs(self, init_fn, golden_inputs=[]): if len(golden_inputs) > 0: assert len(golden_inputs) == len(self.program["inputs"]) diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index 642b0401f..0376c07b5 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -12,19 +12,33 @@ DebugEnv, DebugHooks, get_current_runtime, + set_current_runtime, set_compatible_runtime, get_current_system_desc, open_device, close_device, submit, create_tensor, + create_empty_tensor, create_multi_device_tensor, wait, + to_host, + to_layout, + get_layout, get_op_output_tensor, get_op_debug_str, + memcpy, + deallocate_tensor, WorkaroundEnv, ) except ModuleNotFoundError: raise ImportError( "Error: Project was not built with runtime enabled, rebuild with: -DTTMLIR_ENABLE_RUNTIME=ON" ) + +try: + from ._C import testing +except ImportError: + print( + "Warning: not importing testing submodule since project was not built with runtime testing enabled. 
To enable, rebuild with: -DTTMLIR_ENABLE_RUNTIME_TESTS=ON" + ) diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index c0378727c..e1db607c5 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -8,6 +8,9 @@ #include "tt/runtime/detail/workarounds.h" #include "tt/runtime/runtime.h" #include "tt/runtime/utils.h" +#if defined(TTMLIR_ENABLE_RUNTIME_TESTS) && TTMLIR_ENABLE_RUNTIME_TESTS == 1 +#include "tt/runtime/test/utils.h" +#endif #include #include @@ -22,6 +25,7 @@ PYBIND11_MODULE(_C, m) { .def("deallocate_buffers", &tt::runtime::detail::deallocateBuffers); py::class_(m, "Event"); py::class_(m, "Tensor"); + py::class_(m, "Layout"); py::class_(m, "OpContext"); py::class_(m, "CallbackContext"); py::enum_<::tt::target::DataType>(m, "DataType") @@ -48,6 +52,8 @@ PYBIND11_MODULE(_C, m) { m.def("set_compatible_runtime", &tt::runtime::setCompatibleRuntime, py::arg("binary"), "Set the backend device runtime type to match the binary"); + m.def("set_current_runtime", &tt::runtime::setCurrentRuntime, + py::arg("runtime"), "Set the backend device runtime type"); m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc, "Get the current system descriptor"); m.def( @@ -61,6 +67,15 @@ PYBIND11_MODULE(_C, m) { shape, stride, itemsize, dataType); }, "Create a tensor with borrowed memory"); + m.def( + "create_empty_tensor", + [](::tt::runtime::Device device, ::tt::runtime::Layout layout, + std::vector const &shape, + std::vector const &stride, std::uint32_t itemsize) { + return tt::runtime::createTensor(device, layout, shape, stride, + itemsize); + }, + "Create an empty tensor with the specified layout"); m.def( "create_multi_device_tensor", [](std::vector &ptrs, @@ -69,8 +84,8 @@ PYBIND11_MODULE(_C, m) { ::tt::target::DataType dataType, std::unordered_map const &strategy) { std::vector> data; - data.resize(ptrs.size()); - std::transform(ptrs.begin(), ptrs.end(), data.begin(), + data.reserve(ptrs.size()); + std::transform(ptrs.begin(), ptrs.end(), std::back_inserter(data), [](std::uintptr_t ptr) { return ::tt::runtime::utils::unsafe_borrow_shared( reinterpret_cast(ptr)); @@ -85,10 +100,50 @@ PYBIND11_MODULE(_C, m) { py::arg("num_hw_cqs") = size_t{1}, "Open a mesh of devices for execution"); m.def("close_device", &tt::runtime::closeDevice, "Close a mesh device"); - m.def("submit", &tt::runtime::submit, py::arg("device"), - py::arg("executable"), py::arg("program_index"), py::arg("inputs"), - py::arg("outputs"), "Submit a binary for execution"); - m.def("wait", &tt::runtime::wait, py::arg("event")); + m.def("to_host", &tt::runtime::toHost, py::arg("tensor"), + py::arg("untilize") = false, "Copy the tensor to the host"); + m.def("to_layout", &tt::runtime::toLayout, py::arg("tensor"), + py::arg("device"), py::arg("layout"), + "Create a copy of the tensor with the specified layout"); + m.def("get_layout", &tt::runtime::getLayout, py::arg("executable"), + py::arg("program_index"), py::arg("input_index"), + "Get the layout of the input tensor"); + m.def( + "submit", + [](::tt::runtime::Device device, ::tt::runtime::Binary executable, + std::uint32_t programIndex, + const std::vector<::tt::runtime::Tensor> &inputs) + -> std::vector<::tt::runtime::Tensor> { + return ::tt::runtime::submit(device, executable, programIndex, inputs); + }, + py::arg("device"), py::arg("executable"), py::arg("program_index"), + py::arg("inputs"), + "Submit a ttnn binary for execution, returns a vector of output 
tensors"); + m.def( + "submit", + [](::tt::runtime::Device device, ::tt::runtime::Binary executable, + std::uint32_t programIndex, + const std::vector<::tt::runtime::Tensor> &inputs, + const std::vector<::tt::runtime::Tensor> &outputs) + -> ::tt::runtime::Event { + return ::tt::runtime::submit(device, executable, programIndex, inputs, + outputs); + }, + py::arg("device"), py::arg("executable"), py::arg("program_index"), + py::arg("inputs"), py::arg("outputs"), + "Submit a ttmetal binary for execution. returns event wrapper"); + m.def( + "wait", [](::tt::runtime::Event event) { ::tt::runtime::wait(event); }, + py::arg("event")); + m.def( + "wait", [](::tt::runtime::Tensor tensor) { ::tt::runtime::wait(tensor); }, + py::arg("tensor")); + m.def( + "wait", + [](const std::vector<::tt::runtime::Tensor> &tensors) { + ::tt::runtime::wait(tensors); + }, + py::arg("tensors")); m.def( "get_op_output_tensor", [](tt::runtime::OpContext &opContextHandle, @@ -102,7 +157,15 @@ PYBIND11_MODULE(_C, m) { "Get the debug string of the op"); m.def("get_op_loc_info", &tt::runtime::getOpLocInfo, "Get the location info of the op"); - + m.def( + "memcpy", + [](::tt::runtime::Tensor dst, ::tt::runtime::Tensor src) { + ::tt::runtime::memcpy(dst, src); + }, + py::arg("dst"), py::arg("src"), + "Copy the data from src tensor to dst tensor"); + m.def("deallocate_tensor", &tt::runtime::deallocateTensor, py::arg("tensor"), + py::arg("force") = false, "Deallocate the tensor memory"); py::class_(m, "DebugEnv") .def_static("get", &tt::runtime::debug::Env::get) .def("__str__", [](const tt::runtime::debug::Env &env) { @@ -138,4 +201,17 @@ PYBIND11_MODULE(_C, m) { os << env; return os.str(); }); + +#if defined(TTMLIR_ENABLE_RUNTIME_TESTS) && TTMLIR_ENABLE_RUNTIME_TESTS == 1 + auto testing = m.def_submodule("testing"); + testing.def("get_dram_interleaved_tile_layout", + &tt::runtime::ttnn::test::getDramInterleavedTileLayout, + py::arg("dtype"), "Get dram interleaved tile layout"); + testing.def("get_dram_interleaved_row_major_layout", + &tt::runtime::ttnn::test::getDramInterleavedRowMajorLayout, + py::arg("dtype"), "Get dram interleaved row major layout"); + testing.def("get_host_row_major_layout", + &tt::runtime::ttnn::test::getHostRowMajorLayout, py::arg("dtype"), + "Get host row major layout"); +#endif } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir index e819e68f4..3089da669 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir @@ -1,15 +1,15 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { - func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { + func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : 
(tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir new file mode 100644 index 000000000..97690df78 --- /dev/null +++ b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir @@ -0,0 +1,49 @@ +// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" %s > %t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +// TODO: this is a workaround for compiler assuming input tensors are always on host. The ideal is to directly compile ttir graphs. +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#system_memory = #ttnn.buffer_type +#dram = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, interleaved> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, interleaved> + +module attributes {tt.device = #device} { + func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #dram, <<64x128>>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} + +module attributes {tt.device = #device} { + func.func @multiply(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #dram, <<64x128>>>, shape = 
#ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} + +module attributes {tt.device = #device} { + func.func @subtract(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, #dram, <<64x128>>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> + %4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> + %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> + %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> + return %6 : tensor<64x128xbf16, #ttnn_layout> + } +} diff --git a/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir b/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir index 04b9f1fef..35682c8c0 100644 --- a/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/Unary/isfinite_op.mlir @@ -7,14 +7,14 @@ // RUN: FileCheck --input-file=%t.mlir %s module @jit_eltwise_isfinite attributes {} { - func.func public @test_isfinite(%arg0: tensor<64x128xf32>) -> tensor<64x128xi1> { + func.func public @test_isfinite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xi1> { // CHECK-LABEL: func.func public @test_isfinite // CHECK: ttnn.empty // CHECK: ttnn.isfinite - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: -> tensor<64x128xbf16, - %0 = stablehlo.is_finite %arg0 : (tensor<64x128xf32>) -> tensor<64x128xi1> + %0 = stablehlo.is_finite %arg0 : (tensor<64x128xbf16>) -> tensor<64x128xi1> return %0 : tensor<64x128xi1> } } diff --git a/test/ttmlir/Silicon/StableHLO/select_op.mlir b/test/ttmlir/Silicon/StableHLO/select_op.mlir index 23b7182ce..1cdc5e9d0 100644 --- a/test/ttmlir/Silicon/StableHLO/select_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/select_op.mlir @@ -6,23 +6,23 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module @jit_eltwise_select attributes {} { - func.func public @test_select(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + func.func public @test_select(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK-LABEL: func.func public @test_select // CHECK: tensor.empty // CHECK: [[EQ:{{0-9}}+]] = "ttnn.eq" - // CHECK-SAME: 
tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: tensor<64x128xbf16 // CHECK-SAME: -> tensor<64x128xbf16 - %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xi1> + %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xi1> // CHECK: ttnn.where // CHECK-SAME: [[EQ]] // CHECK-SAME: tensor<64x128xbf16 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: tensor<64x128xf32 - // CHECK-SAME: -> tensor<64x128xf32 - %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<64x128xi1>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: tensor<64x128xbf16 + // CHECK-SAME: -> tensor<64x128xbf16 + %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<64x128xi1>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir index ce0146be4..f1489a5eb 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir @@ -4,14 +4,14 @@ #any_device = #tt.operand_constraint #any_device_tile = #tt.operand_constraint -func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { +func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir deleted file mode 100644 index 79de8c062..000000000 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_le.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - -module attributes {} { - func.func @less_equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty - // CHECK-SAME: [[TENSOR:tensor<13x31xf32,]] - %0 = tensor.empty() : tensor<13x31xf32> - // CHECK: %[[C:.*]] = "ttnn.le" - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> - return %1 : tensor<13x31xf32> - } -} diff --git 
a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir index 3bed0528c..647f94e61 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir @@ -4,13 +4,13 @@ #any_device = #tt.operand_constraint #any_device_tile = #tt.operand_constraint -func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { +func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - %2 = tensor.empty() : tensor<13x37xf32> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) - return %3 : tensor<13x37xf32> + return %3 : tensor<13x37xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index b7912d4c1..2674a66fd 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -65,15 +65,15 @@ func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { return %1 : tensor<64x128xf32> } -func.func @is_finite(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { +func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -278,15 +278,15 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}> } -func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { +func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : 
tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - %2 = tensor.empty() : tensor<13x37xf32> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) - return %3 : tensor<13x37xf32> + return %3 : tensor<13x37xbf16> } func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {