From 4b4fa3d2dcd863199d39579fe470df1a1931013d Mon Sep 17 00:00:00 2001 From: Raymond Kim Date: Thu, 12 Dec 2024 16:26:25 +0000 Subject: [PATCH 01/13] #0: [skip ci] Bump device perf timeout for WH because we keep inching up, prob because more tests --- .github/workflows/perf-device-models-impl.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/perf-device-models-impl.yaml b/.github/workflows/perf-device-models-impl.yaml index 7757285dd67..ee56a304c2f 100644 --- a/.github/workflows/perf-device-models-impl.yaml +++ b/.github/workflows/perf-device-models-impl.yaml @@ -12,7 +12,7 @@ jobs: matrix: test-info: [ {name: "GS", arch: grayskull, runs-on: ["perf-no-reset-grayskull", "bare-metal", "in-service"], machine-type: "bare_metal", timeout: 40}, - {name: "N300 WH B0", arch: wormhole_b0, runs-on: ["N300", "pipeline-perf", "bare-metal", "in-service"], machine-type: "bare_metal", timeout: 30}, + {name: "N300 WH B0", arch: wormhole_b0, runs-on: ["N300", "pipeline-perf", "bare-metal", "in-service"], machine-type: "bare_metal", timeout: 40}, ] name: "${{ matrix.test-info.name }} device perf" env: From ce20864fbb282a1f0bf8a80ebca1f143b0ba7f9a Mon Sep 17 00:00:00 2001 From: Samarth Agarwal Date: Thu, 12 Dec 2024 11:48:19 -0500 Subject: [PATCH 02/13] Prepend DPrint messages with device id, core coords, and RISC (#15790) ### Ticket #14487 --- .../source/tt-metalium/tools/kernel_print.rst | 15 +-- .../tt_metal/debug_tools/CMakeLists.txt | 3 +- .../debug_tools/debug_tools_fixture.hpp | 4 + .../test_print_prepend_device_core_risc.cpp | 113 ++++++++++++++++++ .../test_kernels/misc/print_simple.cpp | 23 ++++ tt_metal/impl/debug/dprint_server.cpp | 54 +++++++-- tt_metal/llrt/rtoptions.cpp | 7 ++ tt_metal/llrt/rtoptions.hpp | 21 ++-- 8 files changed, 215 insertions(+), 25 deletions(-) create mode 100644 tests/tt_metal/tt_metal/debug_tools/dprint/test_print_prepend_device_core_risc.cpp create mode 100644 tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp diff --git a/docs/source/tt-metalium/tools/kernel_print.rst b/docs/source/tt-metalium/tools/kernel_print.rst index 58edf78f38f..3d6844cb2f2 100644 --- a/docs/source/tt-metalium/tools/kernel_print.rst +++ b/docs/source/tt-metalium/tools/kernel_print.rst @@ -21,12 +21,13 @@ Note that the core coordinates are logical coordinates, so worker cores and ethe .. code-block:: - export TT_METAL_DPRINT_CORES=0,0 # required, x,y OR (x1,y1),(x2,y2),(x3,y3) OR (x1,y1)-(x2,y2) OR all OR worker OR dispatch - export TT_METAL_DPRINT_ETH_CORES=0,0 # optional, x,y OR (x1,y1),(x2,y2),(x3,y3) OR (x1,y1)-(x2,y2) OR all OR worker OR dispatch - export TT_METAL_DPRINT_CHIPS=0 # optional, comma separated list of chips - export TT_METAL_DPRINT_RISCVS=BR # optional, default is all RISCs. Use a subset of BR,NC,TR0,TR1,TR2 - export TT_METAL_DPRINT_FILE=log.txt # optional, default is to print to the screen - export TT_METAL_DPRINT_ONE_FILE_PER_RISC=1 # optional, splits DPRINT data on a per-RISC basis into files under $TT_METAL_HOME/generated/dprint/. Overrides TT_METAL_DPRINT_FILE. + export TT_METAL_DPRINT_CORES=0,0 # required, x,y OR (x1,y1),(x2,y2),(x3,y3) OR (x1,y1)-(x2,y2) OR all OR worker OR dispatch + export TT_METAL_DPRINT_ETH_CORES=0,0 # optional, x,y OR (x1,y1),(x2,y2),(x3,y3) OR (x1,y1)-(x2,y2) OR all OR worker OR dispatch + export TT_METAL_DPRINT_CHIPS=0 # optional, comma separated list of chips + export TT_METAL_DPRINT_RISCVS=BR # optional, default is all RISCs. 
Use a subset of BR,NC,TR0,TR1,TR2
+   export TT_METAL_DPRINT_FILE=log.txt                 # optional, default is to print to the screen
+   export TT_METAL_DPRINT_PREPEND_DEVICE_CORE_RISC=0   # optional, enabled by default. Prepends prints with <device id>:(<x>, <y>):<risc>:.
+   export TT_METAL_DPRINT_ONE_FILE_PER_RISC=1          # optional, splits DPRINT data on a per-RISC basis into files under $TT_METAL_HOME/generated/dprint/. Overrides TT_METAL_DPRINT_FILE and disables TT_METAL_DPRINT_PREPEND_DEVICE_CORE_RISC.

 To generate kernel debug prints on the device, include the ``debug/dprint.h`` header and use the APIs defined there. An example with the different features available is shown below:

@@ -122,4 +123,4 @@ formats for printing from CBs are ``DataFormat::Float32``, ``DataFormat::Float16
     }

 .. note::
-    Note that the DPRINT buffer for a RISC is flushed when ``ENDL()`` is called, a ``\n`` character is read, or the device that the RISC belongs to is closed.
+    The DPRINT buffer for a RISC is only flushed when ``ENDL()`` is called, a ``\n`` character is read, or the device that the RISC belongs to is closed.
diff --git a/tests/tt_metal/tt_metal/debug_tools/CMakeLists.txt b/tests/tt_metal/tt_metal/debug_tools/CMakeLists.txt
index 6cafd9ea200..76f23590016 100644
--- a/tests/tt_metal/tt_metal/debug_tools/CMakeLists.txt
+++ b/tests/tt_metal/tt_metal/debug_tools/CMakeLists.txt
@@ -7,6 +7,7 @@ set(UNIT_TESTS_DEBUG_TOOLS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_all_harts.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_before_finish.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_hanging.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_prepend_device_core_risc.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_tensix_dest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_print_tiles.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/dprint/test_raise_wait.cpp
@@ -22,7 +23,7 @@ set(UNIT_TESTS_DEBUG_TOOLS_SRC
 add_executable(unit_tests_debug_tools ${UNIT_TESTS_DEBUG_TOOLS_SRC})
 TT_ENABLE_UNITY_BUILD(unit_tests_debug_tools)
-target_link_libraries(unit_tests_debug_tools PUBLIC test_metal_common_libs)
+target_link_libraries(unit_tests_debug_tools PRIVATE test_metal_common_libs)
 target_include_directories(
     unit_tests_debug_tools
     PRIVATE
diff --git a/tests/tt_metal/tt_metal/debug_tools/debug_tools_fixture.hpp b/tests/tt_metal/tt_metal/debug_tools/debug_tools_fixture.hpp
index 9d4f6714258..481bdfac670 100644
--- a/tests/tt_metal/tt_metal/debug_tools/debug_tools_fixture.hpp
+++ b/tests/tt_metal/tt_metal/debug_tools/debug_tools_fixture.hpp
@@ -46,6 +46,8 @@ class DPrintFixture : public DebugToolsFixture {
         // used by all tests using this fixture, so set dprint enabled for
         // all cores and all devices
         tt::llrt::RunTimeOptions::get_instance().set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, true);
+        tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc(
+            tt::llrt::RunTimeDebugFeatureDprint, false);
         tt::llrt::RunTimeOptions::get_instance().set_feature_all_cores(
             tt::llrt::RunTimeDebugFeatureDprint, CoreType::WORKER, tt::llrt::RunTimeDebugClassWorker);
         tt::llrt::RunTimeOptions::get_instance().set_feature_all_cores(
@@ -79,6 +81,8 @@ class DPrintFixture : public DebugToolsFixture {
             tt::llrt::RunTimeDebugFeatureDprint, CoreType::ETH, tt::llrt::RunTimeDebugClassNoneSpecified);
         tt::llrt::RunTimeOptions::get_instance().set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, false);
         tt::llrt::RunTimeOptions::get_instance().set_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint, "");
+
tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc( + tt::llrt::RunTimeDebugFeatureDprint, true); tt::llrt::RunTimeOptions::get_instance().set_test_mode_enabled(false); } diff --git a/tests/tt_metal/tt_metal/debug_tools/dprint/test_print_prepend_device_core_risc.cpp b/tests/tt_metal/tt_metal/debug_tools/dprint/test_print_prepend_device_core_risc.cpp new file mode 100644 index 00000000000..7d3f55f0ca5 --- /dev/null +++ b/tests/tt_metal/tt_metal/debug_tools/dprint/test_print_prepend_device_core_risc.cpp @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "core_coord.hpp" +#include "debug_tools_fixture.hpp" +#include "gtest/gtest.h" +#include "debug_tools_test_utils.hpp" +#include "kernels/kernel_types.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/host_api.hpp" + +//////////////////////////////////////////////////////////////////////////////// +// A test for checking that prints are prepended with their corresponding device, core and RISC. +//////////////////////////////////////////////////////////////////////////////// +using namespace tt; +using namespace tt::tt_metal; + +namespace { +namespace CMAKE_UNIQUE_NAMESPACE { +static void UpdateGoldenOutput(std::vector& golden_output, const Device* device, const string& risc) { + // Using wildcard characters in lieu of actual values for the physical coordinates as physical coordinates can vary + // by machine + const string& device_core_risc = std::to_string(device->id()) + ":(x=*,y=*):" + risc + ": "; + + const string& output_line_all_riscs = device_core_risc + "Printing on a RISC."; + golden_output.push_back(output_line_all_riscs); + + if (risc != "ER") { + const string& output_line_risc = device_core_risc + "Printing on " + risc + "."; + golden_output.push_back(output_line_risc); + } +} + +static void RunTest(DPrintFixture* fixture, Device* device, const bool add_active_eth_kernel = false) { + std::vector golden_output; + + CoreRange cores({0, 0}, {0, 1}); + Program program = Program(); + + KernelHandle brisc_kernel_id = CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp", + cores, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); + + KernelHandle ncrisc_kernel_id = CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp", + cores, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default}); + + KernelHandle trisc_kernel_id = + CreateKernel(program, "tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp", cores, ComputeConfig{}); + + for (const CoreCoord& core : cores) { + UpdateGoldenOutput(golden_output, device, "BR"); + UpdateGoldenOutput(golden_output, device, "NC"); + UpdateGoldenOutput(golden_output, device, "TR0"); + UpdateGoldenOutput(golden_output, device, "TR1"); + UpdateGoldenOutput(golden_output, device, "TR2"); + } + + if (add_active_eth_kernel) { + const std::unordered_set& active_eth_cores = device->get_active_ethernet_cores(true); + CoreRangeSet crs(std::set(active_eth_cores.begin(), active_eth_cores.end())); + KernelHandle erisc_kernel_id = CreateKernel( + program, + "tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp", + crs, + EthernetConfig{.noc = NOC::NOC_0}); + + for (const CoreCoord& core : active_eth_cores) { + UpdateGoldenOutput(golden_output, device, "ER"); + } + } + + fixture->RunProgram(device, program); + + // Check 
the print log against golden output. + EXPECT_TRUE(FileContainsAllStrings(DPrintFixture::dprint_file_name, golden_output)); +} +} // namespace CMAKE_UNIQUE_NAMESPACE +} // namespace + +TEST_F(DPrintFixture, TensixTestPrintPrependDeviceCoreRisc) { + tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc( + tt::llrt::RunTimeDebugFeatureDprint, true); + for (Device* device : this->devices_) { + this->RunTestOnDevice( + [](DPrintFixture* fixture, Device* device) { CMAKE_UNIQUE_NAMESPACE::RunTest(fixture, device); }, device); + } + tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc( + tt::llrt::RunTimeDebugFeatureDprint, false); +} + +TEST_F(DPrintFixture, TensixActiveEthTestPrintPrependDeviceCoreRisc) { + tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc( + tt::llrt::RunTimeDebugFeatureDprint, true); + for (Device* device : this->devices_) { + if (device->get_active_ethernet_cores(true).empty()) { + log_info(tt::LogTest, "Skipping device {} due to no active ethernet cores...", device->id()); + continue; + } + this->RunTestOnDevice( + [](DPrintFixture* fixture, Device* device) { CMAKE_UNIQUE_NAMESPACE::RunTest(fixture, device, true); }, + device); + } + tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc( + tt::llrt::RunTimeDebugFeatureDprint, false); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp b/tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp new file mode 100644 index 00000000000..3a55bc909cd --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/misc/print_simple.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "debug/dprint.h" + +#if defined(COMPILE_FOR_TRISC) +#include "compute_kernel_api/common.h" +namespace NAMESPACE { +void MAIN { +#else +void kernel_main() { +#endif + DPRINT << "Printing on a RISC." << ENDL(); + DPRINT_UNPACK(DPRINT << "Printing on TR0." << ENDL();); + DPRINT_MATH(DPRINT << "Printing on TR1." << ENDL();); + DPRINT_PACK(DPRINT << "Printing on TR2." << ENDL();); + DPRINT_DATA0(DPRINT << "Printing on BR." << ENDL();); + DPRINT_DATA1(DPRINT << "Printing on NC." << ENDL();); +} +#if defined(COMPILE_FOR_TRISC) +} +#endif diff --git a/tt_metal/impl/debug/dprint_server.cpp b/tt_metal/impl/debug/dprint_server.cpp index d0e62e69210..8b875682fac 100644 --- a/tt_metal/impl/debug/dprint_server.cpp +++ b/tt_metal/impl/debug/dprint_server.cpp @@ -57,21 +57,21 @@ static inline float bfloat16_to_float(uint16_t bfloat_val) { return f; } -static string GetRiscName(CoreType core_type, int hart_id) { +static string GetRiscName(CoreType core_type, int hart_id, bool abbreviated = false) { if (core_type == CoreType::ETH) { switch (hart_id) { case DPRINT_RISCV_INDEX_ER: - return "ERISC"; + return abbreviated ? "ER" : "ERISC"; // Default case falls through and handled at end. } } else { switch (hart_id) { - case DPRINT_RISCV_INDEX_NC: return "NCRISC"; - case DPRINT_RISCV_INDEX_TR0: return "TRISC0"; - case DPRINT_RISCV_INDEX_TR1: return "TRISC1"; - case DPRINT_RISCV_INDEX_TR2: return "TRISC2"; + case DPRINT_RISCV_INDEX_NC: return abbreviated ? "NC" : "NCRISC"; + case DPRINT_RISCV_INDEX_TR0: return abbreviated ? "TR0" : "TRISC0"; + case DPRINT_RISCV_INDEX_TR1: return abbreviated ? "TR1" : "TRISC1"; + case DPRINT_RISCV_INDEX_TR2: return abbreviated ? "TR2" : "TRISC2"; case DPRINT_RISCV_INDEX_BR: - return "BRISC"; + return abbreviated ? 
"BR" : "BRISC"; // Default case falls through and handled at end. } } @@ -226,6 +226,9 @@ struct DebugPrintServerContext { // data is visible to the user. void TransferToAndFlushOutputStream(const HartKey& hart_key, ostringstream* intermediate_stream); + // Returns the dprint data that should be outputted by the output stream. + string GetDataToOutput(const HartKey& hart_key, const ostringstream* stream); + // Returns the stream that the dprint data should be output to. Can be auto-generated files, the user-selected file, // stdout, or nothing. ostream* GetOutputStream(const HartKey& hart_key); @@ -466,6 +469,8 @@ DebugPrintServerContext::DebugPrintServerContext() { tt::llrt::RunTimeOptions::get_instance().get_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint); bool one_file_per_risc = tt::llrt::RunTimeOptions::get_instance().get_feature_one_file_per_risc(tt::llrt::RunTimeDebugFeatureDprint); + bool prepend_device_core_risc = + tt::llrt::RunTimeOptions::get_instance().get_feature_prepend_device_core_risc(tt::llrt::RunTimeDebugFeatureDprint); // One file per risc auto-generates the output files and ignores the env var for it. Print a warning if both are // specified just in case. @@ -475,6 +480,13 @@ DebugPrintServerContext::DebugPrintServerContext() { "TT_METAL_DPRINT_FILE_NAME will be ignored."); } + if (prepend_device_core_risc && one_file_per_risc) { + log_warning( + "Both TT_METAL_DPRINT_PREPEND_DEVICE_CORE_RISC and TT_METAL_DPRINT_ONE_FILE_PER_RISC are specified. " + "TT_METAL_DPRINT_PREPEND_DEVICE_CORE_RISC will be disabled."); + tt::llrt::RunTimeOptions::get_instance().set_feature_prepend_device_core_risc(tt::llrt::RunTimeDebugFeatureDprint, false); + } + // Set the output stream according to RTOptions, either a file name or stdout if none specified. 
std::filesystem::path output_dir(tt::llrt::RunTimeOptions::get_instance().get_root_dir() + logfile_path); std::filesystem::create_directories(output_dir); @@ -1190,12 +1202,36 @@ void DebugPrintServerContext::TransferIntermediateStreamsToOutputStreamAndFlush( void DebugPrintServerContext::TransferToAndFlushOutputStream( const HartKey& hart_key, ostringstream* intermediate_stream) { - const string& intermediate_stream_data = intermediate_stream->str(); + const string& output_data = GetDataToOutput(hart_key, intermediate_stream); ostream* output_stream = GetOutputStream(hart_key); - *output_stream << intermediate_stream_data << flush; + *output_stream << output_data << flush; ResetStream(intermediate_stream); } // TransferToAndFlushOutputStream +string DebugPrintServerContext::GetDataToOutput(const HartKey& hart_key, const ostringstream* stream) { + string output; + const bool prepend_device_core_risc = + tt::llrt::RunTimeOptions::get_instance().get_feature_prepend_device_core_risc(tt::llrt::RunTimeDebugFeatureDprint); + if (prepend_device_core_risc) { + const chip_id_t device_id = get<0>(hart_key); + const CoreDescriptor& core_desc = get<1>(hart_key); + const uint32_t risc_id = get<2>(hart_key); + + const string& device_id_str = to_string(device_id); + const string& core_coord_str = core_desc.coord.str(); + const string& risc_name = GetRiscName(core_desc.type, risc_id, true); + output += fmt::format("{}:{}:{}: ", device_id_str, core_coord_str, risc_name); + } + + if (stream->str().empty()) { + output = ""; + } else { + output += stream->str(); + } + + return output; +} + ostream* DebugPrintServerContext::GetOutputStream(const HartKey& hart_key) { ostream* output_stream = stream_; if (tt::llrt::RunTimeOptions::get_instance().get_feature_one_file_per_risc(tt::llrt::RunTimeDebugFeatureDprint)) { diff --git a/tt_metal/llrt/rtoptions.cpp b/tt_metal/llrt/rtoptions.cpp index 174bd2cdbce..b49316825fb 100644 --- a/tt_metal/llrt/rtoptions.cpp +++ b/tt_metal/llrt/rtoptions.cpp @@ -204,6 +204,7 @@ void RunTimeOptions::ParseFeatureEnv(RunTimeDebugFeatures feature) { ParseFeatureRiscvMask(feature, feature_env_prefix + "_RISCVS"); ParseFeatureFileName(feature, feature_env_prefix + "_FILE"); ParseFeatureOneFilePerRisc(feature, feature_env_prefix + "_ONE_FILE_PER_RISC"); + ParseFeaturePrependDeviceCoreRisc(feature, feature_env_prefix + "_PREPEND_DEVICE_CORE_RISC"); // Set feature enabled if the user asked for any feature cores feature_targets[feature].enabled = false; @@ -356,6 +357,12 @@ void RunTimeOptions::ParseFeatureOneFilePerRisc(RunTimeDebugFeatures feature, co feature_targets[feature].one_file_per_risc = (env_var_str != nullptr); } +void RunTimeOptions::ParseFeaturePrependDeviceCoreRisc(RunTimeDebugFeatures feature, const std::string &env_var) { + char *env_var_str = std::getenv(env_var.c_str()); + feature_targets[feature].prepend_device_core_risc = + (env_var_str != nullptr) ? (strcmp(env_var_str, "1") == 0) : true; +} + } // namespace llrt } // namespace tt diff --git a/tt_metal/llrt/rtoptions.hpp b/tt_metal/llrt/rtoptions.hpp index b7723f0609c..dcfb336b3b3 100644 --- a/tt_metal/llrt/rtoptions.hpp +++ b/tt_metal/llrt/rtoptions.hpp @@ -81,6 +81,7 @@ struct TargetSelection { uint32_t riscv_mask = 0; std::string file_name; // File name to write output to. 
bool one_file_per_risc = false;
+    bool prepend_device_core_risc;
 };

 class RunTimeOptions {
@@ -225,6 +226,12 @@ class RunTimeOptions {
     inline void set_feature_one_file_per_risc(RunTimeDebugFeatures feature, bool one_file_per_risc) {
         feature_targets[feature].one_file_per_risc = one_file_per_risc;
     }
+    inline bool get_feature_prepend_device_core_risc(RunTimeDebugFeatures feature) {
+        return feature_targets[feature].prepend_device_core_risc;
+    }
+    inline void set_feature_prepend_device_core_risc(RunTimeDebugFeatures feature, bool prepend_device_core_risc) {
+        feature_targets[feature].prepend_device_core_risc = prepend_device_core_risc;
+    }
     inline TargetSelection get_feature_targets(RunTimeDebugFeatures feature) { return feature_targets[feature]; }
     inline void set_feature_targets(RunTimeDebugFeatures feature, TargetSelection targets) {
         feature_targets[feature] = targets;
@@ -293,11 +300,12 @@ class RunTimeOptions {
 private:
     // Helper functions to parse feature-specific environment variables.
     void ParseFeatureEnv(RunTimeDebugFeatures feature);
-    void ParseFeatureCoreRange(RunTimeDebugFeatures feature, const std::string& env_var, CoreType core_type);
-    void ParseFeatureChipIds(RunTimeDebugFeatures feature, const std::string& env_var);
-    void ParseFeatureRiscvMask(RunTimeDebugFeatures feature, const std::string& env_var);
-    void ParseFeatureFileName(RunTimeDebugFeatures feature, const std::string& env_var);
-    void ParseFeatureOneFilePerRisc(RunTimeDebugFeatures feature, const std::string& env_var);
+    void ParseFeatureCoreRange(RunTimeDebugFeatures feature, const std::string &env_var, CoreType core_type);
+    void ParseFeatureChipIds(RunTimeDebugFeatures feature, const std::string &env_var);
+    void ParseFeatureRiscvMask(RunTimeDebugFeatures feature, const std::string &env_var);
+    void ParseFeatureFileName(RunTimeDebugFeatures feature, const std::string &env_var);
+    void ParseFeatureOneFilePerRisc(RunTimeDebugFeatures feature, const std::string &env_var);
+    void ParseFeaturePrependDeviceCoreRisc(RunTimeDebugFeatures feature, const std::string &env_var);

     // Helper function to parse watcher-specific environment variables.
     void ParseWatcherEnv();
@@ -315,9 +323,6 @@ class RunTimeOptions {
     bool watcher_feature_disabled(const std::string& name) {
         return watcher_disabled_features.find(name) != watcher_disabled_features.end();
     }
-
-    // Helper function to generate a message string when an environment variable has not been set
-    std::string generate_env_var_not_set_message(const std::string& env_var) const;
 };

 }  // namespace llrt
 }  // namespace tt

From 4cf4a1f656297e77826ef0d79b0535bed3bf9878 Mon Sep 17 00:00:00 2001
From: Ligang Long
Date: Thu, 12 Dec 2024 09:35:33 -0800
Subject: [PATCH 03/13] #0: added test case for unaligned padding kernels (#15946)

### Ticket
#15602

### Problem description
To validate the unaligned padding kernel functionality.

### What's changed
Added one simple test case that front pads by 4 and back pads by 6, testing the kernel's capacity to pad arbitrary sizes.
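For illustration, a minimal sketch of the shape arithmetic this case exercises (the standalone helper below is not part of the patch; the shapes and padding values come from the test parametrization in the diff that follows):

```cpp
#include <array>
#include <cstdio>
#include <utility>

int main() {
    // (n, c, h, w) input shape and (front, back) padding for the last three
    // dimensions, taken from the new test case below.
    std::array<int, 4> shape = {1, 3, 230, 224};
    const std::array<std::pair<int, int>, 3> padding = {{{0, 1}, {3, 25}, {4, 6}}};
    for (std::size_t i = 0; i < padding.size(); ++i) {
        shape[i + 1] += padding[i].first + padding[i].second;
    }
    // Prints "padded shape: (1, 4, 258, 234)". The front pad of 4 and back pad
    // of 6 on the innermost dimension are not multiples of 32, which is the
    // unaligned path this case is meant to cover.
    std::printf("padded shape: (%d, %d, %d, %d)\n", shape[0], shape[1], shape[2], shape[3]);
    return 0;
}
```

The `torch_padding` tuple in the test, `(4, 6, 3, 25, 0, 1)`, encodes the same pairs in `torch.nn.functional.pad` order, innermost dimension first.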
### Checklist
- [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12286944116
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
---
 tests/ttnn/unit_tests/operations/test_pad.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/ttnn/unit_tests/operations/test_pad.py b/tests/ttnn/unit_tests/operations/test_pad.py
index 2ff6a79097a..60a3edc8e33 100644
--- a/tests/ttnn/unit_tests/operations/test_pad.py
+++ b/tests/ttnn/unit_tests/operations/test_pad.py
@@ -16,7 +16,10 @@
 @pytest.mark.parametrize("c", [3])
 @pytest.mark.parametrize("h", [230])
 @pytest.mark.parametrize("w", [224])
-@pytest.mark.parametrize("padding,torch_padding", [(((0, 1), (3, 25), (32, 32)), (32, 32, 3, 25, 0, 1))])
+@pytest.mark.parametrize(
+    "padding,torch_padding",
+    [(((0, 1), (3, 25), (32, 32)), (32, 32, 3, 25, 0, 1)), (((0, 1), (3, 25), (4, 6)), (4, 6, 3, 25, 0, 1))],
+)
 @pytest.mark.parametrize("value", [0])
 def test_pad_rm(device, n, c, h, w, padding, torch_padding, value):
     torch.manual_seed(0)

From 04ab0d6e00e391cd1118d103146e32bf0133a564 Mon Sep 17 00:00:00 2001
From: Raymond Kim <109366641+tt-rkim@users.noreply.github.com>
Date: Thu, 12 Dec 2024 12:51:17 -0500
Subject: [PATCH 04/13] #15971: [skip ci] Update dead link (#15975)

### Ticket
#15971

### Problem description
@dsklavos reported a dead link to the ResNet demo. Fixed.

### What's changed
Updated the ResNet demo link in the TTNN "Get started" guide.

cc: @mywoodstock @LPanosTT

### Checklist
- [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12301771527
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
---
 docs/source/ttnn/ttnn/get_started.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/ttnn/ttnn/get_started.rst b/docs/source/ttnn/ttnn/get_started.rst
index c9a7adc7322..0c2cd2e4de2 100644
--- a/docs/source/ttnn/ttnn/get_started.rst
+++ b/docs/source/ttnn/ttnn/get_started.rst
@@ -26,7 +26,7 @@ Get started with the Falcon 7B demo. Navigate to the `Falcon 7B demo folder
 for details.
You can also check our demos for
-`ResNet `_,
+`ResNet `_,
 `BERT `_,
 `Mistral 7B `_,
 and

From 9a2f119376e4565f932560a450de8b2d7f23c72b Mon Sep 17 00:00:00 2001
From: Denys Makoviichuk
Date: Thu, 12 Dec 2024 10:11:24 -0800
Subject: [PATCH 05/13] [TT-Train] Added LR Schedulers and updated serialization (#15625)

### Problem description
Added LR schedulers, similar to PyTorch's:
* LinearScheduler
* StepScheduler
* LambdaScheduler
* SequentialScheduler

### What's changed
Added the scheduler classes, wired them into the mnist_mlp and nano_gpt examples, and extended the optimizer and serialization interfaces to save and restore scheduler state.

### Checklist
- [X] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12285600185
---
 tt-train/sources/examples/mnist_mlp/main.cpp  |  10 +-
 tt-train/sources/examples/mnist_mlp/utils.hpp |   4 +-
 tt-train/sources/examples/nano_gpt/main.cpp   |  26 +-
 tt-train/sources/examples/nano_gpt/utils.cpp  |  19 ++
 tt-train/sources/examples/nano_gpt/utils.hpp  |  33 ++-
 .../sources/ttml/autograd/module_base.cpp     |   6 +-
 .../sources/ttml/autograd/module_base.hpp     |   4 +-
 tt-train/sources/ttml/optimizers/adamw.cpp    |  91 ++++---
 tt-train/sources/ttml/optimizers/adamw.hpp    |  31 ++-
 .../ttml/optimizers/optimizer_base.cpp        |   2 +-
 .../ttml/optimizers/optimizer_base.hpp        |  13 +-
 tt-train/sources/ttml/optimizers/sgd.cpp      |  23 +-
 tt-train/sources/ttml/optimizers/sgd.hpp      |  24 +-
 .../ttml/schedulers/lambda_scheduler.cpp      |  40 ++++
 .../ttml/schedulers/lambda_scheduler.hpp      |  32 +++
 .../ttml/schedulers/linear_scheduler.cpp      |  55 +++++
 .../ttml/schedulers/linear_scheduler.hpp      |  32 +++
 .../ttml/schedulers/scheduler_base.cpp        |  15 ++
 .../ttml/schedulers/scheduler_base.hpp        |  37 +++
 .../ttml/schedulers/sequential_scheduler.cpp  |  88 +++++++
 .../ttml/schedulers/sequential_scheduler.hpp  |  42 ++++
 .../ttml/schedulers/step_scheduler.cpp        |  50 ++++
 .../ttml/schedulers/step_scheduler.hpp        |  36 +++
 .../ttml/serialization/msgpack_file.cpp       |  67 ++++--
 .../ttml/serialization/msgpack_file.hpp       |  22 ++
 .../ttml/serialization/serializable.hpp       |  28 +++
 .../ttml/serialization/serialization.cpp      |  72 ++++--
 .../ttml/serialization/serialization.hpp      |   8 +-
 .../ttml/ttnn_fixed/trivial_ttnn_ops.hpp      |   1 -
 .../autograd/module_base_parameters_test.cpp  |   1 +
 .../model/linear_regression_full_test.cpp     |   1 +
 tt-train/tests/schedulers/schedulers_test.cpp | 226 ++++++++++++++++++
 32 files changed, 975 insertions(+), 164 deletions(-)
 create mode 100644 tt-train/sources/ttml/schedulers/lambda_scheduler.cpp
 create mode 100644 tt-train/sources/ttml/schedulers/lambda_scheduler.hpp
 create mode 100644 tt-train/sources/ttml/schedulers/linear_scheduler.cpp
 create mode 100644 tt-train/sources/ttml/schedulers/linear_scheduler.hpp
 create mode 100644 tt-train/sources/ttml/schedulers/scheduler_base.cpp
 create mode 100644 tt-train/sources/ttml/schedulers/scheduler_base.hpp
 create mode 100644 tt-train/sources/ttml/schedulers/sequential_scheduler.cpp
 create mode 100644 tt-train/sources/ttml/schedulers/sequential_scheduler.hpp
 create mode 100644 tt-train/sources/ttml/schedulers/step_scheduler.cpp
 create mode 100644 tt-train/sources/ttml/schedulers/step_scheduler.hpp
 create mode 100644 tt-train/sources/ttml/serialization/serializable.hpp
 create mode 100644 tt-train/tests/schedulers/schedulers_test.cpp

diff --git a/tt-train/sources/examples/mnist_mlp/main.cpp b/tt-train/sources/examples/mnist_mlp/main.cpp
index 0528933d7bc..868e827d296 100644
--- a/tt-train/sources/examples/mnist_mlp/main.cpp
+++ b/tt-train/sources/examples/mnist_mlp/main.cpp
@@ -4,6 +4,8 @@
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -19,7 +21,6 @@
 #include "optimizers/sgd.hpp"
 #include "utils.hpp"
 #include
"yaml-cpp/node/node.h" - using ttml::autograd::TensorPtr; using DatasetSample = std::pair, uint8_t>; @@ -95,7 +96,6 @@ int main(int argc, char **argv) { CLI11_PARSE(app, argc, argv); auto yaml_config = YAML::LoadFile(config_name); TrainingConfig config = parse_config(yaml_config); - // Load MNIST data const size_t num_targets = 10; const size_t num_features = 784; @@ -151,7 +151,7 @@ int main(int argc, char **argv) { auto optimizer = ttml::optimizers::SGD(model->parameters(), sgd_config); if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) { fmt::print("Loading model from {}\n", config.model_path); - load_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name); + load_training_state(config.model_path, model, optimizer, model_name, optimizer_name); } // evaluate model before training (sanity check to get reasonable accuracy @@ -176,7 +176,7 @@ int main(int argc, char **argv) { } if (!config.model_path.empty() && training_step % config.model_save_interval == 0) { fmt::print("Saving model to {}\n", config.model_path); - save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name); + save_training_state(config.model_path, model, optimizer, model_name, optimizer_name); } loss->backward(); @@ -196,7 +196,7 @@ int main(int argc, char **argv) { if (!config.model_path.empty()) { fmt::print("Saving model to {}\n", config.model_path); - save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name); + save_training_state(config.model_path, model, optimizer, model_name, optimizer_name); } return 0; diff --git a/tt-train/sources/examples/mnist_mlp/utils.hpp b/tt-train/sources/examples/mnist_mlp/utils.hpp index 00b28a6ffe7..863cb9311eb 100644 --- a/tt-train/sources/examples/mnist_mlp/utils.hpp +++ b/tt-train/sources/examples/mnist_mlp/utils.hpp @@ -38,7 +38,7 @@ class Timers { }; template -void save_model_and_optimizer( +void save_training_state( std::string &model_path, const std::shared_ptr &model, Optimizer &optimizer, @@ -51,7 +51,7 @@ void save_model_and_optimizer( } template -void load_model_and_optimizer( +void load_training_state( std::string &model_path, const std::shared_ptr &model, Optimizer &optimizer, diff --git a/tt-train/sources/examples/nano_gpt/main.cpp b/tt-train/sources/examples/nano_gpt/main.cpp index e086b21095e..87136c9f079 100644 --- a/tt-train/sources/examples/nano_gpt/main.cpp +++ b/tt-train/sources/examples/nano_gpt/main.cpp @@ -142,6 +142,8 @@ struct TrainingConfig { uint32_t gradient_accumulation_steps = 1; std::string model_path; std::string data_path; + std::string scheduler_type = "identity"; + ttml::models::gpt2::TransformerConfig transformer_config; }; @@ -161,10 +163,17 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) { training_config["gradient_accumulation_steps"].as(config.gradient_accumulation_steps); config.model_path = training_config["model_path"].as(""); config.data_path = training_config["data_path"].as(std::string(DATA_FOLDER) + "/shakespeare.txt"); + config.scheduler_type = training_config["scheduler_type"].as(config.scheduler_type); + config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]); return config; } +const std::unordered_map< + std::string, + std::function(ttml::optimizers::OptimizerBase *, size_t)>> + schedulers = {{"identity", create_idendity_scheduler}, {"warmup_linear", create_warmup_with_linear_scheduler}}; + int main(int argc, char **argv) { auto result = signal(SIGINT, 
signal_handler); if (result == SIG_ERR) { @@ -186,7 +195,6 @@ int main(int argc, char **argv) { CLI11_PARSE(app, argc, argv); auto yaml_config = YAML::LoadFile(config_name); TrainingConfig config = parse_config(yaml_config); - wandbcpp::init({.project = config.project_name, .name = generate_run_name(config, add_time_to_name)}); wandbcpp::update_config({ {"model", "transformer"}, @@ -206,10 +214,12 @@ int main(int argc, char **argv) { config.transformer_config.positional_embedding_type == ttml::models::gpt2::PositionalEmbeddingType::Trainable ? "trainable" : "fixed"}, + {"scheduler_type", config.scheduler_type}, }); // set seed ttml::autograd::ctx().set_seed(config.seed); + auto schedule_func = schedulers.at(config.scheduler_type); std::string text; try { @@ -218,11 +228,11 @@ int main(int argc, char **argv) { std::cerr << e.what() << std::endl; return -1; } - fmt::print("Max steps {}\n", config.max_steps); fmt::print("Batch size {}\n", config.batch_size); fmt::print("Gradient accumulation steps {}\n", config.gradient_accumulation_steps); fmt::print("Total batch size {}\n", config.batch_size * config.gradient_accumulation_steps); + fmt::print("Scheduler type {}\n", config.scheduler_type); fmt::print("Seed {}\n", ttml::autograd::ctx().get_seed()); auto sequence_length = config.transformer_config.max_sequence_length; @@ -304,10 +314,10 @@ int main(int argc, char **argv) { fmt::print(" Weight decay: {}\n", adamw_params.weight_decay); fmt::print(" Use Kahan summation: {}\n", adamw_params.use_kahan_summation); auto optimizer = ttml::optimizers::AdamW(model->parameters(), adamw_params); - + auto scheduler = schedule_func(&optimizer, config.max_steps); if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) { fmt::print("Loading model from {}\n", config.model_path); - load_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw"); + load_training_state(config.model_path, model, scheduler, "transformer", "adamw"); fmt::print("Model loaded after {} steps\n", optimizer.get_steps()); } @@ -345,6 +355,7 @@ int main(int argc, char **argv) { if (gradient_accumulator_helper.should_step()) { optimizer.step(); + scheduler->step(); auto global_step = optimizer.get_steps(); fmt::print("Step: {}, Loss: {}\n", global_step, gradient_accumulator_helper.average_loss()); loss_meter.update(gradient_accumulator_helper.average_loss()); @@ -353,11 +364,12 @@ int main(int argc, char **argv) { wandbcpp::log( {{"Step", (int)global_step}, {"Samples", (int)get_samples_count(global_step)}, - {"Loss", loss_meter.average()}}); + {"Loss", loss_meter.average()}, + {"Learning rate", optimizer.get_lr()}}); loss_meter.reset(); } if (!config.model_path.empty() && global_step % config.model_save_interval == 0) { - save_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw"); + save_training_state(config.model_path, model, scheduler, "transformer", "adamw"); } if (global_step >= config.max_steps) { @@ -379,7 +391,7 @@ int main(int argc, char **argv) { } if (!config.model_path.empty()) { - save_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw"); + save_training_state(config.model_path, model, scheduler, "transformer", "adamw"); } auto end_timer = std::chrono::high_resolution_clock::now(); diff --git a/tt-train/sources/examples/nano_gpt/utils.cpp b/tt-train/sources/examples/nano_gpt/utils.cpp index 408a3c01a38..e89b90a8f29 100644 --- a/tt-train/sources/examples/nano_gpt/utils.cpp +++ b/tt-train/sources/examples/nano_gpt/utils.cpp @@ -73,3 
+73,22 @@ void GradientAccumulator::reset() { float GradientAccumulator::average_loss() const { return m_total_loss / static_cast(m_total_samples); } + +std::unique_ptr create_idendity_scheduler( + ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps) { + return std::make_unique(optimizer, [](int epoch) { return 1.0F; }); +} + +std::unique_ptr create_warmup_with_linear_scheduler( + ttml::optimizers::OptimizerBase *optimizer, size_t total_steps) { + const float default_warmup_factor = 0.1F; + const size_t warmup_steps = size_t(total_steps * default_warmup_factor); + const size_t linear_decay_steps = total_steps - warmup_steps; + + std::vector> schedulers; + schedulers.push_back(std::make_unique(optimizer, 0.0F, 1.0F, warmup_steps)); + schedulers.push_back( + std::make_unique(optimizer, 1.0F, 0.01F, linear_decay_steps)); + std::vector steps = {warmup_steps, linear_decay_steps}; + return std::make_unique(optimizer, std::move(schedulers), std::move(steps)); +} diff --git a/tt-train/sources/examples/nano_gpt/utils.hpp b/tt-train/sources/examples/nano_gpt/utils.hpp index e390dc3f483..c7383c1e9ac 100644 --- a/tt-train/sources/examples/nano_gpt/utils.hpp +++ b/tt-train/sources/examples/nano_gpt/utils.hpp @@ -10,6 +10,10 @@ #include #include "autograd/tensor.hpp" +#include "schedulers/lambda_scheduler.hpp" +#include "schedulers/linear_scheduler.hpp" +#include "schedulers/scheduler_base.hpp" +#include "schedulers/sequential_scheduler.hpp" #include "serialization/msgpack_file.hpp" #include "serialization/serialization.hpp" @@ -25,32 +29,42 @@ class LossAverageMeter { void reset(); }; +std::unique_ptr create_idendity_scheduler( + ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps); + +std::unique_ptr create_warmup_with_linear_scheduler( + ttml::optimizers::OptimizerBase *optimizer, size_t total_steps); + std::string read_file_to_str(const std::string &file_path); -template -void save_model_and_optimizer( +template +void save_training_state( std::string &model_path, const std::shared_ptr &model, - Optimizer &optimizer, + const std::unique_ptr &scheduler, const std::string &model_name, const std::string &optimizer_name) { ttml::serialization::MsgPackFile serializer; ttml::serialization::write_module(serializer, model_name, model.get()); - ttml::serialization::write_optimizer(serializer, optimizer_name, &optimizer); + ttml::serialization::write_optimizer(serializer, optimizer_name, scheduler->get_optimizer().get()); + ttml::serialization::write_state_dict(serializer, "scheduler", scheduler->get_state_dict()); serializer.serialize(model_path); } -template -void load_model_and_optimizer( +template +void load_training_state( std::string &model_path, const std::shared_ptr &model, - Optimizer &optimizer, + const std::unique_ptr &scheduler, const std::string &model_name, const std::string &optimizer_name) { ttml::serialization::MsgPackFile deserializer; deserializer.deserialize(model_path); ttml::serialization::read_module(deserializer, model_name, model.get()); - ttml::serialization::read_optimizer(deserializer, optimizer_name, &optimizer); + ttml::serialization::read_optimizer(deserializer, optimizer_name, scheduler->get_optimizer().get()); + auto state_dict = scheduler->get_state_dict(); + ttml::serialization::read_state_dict(deserializer, "scheduler", state_dict); + scheduler->set_state_dict(state_dict); } uint32_t round_up_to_tile(uint32_t value, uint32_t tile_size = 32); @@ -110,11 +124,12 @@ std::string generate_run_name(const TrainingConfig 
&config, bool add_time_to_run if (config.gradient_accumulation_steps > 1) { ss << "_grad_acc_" << config.gradient_accumulation_steps; } - + ss << "_sched_" << config.scheduler_type; if (add_time_to_run_name) { auto now = std::chrono::system_clock::now(); std::time_t current_time = std::chrono::system_clock::to_time_t(now); ss << "_date_" << std::put_time(std::localtime(¤t_time), "%Y-%m-%d_%H:%M:%S"); } + return ss.str(); } diff --git a/tt-train/sources/ttml/autograd/module_base.cpp b/tt-train/sources/ttml/autograd/module_base.cpp index 4cc13b09826..b3f771b3688 100644 --- a/tt-train/sources/ttml/autograd/module_base.cpp +++ b/tt-train/sources/ttml/autograd/module_base.cpp @@ -4,8 +4,6 @@ #include "module_base.hpp" -#include "auto_context.hpp" - namespace ttml::autograd { void ModuleBase::register_tensor(const TensorPtr& tensor_ptr, const std::string& name) { @@ -30,8 +28,8 @@ const std::string& ModuleBase::get_name() const { return m_name; } -NamedParameters ModuleBase::parameters() const { - NamedParameters params; +serialization::NamedParameters ModuleBase::parameters() const { + serialization::NamedParameters params; std::queue> modules_to_process; modules_to_process.emplace(this, get_name() + "/"); diff --git a/tt-train/sources/ttml/autograd/module_base.hpp b/tt-train/sources/ttml/autograd/module_base.hpp index 442d0dc36f1..b2729bde46e 100644 --- a/tt-train/sources/ttml/autograd/module_base.hpp +++ b/tt-train/sources/ttml/autograd/module_base.hpp @@ -7,6 +7,7 @@ #include #include +#include "serialization/serializable.hpp" #include "tensor.hpp" namespace ttml::autograd { @@ -15,7 +16,6 @@ enum class RunMode { TRAIN, EVAL }; class ModuleBase; using ModuleBasePtr = std::shared_ptr; -using NamedParameters = std::unordered_map; class ModuleBase { private: @@ -39,7 +39,7 @@ class ModuleBase { ModuleBase& operator=(ModuleBase&&) = default; [[nodiscard]] const std::string& get_name() const; - [[nodiscard]] NamedParameters parameters() const; + [[nodiscard]] serialization::NamedParameters parameters() const; void train(); void eval(); diff --git a/tt-train/sources/ttml/optimizers/adamw.cpp b/tt-train/sources/ttml/optimizers/adamw.cpp index 8770b9da5a8..d18901797d8 100644 --- a/tt-train/sources/ttml/optimizers/adamw.cpp +++ b/tt-train/sources/ttml/optimizers/adamw.cpp @@ -10,19 +10,19 @@ #include "core/debug.hpp" #include "core/tt_tensor_utils.hpp" #include "optimizers/optimizer_base.hpp" +#include "serialization/serializable.hpp" #include "ttnn_fixed/trivial_ttnn_ops.hpp" - namespace { -const std::string kFirstMoment = "first_moment/"; -const std::string kSecondMoment = "second_moment/"; -const std::string kKahanCompensation = "kahan_compensation/"; - +const std::string kFirstMoment = "first_moment"; +const std::string kSecondMoment = "second_moment"; +const std::string kKahanCompensation = "kahan_compensation"; +const std::string kSteps = "steps"; } // namespace namespace ttml::optimizers { -MorehAdamW::MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config) : +MorehAdamW::MorehAdamW(serialization::NamedParameters parameters, const AdamWConfig& config) : OptimizerBase(std::move(parameters)), m_config(config) { if (m_config.use_kahan_summation) { throw std::runtime_error("MorehAdamW: Kahan summation is not supported. 
Use default AdamW instead."); @@ -95,29 +95,19 @@ void MorehAdamW::step() { } } -[[nodiscard]] autograd::NamedParameters MorehAdamW::get_state_dict() const { - autograd::NamedParameters state_dict; - for (const auto& [key, first_moment] : m_first_moment) { - state_dict.emplace(kFirstMoment + key, first_moment); - } - - for (const auto& [key, second_moment] : m_second_moment) { - state_dict.emplace(kSecondMoment + key, second_moment); - } +[[nodiscard]] serialization::StateDict MorehAdamW::get_state_dict() const { + serialization::StateDict state_dict; + state_dict[kFirstMoment] = m_first_moment; + state_dict[kSecondMoment] = m_second_moment; + state_dict[kSteps] = m_steps; return state_dict; } -void MorehAdamW::set_state_dict(const autograd::NamedParameters& dict) { - for (const auto& [key, tensor] : dict) { - if (key.starts_with(kFirstMoment)) { - m_first_moment[key.substr(kFirstMoment.size())] = tensor; - } else if (key.starts_with(kSecondMoment)) { - m_second_moment[key.substr(kSecondMoment.size())] = tensor; - } else { - throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key)); - } - } +void MorehAdamW::set_state_dict(const serialization::StateDict& dict) { + m_first_moment = std::get(dict.at(kFirstMoment)); + m_second_moment = std::get(dict.at(kSecondMoment)); + m_steps = serialization::get_value_type(dict, kSteps); } [[nodiscard]] size_t MorehAdamW::get_steps() const { @@ -128,7 +118,14 @@ void MorehAdamW::set_steps(size_t steps) { m_steps = steps; } -AdamW::AdamW(autograd::NamedParameters parameters, const AdamWConfig& config) : +float MorehAdamW::get_lr() const { + return m_config.lr; +} +void MorehAdamW::set_lr(float lr) { + m_config.lr = lr; +} + +AdamW::AdamW(serialization::NamedParameters parameters, const AdamWConfig& config) : OptimizerBase(std::move(parameters)), m_config(config) { for (const auto& [key, tensor_ptr] : m_parameters) { if (tensor_ptr->get_requires_grad()) { @@ -226,35 +223,21 @@ void AdamW::step() { } } -[[nodiscard]] autograd::NamedParameters AdamW::get_state_dict() const { - autograd::NamedParameters state_dict; - for (const auto& [key, first_moment] : m_first_moment) { - state_dict.emplace(kFirstMoment + key, first_moment); - } - - for (const auto& [key, second_moment] : m_second_moment) { - state_dict.emplace(kSecondMoment + key, second_moment); - } - - for (const auto& [key, kahan_compensation] : m_kahan_compensation) { - state_dict.emplace(kKahanCompensation + key, kahan_compensation); - } +[[nodiscard]] serialization::StateDict AdamW::get_state_dict() const { + serialization::StateDict state_dict; + state_dict[kFirstMoment] = m_first_moment; + state_dict[kSecondMoment] = m_second_moment; + state_dict[kKahanCompensation] = m_kahan_compensation; + state_dict[kSteps] = m_steps; return state_dict; } -void AdamW::set_state_dict(const autograd::NamedParameters& dict) { - for (const auto& [key, tensor] : dict) { - if (key.starts_with(kFirstMoment)) { - m_first_moment[key.substr(kFirstMoment.size())] = tensor; - } else if (key.starts_with(kSecondMoment)) { - m_second_moment[key.substr(kSecondMoment.size())] = tensor; - } else if (key.starts_with(kKahanCompensation)) { - m_kahan_compensation[key.substr(kKahanCompensation.size())] = tensor; - } else { - throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. 
Key = {}", key)); - } - } +void AdamW::set_state_dict(const serialization::StateDict& dict) { + m_first_moment = std::get(dict.at(kFirstMoment)); + m_second_moment = std::get(dict.at(kSecondMoment)); + m_kahan_compensation = std::get(dict.at(kKahanCompensation)); + m_steps = serialization::get_value_type(dict, kSteps); } [[nodiscard]] size_t AdamW::get_steps() const { @@ -265,4 +248,10 @@ void AdamW::set_steps(size_t steps) { m_steps = steps; } +float AdamW::get_lr() const { + return m_config.lr; +} +void AdamW::set_lr(float lr) { + m_config.lr = lr; +} } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/adamw.hpp b/tt-train/sources/ttml/optimizers/adamw.hpp index da3847f66db..d4505d8cb01 100644 --- a/tt-train/sources/ttml/optimizers/adamw.hpp +++ b/tt-train/sources/ttml/optimizers/adamw.hpp @@ -4,8 +4,8 @@ #include -#include "autograd/module_base.hpp" #include "optimizer_base.hpp" +#include "serialization/serializable.hpp" namespace ttml::optimizers { @@ -23,45 +23,52 @@ struct AdamWConfig { class MorehAdamW : public OptimizerBase { public: - MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config); + MorehAdamW(serialization::NamedParameters parameters, const AdamWConfig& config); void zero_grad() override; void step() override; - [[nodiscard]] autograd::NamedParameters get_state_dict() const override; - void set_state_dict(const autograd::NamedParameters& dict) override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict& dict) override; [[nodiscard]] size_t get_steps() const override; void set_steps(size_t steps) override; + [[nodiscard]] float get_lr() const override; + void set_lr(float lr) override; + private: size_t m_steps{0}; AdamWConfig m_config; - autograd::NamedParameters m_first_moment; - autograd::NamedParameters m_second_moment; + serialization::NamedParameters m_first_moment; + serialization::NamedParameters m_second_moment; }; class AdamW : public OptimizerBase { public: - AdamW(autograd::NamedParameters parameters, const AdamWConfig& config); + AdamW(serialization::NamedParameters parameters, const AdamWConfig& config); void zero_grad() override; void step() override; - [[nodiscard]] autograd::NamedParameters get_state_dict() const override; - void set_state_dict(const autograd::NamedParameters& dict) override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict& dict) override; [[nodiscard]] size_t get_steps() const override; void set_steps(size_t steps) override; + [[nodiscard]] float get_lr() const override; + + void set_lr(float lr) override; + private: size_t m_steps{0}; AdamWConfig m_config; - autograd::NamedParameters m_first_moment; - autograd::NamedParameters m_second_moment; - autograd::NamedParameters m_kahan_compensation; + serialization::NamedParameters m_first_moment; + serialization::NamedParameters m_second_moment; + serialization::NamedParameters m_kahan_compensation; }; } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/optimizer_base.cpp b/tt-train/sources/ttml/optimizers/optimizer_base.cpp index 446f23d6714..7971998d087 100644 --- a/tt-train/sources/ttml/optimizers/optimizer_base.cpp +++ b/tt-train/sources/ttml/optimizers/optimizer_base.cpp @@ -8,7 +8,7 @@ namespace ttml::optimizers { -OptimizerBase::OptimizerBase(autograd::NamedParameters&& parameters) : m_parameters(std::move(parameters)) { 
+OptimizerBase::OptimizerBase(serialization::NamedParameters&& parameters) : m_parameters(std::move(parameters)) { } void OptimizerBase::print_stats() const { diff --git a/tt-train/sources/ttml/optimizers/optimizer_base.hpp b/tt-train/sources/ttml/optimizers/optimizer_base.hpp index 49f1f4a32aa..690d0fd9ed6 100644 --- a/tt-train/sources/ttml/optimizers/optimizer_base.hpp +++ b/tt-train/sources/ttml/optimizers/optimizer_base.hpp @@ -4,13 +4,13 @@ #pragma once -#include "autograd/module_base.hpp" +#include "serialization/serializable.hpp" namespace ttml::optimizers { class OptimizerBase { public: - explicit OptimizerBase(autograd::NamedParameters&& parameters); + explicit OptimizerBase(serialization::NamedParameters&& parameters); OptimizerBase(const OptimizerBase&) = delete; OptimizerBase& operator=(const OptimizerBase&) = delete; OptimizerBase(OptimizerBase&&) = delete; @@ -21,16 +21,19 @@ class OptimizerBase { virtual void step() = 0; - [[nodiscard]] virtual autograd::NamedParameters get_state_dict() const = 0; - virtual void set_state_dict(const autograd::NamedParameters& dict) = 0; + [[nodiscard]] virtual serialization::StateDict get_state_dict() const = 0; + virtual void set_state_dict(const serialization::StateDict& dict) = 0; [[nodiscard]] virtual size_t get_steps() const = 0; virtual void set_steps(size_t steps) = 0; + virtual void set_lr(float lr) = 0; + [[nodiscard]] virtual float get_lr() const = 0; + virtual void print_stats() const; protected: - autograd::NamedParameters m_parameters; + serialization::NamedParameters m_parameters; }; } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/sgd.cpp b/tt-train/sources/ttml/optimizers/sgd.cpp index 0e25feb95fe..48298585644 100644 --- a/tt-train/sources/ttml/optimizers/sgd.cpp +++ b/tt-train/sources/ttml/optimizers/sgd.cpp @@ -9,10 +9,11 @@ #include "autograd/autocast_tensor.hpp" #include "core/debug.hpp" #include "core/tt_tensor_utils.hpp" +#include "serialization/serializable.hpp" namespace ttml::optimizers { -SGD::SGD(ttml::autograd::NamedParameters parameters, const SGDConfig& config) : +SGD::SGD(ttml::serialization::NamedParameters parameters, const SGDConfig& config) : OptimizerBase(std::move(parameters)), m_config(config) { for (const auto& [name, tensor_ptr] : m_parameters) { if (tensor_ptr->get_requires_grad()) { @@ -53,7 +54,7 @@ void SGD::step() { } if (m_config.momentum != 0.0F) { - if (steps != 0) { + if (m_steps != 0) { // apply momentum theta = ttnn::multiply(theta, m_config.momentum); // dampening @@ -76,23 +77,27 @@ void SGD::step() { tensor_ptr->set_value(ttnn::subtract( tensor_ptr->get_value(autograd::PreferredPrecision::FULL), ttnn::multiply(gradients, m_config.lr))); } - steps++; + m_steps++; } -autograd::NamedParameters SGD::get_state_dict() const { - return m_theta; +serialization::StateDict SGD::get_state_dict() const { + serialization::StateDict dict; + dict["theta"] = m_theta; + dict["steps"] = m_steps; + return dict; } -void SGD::set_state_dict(const autograd::NamedParameters& dict) { - m_theta = dict; +void SGD::set_state_dict(const serialization::StateDict& dict) { + m_theta = std::get(dict.at("theta")); + m_steps = serialization::get_value_type(dict, "steps"); } size_t SGD::get_steps() const { - return steps; + return m_steps; } void SGD::set_steps(size_t steps) { - this->steps = steps; + this->m_steps = steps; } } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/sgd.hpp b/tt-train/sources/ttml/optimizers/sgd.hpp index 756facdf26c..298aef045f6 100644 --- 
a/tt-train/sources/ttml/optimizers/sgd.hpp +++ b/tt-train/sources/ttml/optimizers/sgd.hpp @@ -4,12 +4,8 @@ #pragma once -#include - -#include "autograd/module_base.hpp" -#include "autograd/tensor.hpp" -#include "core/tt_tensor_utils.hpp" #include "optimizers/optimizer_base.hpp" +#include "serialization/serializable.hpp" namespace ttml::optimizers { @@ -23,22 +19,30 @@ struct SGDConfig { class SGD : public OptimizerBase { public: - explicit SGD(ttml::autograd::NamedParameters parameters, const SGDConfig& config); + explicit SGD(ttml::serialization::NamedParameters parameters, const SGDConfig& config); void zero_grad() override; void step() override; - [[nodiscard]] autograd::NamedParameters get_state_dict() const override; - void set_state_dict(const autograd::NamedParameters& dict) override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict& dict) override; [[nodiscard]] size_t get_steps() const override; void set_steps(size_t steps) override; + [[nodiscard]] float get_lr() const override { + return m_config.lr; + } + + void set_lr(float lr) override { + m_config.lr = lr; + } + private: - size_t steps{0}; + size_t m_steps{0}; SGDConfig m_config; - ttml::autograd::NamedParameters m_theta; + ttml::serialization::NamedParameters m_theta; }; } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp b/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp new file mode 100644 index 00000000000..f49ac00732d --- /dev/null +++ b/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "lambda_scheduler.hpp" + +#include "optimizers/optimizer_base.hpp" +namespace ttml::schedulers { + +LambdaScheduler::LambdaScheduler(optimizers::OptimizerBase *optimizer, std::function lr_lambda) : + LRSchedulerBase(optimizer), + m_lr_lambda(std::move(lr_lambda)), + m_last_step(0), + m_base_lr(optimizer->get_lr()), + m_last_lr(optimizer->get_lr()) { +} +void LambdaScheduler::step() { + m_last_step += 1; + float lr_factor = m_lr_lambda(m_last_step); + float new_lr = m_base_lr * lr_factor; + get_optimizer()->set_lr(new_lr); + m_last_lr = new_lr; +} +float LambdaScheduler::get_last_lr() const { + return m_last_lr; +} +float LambdaScheduler::get_current_lr() const { + return get_optimizer()->get_lr(); +} +void LambdaScheduler::set_state_dict(const serialization::StateDict &dict) { + m_last_step = serialization::get_value_type(dict, "m_last_step"); + m_last_lr = serialization::get_value_type(dict, "m_last_lr"); +} +serialization::StateDict LambdaScheduler::get_state_dict() const { + serialization::StateDict res; + res["m_last_step"] = m_last_step; + res["m_last_lr"] = m_last_lr; + return res; +}; +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp b/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp new file mode 100644 index 00000000000..e75b167104c --- /dev/null +++ b/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "scheduler_base.hpp" + +namespace ttml::schedulers { +class LambdaScheduler : public LRSchedulerBase { +public: + explicit LambdaScheduler(optimizers::OptimizerBase *optimizer, std::function lr_lambda); + + void step() override; + + [[nodiscard]] float 
get_last_lr() const override; + + [[nodiscard]] float get_current_lr() const override; + + [[nodiscard]] serialization::StateDict get_state_dict() const override; + + void set_state_dict(const serialization::StateDict &dict) override; + +private: + std::function m_lr_lambda; + size_t m_last_step = 0; + float m_base_lr = 0.0F; + float m_last_lr = 0.0F; +}; +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.cpp b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp new file mode 100644 index 00000000000..964bfb0b4f1 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "linear_scheduler.hpp" + +#include "optimizers/optimizer_base.hpp" + +namespace ttml::schedulers { + +LinearScheduler::LinearScheduler( + optimizers::OptimizerBase* optimizer, float start_factor, float end_factor, size_t total_steps) : + LRSchedulerBase(optimizer), + m_base_lr(optimizer->get_lr()), + m_last_lr(m_base_lr), + m_start_factor(start_factor), + m_end_factor(end_factor), + m_total_steps(total_steps), + m_last_step(0) { +} + +void LinearScheduler::step() { + m_last_step += 1; + + float progress = static_cast(m_last_step) / m_total_steps; + progress = std::min(progress, 1.0f); + + float current_factor = m_start_factor + (m_end_factor - m_start_factor) * progress; + float new_lr = m_base_lr * current_factor; + + get_optimizer()->set_lr(new_lr); + m_last_lr = new_lr; +} + +void LinearScheduler::set_state_dict(const serialization::StateDict& dict) { + m_last_step = serialization::get_value_type(dict, "m_last_step"); + m_last_lr = serialization::get_value_type(dict, "m_last_lr"); +} + +serialization::StateDict LinearScheduler::get_state_dict() const { + serialization::StateDict res; + res["m_last_step"] = m_last_step; + res["m_last_lr"] = m_last_lr; + return res; +}; + +float LinearScheduler::get_last_lr() const { + return m_last_lr; +} + +float LinearScheduler::get_current_lr() const { + return get_optimizer()->get_lr(); +} + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.hpp b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp new file mode 100644 index 00000000000..9a8edcf18bb --- /dev/null +++ b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "scheduler_base.hpp" + +namespace ttml::schedulers { + +class LinearScheduler : public LRSchedulerBase { +public: + LinearScheduler(optimizers::OptimizerBase *optimizer, float start_factor, float end_factor, size_t total_steps); + + void step() override; + + [[nodiscard]] float get_last_lr() const override; + + [[nodiscard]] float get_current_lr() const override; + + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict &dict) override; + +private: + float m_base_lr = 0.F; + float m_start_factor = 0.F; + float m_end_factor = 0.F; + int m_total_steps = 0; + size_t m_last_step = 0; + float m_last_lr = 0.F; +}; +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/scheduler_base.cpp b/tt-train/sources/ttml/schedulers/scheduler_base.cpp new file mode 100644 index 00000000000..7e9e90c9092 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/scheduler_base.cpp @@ -0,0 +1,15 @@ +// 
SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "scheduler_base.hpp" + +namespace ttml::schedulers { + +core::not_null ttml::schedulers::LRSchedulerBase::get_optimizer() const { + return m_optimizer; +} +LRSchedulerBase::LRSchedulerBase(optimizers::OptimizerBase *optimizer) : m_optimizer(optimizer) { +} + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/scheduler_base.hpp b/tt-train/sources/ttml/schedulers/scheduler_base.hpp new file mode 100644 index 00000000000..4fd52ff5526 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/scheduler_base.hpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "core/not_null.hpp" +#include "serialization/serializable.hpp" + +namespace ttml::optimizers { +class OptimizerBase; +} + +namespace ttml::schedulers { + +class LRSchedulerBase { +public: + explicit LRSchedulerBase(optimizers::OptimizerBase *optimizer); + + virtual ~LRSchedulerBase() = default; + + virtual void step() = 0; + + [[nodiscard]] virtual float get_last_lr() const = 0; + + [[nodiscard]] virtual float get_current_lr() const = 0; + + [[nodiscard]] core::not_null get_optimizer() const; + + [[nodiscard]] virtual serialization::StateDict get_state_dict() const = 0; + virtual void set_state_dict(const serialization::StateDict &dict) = 0; + +private: + core::not_null m_optimizer; +}; + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp b/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp new file mode 100644 index 00000000000..cac72d08156 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "sequential_scheduler.hpp" + +#include "optimizers/optimizer_base.hpp" +#include "serialization/serializable.hpp" +namespace { +const std::string kCurrentScheduler = "current_scheduler/"; +} +namespace ttml::schedulers { +SequentialScheduler::SequentialScheduler( + optimizers::OptimizerBase *optimizer, + std::vector> schedulers, + std::vector milestones) : + LRSchedulerBase(optimizer), + m_schedulers(std::move(schedulers)), + m_milestones(std::move(milestones)), + m_current_scheduler_index(0), + m_current_step_in_scheduler(0), + m_last_lr(optimizer->get_lr()) { + if (m_schedulers.empty()) { + throw std::invalid_argument("SequentialScheduler requires at least one scheduler."); + } + + // Validate that each scheduler is non-null + for (auto &scheduler : m_schedulers) { + if (!scheduler) { + throw std::invalid_argument("Null scheduler provided to SequentialScheduler."); + } + } +} +void SequentialScheduler::step() { + if (m_current_scheduler_index >= m_schedulers.size()) { + return; + } + + auto ¤t_scheduler = m_schedulers[m_current_scheduler_index]; + auto current_sched_steps = m_milestones[m_current_scheduler_index]; + current_scheduler->step(); + m_current_step_in_scheduler += 1; + m_last_lr = current_scheduler->get_last_lr(); + + if (m_current_step_in_scheduler >= current_sched_steps) { + m_current_scheduler_index += 1; + m_current_step_in_scheduler = 0; + } +} +float SequentialScheduler::get_last_lr() const { + if (m_current_scheduler_index == 0) { + return (m_current_scheduler_index < m_schedulers.size()) + ? 
m_schedulers[m_current_scheduler_index]->get_last_lr() + : m_last_lr; + } else if (m_current_scheduler_index < m_schedulers.size()) { + return m_schedulers[m_current_scheduler_index]->get_last_lr(); + } + return m_last_lr; +} +float SequentialScheduler::get_current_lr() const { + // The current LR of the optimizer should reflect the last scheduler's step + return get_optimizer()->get_lr(); +} + +void SequentialScheduler::set_state_dict(const serialization::StateDict &dict) { + m_current_step_in_scheduler = serialization::get_value_type(dict, "m_current_step_in_scheduler"); + m_last_lr = serialization::get_value_type(dict, "m_last_lr"); + m_current_scheduler_index = serialization::get_value_type(dict, "m_current_scheduler_index"); + serialization::StateDict current_scheduler_dict; + for (auto &[key, value] : dict) { + if (key.find(kCurrentScheduler) == 0) { + current_scheduler_dict[key.substr(kCurrentScheduler.length())] = value; + } + } + m_schedulers[m_current_scheduler_index]->set_state_dict(current_scheduler_dict); +} +serialization::StateDict SequentialScheduler::get_state_dict() const { + serialization::StateDict res; + res["m_current_step_in_scheduler"] = m_current_step_in_scheduler; + res["m_last_lr"] = m_last_lr; + res["m_current_scheduler_index"] = m_current_scheduler_index; + for (auto &[key, value] : m_schedulers[m_current_scheduler_index]->get_state_dict()) { + res[kCurrentScheduler + key] = value; + } + return res; +}; + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp new file mode 100644 index 00000000000..eac686c1b2c --- /dev/null +++ b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "scheduler_base.hpp" + +namespace ttml::schedulers { + +class SequentialScheduler : public LRSchedulerBase { +public: + // Each element in the schedulers vector is a (scheduler, steps) pair. + // The scheduler runs for 'steps' times, then we move on to the next one. 
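+ // For example, with schedulers = {warmup, decay} and milestones = {10, 50}, the first scheduler is stepped + // for the first 10 calls to step() and the second for the following 50 calls.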
+ // A little bit different from the PyTorch implementation, where the milestones might be less than the number of + schedulers, which is misleading. + SequentialScheduler( + optimizers::OptimizerBase *optimizer, + std::vector> schedulers, + std::vector milestones); + + void step() override; + + [[nodiscard]] float get_last_lr() const override; + + [[nodiscard]] float get_current_lr() const override; + + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict &dict) override; + +private: + std::vector> m_schedulers; + std::vector m_milestones; + size_t m_current_scheduler_index = 0; + int m_current_step_in_scheduler = 0; + float m_last_lr = 0.F; +}; + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/step_scheduler.cpp b/tt-train/sources/ttml/schedulers/step_scheduler.cpp new file mode 100644 index 00000000000..ec1acf8cb01 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/step_scheduler.cpp @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "step_scheduler.hpp" + +#include "optimizers/optimizer_base.hpp" + +namespace ttml::schedulers { + +StepScheduler::StepScheduler(optimizers::OptimizerBase *optimizer, size_t step_size, float gamma) : + LRSchedulerBase(optimizer), + m_step_size(step_size), + m_gamma(gamma), + m_last_step(0), + m_base_lr(optimizer->get_lr()), + m_last_lr(m_base_lr) { + if (gamma <= 0.0f) { + throw std::invalid_argument(fmt::format("gamma = {} must be greater than zero.", gamma)); + } +} +void StepScheduler::step() { + m_last_step += 1; + + // Every step_size epochs, lr is scaled by gamma + int num_steps = m_last_step / m_step_size; + float new_lr = m_base_lr * std::pow(m_gamma, static_cast(num_steps)); + + get_optimizer()->set_lr(new_lr); + m_last_lr = new_lr; +} +float StepScheduler::get_last_lr() const { + return m_last_lr; +} +float StepScheduler::get_current_lr() const { + return get_optimizer()->get_lr(); +} + +void StepScheduler::set_state_dict(const serialization::StateDict &dict) { + m_last_step = serialization::get_value_type(dict, "m_last_step"); + m_last_lr = serialization::get_value_type(dict, "m_last_lr"); +} +serialization::StateDict StepScheduler::get_state_dict() const { + serialization::StateDict res; + res["m_last_step"] = m_last_step; + res["m_last_lr"] = m_last_lr; + return res; +}; + +} // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/step_scheduler.hpp b/tt-train/sources/ttml/schedulers/step_scheduler.hpp new file mode 100644 index 00000000000..2f0189e9d78 --- /dev/null +++ b/tt-train/sources/ttml/schedulers/step_scheduler.hpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "scheduler_base.hpp" + +namespace ttml::schedulers { + +class StepScheduler : public LRSchedulerBase { +public: + StepScheduler(optimizers::OptimizerBase *optimizer, size_t step_size, float gamma = 0.1f); + + void step() override; + + [[nodiscard]] float get_last_lr() const override; + + [[nodiscard]] float get_current_lr() const override; + + [[nodiscard]] serialization::StateDict get_state_dict() const override; + + void set_state_dict(const serialization::StateDict &dict) override; + +private: + size_t m_step_size = 0; + float m_gamma = 0; + size_t m_last_step = 0; + + float m_base_lr = 0.F; + float m_last_lr = 0.F; +}; + +} // namespace ttml::schedulers diff --git
a/tt-train/sources/ttml/serialization/msgpack_file.cpp b/tt-train/sources/ttml/serialization/msgpack_file.cpp index 42fb0b53378..573218a0b5d 100644 --- a/tt-train/sources/ttml/serialization/msgpack_file.cpp +++ b/tt-train/sources/ttml/serialization/msgpack_file.cpp @@ -6,12 +6,11 @@ #include -#include +#include #include #define MSGPACK_NO_BOOST #include #include -#include #include #include #include @@ -122,6 +121,10 @@ class MsgPackFile::Impl { } // Overloads for std::span + void put(std::string_view key, std::span value) { + m_data[std::string(key)] = std::vector(value.begin(), value.end()); + } + void put(std::string_view key, std::span value) { m_data[std::string(key)] = std::vector(value.begin(), value.end()); } @@ -142,6 +145,10 @@ class MsgPackFile::Impl { m_data[std::string(key)] = std::vector(value.begin(), value.end()); } + void put(std::string_view key, const ValueType& value) { + m_data[std::string(key)] = value; + } + // Serialization method void serialize(const std::string& filename) { // Create a buffer for packing @@ -216,6 +223,10 @@ class MsgPackFile::Impl { return get_value(key, value); } + bool get(std::string_view key, std::vector& value) const { + return get_value(key, value); + } + bool get(std::string_view key, std::vector& value) const { return get_value(key, value); } @@ -236,23 +247,11 @@ class MsgPackFile::Impl { return get_value(key, value); } -private: - using ValueType = std::variant< - bool, - char, - int, - float, - double, - uint32_t, - size_t, - std::string, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector>; + bool get(std::string_view key, ValueType& value) const { + return get_value(key, value); + } +private: std::unordered_map m_data; // Helper function to get value from m_data @@ -271,6 +270,17 @@ class MsgPackFile::Impl { throw std::runtime_error(fmt::format("Key not found: {}", key)); } } + template <> + bool get_value(std::string_view key, ValueType& value) const { + auto it = m_data.find(std::string(key)); + if (it != m_data.end()) { + value = it->second; + return true; + } else { + // Key not found + throw std::runtime_error(fmt::format("Key not found: {}", key)); + } + } }; MsgPackFile::MsgPackFile() : m_impl(std::make_unique()) { @@ -312,6 +322,10 @@ void MsgPackFile::put(std::string_view key, std::string_view value) { m_impl->put(key, value); } +void MsgPackFile::put(std::string_view key, std::span value) { + m_impl->put(key, value); +} + void MsgPackFile::put(std::string_view key, std::span value) { m_impl->put(key, value); } @@ -332,6 +346,14 @@ void MsgPackFile::put(std::string_view key, std::span value) m_impl->put(key, value); } +void MsgPackFile::put(std::string_view key, const char* value) { + put(key, std::string_view(value)); +} + +void MsgPackFile::put(std::string_view key, const ValueType& value) { + m_impl->put(key, value); +} + void MsgPackFile::serialize(const std::string& filename) { m_impl->serialize(filename); } @@ -372,6 +394,10 @@ void MsgPackFile::get(std::string_view key, std::string& value) const { m_impl->get(key, value); } +void MsgPackFile::get(std::string_view key, std::vector& value) const { + m_impl->get(key, value); +} + void MsgPackFile::get(std::string_view key, std::vector& value) const { m_impl->get(key, value); } @@ -392,7 +418,8 @@ void MsgPackFile::get(std::string_view key, std::vector& value) con m_impl->get(key, value); } -void MsgPackFile::put(std::string_view key, const char* value) { - put(key, std::string_view(value)); +void MsgPackFile::get(std::string_view key, 
ValueType& value) const { + m_impl->get(key, value); } + } // namespace ttml::serialization diff --git a/tt-train/sources/ttml/serialization/msgpack_file.hpp b/tt-train/sources/ttml/serialization/msgpack_file.hpp index 19f36f6cca9..6e170483a6f 100644 --- a/tt-train/sources/ttml/serialization/msgpack_file.hpp +++ b/tt-train/sources/ttml/serialization/msgpack_file.hpp @@ -13,6 +13,23 @@ namespace ttml::serialization { +using ValueType = std::variant< + bool, + char, + int, + float, + double, + uint32_t, + size_t, + std::string, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector, + std::vector>; + class MsgPackFile { public: MsgPackFile(); @@ -44,12 +61,14 @@ class MsgPackFile { void put(std::string_view key, const char* value); // Overloads for std::span + void put(std::string_view key, std::span value); void put(std::string_view key, std::span value); void put(std::string_view key, std::span value); void put(std::string_view key, std::span value); void put(std::string_view key, std::span value); void put(std::string_view key, std::span value); + void put(std::string_view key, const ValueType& value); // Serialization method void serialize(const std::string& filename); @@ -67,12 +86,15 @@ class MsgPackFile { void get(std::string_view key, std::string& value) const; // Methods to get vectors (from spans) + void get(std::string_view key, std::vector& value) const; void get(std::string_view key, std::vector& value) const; void get(std::string_view key, std::vector& value) const; void get(std::string_view key, std::vector& value) const; void get(std::string_view key, std::vector& value) const; void get(std::string_view key, std::vector& value) const; + void get(std::string_view key, ValueType& type) const; + private: class Impl; std::unique_ptr m_impl; diff --git a/tt-train/sources/ttml/serialization/serializable.hpp b/tt-train/sources/ttml/serialization/serializable.hpp new file mode 100644 index 00000000000..689aa24d7ed --- /dev/null +++ b/tt-train/sources/ttml/serialization/serializable.hpp @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include +#include + +#include "autograd/tensor.hpp" +#include "msgpack_file.hpp" + +namespace ttml::serialization { +using NamedParameters = std::unordered_map; +using SerializableType = std::variant; +using StateDict = std::unordered_map; + +template +concept IsValueType = requires { + { std::get(std::declval()) }; +}; + +template +const T& get_value_type(const StateDict& dict, const std::string& key) { + const auto& val_type = std::get(dict.at(key)); + return std::get(val_type); +} + +} // namespace ttml::serialization diff --git a/tt-train/sources/ttml/serialization/serialization.cpp b/tt-train/sources/ttml/serialization/serialization.cpp index d96e26f014f..401b96a26bf 100644 --- a/tt-train/sources/ttml/serialization/serialization.cpp +++ b/tt-train/sources/ttml/serialization/serialization.cpp @@ -21,24 +21,23 @@ namespace ttml::serialization { // trivial type to the std::string template -std::string to_bytes(const T& value) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); - std::string bytes(sizeof(T), '\0'); - std::memcpy(bytes.data(), &value, sizeof(T)); - return bytes; +std::span to_bytes(T& value) { + static_assert(std::is_trivially_copyable_v, "T must be trivially copyable"); + auto ptr = reinterpret_cast(&value); + return std::span(ptr, sizeof(T)); } template -void from_bytes(const 
std::string& bytes, T& value) { - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); +void from_bytes(std::span bytes, T& value) { + static_assert(std::is_trivially_copyable_v, "T must be trivially copyable"); if (bytes.size() != sizeof(T)) { - throw std::invalid_argument(fmt::format( - "Invalid byte size for conversion to type T. Expected: {} Actual: {}, type: {} ", - sizeof(T), - bytes.size(), - core::demangle(typeid(T).name()))); + std::ostringstream oss; + oss << "Invalid byte size for conversion to type T. Expected: " << sizeof(T) << " Actual: " << bytes.size() + << ", type: " << typeid(T).name(); + throw std::invalid_argument(oss.str()); } + std::memcpy(&value, bytes.data(), sizeof(T)); } @@ -77,7 +76,7 @@ void read_ttnn_tensor(MsgPackFile& file, std::string_view name, tt::tt_metal::Te tt::tt_metal::StorageType storage_type{}; auto shape = core::create_shape({1, 1, 1, 1}); - std::string bytes; + std::vector bytes; file.get(std::string(name) + "/shape", bytes); from_bytes(bytes, shape); @@ -127,12 +126,13 @@ void read_autograd_tensor(MsgPackFile& file, std::string_view name, ttml::autogr } } -void write_named_parameters(MsgPackFile& file, std::string_view name, const ttml::autograd::NamedParameters& params) { +void write_named_parameters( + MsgPackFile& file, std::string_view name, const ttml::serialization::NamedParameters& params) { for (const auto& [key, value] : params) { write_autograd_tensor(file, std::string(name) + "/" + key, value); } } -void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autograd::NamedParameters& params) { +void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::serialization::NamedParameters& params) { for (auto& [key, value] : params) { read_autograd_tensor(file, std::string(name) + "/" + key, value); } @@ -141,22 +141,15 @@ void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autog void write_optimizer(MsgPackFile& file, std::string_view name, const optimizers::OptimizerBase* optimizer) { assert(optimizer); auto state_dict = optimizer->get_state_dict(); - for (const auto& [key, value] : state_dict) { - ttml::serialization::write_autograd_tensor(file, std::string(name) + "/" + key, value); - } - file.put(std::string(name) + "/steps", optimizer->get_steps()); + write_state_dict(file, std::string(name), state_dict); } void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::OptimizerBase* optimizer) { assert(optimizer); size_t steps = 0; auto state_dict = optimizer->get_state_dict(); - for (auto& [key, value] : state_dict) { - ttml::serialization::read_autograd_tensor(file, std::string(name) + "/" + key, value); - } + read_state_dict(file, name, state_dict); optimizer->set_state_dict(state_dict); - file.get(std::string(name) + "/steps", steps); - optimizer->set_steps(steps); } void write_module(MsgPackFile& file, std::string_view name, const autograd::ModuleBase* module) { @@ -171,4 +164,35 @@ void read_module(MsgPackFile& file, std::string_view name, autograd::ModuleBase* read_named_parameters(file, name, named_parameters); } +void write_state_dict(MsgPackFile& file, std::string_view name, const serialization::StateDict& state_dict) { + for (const auto& [key, value] : state_dict) { + if (std::holds_alternative(value)) { + file.put(std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + write_ttnn_tensor(file, std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + 
write_autograd_tensor(file, std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + write_named_parameters(file, std::string(name) + "/" + key, std::get(value)); + } else { + throw std::runtime_error("Unsupported type in state dict"); + } + } +} +void read_state_dict(MsgPackFile& file, std::string_view name, serialization::StateDict& state_dict) { + for (auto& [key, value] : state_dict) { + if (std::holds_alternative(value)) { + file.get(std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + read_ttnn_tensor(file, std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + read_autograd_tensor(file, std::string(name) + "/" + key, std::get(value)); + } else if (std::holds_alternative(value)) { + read_named_parameters(file, std::string(name) + "/" + key, std::get(value)); + } else { + throw std::runtime_error("Unsupported type in state dict"); + } + } +} + } // namespace ttml::serialization diff --git a/tt-train/sources/ttml/serialization/serialization.hpp b/tt-train/sources/ttml/serialization/serialization.hpp index 1d4198e9996..6eee8247b53 100644 --- a/tt-train/sources/ttml/serialization/serialization.hpp +++ b/tt-train/sources/ttml/serialization/serialization.hpp @@ -23,8 +23,9 @@ void write_autograd_tensor( MsgPackFile& file, std::string_view name, const ttml::autograd::TensorPtr& tensor, bool save_grads = false); void read_autograd_tensor(MsgPackFile& file, std::string_view name, ttml::autograd::TensorPtr& tensor); -void write_named_parameters(MsgPackFile& file, std::string_view name, const ttml::autograd::NamedParameters& params); -void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autograd::NamedParameters& params); +void write_named_parameters( + MsgPackFile& file, std::string_view name, const ttml::serialization::NamedParameters& params); +void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::serialization::NamedParameters& params); void write_optimizer(MsgPackFile& file, std::string_view name, const optimizers::OptimizerBase* optimizer); void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::OptimizerBase* optimizer); @@ -32,4 +33,7 @@ void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::Optimi void write_module(MsgPackFile& file, std::string_view name, const autograd::ModuleBase* module); void read_module(MsgPackFile& file, std::string_view name, autograd::ModuleBase* module); +void write_state_dict(MsgPackFile& file, std::string_view name, const serialization::StateDict& state_dict); +void read_state_dict(MsgPackFile& file, std::string_view name, serialization::StateDict& state_dict); + } // namespace ttml::serialization diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp index dee98552ef6..564c985c198 100644 --- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp +++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp @@ -4,7 +4,6 @@ #pragma once #include -#include namespace ttml::ttnn_fixed { diff --git a/tt-train/tests/autograd/module_base_parameters_test.cpp b/tt-train/tests/autograd/module_base_parameters_test.cpp index 1edbf7d212e..3275c0a0a54 100644 --- a/tt-train/tests/autograd/module_base_parameters_test.cpp +++ b/tt-train/tests/autograd/module_base_parameters_test.cpp @@ -8,6 +8,7 @@ #include #include "autograd/module_base.hpp" +#include "core/tt_tensor_utils.hpp" #include 
"modules/dropout_module.hpp" #include "modules/layer_norm_module.hpp" #include "modules/linear_module.hpp" diff --git a/tt-train/tests/model/linear_regression_full_test.cpp b/tt-train/tests/model/linear_regression_full_test.cpp index 1af4f315405..0915b05abf7 100644 --- a/tt-train/tests/model/linear_regression_full_test.cpp +++ b/tt-train/tests/model/linear_regression_full_test.cpp @@ -8,6 +8,7 @@ #include #include "autograd/auto_context.hpp" +#include "core/tt_tensor_utils.hpp" #include "modules/linear_module.hpp" #include "ops/losses.hpp" #include "optimizers/sgd.hpp" diff --git a/tt-train/tests/schedulers/schedulers_test.cpp b/tt-train/tests/schedulers/schedulers_test.cpp new file mode 100644 index 00000000000..5fc6d8ae80d --- /dev/null +++ b/tt-train/tests/schedulers/schedulers_test.cpp @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "core/not_null.hpp" +#include "optimizers/optimizer_base.hpp" +#include "schedulers/lambda_scheduler.hpp" +#include "schedulers/linear_scheduler.hpp" +#include "schedulers/sequential_scheduler.hpp" +#include "schedulers/step_scheduler.hpp" + +namespace ttml::optimizers { +class MockOptimizer : public OptimizerBase { +public: + explicit MockOptimizer(float lr) : OptimizerBase(ttml::serialization::NamedParameters{}), m_lr(lr) { + } + + void zero_grad() override {}; + + void step() override {}; + + [[nodiscard]] serialization::StateDict get_state_dict() const override { + return {}; + } + + void set_state_dict(const serialization::StateDict &dict) override {}; + + [[nodiscard]] size_t get_steps() const override { + return {}; + }; + void set_steps(size_t steps) override {}; + + void set_lr(float lr) override { + m_lr = lr; + } + + [[nodiscard]] float get_lr() const override { + return m_lr; + } + +private: + float m_lr = 0; +}; +} // namespace ttml::optimizers + +// ---------------------------------- +// Tests for LambdaScheduler +// ---------------------------------- +TEST(LambdaSchedulerTest, ConstantFactor) { + auto optimizer = std::make_unique(0.1F); + + // Lambda that keeps the LR constant + // The learning rate of each parameter group is set to the initial lr times a given function. When last_epoch=-1, + // sets initial lr as lr. 
+ + ttml::schedulers::LambdaScheduler scheduler(optimizer.get(), [](int epoch) { + (void)epoch; + return 0.5F; + }); + + // Initial LR + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F); + + scheduler.step(); // step 1 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F); + + scheduler.step(); // step 2 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F); + + scheduler.step(); // step 3 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F); +} + +TEST(LambdaSchedulerTest, VaryingFactor) { + auto optimizer = std::make_unique(1.0f); + + // Lambda: lr_factor = 1.0 / (epoch + 1); the scheduler passes the 1-based step count as epoch + ttml::schedulers::LambdaScheduler scheduler(optimizer.get(), [](int epoch) { return 1.0F / (epoch + 1); }); + + // Before stepping + EXPECT_FLOAT_EQ(optimizer->get_lr(), 1.0F); + + scheduler.step(); // step 1: factor = 1/2 = 0.5, lr = 1.0 * 0.5 = 0.5 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.5F); + + scheduler.step(); // step 2: factor = 1/3, lr = 1.0 / 3 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 1.F / 3.F); + + scheduler.step(); // step 3: factor = 1/4, lr = 0.25 + EXPECT_NEAR(optimizer->get_lr(), 1.F / 4.F, 1e-5); + + scheduler.step(); // step 4: factor = 1/5, lr = 0.2 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F); +} + +// ---------------------------------- +// Tests for StepLRScheduler +// ---------------------------------- +TEST(StepLRSchedulerTest, BasicDecay) { + auto optimizer = std::make_unique(0.2F); + + // Decrease LR by factor of 0.1 every 3 steps + ttml::schedulers::StepScheduler scheduler(optimizer.get(), 3, 0.1F); + + for (int i = 0; i < 3; ++i) { + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F); + scheduler.step(); + } + + for (int i = 0; i < 3; ++i) { + // After 3 steps: lr = base_lr * 0.1 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F * 0.1F); + scheduler.step(); + } + // After 6 steps: lr = base_lr * 0.1^2 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F * 0.1F * 0.1F); +} + +// ---------------------------------- +// Tests for LinearScheduler +// ---------------------------------- +TEST(LinearSchedulerTest, DecreasingLR) { + auto optimizer = std::make_unique(0.2F); + + // Linearly go from 0.2 to 0.0 in 4 steps + ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 1.0F, 0.0F, 4); + + // step 1: progress = 1/4=0.25 lr = 0.2 + (0.0-0.2)*0.25 = 0.2 - 0.05=0.15 + scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.15F); + + // step 2: progress=0.5 lr=0.2+(0.0-0.2)*0.5=0.2-0.1=0.1 + scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F); + + // step 3: progress=0.75 lr=0.2+(0.0-0.2)*0.75=0.2-0.15=0.05 + scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.05F); + + // step 4: progress=1.0 lr=0.2+(0.0-0.2)*1.0=0.0 + scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.0f); + + // Extra steps keep it at 0.0 + scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.0f); +} + +// ---------------------------------- +// Tests for SequentialScheduler +// ---------------------------------- +TEST(SequentialSchedulerTest, ChainSchedulers) { + auto optimizer = std::make_unique(1.0f); + + // First: StepLRScheduler for 3 steps (gamma=0.5 every step_size=1) + auto step_scheduler = std::make_unique(optimizer.get(), 1, 0.5F); + + // Then: LinearScheduler for 2 steps from current LR to 0.1 + auto linear_scheduler = std::make_unique(optimizer.get(), 1.0F, 0.1F, 2); + + std::vector> schedulers; + std::vector milestones; + schedulers.push_back(std::move(step_scheduler)); + schedulers.push_back(std::move(linear_scheduler)); + milestones.push_back(3); + milestones.push_back(2); + 
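+ // Expected schedule: the step phase halves the LR three times (0.5, 0.25, 0.125); the linear phase then + // interpolates the factor from 1.0 to 0.1 of the base LR captured at construction (0.55, then 0.1).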
ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(schedulers), std::move(milestones)); + + // Initial LR = 1.0 + // Run StepLRScheduler for 3 steps: + // step_scheduler: every step reduces LR by factor 0.5 + seq_scheduler.step(); // 1st step: LR=1.0*0.5=0.5 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.5F); + + seq_scheduler.step(); // 2nd step: LR=0.5*0.5=0.25 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.25F); + + seq_scheduler.step(); // 3rd step: LR=0.25*0.5=0.125 + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.125F); + + // total_steps=2, start_lr=0.125, end_lr=0.1 + // step 1: progress=1/2=0.5 lr=1.0+(0.1-1.0)*0.5=0.55 + seq_scheduler.step(); + EXPECT_NEAR(optimizer->get_lr(), 0.55, 1e-5); + + // step 2: progress=2/2=1.0 lr=1.0+(0.1-1.0)*1.0=0.1 (min lr in linear scheduler) + seq_scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F); + + // Further steps do nothing (we finished all schedulers) + seq_scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F); +} + +TEST(SequentialSchedulerTest, WarmupSetup) { + auto start_lr = 3.e-4F; + auto optimizer = std::make_unique(start_lr); + + // First: LinearScheduler for 10 steps from 0 to start_lr + auto warmup_scheduler = std::make_unique(optimizer.get(), 0.0F, 1.0F, 10); + + // Then: LinearScheduler for 50 steps from start_lr to 0.1F * start_lr + auto linear_scheduler = std::make_unique(optimizer.get(), 1.F, 0.1F, 50); + + std::vector> schedulers; + std::vector milestones; + schedulers.push_back(std::move(warmup_scheduler)); + schedulers.push_back(std::move(linear_scheduler)); + milestones.push_back(10); + milestones.push_back(50); + ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(schedulers), std::move(milestones)); + + for (int i = 0; i < 10; i++) { + // Linear warmup: 10 steps from 0 to start_lr + seq_scheduler.step(); + EXPECT_NEAR(optimizer->get_lr(), start_lr * (i + 1) / 10, 1e-5); + } + for (int i = 0; i < 50; i++) { + // Linear decay: 50 steps from start_lr to 0.1F * start_lr + seq_scheduler.step(); + EXPECT_NEAR(optimizer->get_lr(), start_lr * (1.0F - 0.9F * (i + 1) / 50.F), 1e-5); + } +} From 9f83120d2b7b389439f0227a9612bbaa9b7de262 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Wed, 11 Dec 2024 06:59:34 +0000 Subject: [PATCH 06/13] #15944: Add temporary api for automatically creating a sub-device manager with a fabric sub-device. 
Remove create_sub_device_manager overload of MeshDevice that accepts a different sub-device config per device. This was added due to eth cores being different per device, but this is now handled internally --- tests/ttnn/unit_tests/test_sub_device.py | 36 ++--- tt_metal/distributed/mesh_device.cpp | 12 +- tt_metal/distributed/mesh_device.hpp | 5 +- tt_metal/impl/device/device.cpp | 14 ++ tt_metal/impl/device/device.hpp | 5 + .../impl/sub_device/sub_device_manager.cpp | 144 +++++++++--------- .../impl/sub_device/sub_device_manager.hpp | 7 + .../ttnn/distributed/distributed_pybind.cpp | 17 ++- 8 files changed, 138 insertions(+), 102 deletions(-) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index f7bfb20401a..2eb998b720a 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -7,7 +7,7 @@ import ttnn -def run_sub_devices(device, replicate_sub_devices=False): +def run_sub_devices(device, create_fabric_sub_device=False): tensix_cores0 = ttnn.CoreRangeSet( { ttnn.CoreRange( @@ -28,12 +28,12 @@ def run_sub_devices(device, replicate_sub_devices=False): sub_device_2 = ttnn.SubDevice([tensix_cores1]) sub_devices_1 = [sub_device_1, sub_device_2] sub_devices_2 = [sub_device_2] - if replicate_sub_devices: - num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() - sub_devices_1 = [sub_devices_1] * num_devices - sub_devices_2 = [sub_devices_2] * num_devices - sub_device_manager1 = device.create_sub_device_manager(sub_devices_1, 3200) - sub_device_manager2 = device.create_sub_device_manager(sub_devices_2, 3200) + if create_fabric_sub_device: + sub_device_manager1 = device.create_sub_device_manager_with_fabric(sub_devices_1, 3200) + sub_device_manager2 = device.create_sub_device_manager_with_fabric(sub_devices_2, 3200) + else: + sub_device_manager1 = device.create_sub_device_manager(sub_devices_1, 3200) + sub_device_manager2 = device.create_sub_device_manager(sub_devices_2, 3200) device.load_sub_device_manager(sub_device_manager1) ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(1)]) ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)]) @@ -45,7 +45,7 @@ def run_sub_devices(device, replicate_sub_devices=False): device.remove_sub_device_manager(sub_device_manager2) -def run_sub_devices_program(device, replicate_sub_devices=False): +def run_sub_devices_program(device, create_fabric_sub_device=False): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) @@ -74,10 +74,10 @@ def run_sub_devices_program(device, replicate_sub_devices=False): sub_device_1 = ttnn.SubDevice([tensix_cores0]) sub_device_2 = ttnn.SubDevice([tensix_cores1]) sub_devices = [sub_device_1, sub_device_2] - if replicate_sub_devices: - num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() - sub_devices = [sub_devices] * num_devices - sub_device_manager = device.create_sub_device_manager(sub_devices, 3200) + if create_fabric_sub_device: + sub_device_manager = device.create_sub_device_manager_with_fabric(sub_devices, 3200) + else: + sub_device_manager = device.create_sub_device_manager(sub_devices, 3200) device.load_sub_device_manager(sub_device_manager) x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16) @@ -140,9 +140,9 @@ def test_sub_devices(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
-@pytest.mark.parametrize("replicate_sub_devices", (False, True)) -def test_sub_devices_mesh(mesh_device, replicate_sub_devices, enable_async_mode): - run_sub_devices(mesh_device, replicate_sub_devices) +@pytest.mark.parametrize("create_fabric_sub_device", (False, True)) +def test_sub_devices_mesh(mesh_device, create_fabric_sub_device, enable_async_mode): + run_sub_devices(mesh_device, create_fabric_sub_device) @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) @@ -151,6 +151,6 @@ def test_sub_device_program(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) -@pytest.mark.parametrize("replicate_sub_devices", (False, True)) -def test_sub_device_program_mesh(mesh_device, replicate_sub_devices, enable_async_mode): - run_sub_devices_program(mesh_device, replicate_sub_devices) +@pytest.mark.parametrize("create_fabric_sub_device", (False, True)) +def test_sub_device_program_mesh(mesh_device, create_fabric_sub_device, enable_async_mode): + run_sub_devices_program(mesh_device, create_fabric_sub_device) diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index 6971abd948e..fb709c39901 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -490,21 +490,21 @@ MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span>& mesh_sub_devices, DeviceAddr local_l1_size) { +std::tuple MeshDevice::create_sub_device_manager_with_fabric(tt::stl::Span sub_devices, DeviceAddr local_l1_size) { MeshSubDeviceManagerId mesh_sub_device_manager_id(*this); - TT_FATAL(mesh_sub_devices.size() == this->num_devices(), "Number of devices does not match number of sub-device configurations"); + SubDeviceId fabric_sub_device_id; for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; - tt::stl::Span sub_devices(mesh_sub_devices[i]); - device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() { - sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size); + // All fabric sub-device ids will be the same, since all managers are created with the same sub_devices input + device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id, &fabric_sub_device_id]() { + std::tie(sub_device_manager_id, fabric_sub_device_id) = device->create_sub_device_manager_with_fabric(sub_devices, local_l1_size); }); } for (auto* device : this->devices) { device->synchronize(); } - return mesh_sub_device_manager_id; + return {mesh_sub_device_manager_id, fabric_sub_device_id}; } void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index a7727fb97bd..01a63d2e286 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -137,8 +137,9 @@ class MeshDevice : public std::enable_shared_from_this { MeshSubDeviceManagerId create_sub_device_manager( tt::stl::Span sub_devices, DeviceAddr local_l1_size); - MeshSubDeviceManagerId create_sub_device_manager( - const std::vector>& mesh_sub_devices, DeviceAddr local_l1_size); + // TODO #15944: Temporary api until migration to actual fabric is complete + std::tuple create_sub_device_manager_with_fabric( + tt::stl::Span sub_devices, DeviceAddr local_l1_size); void load_sub_device_manager(MeshSubDeviceManagerId 
mesh_sub_device_manager_id); void clear_loaded_sub_device_manager(); void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 5c2875b4885..7a708b760d4 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -3724,6 +3724,20 @@ SubDeviceManagerId Device::create_sub_device_manager(tt::stl::Spanfirst; } +std::tuple Device::create_sub_device_manager_with_fabric(tt::stl::Span sub_devices, DeviceAddr local_l1_size) { + auto fabric_sub_device = SubDevice(std::array{CoreRangeSet(), this->default_sub_device_manager_->sub_device(SubDeviceId{0}).cores(HalProgrammableCoreType::ACTIVE_ETH)}); + auto new_sub_devices = std::vector(sub_devices.begin(), sub_devices.end()); + new_sub_devices.push_back(fabric_sub_device); + auto fabric_sub_device_id = SubDeviceId{static_cast(new_sub_devices.size() - 1)}; + auto sub_device_manager_id = this->create_sub_device_manager(new_sub_devices, local_l1_size); + this->sub_device_managers_[sub_device_manager_id]->set_fabric_sub_device_id(fabric_sub_device_id); + return {sub_device_manager_id, fabric_sub_device_id}; +} + +std::optional Device::get_fabric_sub_device_id() const { + return this->active_sub_device_manager_->fabric_sub_device_id(); +} + void Device::load_sub_device_manager(SubDeviceManagerId sub_device_manager_id) { TT_FATAL(!this->using_slow_dispatch(), "Using sub device managers is unsupported with slow dispatch"); if (this->active_sub_device_manager_id_ == sub_device_manager_id) { diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index a8cdb1f23b0..7fbe89bd853 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -389,6 +389,11 @@ class Device { void clear_loaded_sub_device_manager(); void remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id); const std::vector &get_sub_device_ids() const; + + // TODO #15944: Temporary api until migration to actual fabric is complete + std::tuple create_sub_device_manager_with_fabric(tt::stl::Span sub_devices, DeviceAddr local_l1_size); + std::optional get_fabric_sub_device_id() const; + private: void initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, tt::stl::Span l1_bank_remap); SubDeviceManagerId get_next_sub_device_manager_id(); diff --git a/tt_metal/impl/sub_device/sub_device_manager.cpp b/tt_metal/impl/sub_device/sub_device_manager.cpp index c1500f85064..2c4706590c8 100644 --- a/tt_metal/impl/sub_device/sub_device_manager.cpp +++ b/tt_metal/impl/sub_device/sub_device_manager.cpp @@ -38,16 +38,16 @@ SubDeviceManager::SubDeviceManager( SubDeviceManager::SubDeviceManager(Device* device, std::unique_ptr&& global_allocator) : device_(device) { TT_ASSERT(device != nullptr, "Device must not be null"); - this->local_l1_size_ = 0; - const auto& compute_grid_size = this->device_->compute_with_storage_grid_size(); - const auto& active_eth_cores = this->device_->get_active_ethernet_cores(true); + local_l1_size_ = 0; + const auto& compute_grid_size = device_->compute_with_storage_grid_size(); + const auto& active_eth_cores = device_->get_active_ethernet_cores(true); std::vector active_eth_core_ranges; active_eth_core_ranges.reserve(active_eth_cores.size()); for (const auto& core : active_eth_cores) { active_eth_core_ranges.emplace_back(core, core); } - this->sub_devices_ = {SubDevice(std::array{ + sub_devices_ = {SubDevice(std::array{ CoreRangeSet(CoreRange({0, 0}, {compute_grid_size.x 
- 1, compute_grid_size.y - 1})), CoreRangeSet(std::move(active_eth_core_ranges))})}; this->populate_sub_device_ids(); @@ -59,7 +59,7 @@ SubDeviceManager::SubDeviceManager(Device* device, std::unique_ptr&& } SubDeviceManager::~SubDeviceManager() { - for (const auto& allocator : this->sub_device_allocators_) { + for (const auto& allocator : sub_device_allocators_) { if (allocator) { // Clear the bank managers, this makes subsequent buffer deallocations fast allocator::clear(*allocator); @@ -73,9 +73,9 @@ SubDeviceManager::~SubDeviceManager() { } } -uint8_t SubDeviceManager::num_sub_devices() const { return this->sub_devices_.size(); } +uint8_t SubDeviceManager::num_sub_devices() const { return sub_devices_.size(); } -const std::vector& SubDeviceManager::get_sub_device_ids() const { return this->sub_device_ids_; } +const std::vector& SubDeviceManager::get_sub_device_ids() const { return sub_device_ids_; } const SubDevice& SubDeviceManager::sub_device(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); @@ -88,46 +88,46 @@ const vector_memcpy_aligned& SubDeviceManager::noc_mcast_unicast_data( uint8_t SubDeviceManager::num_noc_mcast_txns(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->num_noc_mcast_txns_[sub_device_index]; + return num_noc_mcast_txns_[sub_device_index]; } uint8_t SubDeviceManager::num_noc_unicast_txns(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->num_noc_unicast_txns_[sub_device_index]; + return num_noc_unicast_txns_[sub_device_index]; } uint8_t SubDeviceManager::noc_mcast_data_start_index(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->noc_mcast_data_start_index_[sub_device_index]; + return noc_mcast_data_start_index_[sub_device_index]; } uint8_t SubDeviceManager::noc_unicast_data_start_index(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->noc_unicast_data_start_index_[sub_device_index]; + return noc_unicast_data_start_index_[sub_device_index]; } const std::unique_ptr& SubDeviceManager::get_initialized_allocator(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); - TT_FATAL(this->sub_device_allocators_[sub_device_index], "SubDevice allocator not initialized"); - return this->sub_device_allocators_[sub_device_index]; + TT_FATAL(sub_device_allocators_[sub_device_index], "SubDevice allocator not initialized"); + return sub_device_allocators_[sub_device_index]; } std::unique_ptr& SubDeviceManager::sub_device_allocator(SubDeviceId sub_device_id) { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->sub_device_allocators_[sub_device_index]; + return sub_device_allocators_[sub_device_index]; } std::shared_ptr& SubDeviceManager::create_trace(uint32_t tid) { - auto [trace, emplaced] = this->trace_buffer_pool_.emplace(tid, Trace::create_empty_trace_buffer()); + auto [trace, emplaced] = trace_buffer_pool_.emplace(tid, Trace::create_empty_trace_buffer()); TT_ASSERT(emplaced, "Trace buffer with tid {} already exists", tid); return trace->second; } -void SubDeviceManager::release_trace(uint32_t tid) { this->trace_buffer_pool_.erase(tid); } +void SubDeviceManager::release_trace(uint32_t tid) { trace_buffer_pool_.erase(tid); } std::shared_ptr SubDeviceManager::get_trace(uint32_t tid) { - auto 
trace = this->trace_buffer_pool_.find(tid); - if (trace != this->trace_buffer_pool_.end()) { + auto trace = trace_buffer_pool_.find(tid); + if (trace != trace_buffer_pool_.end()) { return trace->second; } return nullptr; @@ -135,18 +135,18 @@ std::shared_ptr SubDeviceManager::get_trace(uint32_t tid) { void SubDeviceManager::reset_worker_launch_message_buffer_state() { std::for_each( - this->worker_launch_message_buffer_state_.begin(), - this->worker_launch_message_buffer_state_.end(), + worker_launch_message_buffer_state_.begin(), + worker_launch_message_buffer_state_.end(), std::mem_fn(&LaunchMessageRingBufferState::reset)); } LaunchMessageRingBufferState& SubDeviceManager::get_worker_launch_message_buffer_state(SubDeviceId sub_device_id) { auto sub_device_index = this->get_sub_device_index(sub_device_id); - return this->worker_launch_message_buffer_state_[sub_device_index]; + return worker_launch_message_buffer_state_[sub_device_index]; } bool SubDeviceManager::has_allocations() const { - for (const auto& allocator : this->sub_device_allocators_) { + for (const auto& allocator : sub_device_allocators_) { if (allocator && allocator->allocated_buffers.size() > 0) { return true; } @@ -154,25 +154,35 @@ bool SubDeviceManager::has_allocations() const { return false; } -DeviceAddr SubDeviceManager::local_l1_size() const { return this->local_l1_size_; } +DeviceAddr SubDeviceManager::local_l1_size() const { return local_l1_size_; } + +void SubDeviceManager::set_fabric_sub_device_id(SubDeviceId fabric_sub_device_id) { + const auto& fabric_sub_device = this->sub_device(fabric_sub_device_id); + TT_FATAL( + fabric_sub_device.cores(HalProgrammableCoreType::TENSIX).num_cores() == 0, + "Fabric sub device must not have Tensix cores"); + fabric_sub_device_id_ = fabric_sub_device_id; +} + +std::optional SubDeviceManager::fabric_sub_device_id() const { return fabric_sub_device_id_; } uint8_t SubDeviceManager::get_sub_device_index(SubDeviceId sub_device_id) const { auto sub_device_index = sub_device_id.to_index(); TT_FATAL( - sub_device_index < this->sub_devices_.size(), + sub_device_index < sub_devices_.size(), "SubDevice index {} out of bounds {}", sub_device_index, - this->sub_devices_.size()); + sub_devices_.size()); return sub_device_index; } void SubDeviceManager::validate_sub_devices() const { - TT_FATAL(this->sub_devices_.size() <= SubDeviceManager::MAX_NUM_SUB_DEVICES, "Too many sub devices specified"); + TT_FATAL(sub_devices_.size() <= SubDeviceManager::MAX_NUM_SUB_DEVICES, "Too many sub devices specified"); // Validate sub device cores fit inside the device grid - const auto& compute_grid_size = this->device_->compute_with_storage_grid_size(); + const auto& compute_grid_size = device_->compute_with_storage_grid_size(); CoreRange device_worker_cores = CoreRange({0, 0}, {compute_grid_size.x - 1, compute_grid_size.y - 1}); - const auto& device_eth_cores = this->device_->get_active_ethernet_cores(true); - for (const auto& sub_device : this->sub_devices_) { + const auto& device_eth_cores = device_->get_active_ethernet_cores(true); + for (const auto& sub_device : sub_devices_) { const auto& worker_cores = sub_device.cores(HalProgrammableCoreType::TENSIX); TT_FATAL( device_worker_cores.contains(worker_cores), @@ -191,15 +201,15 @@ void SubDeviceManager::validate_sub_devices() const { "Ethernet cores {} specified in sub device must be within device grid", eth_cores); } - if (this->sub_devices_.size() < 2) { + if (sub_devices_.size() < 2) { return; } // Validate no overlap of sub devices - for (uint32_t i 
= 0; i < this->sub_devices_.size(); ++i) { - for (uint32_t j = i + 1; j < this->sub_devices_.size(); ++j) { + for (uint32_t i = 0; i < sub_devices_.size(); ++i) { + for (uint32_t j = i + 1; j < sub_devices_.size(); ++j) { for (uint32_t k = 0; k < NumHalProgrammableCoreTypes; ++k) { TT_FATAL( - !(this->sub_devices_[i].cores()[k].intersects(this->sub_devices_[j].cores()[k])), + !(sub_devices_[i].cores()[k].intersects(sub_devices_[j].cores()[k])), "SubDevices specified for SubDeviceManager intersect"); } } @@ -207,33 +217,33 @@ void SubDeviceManager::validate_sub_devices() const { } void SubDeviceManager::populate_sub_device_ids() { - this->sub_device_ids_.resize(this->num_sub_devices()); + sub_device_ids_.resize(this->num_sub_devices()); for (uint8_t i = 0; i < this->num_sub_devices(); ++i) { - this->sub_device_ids_[i] = SubDeviceId{i}; + sub_device_ids_[i] = SubDeviceId{i}; } } void SubDeviceManager::populate_num_cores() { - for (const auto& sub_device : this->sub_devices_) { + for (const auto& sub_device : sub_devices_) { for (uint32_t i = 0; i < NumHalProgrammableCoreTypes; ++i) { - this->num_cores_[i] += sub_device.num_cores(static_cast(i)); + num_cores_[i] += sub_device.num_cores(static_cast(i)); } } } void SubDeviceManager::populate_sub_allocators() { - this->sub_device_allocators_.resize(this->num_sub_devices()); - if (this->local_l1_size_ == 0) { + sub_device_allocators_.resize(this->num_sub_devices()); + if (local_l1_size_ == 0) { return; } - const auto& global_allocator_config = this->device_->get_initialized_allocator()->config; + const auto& global_allocator_config = device_->get_initialized_allocator()->config; // Construct allocator config from soc_desc // Take max alignment to satisfy NoC rd/wr constraints // Tensix/Eth -> PCIe/DRAM src and dst addrs must be L1_ALIGNMENT aligned // PCIe/DRAM -> Tensix/Eth src and dst addrs must be DRAM_ALIGNMENT aligned // Tensix/Eth <-> Tensix/Eth src and dst addrs must be L1_ALIGNMENT aligned for (uint32_t i = 0; i < this->num_sub_devices(); ++i) { - const auto& compute_cores = this->sub_devices_[i].cores(HalProgrammableCoreType::TENSIX); + const auto& compute_cores = sub_devices_[i].cores(HalProgrammableCoreType::TENSIX); if (compute_cores.empty()) { continue; } @@ -243,7 +253,7 @@ void SubDeviceManager::populate_sub_allocators() { l1_bank_remap.reserve(compute_cores_vec.size()); for (const auto& core : compute_cores_vec) { // These are compute cores, so they should have a single bank - l1_bank_remap.push_back(this->device_->bank_ids_from_logical_core(BufferType::L1, core)[0]); + l1_bank_remap.push_back(device_->bank_ids_from_logical_core(BufferType::L1, core)[0]); } AllocatorConfig config( {.num_dram_channels = global_allocator_config.num_dram_channels, @@ -252,7 +262,7 @@ void SubDeviceManager::populate_sub_allocators() { .dram_unreserved_base = global_allocator_config.dram_unreserved_base, .l1_unreserved_base = global_allocator_config.l1_unreserved_base, .worker_grid = compute_cores, - .worker_l1_size = global_allocator_config.l1_unreserved_base + this->local_l1_size_, + .worker_l1_size = global_allocator_config.l1_unreserved_base + local_l1_size_, .storage_core_bank_size = std::nullopt, .l1_small_size = 0, .trace_region_size = 0, @@ -275,53 +285,51 @@ void SubDeviceManager::populate_sub_allocators() { // sub_devices only have compute cores for allocation for (const CoreCoord& core : corerange_to_cores(compute_cores)) { - const auto noc_coord = this->device_->worker_core_from_logical_core(core); + const auto noc_coord = 
device_->worker_core_from_logical_core(core); config.core_type_from_noc_coord_table.insert({noc_coord, AllocCoreType::ComputeAndStore}); } // L1_BANKING scheme creates 1 bank per DRAM core and splits up L1 such that there are power 2 num L1 banks // This is the only allocator scheme supported because kernel APIs assume num L1 banks are power of 2 - TT_ASSERT(this->device_->allocator_scheme_ == MemoryAllocator::L1_BANKING); - this->sub_device_allocators_[i] = std::make_unique(config); + TT_ASSERT(device_->allocator_scheme_ == MemoryAllocator::L1_BANKING); + sub_device_allocators_[i] = std::make_unique(config); } } void SubDeviceManager::populate_noc_data() { uint32_t num_sub_devices = this->num_sub_devices(); - this->num_noc_mcast_txns_.resize(num_sub_devices); - this->num_noc_unicast_txns_.resize(num_sub_devices); - this->noc_mcast_data_start_index_.resize(num_sub_devices); - this->noc_unicast_data_start_index_.resize(num_sub_devices); + num_noc_mcast_txns_.resize(num_sub_devices); + num_noc_unicast_txns_.resize(num_sub_devices); + noc_mcast_data_start_index_.resize(num_sub_devices); + noc_unicast_data_start_index_.resize(num_sub_devices); - NOC noc_index = this->device_->dispatch_go_signal_noc(); + NOC noc_index = device_->dispatch_go_signal_noc(); uint32_t idx = 0; for (uint32_t i = 0; i < num_sub_devices; ++i) { - const auto& tensix_cores = this->sub_devices_[i].cores(HalProgrammableCoreType::TENSIX); - const auto& eth_cores = this->sub_devices_[i].cores(HalProgrammableCoreType::ACTIVE_ETH); + const auto& tensix_cores = sub_devices_[i].cores(HalProgrammableCoreType::TENSIX); + const auto& eth_cores = sub_devices_[i].cores(HalProgrammableCoreType::ACTIVE_ETH); - this->noc_mcast_data_start_index_[i] = idx; - this->num_noc_mcast_txns_[i] = tensix_cores.size(); - this->noc_mcast_unicast_data_.resize(idx + this->num_noc_mcast_txns_[i] * 2); + noc_mcast_data_start_index_[i] = idx; + num_noc_mcast_txns_[i] = tensix_cores.size(); + noc_mcast_unicast_data_.resize(idx + num_noc_mcast_txns_[i] * 2); for (const auto& core_range : tensix_cores.ranges()) { - auto virtual_start = - this->device_->virtual_core_from_logical_core(core_range.start_coord, CoreType::WORKER); - auto virtual_end = this->device_->virtual_core_from_logical_core(core_range.end_coord, CoreType::WORKER); + auto virtual_start = device_->virtual_core_from_logical_core(core_range.start_coord, CoreType::WORKER); + auto virtual_end = device_->virtual_core_from_logical_core(core_range.end_coord, CoreType::WORKER); auto virtual_core_range = CoreRange(virtual_start, virtual_end); - this->noc_mcast_unicast_data_[idx++] = - this->device_->get_noc_multicast_encoding(noc_index, virtual_core_range); - this->noc_mcast_unicast_data_[idx++] = core_range.size(); + noc_mcast_unicast_data_[idx++] = device_->get_noc_multicast_encoding(noc_index, virtual_core_range); + noc_mcast_unicast_data_[idx++] = core_range.size(); } - this->noc_unicast_data_start_index_[i] = idx; + noc_unicast_data_start_index_[i] = idx; // TODO: Precompute number of eth cores and resize once for (const auto& core_range : eth_cores.ranges()) { - this->noc_mcast_unicast_data_.resize(idx + core_range.size()); + noc_mcast_unicast_data_.resize(idx + core_range.size()); for (const auto& core : core_range) { - auto virtual_core = this->device_->virtual_core_from_logical_core(core, CoreType::ETH); - this->noc_mcast_unicast_data_[idx++] = this->device_->get_noc_unicast_encoding(noc_index, virtual_core); + auto virtual_core = device_->virtual_core_from_logical_core(core, CoreType::ETH); 
+ noc_mcast_unicast_data_[idx++] = device_->get_noc_unicast_encoding(noc_index, virtual_core); } } - this->num_noc_unicast_txns_[i] = idx - this->noc_unicast_data_start_index_[i]; + num_noc_unicast_txns_[i] = idx - noc_unicast_data_start_index_[i]; TT_FATAL( idx <= dispatch_constants::DISPATCH_GO_SIGNAL_NOC_DATA_ENTRIES, @@ -332,7 +340,7 @@ void SubDeviceManager::populate_noc_data() { } void SubDeviceManager::populate_worker_launch_message_buffer_state() { - this->worker_launch_message_buffer_state_.resize(this->num_sub_devices()); + worker_launch_message_buffer_state_.resize(this->num_sub_devices()); this->reset_worker_launch_message_buffer_state(); } diff --git a/tt_metal/impl/sub_device/sub_device_manager.hpp b/tt_metal/impl/sub_device/sub_device_manager.hpp index 90c98090814..356555dfff9 100644 --- a/tt_metal/impl/sub_device/sub_device_manager.hpp +++ b/tt_metal/impl/sub_device/sub_device_manager.hpp @@ -68,6 +68,10 @@ class SubDeviceManager { bool has_allocations() const; DeviceAddr local_l1_size() const; + // #TODO #15944: Temporary until migration to actual fabric is complete + void set_fabric_sub_device_id(SubDeviceId sub_device_id); + std::optional fabric_sub_device_id() const; + private: void validate_sub_devices() const; uint8_t get_sub_device_index(SubDeviceId sub_device_id) const; @@ -97,6 +101,9 @@ class SubDeviceManager { std::unordered_map> trace_buffer_pool_; std::vector worker_launch_message_buffer_state_; + + // TODO #15944: Temporary until migration to actual fabric is complete + std::optional fabric_sub_device_id_ = std::nullopt; }; } // namespace detail diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ed946f23d9b..71114e39ee5 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -197,23 +197,24 @@ void py_module(py::module& module) { MeshSubDeviceManagerId: The ID of the created sub-device manager. )doc") .def( - "create_sub_device_manager", - [](MeshDevice& self, - const std::vector>& mesh_sub_devices, - DeviceAddr local_l1_size) { return self.create_sub_device_manager(mesh_sub_devices, local_l1_size); }, + "create_sub_device_manager_with_fabric", + [](MeshDevice& self, const std::vector& sub_devices, DeviceAddr local_l1_size) { + return self.create_sub_device_manager(sub_devices, local_l1_size); + }, py::arg("sub_devices"), py::arg("local_l1_size"), R"doc( - Creates a sub-device manager for the given mesh device. + Creates a sub-device manager for the given mesh device. This will automatically create a sub-device of ethernet cores for use with fabric. + Note that this is a temporary API until migration to actual fabric is complete. Args: - mesh_sub_devices (List[List[ttnn.SubDevice]]): The sub-devices to include in the sub-device manager. - Each element of the outer list will be used to configure the corresponding device in the MeshDevice. - This means that the individual devices in the MeshDevice may have different configurations. + sub_devices (List[ttnn.SubDevice]): The sub-devices to include in the sub-device manager. No ethernet cores should be included in this list. + This configuration will be used for each device in the MeshDevice. local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. Returns: MeshSubDeviceManagerId: The ID of the created sub-device manager. + SubDeviceId: The ID of the sub-device that will be used for fabric. 
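Example:
    A minimal usage sketch (illustrative only; assumes an open ttnn.MeshDevice handle named mesh_device, and unpacks the two return values documented above):

        tensix_cores = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
        sub_devices = [ttnn.SubDevice([tensix_cores])]
        manager_id, fabric_sub_device_id = mesh_device.create_sub_device_manager_with_fabric(sub_devices, 3200)
        mesh_device.load_sub_device_manager(manager_id)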
)doc") .def( "load_sub_device_manager", From 8926ed0357aecfef56c39b67c074a322bde16886 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Wed, 11 Dec 2024 21:57:52 +0000 Subject: [PATCH 07/13] #0: Update global semaphores and global circular buffers metal apis to be thread-safe instead of depending on ttnn apis --- .../tt_metal/api/test_global_semaphores.cpp | 5 +- .../ttnn/unit_tests/test_global_semaphore.py | 2 +- .../impl/buffers/global_circular_buffer.cpp | 112 ++++++++++-------- .../impl/buffers/global_circular_buffer.hpp | 33 ++++-- tt_metal/impl/buffers/global_semaphore.cpp | 48 +++++--- tt_metal/impl/buffers/global_semaphore.hpp | 54 +++++---- ttnn/cpp/pybind11/global_semaphore.cpp | 18 ++- ttnn/cpp/ttnn/global_semaphore.cpp | 14 ++- ttnn/cpp/ttnn/global_semaphore.hpp | 8 +- 9 files changed, 173 insertions(+), 121 deletions(-) diff --git a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp index f88e71efda9..58a0f987353 100644 --- a/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp +++ b/tests/tt_metal/tt_metal/api/test_global_semaphores.cpp @@ -86,6 +86,7 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { for (auto device : devices_) { { uint32_t initial_value = 1; + uint32_t reset_value = 2; std::vector overwrite_value = {2}; auto global_semaphore = tt::tt_metal::CreateGlobalSemaphore(device, cores, initial_value); auto address = global_semaphore->address(); @@ -104,14 +105,14 @@ TEST_F(DispatchFixture, ResetGlobalSemaphores) { EXPECT_EQ(sem_vals[0], overwrite_value[0]); } - global_semaphore->reset_semaphore_value(); + global_semaphore->reset_semaphore_value(reset_value); Synchronize(device); for (const auto& core : cores_vec) { auto sem_vals = tt::llrt::read_hex_vec_from_core( device->id(), device->worker_core_from_logical_core(core), address, sizeof(uint32_t)); tt::llrt::write_hex_vec_to_core( device->id(), device->worker_core_from_logical_core(core), overwrite_value, address); - EXPECT_EQ(sem_vals[0], initial_value); + EXPECT_EQ(sem_vals[0], reset_value); } } } diff --git a/tests/ttnn/unit_tests/test_global_semaphore.py b/tests/ttnn/unit_tests/test_global_semaphore.py index 24c6fa107de..32c17742c8b 100644 --- a/tests/ttnn/unit_tests/test_global_semaphore.py +++ b/tests/ttnn/unit_tests/test_global_semaphore.py @@ -29,7 +29,7 @@ def run_global_semaphore(device): assert ttnn.get_global_semaphore_address(global_sem0) != ttnn.get_global_semaphore_address(global_sem1) - ttnn.reset_global_semaphore_value(global_sem0) + ttnn.reset_global_semaphore_value(global_sem0, 3) @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) diff --git a/tt_metal/impl/buffers/global_circular_buffer.cpp b/tt_metal/impl/buffers/global_circular_buffer.cpp index 094670b2a30..02cc48b3a87 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.cpp +++ b/tt_metal/impl/buffers/global_circular_buffer.cpp @@ -28,7 +28,8 @@ GlobalCircularBuffer::GlobalCircularBuffer( const std::unordered_map& sender_receiver_core_mapping, uint32_t size, BufferType buffer_type, - tt::stl::Span sub_device_ids) : + tt::stl::Span sub_device_ids, + Private) : device_(device), sender_receiver_core_mapping_(sender_receiver_core_mapping), size_(size) { TT_FATAL(this->device_ != nullptr, "Device cannot be null"); uint32_t num_sender_cores = sender_receiver_core_mapping.size(); @@ -86,58 +87,65 @@ void GlobalCircularBuffer::setup_cb_buffers( shard_parameters, std::nullopt); - const auto& core_to_core_id = 
this->cb_config_buffer_->get_buffer_page_mapping()->core_to_core_id_; - - std::vector cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0); - uint32_t buffer_address = this->cb_buffer_->address(); - uint32_t noc_xy_address = this->cb_config_buffer_->address() + num_config_elements * sizeof(uint32_t); - uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment); - - for (const auto& [sender_core, receiver_cores] : this->sender_receiver_core_mapping_) { - const auto& receiver_cores_vec = corerange_to_cores(receiver_cores); - uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t); - uint32_t num_receivers = receiver_cores.num_cores(); - uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment; - cb_config_host_buffer[sender_idx++] = 1; - cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores(); - cb_config_host_buffer[sender_idx++] = buffer_address; - cb_config_host_buffer[sender_idx++] = this->size_; - cb_config_host_buffer[sender_idx++] = buffer_address; - cb_config_host_buffer[sender_idx++] = noc_xy_address; - cb_config_host_buffer[sender_idx++] = pages_sent_address; - - auto sender_physical_coord = this->device_->worker_core_from_logical_core(sender_core); - for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) { - auto receiver_physical_coord = this->device_->worker_core_from_logical_core(receiver_cores_vec[i]); - cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x; - cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y; - - uint32_t receiver_idx = core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t); - cb_config_host_buffer[receiver_idx++] = 0; - cb_config_host_buffer[receiver_idx++] = num_receivers; - cb_config_host_buffer[receiver_idx++] = buffer_address; - cb_config_host_buffer[receiver_idx++] = this->size_; - cb_config_host_buffer[receiver_idx++] = buffer_address; - cb_config_host_buffer[receiver_idx++] = noc_xy_address; - cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment; - cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x; - cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y; - } - } - // Write the config buffer to the device // Only block for the slow dispatch case - if (this->device_->using_slow_dispatch()) { - detail::WriteToBuffer(*this->cb_config_buffer_, cb_config_host_buffer); - tt::Cluster::instance().l1_barrier(this->device_->id()); - } else { - EnqueueWriteBuffer( - this->device_->command_queue(), - this->cb_config_buffer_, - cb_config_host_buffer.data(), - false, - sub_device_ids); - } + auto* device = this->device_; + device->push_work([device, + cb_config_size, + cb_config_page_size, + num_noc_xy_words, + l1_alignment, + buffer_address = this->cb_buffer_->address(), + cb_config_buffer = this->cb_config_buffer_, + size = this->size_, + sender_receiver_core_mapping = this->sender_receiver_core_mapping_, + sub_device_ids = std::vector(sub_device_ids.begin(), sub_device_ids.end())] { + auto config_buffer_address = cb_config_buffer->address(); + const auto& core_to_core_id = cb_config_buffer->get_buffer_page_mapping()->core_to_core_id_; + std::vector cb_config_host_buffer(cb_config_size / sizeof(uint32_t), 0); + uint32_t noc_xy_address = config_buffer_address + num_config_elements * sizeof(uint32_t); + uint32_t pages_sent_address = align(noc_xy_address + num_noc_xy_words * sizeof(uint32_t), l1_alignment); + + for (const auto& 
[sender_core, receiver_cores] : sender_receiver_core_mapping) { + const auto& receiver_cores_vec = corerange_to_cores(receiver_cores); + uint32_t sender_idx = core_to_core_id.at(sender_core) * cb_config_page_size / sizeof(uint32_t); + uint32_t num_receivers = receiver_cores.num_cores(); + uint32_t pages_acked_address = pages_sent_address + num_receivers * l1_alignment; + cb_config_host_buffer[sender_idx++] = 1; + cb_config_host_buffer[sender_idx++] = receiver_cores.num_cores(); + cb_config_host_buffer[sender_idx++] = buffer_address; + cb_config_host_buffer[sender_idx++] = size; + cb_config_host_buffer[sender_idx++] = buffer_address; + cb_config_host_buffer[sender_idx++] = noc_xy_address; + cb_config_host_buffer[sender_idx++] = pages_sent_address; + + auto sender_physical_coord = device->worker_core_from_logical_core(sender_core); + for (uint32_t i = 0; i < receiver_cores_vec.size(); i++) { + auto receiver_physical_coord = device->worker_core_from_logical_core(receiver_cores_vec[i]); + cb_config_host_buffer[sender_idx++] = receiver_physical_coord.x; + cb_config_host_buffer[sender_idx++] = receiver_physical_coord.y; + + uint32_t receiver_idx = + core_to_core_id.at(receiver_cores_vec[i]) * cb_config_page_size / sizeof(uint32_t); + cb_config_host_buffer[receiver_idx++] = 0; + cb_config_host_buffer[receiver_idx++] = num_receivers; + cb_config_host_buffer[receiver_idx++] = buffer_address; + cb_config_host_buffer[receiver_idx++] = size; + cb_config_host_buffer[receiver_idx++] = buffer_address; + cb_config_host_buffer[receiver_idx++] = noc_xy_address; + cb_config_host_buffer[receiver_idx++] = pages_sent_address + 2 * i * l1_alignment; + cb_config_host_buffer[receiver_idx++] = sender_physical_coord.x; + cb_config_host_buffer[receiver_idx++] = sender_physical_coord.y; + } + } + if (device->using_slow_dispatch()) { + detail::WriteToBuffer(*cb_config_buffer, cb_config_host_buffer); + tt::Cluster::instance().l1_barrier(device->id()); + } else { + EnqueueWriteBuffer( + device->command_queue(), cb_config_buffer, cb_config_host_buffer.data(), false, sub_device_ids); + } + }); } std::shared_ptr GlobalCircularBuffer::create( @@ -147,7 +155,7 @@ std::shared_ptr GlobalCircularBuffer::create( BufferType buffer_type, tt::stl::Span sub_device_ids) { return std::make_shared( - device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids); + device, sender_receiver_core_mapping, size, buffer_type, sub_device_ids, Private()); } const Buffer& GlobalCircularBuffer::cb_buffer() const { return *this->cb_buffer_; } diff --git a/tt_metal/impl/buffers/global_circular_buffer.hpp b/tt_metal/impl/buffers/global_circular_buffer.hpp index ca0c56da71f..96f8cfec73c 100644 --- a/tt_metal/impl/buffers/global_circular_buffer.hpp +++ b/tt_metal/impl/buffers/global_circular_buffer.hpp @@ -26,20 +26,11 @@ namespace v1 { namespace experimental { class GlobalCircularBuffer { -public: - GlobalCircularBuffer( - Device* device, - const std::unordered_map& sender_receiver_core_mapping, - uint32_t size, - BufferType buffer_type, - tt::stl::Span sub_device_ids); - - GlobalCircularBuffer(const GlobalCircularBuffer&) = default; - GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = default; - - GlobalCircularBuffer(GlobalCircularBuffer&&) noexcept = default; - GlobalCircularBuffer& operator=(GlobalCircularBuffer&&) noexcept = default; + struct Private { + explicit Private() = default; + }; +public: static std::shared_ptr create( Device* device, const std::unordered_map& sender_receiver_core_mapping, @@ -47,6 +38,12 @@ class 
GlobalCircularBuffer { BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); + GlobalCircularBuffer(const GlobalCircularBuffer&) = delete; + GlobalCircularBuffer& operator=(const GlobalCircularBuffer&) = delete; + + GlobalCircularBuffer(GlobalCircularBuffer&&) noexcept = delete; + GlobalCircularBuffer& operator=(GlobalCircularBuffer&&) noexcept = delete; + const Buffer& cb_buffer() const; const CoreRangeSet& sender_cores() const; @@ -59,6 +56,16 @@ class GlobalCircularBuffer { static constexpr auto attribute_names = std::forward_as_tuple("sender_receiver_core_mapping", "size"); const auto attribute_values() const { return std::make_tuple(this->sender_receiver_core_mapping_, this->size_); } + // "Private" constructor to prevent direct instantiation + // Use GlobalCircularBuffer::create instead + GlobalCircularBuffer( + Device* device, + const std::unordered_map& sender_receiver_core_mapping, + uint32_t size, + BufferType buffer_type, + tt::stl::Span sub_device_ids, + Private); + private: void setup_cb_buffers( BufferType buffer_type, uint32_t max_num_receivers_per_sender, tt::stl::Span sub_device_ids); diff --git a/tt_metal/impl/buffers/global_semaphore.cpp b/tt_metal/impl/buffers/global_semaphore.cpp index 57ef080d0f7..af976bd1a07 100644 --- a/tt_metal/impl/buffers/global_semaphore.cpp +++ b/tt_metal/impl/buffers/global_semaphore.cpp @@ -24,9 +24,10 @@ GlobalSemaphore::GlobalSemaphore( const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type, - tt::stl::Span sub_device_ids) : - device_(device), cores_(cores), initial_value_(initial_value) { - this->setup_buffer(buffer_type, sub_device_ids); + tt::stl::Span sub_device_ids, + Private) : + device_(device), cores_(cores) { + this->setup_buffer(initial_value, buffer_type, sub_device_ids); } GlobalSemaphore::GlobalSemaphore( @@ -34,12 +35,14 @@ GlobalSemaphore::GlobalSemaphore( CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type, - tt::stl::Span sub_device_ids) : - device_(device), cores_(std::move(cores)), initial_value_(initial_value) { - this->setup_buffer(buffer_type, sub_device_ids); + tt::stl::Span sub_device_ids, + Private) : + device_(device), cores_(std::move(cores)) { + this->setup_buffer(initial_value, buffer_type, sub_device_ids); } -void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Span sub_device_ids) { +void GlobalSemaphore::setup_buffer( + uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { TT_FATAL( buffer_type == BufferType::L1 or buffer_type == BufferType::L1_SMALL, "Global semaphore can only be created for L1 buffer types"); @@ -58,8 +61,7 @@ void GlobalSemaphore::setup_buffer(BufferType buffer_type, tt::stl::Spanhost_buffer_ = std::vector(num_cores, this->initial_value_); - this->reset_semaphore_value(sub_device_ids); + this->reset_semaphore_value(initial_value, sub_device_ids); } std::shared_ptr GlobalSemaphore::create( @@ -68,7 +70,7 @@ std::shared_ptr GlobalSemaphore::create( uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { - return std::make_shared(device, cores, initial_value, buffer_type, sub_device_ids); + return std::make_shared(device, cores, initial_value, buffer_type, sub_device_ids, Private()); } std::shared_ptr GlobalSemaphore::create( Device* device, @@ -76,23 +78,31 @@ std::shared_ptr GlobalSemaphore::create( uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids) { - return std::make_shared(device, std::move(cores), initial_value, buffer_type, 
sub_device_ids); + return std::make_shared( + device, std::move(cores), initial_value, buffer_type, sub_device_ids, Private()); } Device* GlobalSemaphore::device() const { return device_; } DeviceAddr GlobalSemaphore::address() const { return buffer_->address(); } -void GlobalSemaphore::reset_semaphore_value(tt::stl::Span sub_device_ids) { +void GlobalSemaphore::reset_semaphore_value(uint32_t reset_value, tt::stl::Span sub_device_ids) { // Write the initial value to the semaphore to the device // Only block for the slow dispatch case - if (this->device_->using_slow_dispatch()) { - detail::WriteToBuffer(*this->buffer_, this->host_buffer_); - tt::Cluster::instance().l1_barrier(this->device_->id()); - } else { - EnqueueWriteBuffer( - this->device_->command_queue(), this->buffer_, this->host_buffer_.data(), false, sub_device_ids); - } + auto* device = this->device_; + device->push_work([device, + reset_value, + sub_device_ids = std::vector(sub_device_ids.begin(), sub_device_ids.end()), + num_cores = this->cores_.num_cores(), + buffer = this->buffer_] { + std::vector host_buffer(num_cores, reset_value); + if (device->using_slow_dispatch()) { + detail::WriteToBuffer(*buffer, host_buffer); + tt::Cluster::instance().l1_barrier(device->id()); + } else { + EnqueueWriteBuffer(device->command_queue(), buffer, host_buffer, false, sub_device_ids); + } + }); } } // namespace tt::tt_metal diff --git a/tt_metal/impl/buffers/global_semaphore.hpp b/tt_metal/impl/buffers/global_semaphore.hpp index 0d912b2f9ac..24e404a28e7 100644 --- a/tt_metal/impl/buffers/global_semaphore.hpp +++ b/tt_metal/impl/buffers/global_semaphore.hpp @@ -20,60 +20,66 @@ class Buffer; class Device; class GlobalSemaphore { + struct Private { + explicit Private() = default; + }; + public: - GlobalSemaphore( + static std::shared_ptr create( Device* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); - GlobalSemaphore( + static std::shared_ptr create( Device* device, CoreRangeSet&& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1, tt::stl::Span sub_device_ids = {}); - GlobalSemaphore(const GlobalSemaphore&) = default; - GlobalSemaphore& operator=(const GlobalSemaphore&) = default; + GlobalSemaphore(const GlobalSemaphore&) = delete; + GlobalSemaphore& operator=(const GlobalSemaphore&) = delete; - GlobalSemaphore(GlobalSemaphore&&) noexcept = default; - GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = default; + GlobalSemaphore(GlobalSemaphore&&) noexcept = delete; + GlobalSemaphore& operator=(GlobalSemaphore&&) noexcept = delete; - static std::shared_ptr create( + Device* device() const; + + DeviceAddr address() const; + + void reset_semaphore_value(uint32_t reset_value, tt::stl::Span sub_device_ids = {}); + + static constexpr auto attribute_names = std::forward_as_tuple("cores"); + const auto attribute_values() const { return std::make_tuple(this->cores_); } + + // "Private" constructor to prevent direct instantiation + // Use GlobalSemaphore::create instead + GlobalSemaphore( Device* device, const CoreRangeSet& cores, uint32_t initial_value, - BufferType buffer_type = BufferType::L1, - tt::stl::Span sub_device_ids = {}); + BufferType buffer_type, + tt::stl::Span sub_device_ids, + Private); - static std::shared_ptr create( + GlobalSemaphore( Device* device, CoreRangeSet&& cores, uint32_t initial_value, - BufferType buffer_type = BufferType::L1, - tt::stl::Span sub_device_ids = {}); - - Device* device() const; - - DeviceAddr 
address() const; - - void reset_semaphore_value(tt::stl::Span sub_device_ids = {}); - - static constexpr auto attribute_names = std::forward_as_tuple("cores", "initial_value"); - const auto attribute_values() const { return std::make_tuple(this->cores_, this->initial_value_); } + BufferType buffer_type, + tt::stl::Span sub_device_ids, + Private); private: - void setup_buffer(BufferType buffer_type, tt::stl::Span sub_device_ids); + void setup_buffer(uint32_t initial_value, BufferType buffer_type, tt::stl::Span sub_device_ids); // GlobalSemaphore is implemented as a wrapper around a sharded buffer // This can be updated in the future to be its own container with optimized dispatch functions std::shared_ptr buffer_; - std::vector host_buffer_; Device* device_; CoreRangeSet cores_; - uint32_t initial_value_ = 0; }; } // namespace v0 diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp index f6e44cb3419..726c960e166 100644 --- a/ttnn/cpp/pybind11/global_semaphore.cpp +++ b/ttnn/cpp/pybind11/global_semaphore.cpp @@ -57,15 +57,20 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - py::overload_cast&, const std::vector&>( - &reset_global_semaphore_value), + [](const std::shared_ptr& global_semaphore, + uint32_t reset_value, + const std::vector& sub_device_ids) { + ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); + }, py::arg("global_semaphore"), + py::arg("reset_value"), py::arg("sub_device_ids") = std::vector(), R"doc( Reset the value of the global semaphore. Args: global_semaphore (GlobalSemaphore): The global semaphore object. + reset_value (int): The value to reset the global semaphore to. sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the global semaphore value to device. Defaults to waiting on all sub-devices. )doc"); @@ -111,15 +116,20 @@ void py_module(py::module& module) { module.def( "reset_global_semaphore_value", - py::overload_cast&>( - &reset_global_semaphore_value), + [](const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + const std::vector& sub_device_ids) { + ttnn::global_semaphore::reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); + }, py::arg("global_semaphore"), + py::arg("reset_value"), py::arg("sub_device_ids") = std::vector(), R"doc( Reset the value of the global semaphore. Args: global_semaphore (GlobalSemaphore): The global semaphore object. + reset_value (int): The value to reset the global semaphore to. sub_device_ids (List[ttnn.SubDeviceIds]): Sub-device IDs to wait on before writing the global semaphore value to device. 
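+            Example:
+                A minimal sketch (illustrative; the creation call is assumed from the
+                single-device API and may differ for mesh devices):
+
+                .. code-block:: python
+
+                    global_sem = ttnn.create_global_semaphore(mesh_device, cores, initial_value=1)
+                    ttnn.reset_global_semaphore_value(global_sem, reset_value=3)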
)doc"); } diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp index a74a4b350cc..777fe337b71 100644 --- a/ttnn/cpp/ttnn/global_semaphore.cpp +++ b/ttnn/cpp/ttnn/global_semaphore.cpp @@ -41,9 +41,13 @@ DeviceAddr get_global_semaphore_address(const std::shared_ptr& } void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids) { + const std::shared_ptr& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids) { auto* device = global_semaphore->device(); - device->push_work([global_semaphore, sub_device_ids] { global_semaphore->reset_semaphore_value(sub_device_ids); }); + device->push_work([global_semaphore, reset_value, sub_device_ids] { + global_semaphore->reset_semaphore_value(reset_value, sub_device_ids); + }); } MultiDeviceGlobalSemaphore create_global_semaphore( @@ -82,9 +86,11 @@ std::vector get_global_semaphore_address(const MultiDeviceGlobalSema } void reset_global_semaphore_value( - const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids) { + const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids) { for (const auto& global_semaphore : global_semaphore.global_semaphores) { - reset_global_semaphore_value(global_semaphore, sub_device_ids); + reset_global_semaphore_value(global_semaphore, reset_value, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp index b04cda2dd27..121e8c03cdf 100644 --- a/ttnn/cpp/ttnn/global_semaphore.hpp +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -24,7 +24,9 @@ std::shared_ptr create_global_semaphore( tt::stl::Span sub_device_ids = {}); DeviceAddr get_global_semaphore_address(const std::shared_ptr& global_semaphore); void reset_global_semaphore_value( - const std::shared_ptr& global_semaphore, const std::vector& sub_device_ids = {}); + const std::shared_ptr& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids = {}); // Multi Device APIs MultiDeviceGlobalSemaphore create_global_semaphore( @@ -35,6 +37,8 @@ MultiDeviceGlobalSemaphore create_global_semaphore( tt::stl::Span sub_device_ids = {}); std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore); void reset_global_semaphore_value( - const MultiDeviceGlobalSemaphore& global_semaphore, const std::vector& sub_device_ids = {}); + const MultiDeviceGlobalSemaphore& global_semaphore, + uint32_t reset_value, + tt::stl::Span sub_device_ids = {}); } // namespace ttnn::global_semaphore From 52cc43757863f0623ee4ede86a4ff65e9d9f9a3f Mon Sep 17 00:00:00 2001 From: Colman Glagovich <114512306+cglagovichTT@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:46:49 -0500 Subject: [PATCH 08/13] Create chunked-prefill mode in SDPA op (#15907) --- .../misc/test_scaled_dot_product_attention.py | 231 +++++++++++++++- .../sdpa/device/kernels/compute/sdpa.cpp | 8 +- .../kernels/dataflow/reader_interleaved.cpp | 181 ++++++++++--- .../kernels/dataflow/writer_interleaved.cpp | 11 +- .../transformer/sdpa/device/sdpa_op.cpp | 250 ++++++++++++++---- .../transformer/sdpa/device/sdpa_op.hpp | 5 + .../sdpa/device/sdpa_program_factory.cpp | 134 ++++++++-- .../sdpa/device/sdpa_program_factory.hpp | 2 + .../ttnn/operations/transformer/sdpa/sdpa.cpp | 56 ++++ .../ttnn/operations/transformer/sdpa/sdpa.hpp | 29 ++ .../transformer/sdpa/sdpa_pybind.cpp | 68 +++++ 11 files changed, 844 insertions(+), 131 deletions(-) diff --git 
a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py index 736b2f2db82..37df74b5bf5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py @@ -73,7 +73,7 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype tt_back = ttnn.transformer.scaled_dot_product_attention( tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config ) - tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + tt_back = ttnn.to_torch(tt_back) K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S @@ -238,7 +238,7 @@ def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dt program_config=program_config, compute_kernel_config=compute_kernel_config, ) - tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + tt_back = ttnn.to_torch(tt_back) if nkv > 1 and nkv != nh: assert nh % nkv == 0 @@ -297,3 +297,230 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si pytest.skip("s must be divisible by q_chunk_size and k_chunk_size") ttnn.device.DisablePersistentKernelCache() run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk) + + +def run_test_chunked_sdpa( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_high_precision_compute, + grid_size=None, +): + program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(), + q_chunk_size=q_chunk_size, + k_chunk_size=k_chunk_size, + exp_approx_mode=False, + ) + + if use_high_precision_compute: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=False, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + + Q = fa_rand(b, nh, s, d) + K = fa_rand(b, nkv, s, d) + V = fa_rand(b, nkv, s, d) + K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S + V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1) # b, nh, d, S + gt = torch.nn.functional.scaled_dot_product_attention(Q, K_repeated, V_repeated, is_causal=True) + + # Print shapes of all inputs along with input names + logger.debug(f"Q: {Q.shape}") + logger.debug(f"K: {K.shape}") + logger.debug(f"V: {V.shape}") + + assert s % prefill_chunk_size == 0, "s must be divisible by prefill_chunk_size" + assert prefill_chunk_size % page_block_size == 0, "prefill_chunk_size must be divisible by page_block_size" + num_prefill_chunks = s // prefill_chunk_size + # Prepare K, V paged for TT + max_num_blocks_per_seq = s // page_block_size + assert max_num_blocks_per_seq * page_block_size == s + max_num_blocks = b * max_num_blocks_per_seq + assert max_num_blocks * page_block_size == b * s + + # Shuffle paged KV cache according to 
some random page_table + permutation = torch.randperm(max_num_blocks) + reverse_permutation = torch.argsort(permutation) + # page_table is the reverse permutation from shuffled -> unshuffled, and is used to map + # a virtual block to the physical block id. + page_table = reverse_permutation.reshape(b, max_num_blocks_per_seq) + + def page_cache(cache): + paged_cache = ( + cache.reshape(b, nkv, max_num_blocks_per_seq, page_block_size, d) + .transpose(1, 2) + .reshape(max_num_blocks, nkv, page_block_size, d) + ) + + shuffled_page_cache = paged_cache[permutation] + return shuffled_page_cache + + def unpage_cache(cache): + unshuffled_page_cache = cache[reverse_permutation] + paged_cache_back = ( + unshuffled_page_cache.reshape(b, nkv, max_num_blocks_per_seq, page_block_size, d) + .transpose(1, 2) + .reshape(b, nkv, s, d) + ) + return paged_cache_back + + # Check that we can convert from normal to paged to normal + assert torch.allclose(unpage_cache(page_cache(K)), K), "K is not equal to unpage_cache(page_cache(K))" + assert torch.allclose(unpage_cache(page_cache(V)), V), "V is not equal to unpage_cache(page_cache(V))" + + tt_paged_K = ttnn.Tensor(page_cache(K), k_dtype).to(ttnn.TILE_LAYOUT).to(device) + tt_paged_V = ttnn.Tensor(page_cache(V), k_dtype).to(ttnn.TILE_LAYOUT).to(device) + page_table_tt = ttnn.Tensor(page_table, ttnn.int32).to(device) + + for chunk_idx in range(num_prefill_chunks): + # Chunk Q + Q_chunk = Q[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size] + tt_Q_chunk = ttnn.Tensor(Q_chunk, q_dtype).to(ttnn.TILE_LAYOUT).to(device) + chunk_start_idx = chunk_idx * prefill_chunk_size + + tt_back = ttnn.transformer.chunked_scaled_dot_product_attention( + tt_Q_chunk, + tt_paged_K, + tt_paged_V, + page_table_tt, + chunk_start_idx, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + ) + tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + gt_chunk = gt[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size] + out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998) + logger.debug(f"python vs pytorch: {out_pcc}") + assert out_pass + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16]) +@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"]) +@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"]) +@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048]) +@pytest.mark.parametrize("page_block_size", [64, 128]) +@pytest.mark.parametrize( + "b, nh, nkv, s, d", + [ + [1, 8, 1, 16 * 1024, 128], + ], # Llama2-70B +) +def test_sdpa_chunked( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_program_cache, + use_high_precision_compute=False, +): + for _ in range(2): + run_test_chunked_sdpa( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_high_precision_compute, + ) + + # Print number of program cache entries + assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( + device.num_program_cache_entries() + ) + + +@skip_for_blackhole("Mismatching on BH, see #12349") 
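+# Same chunked-prefill flow as test_sdpa_chunked above, but pinned to a 1x1 compute grid
+# so a single core must iterate over both users in the batch (b=2).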
+@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16]) +@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("q_chunk_size", [128]) +@pytest.mark.parametrize("k_chunk_size", [128]) +@pytest.mark.parametrize("prefill_chunk_size", [1024]) +@pytest.mark.parametrize("page_block_size", [64]) +@pytest.mark.parametrize( + "b, nh, nkv, s, d", + [ + [2, 1, 1, 4096, 128], + ], # Llama2-70B +) +def test_sdpa_chunked_iterate_batch( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_program_cache, + use_high_precision_compute=False, +): + """ + This tests chunked prefill where a single core has more than one user to process. + """ + for _ in range(2): + run_test_chunked_sdpa( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_high_precision_compute, + grid_size=(1, 1), + ) + + # Print number of program cache entries + assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( + device.num_program_cache_entries() + ) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 976364d32cb..1fa196dc220 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -360,6 +360,7 @@ void MAIN { constexpr uint32_t is_causal = get_compile_time_arg_val(22) == 1; constexpr uint32_t use_provided_mask = get_compile_time_arg_val(23) == 1; + constexpr uint32_t is_chunked = get_compile_time_arg_val(24) == 1; const uint32_t core_id = get_arg_val(0); const uint32_t local_batch_start = get_arg_val(1); @@ -368,6 +369,7 @@ void MAIN { const uint32_t local_nh_end = get_arg_val(4); const uint32_t local_q_start = get_arg_val(5); const uint32_t local_q_end = get_arg_val(6); + const uint32_t chunked_q_chunk_offset = get_arg_val(7); const uint32_t q_chunks_per_core = local_q_end - local_q_start; @@ -413,7 +415,10 @@ void MAIN { #endif // Get Q chunk - const uint32_t q_low_idx = + if constexpr (is_chunked) { + q_chunk = chunked_q_chunk_offset + q_chunk; + } + uint32_t q_low_idx = q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk uint32_t q_high_idx; if constexpr (is_causal) { @@ -510,6 +515,7 @@ void MAIN { out_subblock_h, out_subblock_w, false /*transpose*/); + reconfig_data_format_srca(cb_out_im); cb_pop_front(cb_qk_im, qk_chunk_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp index 8b945b404e8..294a1b0b2a2 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp @@ -10,6 +10,21 @@ constexpr uint32_t get_barrier_read_threshold() { return ((512 / num_readers) * (1024 + 128)) / tile_bytes; } +template +uint32_t virtual_seq_tile_id_to_physical_tile_id( + uint32_t seq_tile_idx, uint32_t cur_head, volatile tt_l1_ptr const uint32_t* const page_table_ptr) { + // Given some index in the sequence 
tiles in range [0, max_seq_len_t] + // Return the physical tile id for that tile row + constexpr uint32_t block_stride = num_heads * block_size_t * Wt; + const uint32_t head_offset = cur_head * block_size_t * Wt; + + const uint32_t virtual_block = seq_tile_idx / block_size_t; + const uint32_t physical_block = page_table_ptr[virtual_block]; + const uint32_t block_row_offset = seq_tile_idx % block_size_t; + const uint32_t block_offset = block_row_offset * Wt; + return physical_block * block_stride + head_offset + block_offset; +} + void kernel_main() { constexpr uint32_t B = get_compile_time_arg_val(0); constexpr uint32_t NQH = get_compile_time_arg_val(1); @@ -24,18 +39,25 @@ void kernel_main() { constexpr uint32_t num_cores = get_compile_time_arg_val(10); constexpr uint32_t is_causal = get_compile_time_arg_val(11) == 1; constexpr uint32_t use_provided_mask = get_compile_time_arg_val(12) == 1; + constexpr uint32_t is_chunked = get_compile_time_arg_val(13) == 1; + constexpr uint32_t page_table_is_dram = get_compile_time_arg_val(14) == 1; + constexpr uint32_t block_size_t = get_compile_time_arg_val(15); + constexpr uint32_t page_table_stick_size = get_compile_time_arg_val(16); - const uint32_t q_addr = get_arg_val(0); - const uint32_t k_addr = get_arg_val(1); - const uint32_t v_addr = get_arg_val(2); - const uint32_t mask_addr = get_arg_val(3); - const uint32_t core_id = get_arg_val(4); - const uint32_t local_batch_start = get_arg_val(5); - const uint32_t local_batch_end = get_arg_val(6); - const uint32_t local_nh_start = get_arg_val(7); - const uint32_t local_nh_end = get_arg_val(8); - const uint32_t local_q_start = get_arg_val(9); - const uint32_t local_q_end = get_arg_val(10); + uint32_t argidx = 0; + const uint32_t q_addr = get_arg_val(argidx++); + const uint32_t k_addr = get_arg_val(argidx++); + const uint32_t v_addr = get_arg_val(argidx++); + const uint32_t mask_addr = get_arg_val(argidx++); + const uint32_t page_table_addr = get_arg_val(argidx++); + const uint32_t core_id = get_arg_val(argidx++); + const uint32_t local_batch_start = get_arg_val(argidx++); + const uint32_t local_batch_end = get_arg_val(argidx++); + const uint32_t local_nh_start = get_arg_val(argidx++); + const uint32_t local_nh_end = get_arg_val(argidx++); + const uint32_t local_q_start = get_arg_val(argidx++); + const uint32_t local_q_end = get_arg_val(argidx++); + const uint32_t chunked_q_chunk_offset = get_arg_val(argidx++); const uint32_t q_chunks_per_core = local_q_end - local_q_start; @@ -49,6 +71,7 @@ void kernel_main() { constexpr uint32_t cb_k_in = tt::CBIndex::c_1; constexpr uint32_t cb_v_in = tt::CBIndex::c_2; constexpr uint32_t cb_mask_in = tt::CBIndex::c_3; + constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6; constexpr uint32_t onetile = 1; constexpr uint32_t q_tile_bytes = get_tile_size(cb_q_in); @@ -76,6 +99,8 @@ void kernel_main() { const InterleavedAddrGenFast mask_reader = { .bank_base_address = mask_addr, .page_size = mask_tile_bytes, .data_format = mask_data_format}; + volatile tt_l1_ptr uint32_t* page_table_ptr; + uint32_t q_tile_id = 0; uint32_t k_tile_id = 0; uint32_t v_tile_id = 0; @@ -83,6 +108,18 @@ void kernel_main() { uint32_t barrier_count = 0; for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) { + if constexpr (is_chunked) { + // Chunked means that we have paged attention + const InterleavedAddrGen page_table_gen = { + .bank_base_address = page_table_addr, .page_size = page_table_stick_size}; + cb_reserve_back(cb_id_page_table, 1); + uint32_t page_table_cb_wr_ptr = 
get_write_ptr(cb_id_page_table); + uint64_t page_table_noc_addr = get_noc_addr(nb, page_table_gen); + noc_async_read(page_table_noc_addr, page_table_cb_wr_ptr, page_table_stick_size); + noc_async_read_barrier(); + cb_push_back(cb_id_page_table, 1); + page_table_ptr = reinterpret_cast(page_table_cb_wr_ptr); + } const uint32_t q_batch_offset = nb * NQH * Sqt * DHt; const uint32_t kv_batch_offset = nb * NKH * Skt * DHt; const uint32_t mask_batch_offset = nb * Sqt * Skt; @@ -124,7 +161,10 @@ void kernel_main() { cb_push_back(cb_q_in, q_chunk_tiles); - const uint32_t q_low_idx = + if constexpr (is_chunked) { + q_chunk = chunked_q_chunk_offset + q_chunk; + } + uint32_t q_low_idx = q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk uint32_t q_high_idx; if constexpr (is_causal) { @@ -142,25 +182,53 @@ void kernel_main() { const uint32_t k_high_idx = k_low_idx + Sk_chunk_t; const uint32_t k_start_tile_id = kv_batch_offset + kv_head_offset + k_chunk * Sk_chunk_t * DHt; - // Read K chunk transposed - cb_reserve_back(cb_k_in, k_chunk_tiles); - uint32_t k_write_ptr = get_write_ptr(cb_k_in); - barrier_count = 0; - for (uint32_t col = 0; col < DHt; ++col) { - k_tile_id = k_start_tile_id + col; + if constexpr (is_chunked) { + // Use page table to read K chunk + const uint32_t k_chunk_start_row_num = k_chunk * Sk_chunk_t; + cb_reserve_back(cb_k_in, k_chunk_tiles); + uint32_t k_write_ptr = get_write_ptr(cb_k_in); + barrier_count = 0; for (uint32_t row = 0; row < Sk_chunk_t; ++row) { - noc_async_read_tile(k_tile_id, k_reader, k_write_ptr); - k_tile_id += DHt; - k_write_ptr += k_tile_bytes; + uint32_t k_write_ptr_col = k_write_ptr + row * k_tile_bytes; + uint32_t virtual_k_tile_row_num = k_chunk_start_row_num + row; + uint32_t physical_k_tile_id = + virtual_seq_tile_id_to_physical_tile_id( + virtual_k_tile_row_num, kv_head, page_table_ptr); + for (uint32_t col = 0; col < DHt; ++col) { + noc_async_read_tile(physical_k_tile_id, k_reader, k_write_ptr_col); + physical_k_tile_id += 1; // Go to next tile in row + k_write_ptr_col += Sk_chunk_t * k_tile_bytes; // Go to next column in CB - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } } } + noc_async_read_barrier(); + cb_push_back(cb_k_in, k_chunk_tiles); + + } else { + // Read K chunk transposed + cb_reserve_back(cb_k_in, k_chunk_tiles); + uint32_t k_write_ptr = get_write_ptr(cb_k_in); + barrier_count = 0; + for (uint32_t col = 0; col < DHt; ++col) { + k_tile_id = k_start_tile_id + col; + for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + noc_async_read_tile(k_tile_id, k_reader, k_write_ptr); + k_tile_id += DHt; + k_write_ptr += k_tile_bytes; + + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_k_in, k_chunk_tiles); } - noc_async_read_barrier(); - cb_push_back(cb_k_in, k_chunk_tiles); if constexpr (use_provided_mask) { // Finding the diagonal is harder now that q_chunk_size and k_chunk_size can differ @@ -191,25 +259,56 @@ void kernel_main() { cb_push_back(cb_mask_in, mask_chunk_tiles); } - v_tile_id = k_start_tile_id; - // Read V chunk - cb_reserve_back(cb_v_in, k_chunk_tiles); - uint32_t v_write_ptr = get_write_ptr(cb_v_in); - barrier_count = 0; - for (uint32_t tile = 0; tile < k_chunk_tiles; ++tile) { - noc_async_read_tile(v_tile_id, v_reader, v_write_ptr); - v_tile_id += 1; - 
v_write_ptr += v_tile_bytes; - - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; + if constexpr (is_chunked) { + // Use page table to read V chunk + const uint32_t k_chunk_start_row_num = k_chunk * Sk_chunk_t; + cb_reserve_back(cb_v_in, k_chunk_tiles); + uint32_t v_write_ptr = get_write_ptr(cb_v_in); + barrier_count = 0; + + for (uint32_t row = 0; row < Sk_chunk_t; ++row) { + uint32_t virtual_v_tile_row_num = k_chunk_start_row_num + row; + uint32_t physical_v_tile_id = + virtual_seq_tile_id_to_physical_tile_id( + virtual_v_tile_row_num, kv_head, page_table_ptr); + for (uint32_t col = 0; col < DHt; ++col) { + noc_async_read_tile(physical_v_tile_id, v_reader, v_write_ptr); + physical_v_tile_id += 1; + v_write_ptr += v_tile_bytes; + + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_v_in, k_chunk_tiles); + } else { + v_tile_id = k_start_tile_id; + // Read V chunk + cb_reserve_back(cb_v_in, k_chunk_tiles); + uint32_t v_write_ptr = get_write_ptr(cb_v_in); + barrier_count = 0; + for (uint32_t tile = 0; tile < k_chunk_tiles; ++tile) { + noc_async_read_tile(v_tile_id, v_reader, v_write_ptr); + v_tile_id += 1; + v_write_ptr += v_tile_bytes; + + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } } + noc_async_read_barrier(); + cb_push_back(cb_v_in, k_chunk_tiles); } - noc_async_read_barrier(); - cb_push_back(cb_v_in, k_chunk_tiles); } } } + + if constexpr (is_chunked) { + cb_pop_front(cb_id_page_table, 1); + } } } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp index 5cf07e576e2..5ac5d251036 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp @@ -50,7 +50,6 @@ void fill_diagonal_tile(uint32_t cb_id, uint32_t tile_id, uint32_t partial_val) fill_tile(cb_id, tile_id, 0); - // DPRINT << "Fill partial tile" << ENDL(); const uint16_t datum_val = partial_val >> 16; volatile tt_l1_ptr uint16_t* uint16_ptr = reinterpret_cast(get_write_ptr(cb_id) + tile_id * tile_bytes); @@ -147,6 +146,7 @@ void kernel_main() { constexpr uint32_t num_cores = get_compile_time_arg_val(11); constexpr uint32_t is_causal = get_compile_time_arg_val(12) == 1; constexpr uint32_t use_provided_mask = get_compile_time_arg_val(13) == 1; + constexpr uint32_t is_chunked = get_compile_time_arg_val(14) == 1; const uint32_t out_addr = get_arg_val(0); const uint32_t core_id = get_arg_val(1); @@ -156,6 +156,7 @@ void kernel_main() { const uint32_t local_nh_end = get_arg_val(5); const uint32_t local_q_start = get_arg_val(6); const uint32_t local_q_end = get_arg_val(7); + const uint32_t chunk_start_t_in_q_chunks = get_arg_val(8); const uint32_t q_chunks_per_core = local_q_end - local_q_start; @@ -205,9 +206,13 @@ void kernel_main() { out_tile_id = q_batch_offset + q_head_offset + q_chunk_offset; if constexpr (is_causal) { - const uint32_t q_low_idx = + if constexpr (is_chunked) { + // Bump it up to the chunk start + q_chunk = chunk_start_t_in_q_chunks + q_chunk; + } + uint32_t q_low_idx = q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk - const uint32_t q_high_idx = q_low_idx + Sq_chunk_t; + uint32_t q_high_idx = q_low_idx + Sq_chunk_t; 
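+            // q_low_idx/q_high_idx are now global sequence-tile bounds for this chunk, so the
+            // causal loop below only visits K chunks at or below this Q chunk's diagonal.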
for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) { const uint32_t k_low_idx = k_chunk * Sk_chunk_t; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp index 5b3981bedfe..cc137f759ed 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp @@ -14,9 +14,11 @@ namespace ttnn::operations::transformer { void ScaledDotProductAttention::validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { + // Common validations for both modes + TT_FATAL(input_tensors.size() == 3, "Must have 3 input tensors (Q, K, V)"); TT_FATAL( - input_tensors.size() == 3 and optional_input_tensors.size() == 1, - "Must have 3 input tensors and optional mask"); + optional_input_tensors.size() == 1 or optional_input_tensors.size() == 2, + "Must have 1 or 2 optional tensors (mask/page_table)"); for (auto& input_tensor : input_tensors) { TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device"); @@ -29,75 +31,188 @@ void ScaledDotProductAttention::validate( "Operands to SDPA need to be in DRAM"); } - TT_FATAL( - !(this->is_causal && optional_input_tensors.at(0).has_value()), - "is_causal and attn_mask cannot both be present. Got is_causal: {}, attn_mask: {}", - this->is_causal, - optional_input_tensors.at(0).has_value()); + auto validate_regular_mode = [&]() { + TT_FATAL( + !(this->is_causal && optional_input_tensors.at(0).has_value()), + "is_causal and attn_mask cannot both be present. Got is_causal: {}, attn_mask: {}", + this->is_causal, + optional_input_tensors.at(0).has_value()); + + const auto& mask_option = optional_input_tensors.at(0); + if (mask_option.has_value()) { + auto mask = mask_option.value(); + TT_FATAL( + mask.storage_type() == StorageType::DEVICE, + "When mask is provided to SDPA, the tensor must be on device"); + TT_FATAL( + input_tensors.at(0).device() == mask.device(), + "When mask is provided to SDPA, it must be on the same device as the input tensors"); + TT_FATAL(mask.get_layout() == Layout::TILE, "When mask is provided to SDPA, it must be tilized"); + TT_FATAL( + mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B || + mask.get_dtype() == DataType::BFLOAT4_B, + "When mask is provided to SDPA, it must be in BF16, BFP8, or BFP4 dataformat"); + + TT_FATAL( + mask.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, + "When mask is provided to SDPA, it must be in DRAM"); + + const auto mask_shape = mask.get_legacy_shape(); + const auto q_shape = input_tensors.at(0).get_legacy_shape(); + const auto k_shape = input_tensors.at(1).get_legacy_shape(); + + TT_FATAL(mask_shape[0] == q_shape[0], "Mask batch dim must match Q batch dim"); + TT_FATAL(mask_shape[1] == 1, "Mask num_heads must be 1 to be broadcasted across all heads"); + TT_FATAL(mask_shape[2] == q_shape[2], "Mask sequence length must match Q sequence length"); + TT_FATAL(mask_shape[3] == k_shape[2], "Mask sequence length must match K sequence length"); + } + + // Shape checks + const auto q_shape = input_tensors.at(0).get_legacy_shape(); + const auto k_shape = input_tensors.at(1).get_legacy_shape(); + const auto v_shape = input_tensors.at(2).get_legacy_shape(); + const auto B = q_shape[0]; + const auto nqh = q_shape[1]; + const auto nkv = k_shape[1]; + const auto Sq = q_shape[2]; + const auto DH = q_shape[3]; + const auto Sk = k_shape[2]; + if 
(this->is_causal) { + TT_FATAL( + Sq == Sk, "Causal SDPA requires Q and K to have the same sequence length. Got Q: {}, K: {}", Sq, Sk); + } - const auto& mask_option = optional_input_tensors.at(0); - if (mask_option.has_value()) { - auto mask = optional_input_tensors.at(0).value(); TT_FATAL( - mask.storage_type() == StorageType::DEVICE, "When mask is provided to SDPA, the tensor must be on device"); + k_shape[0] == B && v_shape[0] == B, "K and V batch must match. Got K: {}, V: {}", k_shape[0], v_shape[0]); + TT_FATAL(v_shape[1] == nkv, "K and V num_heads must match. Got K: {}, V: {}", k_shape[1], v_shape[1]); + TT_FATAL(v_shape[2] == Sk, "K and V sequence length must match. Got K: {}, V: {}", k_shape[2], v_shape[2]); TT_FATAL( - input_tensors.at(0).device() == mask.device(), - "When mask is provided to SDPA, it must be on the same device as the input tensors"); - TT_FATAL(mask.get_layout() == Layout::TILE, "When mask is provided to SDPA, it must be tilized"); + k_shape[3] == DH && v_shape[3] == DH, + "K and V hidden dim must match. Got K: {}, V: {}", + k_shape[3], + v_shape[3]); TT_FATAL( - mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B || - mask.get_dtype() == DataType::BFLOAT4_B, - "When mask is provided to SDPA, it must be in BF16, BFP8, or BFP4 dataformat"); + nqh >= nkv && nqh % nkv == 0, + "Q num_heads must be >= K num_heads and divisible by K num_heads. Got Q: {}, K: {}", + nqh, + nkv); + if (this->program_config.has_value()) { + auto q_chunk_size = program_config->q_chunk_size; + auto k_chunk_size = program_config->k_chunk_size; + + TT_FATAL( + Sq % q_chunk_size == 0, + "q_chunk_size must divide q_shape[-2]. Got q_chunk_size: {}, q_shape[-2]: {}", + q_chunk_size, + q_shape[-2]); + TT_FATAL( + Sk % k_chunk_size == 0, + "k_chunk_size must divide k_shape[-2]. 
Got k_chunk_size: {}, k_shape[-2]: {}", + k_chunk_size, + k_shape[-2]); + } + }; + + auto validate_chunked_mode = [&]() { + TT_FATAL(chunk_start_idx.has_value(), "chunk_start_idx must be provided for chunked mode"); + TT_FATAL(chunk_start_idx.value() >= 0, "chunk_start_idx must be non-negative"); + + // Validate page table tensor + const auto& page_table = optional_input_tensors[1].value(); + TT_FATAL(page_table.storage_type() == StorageType::DEVICE, "Page table tensor must be on device"); TT_FATAL( - mask.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, - "When mask is provided to SDPA, it must be in DRAM"); - } + input_tensors.at(0).device() == page_table.device(), + "Page table must be on the same device as the input tensors"); + TT_FATAL(page_table.get_layout() == Layout::ROW_MAJOR, "Page table must be row major"); + // Check that page table is int32 + TT_FATAL(page_table.get_dtype() == DataType::INT32, "Page table must be int32"); + // Validate that first optional tensor (mask) is not provided + TT_FATAL( + !optional_input_tensors[0].has_value(), + "Attention mask should not be provided in chunked mode - masking is handled internally"); - // assert all dataformats are the same - TT_FATAL( - input_tensors.at(0).get_dtype() == input_tensors.at(1).get_dtype() && - input_tensors.at(0).get_dtype() == input_tensors.at(2).get_dtype(), - "All inputs to SDPA must have the same dataformat"); - - TT_FATAL(this->output_mem_config.buffer_type == tt::tt_metal::BufferType::DRAM, "Output must be in DRAM"); - - // Check shapes - const auto q_shape = input_tensors.at(0).get_legacy_shape(); - const auto k_shape = input_tensors.at(1).get_legacy_shape(); - const auto v_shape = input_tensors.at(2).get_legacy_shape(); - const auto B = q_shape[0]; - const auto nqh = q_shape[1]; - const auto nkv = k_shape[1]; - const auto Sq = q_shape[2]; - const auto DH = q_shape[3]; - const auto Sk = k_shape[2]; - if (this->is_causal) { - TT_FATAL(Sq == Sk, "Causal SDPA requires Q and K to have the same sequence length. Got Q: {}, K: {}", Sq, Sk); - } + // Additional chunked-specific validations + const auto q_shape = input_tensors.at(0).get_legacy_shape(); + const auto k_shape = input_tensors.at(1).get_legacy_shape(); + const auto v_shape = input_tensors.at(2).get_legacy_shape(); + const auto page_table_shape = page_table.get_legacy_shape(); + const auto B = q_shape[0]; + const auto nqh = q_shape[1]; + const auto nkv = k_shape[1]; + const auto Sq = q_shape[2]; + const auto DH = q_shape[3]; + const auto k_page_size = k_shape[2]; + const uint32_t num_pages_per_user = page_table.get_legacy_shape()[1]; + // Check that k page size matches v page size + TT_FATAL( + k_page_size == v_shape[2], "K page size must match V page size. Got K: {}, V: {}", k_page_size, v_shape[2]); + // Check that page table has same batch size as input tensors + TT_FATAL( + page_table_shape[0] == B, + "Page table batch size must match input batch size. Got Page table: {}, Input: {}", + page_table_shape[0], + B); + // Calculate K length based on number of pages per user + const uint32_t kv_length = num_pages_per_user * k_page_size; - TT_FATAL(k_shape[0] == B && v_shape[0] == B, "K and V batch must match. Got K: {}, V: {}", k_shape[0], v_shape[0]); - TT_FATAL(v_shape[1] == nkv, "K and V num_heads must match. Got K: {}, V: {}", k_shape[1], v_shape[1]); - TT_FATAL(v_shape[2] == Sk, "K and V sequence length must match. Got K: {}, V: {}", k_shape[2], v_shape[2]); - TT_FATAL(k_shape[3] == DH && v_shape[3] == DH, "K and V hidden dim must match. 
Got K: {}, V: {}", k_shape[3], v_shape[3]); - TT_FATAL(nqh >= nkv && nqh % nkv == 0, "Q num_heads must be >= K num_heads and divisible by K num_heads. Got Q: {}, K: {}", nqh, nkv); + TT_FATAL(v_shape[1] == nkv, "K and V num_heads must match. Got K: {}, V: {}", k_shape[1], v_shape[1]); + TT_FATAL( + k_shape[3] == DH && v_shape[3] == DH, + "K and V hidden dim must match. Got K: {}, V: {}", + k_shape[3], + v_shape[3]); + TT_FATAL( + nqh >= nkv && nqh % nkv == 0, + "Q num_heads must be >= K num_heads and divisible by K num_heads. Got Q: {}, K: {}", + nqh, + nkv); - if (mask_option.has_value()) { - const auto mask_shape = mask_option.value().get_legacy_shape(); + if (this->program_config.has_value()) { + auto q_chunk_size = program_config->q_chunk_size; + auto k_chunk_size = program_config->k_chunk_size; - TT_FATAL(mask_shape[0] == B, "Mask batch dim must match Q batch dim"); - TT_FATAL(mask_shape[1] == 1, "Mask num_heads must be 1 to be broadcasted across all heads"); - TT_FATAL(mask_shape[2] == Sq, "Mask sequence length must match Q sequence length"); - TT_FATAL(mask_shape[3] == Sk, "Mask sequence length must match K sequence length"); - } + TT_FATAL( + Sq % q_chunk_size == 0, + "q_chunk_size must divide q_shape[-2]. Got q_chunk_size: {}, q_shape[-2]: {}", + q_chunk_size, + q_shape[-2]); + TT_FATAL( + kv_length % k_chunk_size == 0, + "k_chunk_size must divide k_shape[-2]. Got k_chunk_size: {}, k_shape[-2]: {}", + k_chunk_size, + k_shape[-2]); + } - if (this->program_config.has_value()) { - auto q_chunk_size = program_config->q_chunk_size; - auto k_chunk_size = program_config->k_chunk_size; + // In chunked mode, K's sequence dimension should be >= Q's sequence dimension + chunk_start_idx + TT_FATAL( + kv_length >= q_shape[2] + chunk_start_idx.value(), + "K's sequence length must be >= Q's sequence length + chunk_start_idx. Got K: {}, Q: {}, chunk_start_idx: " + "{}", + kv_length, + q_shape[2], + chunk_start_idx.value()); + }; - TT_FATAL(Sq % q_chunk_size == 0, "q_chunk_size must divide q_shape[-2]. Got q_chunk_size: {}, q_shape[-2]: {}", q_chunk_size, q_shape[-2]); - TT_FATAL(Sk % k_chunk_size == 0, "k_chunk_size must divide k_shape[-2]. Got k_chunk_size: {}, k_shape[-2]: {}", k_chunk_size, k_shape[-2]); + auto check_conditions = [&]() { + bool has_chunk_start = chunk_start_idx.has_value(); + bool has_two_optional_inputs = optional_input_tensors.size() == 2; + bool has_page_table = optional_input_tensors.size() > 1 && optional_input_tensors.at(1).has_value(); + TT_FATAL( + has_chunk_start == has_two_optional_inputs, "chunk_start_idx and number of optional inputs must match"); + TT_FATAL( + has_two_optional_inputs == has_page_table, + "page_table must be provided if and only if there are two optional inputs"); + }; + check_conditions(); + bool is_chunked_mode = chunk_start_idx.has_value(); + + // Check if we're in chunked mode and call appropriate validation + if (is_chunked_mode) { + validate_chunked_mode(); + } else { + validate_regular_mode(); } } @@ -125,6 +240,8 @@ operation::ProgramWithCallbacks ScaledDotProductAttention::create_program( std::size_t q_chunk_size = this->program_config ? this->program_config->q_chunk_size : 32; std::size_t k_chunk_size = this->program_config ? this->program_config->k_chunk_size : 32; + // get page table if chunked + const auto page_table = this->chunk_start_idx.has_value() ? 
optional_input_tensors.at(1) : std::nullopt; return detail::sdpa_multi_core( input_tensor_q, @@ -132,6 +249,8 @@ operation::ProgramWithCallbacks ScaledDotProductAttention::create_program( input_tensor_v, output_tensor, attn_mask, + page_table, + this->chunk_start_idx, scale, this->is_causal, q_chunk_size, @@ -140,4 +259,19 @@ operation::ProgramWithCallbacks ScaledDotProductAttention::create_program( this->program_config); } +operation::Hash ScaledDotProductAttention::compute_program_hash( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors) const { + bool is_chunked_prefill = this->chunk_start_idx.has_value(); + return operation::hash_operation( + this->scale, + this->output_mem_config, + this->program_config, + this->is_causal, + is_chunked_prefill, + this->compute_kernel_config, + input_tensors, + optional_input_tensors); +} + } // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp index dfe88389084..02f76971bec 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp @@ -18,6 +18,7 @@ struct ScaledDotProductAttention { const MemoryConfig output_mem_config; const std::optional program_config; const bool is_causal; + const std::optional chunk_start_idx; const DeviceComputeKernelConfig compute_kernel_config; void validate( @@ -30,6 +31,10 @@ struct ScaledDotProductAttention { const std::vector& input_tensors, const std::vector>& optional_input_tensors, std::vector& output_tensors) const; + + operation::Hash compute_program_hash( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors) const; }; } // namespace ttnn::operations::transformer diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index 9278d02c812..ec9805872e7 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "sdpa_program_factory.hpp" +#include "sdpa_op.hpp" #include @@ -26,6 +27,8 @@ operation::ProgramWithCallbacks sdpa_multi_core( const Tensor& input_tensor_v, const Tensor& output_tensor, const std::optional& attn_mask, + const std::optional& page_table, + const std::optional& chunk_start_idx, std::optional scale, bool is_causal, std::size_t q_chunk_size, @@ -42,8 +45,13 @@ operation::ProgramWithCallbacks sdpa_multi_core( const auto q_shape = input_tensor_q.get_legacy_shape(); const auto k_shape = input_tensor_k.get_legacy_shape(); const uint32_t B = q_shape[0], NQH = q_shape[1], Sq = q_shape[2], DH = q_shape[3]; - const uint32_t Sk = k_shape[2]; const uint32_t NKH = k_shape[1]; + + // Paged cache parameters when in chunked mode + bool is_chunked = chunk_start_idx.has_value(); + // In chunked mode, we only need to process K/V up to chunk_start_idx + Sq + const uint32_t Sk = is_chunked ? 
(chunk_start_idx.value() + Sq) : k_shape[2];
+
     const uint32_t Sqt = Sq / TILE_HEIGHT;
     const uint32_t Skt = Sk / TILE_HEIGHT;
     const uint32_t DHt = DH / TILE_WIDTH;
@@ -72,6 +80,38 @@ operation::ProgramWithCallbacks sdpa_multi_core(
     tt::log_debug("k_num_chunks: {}", k_num_chunks);
     tt::log_debug("NKH: {}", NKH);

+    // In chunked prefill mode, the offset of Q in terms of Q chunks
+    uint32_t chunked_q_chunk_offset = 0;
+    uint32_t block_size = 0;
+    uint32_t block_size_t = 0;
+    uint32_t max_blocks_per_seq = 0;
+    uint32_t page_table_stick_size = 0;
+    bool page_table_is_dram = true;
+    tt::DataFormat page_table_df = tt::DataFormat::Int32;
+
+    if (is_chunked) {
+        chunked_q_chunk_offset = chunk_start_idx.value() / q_chunk_size;
+        const auto& page_table_tensor = page_table.value();
+        block_size = k_shape[2];  // K's sequence dimension represents block size
+        block_size_t = block_size / TILE_HEIGHT;
+        max_blocks_per_seq = page_table_tensor.get_legacy_shape()[1];
+        page_table_stick_size = page_table_tensor.buffer()->aligned_page_size();
+        TT_FATAL(
+            page_table_stick_size % 32 == 0,
+            "page table page size in bytes must be a multiple of 32 due to address alignment");
+        page_table_is_dram = page_table_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM;
+    }
+    // Log page table info
+    tt::log_debug("is_chunked: {}", is_chunked);
+    if (is_chunked) {
+        tt::log_debug("block_size: {}", block_size);
+        tt::log_debug("block_size_t: {}", block_size_t);
+        tt::log_debug("max_blocks_per_seq: {}", max_blocks_per_seq);
+        tt::log_debug("page_table_stick_size: {}", page_table_stick_size);
+        tt::log_debug("page_table_is_dram: {}", page_table_is_dram);
+        tt::log_debug("page_table_df: {}", page_table_df);
+    }
+
     Program program = CreateProgram();

     Device* device = input_tensor_q.device();
@@ -220,22 +264,26 @@ operation::ProgramWithCallbacks sdpa_multi_core(
     scale_union.f = scale.value_or(1.0f);

     std::vector<uint32_t> reader_compile_time_args = {// interleaved accessor args
-        B,
-        NQH,
-        NKH,
-        Sqt,
-        Skt,
-        DHt,
-        Sq_chunk_t,
-        q_num_chunks,
-        Sk_chunk_t,
-        k_num_chunks,
-        num_cores,
-        (std::uint32_t)is_causal,
-        (std::uint32_t)use_provided_mask
-    };
-
-    std::vector<uint32_t> writer_compile_time_args = {// interleaved accessor args
+        B,
+        NQH,
+        NKH,
+        Sqt,
+        Skt,
+        DHt,
+        Sq_chunk_t,
+        q_num_chunks,
+        Sk_chunk_t,
+        k_num_chunks,
+        num_cores,
+        (std::uint32_t)is_causal,
+        (std::uint32_t)use_provided_mask,
+        (uint32_t)is_chunked,
+        (uint32_t)page_table_is_dram,
+        block_size_t,
+        page_table_stick_size};
+
+    std::vector<uint32_t> writer_compile_time_args = {
+        // interleaved accessor args
         B,
         NQH,
         NKH,
@@ -249,10 +297,12 @@
         scale_union.u,
         num_cores,
         (std::uint32_t)is_causal,
-        (std::uint32_t)use_provided_mask
+        (std::uint32_t)use_provided_mask,
+        (uint32_t)is_chunked,
     };

-    std::vector<uint32_t> compute_compile_time_args = {// matmul args
+    std::vector<uint32_t> compute_compile_time_args = {
+        // matmul args
         B,
         NQH,
         NKH,
@@ -276,7 +326,8 @@
         out_num_blocks,
         num_cores,
         (std::uint32_t)is_causal,
-        (std::uint32_t)use_provided_mask
+        (std::uint32_t)use_provided_mask,
+        (uint32_t)is_chunked,
     };

     std::map<std::string, std::string> defines;
@@ -381,6 +432,12 @@
         .set_page_size(tt::CBIndex::c_5, scalar_tile_size);
     auto cb_in5_id = CreateCircularBuffer(program, core_grid, c_in5_config);

+    if (is_chunked) {
+        auto c_in6_config = 
CircularBufferConfig(page_table_stick_size, {{tt::CBIndex::c_6, page_table_df}}) + .set_page_size(tt::CBIndex::c_6, page_table_stick_size); + auto cb_in6_id = CreateCircularBuffer(program, core_grid, c_in6_config); + } + // cb_qk_im auto c_intermed0_config = CircularBufferConfig(qk_tiles * im_tile_size, {{tt::CBIndex::c_24, im_df}}) .set_page_size(tt::CBIndex::c_24, im_tile_size); @@ -471,13 +528,15 @@ operation::ProgramWithCallbacks sdpa_multi_core( k_addr, v_addr, mask_addr, + is_chunked ? page_table.value().buffer()->address() : 0, i, local_batch_start, local_batch_end, local_nh_start, local_nh_end, local_q_start, - local_q_end}); + local_q_end, + chunked_q_chunk_offset}); SetRuntimeArgs( program, writer_kernels_id, @@ -489,12 +548,20 @@ operation::ProgramWithCallbacks sdpa_multi_core( local_nh_start, local_nh_end, local_q_start, - local_q_end}); + local_q_end, + chunked_q_chunk_offset}); SetRuntimeArgs( program, compute_kernels_id, core, - {i, local_batch_start, local_batch_end, local_nh_start, local_nh_end, local_q_start, local_q_end}); + {i, + local_batch_start, + local_batch_end, + local_nh_start, + local_nh_end, + local_q_start, + local_q_end, + chunked_q_chunk_offset}); } auto override_runtime_arguments_callback = @@ -513,7 +580,9 @@ operation::ProgramWithCallbacks sdpa_multi_core( q_num_chunks, is_causal, cb_in0_id, - cb_out0_id]( + cb_out0_id, + is_chunked, + q_chunk_size]( const void* operation, Program& program, const std::vector& input_tensors, @@ -532,9 +601,17 @@ operation::ProgramWithCallbacks sdpa_multi_core( uint32_t mask_addr = mask_buffer != nullptr ? mask_buffer->address() : 0; uint32_t out_addr = out0_buffer->address(); + uint32_t page_table_addr = 0; + uint32_t chunked_q_chunk_offset = 0; + if (is_chunked) { + page_table_addr = optional_input_tensors.at(1).value().buffer()->address(); + chunked_q_chunk_offset = + static_cast(operation)->chunk_start_idx.value() / q_chunk_size; + } + auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id); auto& writer_args_by_core = GetRuntimeArgs(program, writer_kernels_id); - + auto& compute_args_by_core = GetRuntimeArgs(program, compute_kernels_id); // Set reader rt args for (uint32_t i = 0; i < num_cores; ++i) { CoreCoord core = {i % grid_size.x, i / grid_size.x}; @@ -566,13 +643,18 @@ operation::ProgramWithCallbacks sdpa_multi_core( auto& reader_args = reader_args_by_core[core.x][core.y]; auto& writer_args = writer_args_by_core[core.x][core.y]; - + auto& compute_args = compute_args_by_core[core.x][core.y]; reader_args[0] = q_addr; reader_args[1] = k_addr; reader_args[2] = v_addr; reader_args[3] = mask_addr; + reader_args[4] = page_table_addr; + reader_args[12] = chunked_q_chunk_offset; writer_args[0] = out_addr; + writer_args[8] = chunked_q_chunk_offset; + + compute_args[7] = chunked_q_chunk_offset; } }; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.hpp index 06ede29653b..f9fcb254419 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.hpp @@ -16,6 +16,8 @@ operation::ProgramWithCallbacks sdpa_multi_core( const Tensor& input_tensor_v, const Tensor& output_tensor, const std::optional& attn_mask, + const std::optional& page_table, + const std::optional& chunk_start_idx, std::optional scale, bool is_causal, std::size_t q_chunk_size, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp 
b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
index 1cd57008807..0b7320b6a85 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.cpp
@@ -37,6 +37,7 @@ ttnn::Tensor ExecuteScaledDotProductAttention::invoke(
             .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG),
             .program_config = program_config,
             .is_causal = is_causal,
+            .chunk_start_idx = std::nullopt,
             .compute_kernel_config = kernel_config_val},
         {input_tensor_q, input_tensor_k, input_tensor_v},
         {attn_mask},
@@ -68,4 +69,59 @@ ttnn::Tensor ExecuteScaledDotProductAttention::invoke(
         compute_kernel_config);
 }

+ttnn::Tensor ExecuteChunkedScaledDotProductAttention::invoke(
+    uint8_t queue_id,
+    const ttnn::Tensor& input_tensor_q,
+    const ttnn::Tensor& input_tensor_k,
+    const ttnn::Tensor& input_tensor_v,
+    const ttnn::Tensor& page_table_tensor,
+    int64_t chunk_start_idx,
+    std::optional<float> scale,
+    const std::optional<MemoryConfig>& memory_config,
+    std::optional<SDPAProgramConfig> program_config,
+    std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
+    auto arch = input_tensor_q.storage_type() == StorageType::DEVICE
+                    ? input_tensor_q.device()->arch()
+                    : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch();
+    auto kernel_config_val =
+        init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi2, true, false, false);
+
+    return operation::run(
+               ScaledDotProductAttention{
+                   .scale = scale,
+                   .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG),
+                   .program_config = program_config,
+                   .is_causal = true,  // Always causal for chunked version
+                   .chunk_start_idx = chunk_start_idx,
+                   .compute_kernel_config = kernel_config_val},
+               {input_tensor_q, input_tensor_k, input_tensor_v},
+               {std::nullopt, page_table_tensor},  // No attention mask - handled internally based on chunk_start_idx
+               {},
+               queue_id)
+        .at(0);
+}
+
+ttnn::Tensor ExecuteChunkedScaledDotProductAttention::invoke(
+    const ttnn::Tensor& input_tensor_q,
+    const ttnn::Tensor& input_tensor_k,
+    const ttnn::Tensor& input_tensor_v,
+    const ttnn::Tensor& page_table_tensor,
+    int64_t chunk_start_idx,
+    std::optional<float> scale,
+    const std::optional<MemoryConfig>& memory_config,
+    std::optional<SDPAProgramConfig> program_config,
+    std::optional<DeviceComputeKernelConfig> compute_kernel_config) {
+    return invoke(
+        DefaultQueueId,
+        input_tensor_q,
+        input_tensor_k,
+        input_tensor_v,
+        page_table_tensor,
+        chunk_start_idx,
+        scale,
+        memory_config,
+        program_config,
+        compute_kernel_config);
+}
+
 }  // namespace ttnn::operations::transformer
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
index cba93297240..1f94b192158 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa.hpp
@@ -36,6 +36,31 @@ struct ExecuteScaledDotProductAttention {
         std::optional<DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
 };

+struct ExecuteChunkedScaledDotProductAttention {
+    static ttnn::Tensor invoke(
+        uint8_t queue_id,
+        const ttnn::Tensor& input_tensor_q,
+        const ttnn::Tensor& input_tensor_k,
+        const ttnn::Tensor& input_tensor_v,
+        const ttnn::Tensor& page_table_tensor,
+        int64_t chunk_start_idx,
+        std::optional<float> scale = std::nullopt,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        std::optional<SDPAProgramConfig> program_config = std::nullopt,
+        std::optional<DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+
+    static ttnn::Tensor invoke(
+        const ttnn::Tensor& input_tensor_q,
+        const ttnn::Tensor& input_tensor_k,
+        const 
ttnn::Tensor& input_tensor_v, + const ttnn::Tensor& page_table_tensor, + int64_t chunk_start_idx, + std::optional scale = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional program_config = std::nullopt, + std::optional compute_kernel_config = std::nullopt); +}; + } // namespace operations::transformer namespace transformer { @@ -44,6 +69,10 @@ constexpr auto scaled_dot_product_attention = ttnn::register_operation_with_auto "ttnn::transformer::scaled_dot_product_attention", ttnn::operations::transformer::ExecuteScaledDotProductAttention>(); +constexpr auto chunked_scaled_dot_product_attention = ttnn::register_operation_with_auto_launch_op< + "ttnn::transformer::chunked_scaled_dot_product_attention", + ttnn::operations::transformer::ExecuteChunkedScaledDotProductAttention>(); + } // namespace transformer } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp index 638e257b8e6..54817868168 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp @@ -81,5 +81,73 @@ void py_bind_sdpa(py::module& module) { py::arg("compute_kernel_config").noconvert() = std::nullopt, py::arg("queue_id") = 0, }); + + auto chunked_doc = + R"doc( + Chunked causal scaled dot product attention for processing long sequences in chunks. + This variant allows processing of sequences longer than the maximum supported length + by splitting the input into chunks and maintaining KV cache state. + The KV cache is page-based, and the page table tensor is used to map the page indices to the corresponding KV cache indices. + + Args: + input_tensor_q (ttnn.Tensor): the input tensor. [b x nqh x s x dh] + input_tensor_k (ttnn.Tensor): the input tensor. [b x nkv x s x dh] + input_tensor_v (ttnn.Tensor): the input tensor. [b x nkv x s x dh] + page_table_tensor (ttnn.Tensor): the page table tensor. [b x num_pages] + chunk_start_idx (int): Absolute position in the sequence where this chunk starts. + + Keyword args: + scale (float, optional): Defaults to `None`. + memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. + program_config (SDPAProgramConfig, optional): Defaults to `None`. + compute_kernel_config (ttnn.DeviceComputeKernelConfig, optional): Defaults to `None`. + queue_id (int, optional): command queue id. Defaults to `0`. + + Returns: + ttnn.Tensor: the output tensor [b x nqh x s x dh]. 
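+
+        Example:
+            An illustrative sketch of a call (the tensor names and the chunk offset of 2048 below are assumed placeholders, not values required by this op):
+
+            >>> out = ttnn.transformer.chunked_scaled_dot_product_attention(
+            ...     q_chunk, k_cache, v_cache, page_table, chunk_start_idx=2048
+            ... )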
+ + )doc"; + + using ChunkedOperationType = decltype(ttnn::transformer::chunked_scaled_dot_product_attention); + ttnn::bind_registered_operation( + module, + ttnn::transformer::chunked_scaled_dot_product_attention, + chunked_doc, + ttnn::pybind_overload_t{ + [](const ChunkedOperationType& self, + const ttnn::Tensor& input_tensor_q, + const ttnn::Tensor& input_tensor_k, + const ttnn::Tensor& input_tensor_v, + const ttnn::Tensor& page_table_tensor, + int64_t chunk_start_idx, + std::optional scale, + const std::optional& memory_config, + std::optional program_config, + std::optional compute_kernel_config, + uint8_t queue_id) { + return self( + queue_id, + input_tensor_q, + input_tensor_k, + input_tensor_v, + page_table_tensor, + chunk_start_idx, + scale, + memory_config, + program_config, + compute_kernel_config); + }, + py::arg("input_tensor_q").noconvert(), + py::arg("input_tensor_k").noconvert(), + py::arg("input_tensor_v").noconvert(), + py::arg("page_table_tensor").noconvert(), + py::arg("chunk_start_idx"), + py::kw_only(), + py::arg("scale").noconvert() = std::nullopt, + py::arg("memory_config").noconvert() = std::nullopt, + py::arg("program_config").noconvert() = std::nullopt, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + py::arg("queue_id") = 0, + }); } } // namespace ttnn::operations::transformer From 3bf0fd65a23bcd17f9a9d30d9f75c5406f418102 Mon Sep 17 00:00:00 2001 From: Andrew Fuller Date: Thu, 12 Dec 2024 14:09:52 -0500 Subject: [PATCH 09/13] Run Blackhole on 22.04 (#15970) ### Ticket #14393 ### Problem description Move BH pipeline to 22.04 ### What's changed * Plumb the correct version of Python to the wheel builder * Install the repo to be able to run the unit tests -- not really sure why it fails in 22.04 but was okay with 20.04. ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12300743997 - [x] Blackhole Post commit (if applicable) https://github.com/tenstorrent/tt-metal/actions/runs/12289973402 --- .github/actions/prepare-metal-run/action.yml | 6 ++++++ .github/workflows/_build-wheels-impl.yaml | 12 ++++++++++++ .github/workflows/blackhole-post-commit.yaml | 10 +++++----- .github/workflows/build-and-unit-tests.yaml | 3 ++- .github/workflows/cpp-post-commit.yaml | 5 +++-- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/.github/actions/prepare-metal-run/action.yml b/.github/actions/prepare-metal-run/action.yml index 6ea399af154..0c1d95f3a73 100644 --- a/.github/actions/prepare-metal-run/action.yml +++ b/.github/actions/prepare-metal-run/action.yml @@ -9,6 +9,10 @@ inputs: description: "Whether to load with profiler" required: false default: 'false' + python-version: + description: 'Which version of Python to use to run the tests.' 
+ required: false + default: '3.8' runs: using: "composite" @@ -25,6 +29,8 @@ runs: shell: bash run: tar -xvf ttm_${{ inputs.arch }}.tar - uses: ./.github/actions/install-python-deps + with: + python-version: ${{ inputs.python-version }} - name: Collect Workflow Telemetry if: ${{ !cancelled() }} uses: catchpoint/workflow-telemetry-action@v2 diff --git a/.github/workflows/_build-wheels-impl.yaml b/.github/workflows/_build-wheels-impl.yaml index 0d4c7179135..239729947f0 100644 --- a/.github/workflows/_build-wheels-impl.yaml +++ b/.github/workflows/_build-wheels-impl.yaml @@ -31,6 +31,17 @@ jobs: os: ${{ inputs.os }} - name: Clean up dirty files run: git clean -f -d + - name: Set Python Version + id: python-version + run: | + if [[ "${{ inputs.os }}" == "ubuntu-20.04" ]]; then + echo "python-version=3.8" >> $GITHUB_ENV + elif [[ "${{ inputs.os }}" == "ubuntu-22.04" ]]; then + echo "python-version=3.10" >> $GITHUB_ENV + else + echo "Unsupported OS version: ${{ inputs.os }}" + exit 1 + fi - uses: actions/setup-python@v5.0.0 with: cache: 'pip' @@ -47,6 +58,7 @@ jobs: if: ${{ inputs.from-precompiled }} with: arch: ${{ inputs.arch }} + python-version: ${{ env.python-version }} - name: Set precompiled dir for precompile builds if: ${{ inputs.from-precompiled }} # TT_FROM_PRECOMPILED_DIR env variable allows us to not re-run the full C++ build and instead diff --git a/.github/workflows/blackhole-post-commit.yaml b/.github/workflows/blackhole-post-commit.yaml index ba479a8b63c..237654297f9 100644 --- a/.github/workflows/blackhole-post-commit.yaml +++ b/.github/workflows/blackhole-post-commit.yaml @@ -29,14 +29,14 @@ jobs: uses: ./.github/workflows/build-artifact.yaml secrets: inherit with: - os: "ubuntu-20.04-amd64" + os: "ubuntu-22.04-amd64" arch: '["blackhole"]' build-docker: false build-wheels: needs: build-artifact uses: ./.github/workflows/_build-wheels-impl.yaml with: - os: "ubuntu-20.04" + os: "ubuntu-22.04" arch: "blackhole" from-precompiled: true # build-artifact-profiler: @@ -58,7 +58,7 @@ jobs: arch: blackhole runner-label: BH timeout: 30 - os: "ubuntu-20.04" + os: "ubuntu-22.04" fd-unit-tests: needs: build-wheels uses: ./.github/workflows/fast-dispatch-build-and-unit-tests.yaml @@ -66,7 +66,7 @@ jobs: with: arch: blackhole runner-label: BH - os: "ubuntu-20.04" + os: "ubuntu-22.04" # FD C++ Unit Tests cpp-unit-tests: needs: build-artifact @@ -76,7 +76,7 @@ jobs: arch: blackhole runner-label: BH timeout: 60 - os: "ubuntu-20.04" + os: "ubuntu-22.04" # profiler-regression: # needs: build-artifact-profiler diff --git a/.github/workflows/build-and-unit-tests.yaml b/.github/workflows/build-and-unit-tests.yaml index e51dced1890..4f6c87644a8 100644 --- a/.github/workflows/build-and-unit-tests.yaml +++ b/.github/workflows/build-and-unit-tests.yaml @@ -70,7 +70,8 @@ jobs: -e TT_METAL_SLOW_DISPATCH_MODE=1 -e LD_LIBRARY_PATH=${{ github.workspace }}/build/lib run_args: | - python3 -m pip install -r $(pwd)/tt_metal/python_env/requirements-dev.txt + pip install --force-reinstall pip==21.2.4 + pip install -r tt_metal/python_env/requirements-dev.txt pip install -e . 
./tests/scripts/run_tests.sh --tt-arch ${{ inputs.arch }} --pipeline-type post_commit --dispatch-mode slow
       - uses: ./.github/actions/slack-report
diff --git a/.github/workflows/cpp-post-commit.yaml b/.github/workflows/cpp-post-commit.yaml
index c90c623cb76..fbaea0cf83a 100644
--- a/.github/workflows/cpp-post-commit.yaml
+++ b/.github/workflows/cpp-post-commit.yaml
@@ -84,8 +84,9 @@ jobs:
           -e ARCH_NAME=${{ inputs.arch }}
           -e LD_LIBRARY_PATH=${{ github.workspace }}/build/lib
           run_args: |
-            python3 -m pip install -r $(pwd)/tt_metal/python_env/requirements-dev.txt
-            python3 -m pip install -e .
+            pip install --force-reinstall pip==21.2.4
+            pip install -r tt_metal/python_env/requirements-dev.txt
+            pip install -e .
             ${{ matrix.test-group.cmd }}
       - uses: ./.github/actions/slack-report
         if: ${{ failure() }}

From 279ef8b34e52cef104b5bbeaf07f66a5dbe221bb Mon Sep 17 00:00:00 2001
From: Roman Furko
Date: Thu, 12 Dec 2024 11:10:27 -0800
Subject: [PATCH 10/13] [tt-train] Add BPE tokenizer option (#15951)

### Problem description
The BPE tokenizer from the original GPT-2 training is an important component.

### What's changed
* Add the BPE tokenizer option
* Add a config with the BPE tokenizer (use the max possible batch for now - 2)
* Replace the plain tokenizer variable with a unique_ptr to the tokenizer
* Add a workaround for the `tilize_with_zero_padding` issue for shape (1, 1, 1, 50304)

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12289134453
- [x] New/Existing tests provide coverage for changes
---
 .../training_shakespear_nanogpt_bpe.yaml      | 21 ++++++++++++++
 tt-train/sources/examples/nano_gpt/main.cpp   | 29 ++++++++++++++-----
 tt-train/sources/ttml/CMakeLists.txt          | 18 +++++++++++-
 .../sources/ttml/core/tt_tensor_utils.cpp     | 15 ++++++++--
 tt-train/sources/ttml/datasets/utils.cpp      | 17 ++++++-----
 tt-train/sources/ttml/datasets/utils.hpp      |  3 +-
 .../sources/ttml/tokenizers/bpe_tokenizer.hpp |  2 +-
 .../ttml/tokenizers/char_tokenizer.hpp        |  2 +-
 .../tokenizers/char_tokenizer_trainer.cpp     |  4 +--
 .../tokenizers/char_tokenizer_trainer.hpp     |  2 +-
 .../ttml/tokenizers/tokenizer_base.hpp        |  3 ++
 tt-train/tests/core/tensor_utils_test.cpp     | 16 ++++++++++
 tt-train/tests/model/gpt2s_test.cpp           |  4 +--
 tt-train/tests/model/nano_gpt_test.cpp        |  5 ++--
 .../char_tokenizer_trainer_test.cpp           | 28 +++++++++---------
 15 files changed, 125 insertions(+), 44 deletions(-)
 create mode 100644 tt-train/configs/training_shakespear_nanogpt_bpe.yaml

diff --git a/tt-train/configs/training_shakespear_nanogpt_bpe.yaml b/tt-train/configs/training_shakespear_nanogpt_bpe.yaml
new file mode 100644
index 00000000000..84dbef90ee2
--- /dev/null
+++ b/tt-train/configs/training_shakespear_nanogpt_bpe.yaml
@@ -0,0 +1,21 @@
+training_config:
+  project_name: "tt_train_nano_gpt"
+  seed: 5489
+  model_save_interval: 500
+  batch_size: 2
+  num_epochs: 1
+  max_steps: 5000
+  learning_rate: 0.0003
+  weight_decay: 0.01
+  tokenizer_type: bpe
+
+  transformer_config:
+    num_heads: 6
+    embedding_dim: 384
+    dropout_prob: 0.2
+    num_blocks: 6
+    vocab_size: 96
+    max_sequence_length: 256
+    positional_embedding_type: trainable
+    experimental:
+      use_composite_layernorm: false
diff --git a/tt-train/sources/examples/nano_gpt/main.cpp b/tt-train/sources/examples/nano_gpt/main.cpp
index 87136c9f079..89fb363187f 100644
--- a/tt-train/sources/examples/nano_gpt/main.cpp
+++ b/tt-train/sources/examples/nano_gpt/main.cpp
@@ -20,6 +20,7 @@
 #include "ops/losses.hpp"
 #include "optimizers/adamw.hpp"
 #include "optimizers/sgd.hpp"
+#include "tokenizers/bpe_tokenizer.hpp"
 #include 
"tokenizers/char_tokenizer.hpp" #include "ttnn_fixed/trivial_ttnn_ops.hpp" #include "utils.hpp" @@ -142,6 +143,7 @@ struct TrainingConfig { uint32_t gradient_accumulation_steps = 1; std::string model_path; std::string data_path; + std::string tokenizer_type = "char"; std::string scheduler_type = "identity"; ttml::models::gpt2::TransformerConfig transformer_config; @@ -163,6 +165,7 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) { training_config["gradient_accumulation_steps"].as(config.gradient_accumulation_steps); config.model_path = training_config["model_path"].as(""); config.data_path = training_config["data_path"].as(std::string(DATA_FOLDER) + "/shakespeare.txt"); + config.tokenizer_type = training_config["tokenizer_type"].as(config.tokenizer_type); config.scheduler_type = training_config["scheduler_type"].as(config.scheduler_type); config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]); @@ -208,6 +211,7 @@ int main(int argc, char **argv) { {"sequence_length", static_cast(config.transformer_config.max_sequence_length)}, {"max_steps", static_cast(config.max_steps)}, {"seed", static_cast(config.seed)}, + {"tokenizer_type", config.tokenizer_type}, {"use_kahan_summation", config.use_kahan_summation}, {"gradient_accumulation_steps", static_cast(config.gradient_accumulation_steps)}, {"positional_embedding_type", @@ -236,10 +240,22 @@ int main(int argc, char **argv) { fmt::print("Seed {}\n", ttml::autograd::ctx().get_seed()); auto sequence_length = config.transformer_config.max_sequence_length; - auto [dataset, tokenizer] = - ttml::datasets::create_in_memory_token_dataset(text, sequence_length); + auto create_dataset_and_tokenizer = [](const auto &text, const auto sequence_length, const auto &tokenizer_type) { + if (tokenizer_type == "char") { + return ttml::datasets::create_in_memory_token_dataset( + text, sequence_length); + } else if (tokenizer_type == "bpe") { + return ttml::datasets::create_in_memory_token_dataset( + text, sequence_length); + } else { + throw std::runtime_error("Unknown tokenizer type: " + tokenizer_type); + } + }; + + auto [dataset, tokenizer] = create_dataset_and_tokenizer(text, sequence_length, config.tokenizer_type); fmt::print("Dataset size: {}\n", dataset.get_size()); - fmt::print("Vocab size: {}\n", tokenizer.get_vocab_size()); + fmt::print("Vocab size: {}\n", tokenizer->get_vocab_size()); + fmt::print("Tokenizer type: {}\n", config.tokenizer_type); auto *device = &ttml::autograd::ctx().get_device(); device->enable_program_cache(); @@ -269,8 +285,7 @@ int main(int argc, char **argv) { mask, ttml::core::create_shape({config.batch_size, num_heads, sequence_length, sequence_length}), device)); std::function && samples)> collate_fn = - [sequence_length, num_heads, vocab_size = tokenizer.get_vocab_size(), device, &cached_data]( - std::vector &&samples) { + [sequence_length, num_heads, device, &cached_data](std::vector &&samples) { auto start_timer = std::chrono::high_resolution_clock::now(); const uint32_t batch_size = samples.size(); std::vector &data = cached_data.data; @@ -302,7 +317,7 @@ int main(int argc, char **argv) { auto train_dataloader = DataLoader(dataset, /* batch_size */ config.batch_size, /* shuffle */ true, collate_fn); fmt::print("Overriding vocab size to be divisible by 32\n"); - config.transformer_config.vocab_size = round_up_to_tile(tokenizer.get_vocab_size()); + config.transformer_config.vocab_size = round_up_to_tile(tokenizer->get_vocab_size()); auto model = 
ttml::models::gpt2::create(config.transformer_config);

     auto adamw_params = ttml::optimizers::AdamWConfig();
@@ -324,7 +339,7 @@ int main(int argc, char **argv) {
     if (is_eval) {
         fmt::print("\nEvaluation started\n");
         for (;;) {
-            generate(model, tokenizer, config.transformer_config.max_sequence_length, num_heads);
+            generate(model, *tokenizer, config.transformer_config.max_sequence_length, num_heads);
         }
         fmt::print("\nEvaluation finished\n");
         return 0;
diff --git a/tt-train/sources/ttml/CMakeLists.txt b/tt-train/sources/ttml/CMakeLists.txt
index 0e241cd7bb6..9efb0b24343 100644
--- a/tt-train/sources/ttml/CMakeLists.txt
+++ b/tt-train/sources/ttml/CMakeLists.txt
@@ -130,4 +130,20 @@ target_link_libraries(
 )
 target_compile_options(wandbcpp PUBLIC -stdlib=libc++)

-add_definitions(-DTOKENIZERS_DATA_PATH="${CMAKE_CURRENT_SOURCE_DIR}/data/tokenizers")
+add_definitions(-DTOKENIZERS_DATA_PATH="${CMAKE_SOURCE_DIR}/data")
+
+set(TOKENIZER_URL "https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1/resolve/main/tokenizer.json")
+set(TOKENIZER_FILE "${CMAKE_SOURCE_DIR}/data/gpt2-tokenizer.json")
+
+# Check if the file already exists before downloading
+if(NOT EXISTS "${TOKENIZER_FILE}")
+    message(STATUS "Downloading tokenizer file to ${TOKENIZER_FILE}")
+    file(
+        DOWNLOAD
+        ${TOKENIZER_URL}
+        ${TOKENIZER_FILE}
+        SHOW_PROGRESS
+    )
+else()
+    message(STATUS "Tokenizer file already exists at ${TOKENIZER_FILE}, skipping download.")
+endif()
diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
index 706c8d98dfc..00e2c0761c7 100644
--- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp
+++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -245,9 +245,18 @@ tt::tt_metal::Tensor from_vector(
     auto owned_buffer = create_owned_buffer_from_vector_of_floats(buffer, data_type);
     // remove possible paddings from the shape (it conflicts with ROW MAJOR)
     auto output = tt::tt_metal::Tensor(OwnedStorage{owned_buffer}, logical_shape, data_type, Layout::ROW_MAJOR);
-    output = ttnn::to_device(output, device, output_mem_config);
-    if (layout == Layout::TILE) {
-        output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true);
+
+    const size_t MAX_TILE_DIMENSION = 32768;
+    // Temporary workaround for the issue with tilize for large size
+    // https://github.com/tenstorrent/tt-metal/issues/15950
+    if (logical_shape[-1] > MAX_TILE_DIMENSION && layout == Layout::TILE) {
+        output = ttnn::to_layout(output, Layout::TILE, std::nullopt, output_mem_config, device);
+        output = ttnn::to_device(output, device, output_mem_config);
+    } else {
+        output = ttnn::to_device(output, device, output_mem_config);
+        if (layout == Layout::TILE) {
+            output = ttnn::tilize_with_zero_padding(output, output_mem_config, std::nullopt, /* multicore */ true);
+        }
     }

     return output;
diff --git a/tt-train/sources/ttml/datasets/utils.cpp b/tt-train/sources/ttml/datasets/utils.cpp
index ee42f0a55ec..aa85fcd8914 100644
--- a/tt-train/sources/ttml/datasets/utils.cpp
+++ b/tt-train/sources/ttml/datasets/utils.cpp
@@ -7,6 +7,7 @@
 #include "datasets/in_memory_token_dataset.hpp"
 #include "tokenizers/bpe_tokenizer.hpp"
 #include "tokenizers/char_tokenizer_trainer.hpp"
+#include "tokenizers/tokenizer_base.hpp"

 namespace {
 constexpr auto gpt2_tokenizer_file_name = "/gpt2-tokenizer.json";

 namespace ttml::datasets {
 template <>
-std::tuple<InMemoryTokenDataset, tokenizers::CharTokenizer> create_in_memory_token_dataset(
-    const 
std::string &text, uint32_t seq_length) { - tokenizers::CharTokenizer tokenizer = tokenizers::CharTokenizerTrainer::train(text); +std::tuple> +create_in_memory_token_dataset(const std::string &text, uint32_t seq_length) { + std::unique_ptr tokenizer = tokenizers::CharTokenizerTrainer::train(text); - std::vector tokenized_text = tokenizer.encode(text); + std::vector tokenized_text = tokenizer->encode(text); return {InMemoryTokenDataset(tokenized_text, seq_length), std::move(tokenizer)}; } template <> -std::tuple create_in_memory_token_dataset( - const std::string &text, uint32_t seq_length) { +std::tuple> +create_in_memory_token_dataset(const std::string &text, uint32_t seq_length) { auto json_file_path = std::string(TOKENIZERS_DATA_PATH) + gpt2_tokenizer_file_name; - auto tokenizer = tokenizers::BPETokenizer(json_file_path); + std::unique_ptr tokenizer = std::make_unique(json_file_path); - const std::vector tokenized_text = tokenizer.encode(text); + const std::vector tokenized_text = tokenizer->encode(text); return {InMemoryTokenDataset(tokenized_text, seq_length), std::move(tokenizer)}; } diff --git a/tt-train/sources/ttml/datasets/utils.hpp b/tt-train/sources/ttml/datasets/utils.hpp index 8e4bbdc6688..1c96d37ec0d 100644 --- a/tt-train/sources/ttml/datasets/utils.hpp +++ b/tt-train/sources/ttml/datasets/utils.hpp @@ -10,11 +10,12 @@ #include "autograd/auto_context.hpp" #include "dataset_subset.hpp" #include "in_memory_token_dataset.hpp" +#include "tokenizers/tokenizer_base.hpp" namespace ttml::datasets { template -std::tuple create_in_memory_token_dataset( +std::tuple> create_in_memory_token_dataset( const std::string& text, uint32_t seq_length); template diff --git a/tt-train/sources/ttml/tokenizers/bpe_tokenizer.hpp b/tt-train/sources/ttml/tokenizers/bpe_tokenizer.hpp index 7e86ef52222..8dec1f7ed65 100644 --- a/tt-train/sources/ttml/tokenizers/bpe_tokenizer.hpp +++ b/tt-train/sources/ttml/tokenizers/bpe_tokenizer.hpp @@ -21,7 +21,7 @@ class BPETokenizer : public TokenizerBase { [[nodiscard]] std::vector encode(const std::string& text) const override; [[nodiscard]] std::string decode(const std::vector& tokens) const override; - [[nodiscard]] uint32_t get_vocab_size() const; + [[nodiscard]] uint32_t get_vocab_size() const override; private: class BPETokenizerImpl; diff --git a/tt-train/sources/ttml/tokenizers/char_tokenizer.hpp b/tt-train/sources/ttml/tokenizers/char_tokenizer.hpp index f5f84ca45c1..a8f50e7eaa8 100644 --- a/tt-train/sources/ttml/tokenizers/char_tokenizer.hpp +++ b/tt-train/sources/ttml/tokenizers/char_tokenizer.hpp @@ -33,7 +33,7 @@ class CharTokenizer : public TokenizerBase { [[nodiscard]] const CharTokenizer::Vocabulary& get_vocabulary() const; - [[nodiscard]] uint32_t get_vocab_size() const; + [[nodiscard]] uint32_t get_vocab_size() const override; ~CharTokenizer() override = default; diff --git a/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.cpp b/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.cpp index 6fec9cbbe51..abcb9bafdc4 100644 --- a/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.cpp +++ b/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.cpp @@ -10,7 +10,7 @@ namespace ttml::tokenizers { -CharTokenizer CharTokenizerTrainer::train(const std::string& text, bool add_padding_token) { +std::unique_ptr CharTokenizerTrainer::train(const std::string& text, bool add_padding_token) { CharTokenizer::Vocabulary vocabulary; // using set instead of unordered_set to stabilize order @@ -24,7 +24,7 @@ CharTokenizer 
CharTokenizerTrainer::train(const std::string& text, bool add_padd vocabulary[std::string(1, chr)] = static_cast(vocabulary.size()); } - return CharTokenizer(vocabulary); + return std::make_unique(vocabulary); } } // namespace ttml::tokenizers diff --git a/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.hpp b/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.hpp index b0b5f782156..18059482b37 100644 --- a/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.hpp +++ b/tt-train/sources/ttml/tokenizers/char_tokenizer_trainer.hpp @@ -10,6 +10,6 @@ namespace ttml::tokenizers { // right now it is very simple class CharTokenizerTrainer { public: - [[nodiscard]] static CharTokenizer train(const std::string& text, bool add_padding_token = true); + [[nodiscard]] static std::unique_ptr train(const std::string& text, bool add_padding_token = true); }; } // namespace ttml::tokenizers diff --git a/tt-train/sources/ttml/tokenizers/tokenizer_base.hpp b/tt-train/sources/ttml/tokenizers/tokenizer_base.hpp index f62c77294a6..8ae835cb885 100644 --- a/tt-train/sources/ttml/tokenizers/tokenizer_base.hpp +++ b/tt-train/sources/ttml/tokenizers/tokenizer_base.hpp @@ -25,6 +25,9 @@ class TokenizerBase { // Pure virtual function to decode a vector of token IDs back into a string [[nodiscard]] virtual std::string decode(const std::vector& tokens) const = 0; + + // Pure virtual function to get the vocabulary size + [[nodiscard]] virtual uint32_t get_vocab_size() const = 0; }; } // namespace ttml::tokenizers diff --git a/tt-train/tests/core/tensor_utils_test.cpp b/tt-train/tests/core/tensor_utils_test.cpp index 72e518de091..97e93534297 100644 --- a/tt-train/tests/core/tensor_utils_test.cpp +++ b/tt-train/tests/core/tensor_utils_test.cpp @@ -27,6 +27,22 @@ TEST(TensorUtilsTest, TestFloatToFromTensorEven) { } } +TEST(TensorUtilsTest, TestFloatToFromTensorGPT2Tokenizer) { + auto* device = &ttml::autograd::ctx().get_device(); + const size_t N = 50304; + std::vector test_data(N, 0.F); + + auto shape = ttml::core::create_shape({1, 1, 1, N}); + auto tensor = ttml::core::from_vector(test_data, shape, device); + + auto vec_back = ttml::core::to_vector(tensor); + + ASSERT_EQ(vec_back.size(), test_data.size()); + for (size_t i = 0; i < test_data.size(); i++) { + EXPECT_EQ(vec_back[i], test_data[i]); + } +} + TEST(TensorUtilsTest, TestFloatToFromTensorOdd) { auto* device = &ttml::autograd::ctx().get_device(); std::vector test_data = {30.F, 20.F, 2.F}; diff --git a/tt-train/tests/model/gpt2s_test.cpp b/tt-train/tests/model/gpt2s_test.cpp index bfad28597d8..5791f4628c3 100644 --- a/tt-train/tests/model/gpt2s_test.cpp +++ b/tt-train/tests/model/gpt2s_test.cpp @@ -53,8 +53,8 @@ TEST(GPT2SBatch64Test, Matmul) { {{{768, 65536}, {65536, 2304}, false, false}, ExpectedResult::OK}, {{{65536, 768}, {65536, 2304}, true, false}, ExpectedResult::OK}, {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::ERROR}, - {{{65536, 768}, {50257, 768}, false, true}, ExpectedResult::ERROR}, - {{{65536, 50257}, {50257, 768}, false, false}, ExpectedResult::ERROR}, + {{{65536, 768}, {50304, 768}, false, true}, ExpectedResult::ERROR}, + {{{65536, 50304}, {50304, 768}, false, false}, ExpectedResult::ERROR}, }; auto run_matmul = [](auto& a, auto& b, bool transpose_a, bool transpose_b) { diff --git a/tt-train/tests/model/nano_gpt_test.cpp b/tt-train/tests/model/nano_gpt_test.cpp index b5f02ede312..2fd35899e3a 100644 --- a/tt-train/tests/model/nano_gpt_test.cpp +++ b/tt-train/tests/model/nano_gpt_test.cpp @@ -95,8 +95,7 @@ void 
train_test(bool use_moreh_adamw = false) { mask, ttml::core::create_shape({config.batch_size, num_heads, sequence_length, sequence_length}), device)); std::function && samples)> collate_fn = - [sequence_length, num_heads, vocab_size = tokenizer.get_vocab_size(), device, &cached_data]( - std::vector &&samples) { + [sequence_length, num_heads, device, &cached_data](std::vector &&samples) { auto start_timer = std::chrono::high_resolution_clock::now(); const uint32_t batch_size = samples.size(); std::vector &data = cached_data.data; @@ -126,7 +125,7 @@ void train_test(bool use_moreh_adamw = false) { auto train_dataloader = DataLoader(dataset, /* batch_size */ config.batch_size, /* shuffle */ true, collate_fn); fmt::print("Overriding vocab size to be divisible by 32\n"); - config.transformer_config.vocab_size = (tokenizer.get_vocab_size() + 31) / 32 * 32; + config.transformer_config.vocab_size = (tokenizer->get_vocab_size() + 31) / 32 * 32; auto model = ttml::models::gpt2::create(config.transformer_config); auto adamw_params = ttml::optimizers::AdamWConfig(); diff --git a/tt-train/tests/tokenizers/char_tokenizer_trainer_test.cpp b/tt-train/tests/tokenizers/char_tokenizer_trainer_test.cpp index 8e1490456b0..bed54a3e5d7 100644 --- a/tt-train/tests/tokenizers/char_tokenizer_trainer_test.cpp +++ b/tt-train/tests/tokenizers/char_tokenizer_trainer_test.cpp @@ -18,18 +18,18 @@ class CharTokenizerTrainerTest : public ::testing::Test { // Test that the trainer creates a tokenizer with the correct vocabulary TEST_F(CharTokenizerTrainerTest, TrainVocabulary) { std::string text = "hello world"; - CharTokenizer tokenizer = trainer.train(text); + std::unique_ptr tokenizer_ptr = trainer.train(text); CharTokenizer::Vocabulary expected_vocabulary = { {" ", 1}, {"d", 2}, {"e", 3}, {"h", 4}, {"l", 5}, {"o", 6}, {"r", 7}, {"w", 8}}; // Verify that the generated vocabulary matches the expected one const auto special_tokens_count = 3UL; - ASSERT_EQ(tokenizer.get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); + ASSERT_EQ(tokenizer_ptr->get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); for (const auto& pair : expected_vocabulary) { - auto it = tokenizer.get_vocabulary().find(pair.first); - ASSERT_NE(it, tokenizer.get_vocabulary().end()); + auto it = tokenizer_ptr->get_vocabulary().find(pair.first); + ASSERT_NE(it, tokenizer_ptr->get_vocabulary().end()); ASSERT_EQ(it->second, pair.second); } } @@ -37,17 +37,17 @@ TEST_F(CharTokenizerTrainerTest, TrainVocabulary) { // Test that the trainer handles duplicate characters correctly TEST_F(CharTokenizerTrainerTest, TrainWithDuplicateCharacters) { std::string text = "aaaabbbb"; - CharTokenizer tokenizer = trainer.train(text); + std::unique_ptr tokenizer_ptr = trainer.train(text); CharTokenizer::Vocabulary expected_vocabulary = {{"a", 1}, {"b", 2}}; // Verify that the generated vocabulary has no duplicates const auto special_tokens_count = 3UL; - ASSERT_EQ(tokenizer.get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); + ASSERT_EQ(tokenizer_ptr->get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); for (const auto& pair : expected_vocabulary) { - auto it = tokenizer.get_vocabulary().find(pair.first); - ASSERT_NE(it, tokenizer.get_vocabulary().end()); + auto it = tokenizer_ptr->get_vocabulary().find(pair.first); + ASSERT_NE(it, tokenizer_ptr->get_vocabulary().end()); ASSERT_EQ(it->second, pair.second); } } @@ -55,17 +55,17 @@ TEST_F(CharTokenizerTrainerTest, 
TrainWithDuplicateCharacters) { // Test that the trainer starts indexing from the specified starting index TEST_F(CharTokenizerTrainerTest, TrainWithNoPaddingToken) { std::string text = "abc"; - CharTokenizer tokenizer = trainer.train(text, /* add_padding_token */ false); + std::unique_ptr tokenizer_ptr = trainer.train(text, /* add_padding_token */ false); CharTokenizer::Vocabulary expected_vocabulary = {{"a", 0}, {"b", 1}, {"c", 2}}; // Verify that the generated vocabulary starts at the correct index const auto special_tokens_count = 2UL; - ASSERT_EQ(tokenizer.get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); + ASSERT_EQ(tokenizer_ptr->get_vocabulary().size(), expected_vocabulary.size() + special_tokens_count); for (const auto& pair : expected_vocabulary) { - auto it = tokenizer.get_vocabulary().find(pair.first); - ASSERT_NE(it, tokenizer.get_vocabulary().end()); + auto it = tokenizer_ptr->get_vocabulary().find(pair.first); + ASSERT_NE(it, tokenizer_ptr->get_vocabulary().end()); ASSERT_EQ(it->second, pair.second); } } @@ -73,9 +73,9 @@ TEST_F(CharTokenizerTrainerTest, TrainWithNoPaddingToken) { // Test that the trainer handles an empty string correctly TEST_F(CharTokenizerTrainerTest, TrainWithEmptyString) { std::string text; - CharTokenizer tokenizer = trainer.train(text, /* add_padding_token */ false); + std::unique_ptr tokenizer_ptr = trainer.train(text, /* add_padding_token */ false); // Verify that the generated vocabulary is empty const auto special_tokens_count = 2UL; - ASSERT_EQ(tokenizer.get_vocabulary().size(), special_tokens_count); + ASSERT_EQ(tokenizer_ptr->get_vocabulary().size(), special_tokens_count); } From ee62a8638dae436582d6c6039b734e2ce8dc8141 Mon Sep 17 00:00:00 2001 From: Kartik Paigwar <132708568+kpaigwar@users.noreply.github.com> Date: Thu, 12 Dec 2024 14:32:56 -0500 Subject: [PATCH 11/13] #15877: added support for subcoregrid in sdpa decode (#15927) ### Ticket [Link to Github Issue](https://github.com/tenstorrent/tt-metal/issues/15877) ### Problem description SDPA decode can now run on subcoregrids if sub_core_grids is passed in SDPA Program Config --- ...est_scaled_dot_product_attention_decode.py | 95 ++++++++++++++++++- .../operations/transformer/sdpa_config.hpp | 1 + .../device/sdpa_decode_program_factory.cpp | 80 +++++++++++----- .../transformer/transformer_pybind.cpp | 4 +- 4 files changed, 151 insertions(+), 29 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py index 8be5f8b317e..1c908f0ab94 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py @@ -321,11 +321,21 @@ def run_test_sdpa_decode_single_iter( sharded_out=False, start_indices=None, causal=True, + start_core=ttnn.CoreCoord(0, 0), + sub_core_grids=None, ): compute_grid_size = device.compute_with_storage_grid_size() - if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: - pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}") - + if sub_core_grids is None: + if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: + pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}") + else: + unharvested_grid_size = (7, 10) + if 
compute_grid_size.x > unharvested_grid_size[0] or compute_grid_size.y > unharvested_grid_size[1]: + pytest.skip(f"Need {unharvested_grid_size} grid size to run this test but core grid is {compute_grid_size}") + if grid_size[0] * grid_size[1] > sub_core_grids.num_cores(): + pytest.skip( + f"Need {grid_size[0]*grid_size[1]} grid size to run this test but core grid is {sub_core_grids.num_cores()}" + ) padded_num_heads = nearest_pow_2(nearest_n(nh, n=32)) torch.manual_seed(1234) @@ -346,7 +356,14 @@ def run_test_sdpa_decode_single_iter( ) dram_memcfg = ttnn.DRAM_MEMORY_CONFIG - shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)}) + if sub_core_grids is None: + shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)}) + compute_sub_core_grids = None + else: + shard_grid = ttnn.num_cores_to_corerangeset_in_subcoregrids(start_core, b, sub_core_grids, row_wise=True) + compute_sub_core_grids = ttnn.num_cores_to_corerangeset_in_subcoregrids( + start_core, grid_size[0] * grid_size[1], sub_core_grids, row_wise=True + ) shard_spec = ttnn.ShardSpec(shard_grid, (padded_num_heads, d), ttnn.ShardOrientation.ROW_MAJOR, False) height_sharded_memcfg = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) @@ -364,6 +381,7 @@ def run_test_sdpa_decode_single_iter( k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, + sub_core_grids=compute_sub_core_grids, q_chunk_size=padded_num_heads, k_chunk_size=k_chunk_size, exp_approx_mode=False, @@ -904,6 +922,75 @@ def test_sdpa_decode_sharded(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype ) +@skip_for_blackhole("Unsupported on BH, see #12349") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("device_params", [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], indirect=True) +@pytest.mark.parametrize( + "dtype, q_dtype", + [ + [ttnn.bfloat8_b, ttnn.bfloat16], + ], + ids=[ + "bfp8_cache_bf16_act", + ], +) +@pytest.mark.parametrize( + "b, nh, nkv, s, d, grid_size", + ( + [8, 8, 1, 2048, 128, (8, 4)], + [8, 8, 1, 256, 128, (8, 4)], + ), # Llama2-70B +) +@pytest.mark.parametrize( + "start_core, sub_core_grids", + [ + ( + ttnn.CoreCoord(1, 0), + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)), + ] + ), + ), + ], +) +def test_sdpa_decode_sharded_on_subcoregrids( + device, use_program_cache, b, nh, nkv, s, d, dtype, grid_size, q_dtype, start_core, sub_core_grids +): + run_test_sdpa_decode_single_iter( + device, + b, + nh, + nkv, + s, + d, + dtype, + grid_size, + q_dtype, + sharded_in=True, + sharded_out=True, + start_core=start_core, + sub_core_grids=sub_core_grids, + ) + run_test_sdpa_decode_single_iter( + device, + b, + nh, + nkv, + s, + d, + dtype, + grid_size, + q_dtype, + sharded_in=True, + sharded_out=True, + start_core=start_core, + sub_core_grids=sub_core_grids, + ) + assert device.num_program_cache_entries() == 1 + + @skip_for_blackhole("Unsupported on BH, see #12349") @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") @pytest.mark.skip("Skipping Perf Test in CI") diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp index 8dc18614f00..c968f5d8a7f 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp @@ -11,6 +11,7 @@ namespace 
ttnn::operations::transformer { struct SDPAProgramConfig { CoreCoord compute_with_storage_grid_size; + std::optional sub_core_grids; std::size_t q_chunk_size; std::size_t k_chunk_size; std::optional exp_approx_mode; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp index 7c09d0e4de0..9615b729578 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp @@ -122,9 +122,20 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( CoreCoord grid_size = program_config.has_value() ? program_config->compute_with_storage_grid_size : device->compute_with_storage_grid_size(); - auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); uint32_t num_cores_available = grid_size.x * grid_size.y; + CoreRangeSet core_grid; + bool on_subcoregrid = false; + if (program_config.has_value() && program_config->sub_core_grids.has_value()) { + core_grid = program_config->sub_core_grids.value(); + TT_FATAL( + core_grid.num_cores() == num_cores_available, + "Number of cores in sub_core_grids must match the number of cores available"); + on_subcoregrid = true; + } else { + core_grid = CoreRangeSet(std::vector{CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1})}); + } + uint32_t num_cores_in_grid = device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y; TT_FATAL(num_cores_available <= num_cores_in_grid, "Expected number of cores available to be less than or equal to the number of cores in the grid, got {} and {}", num_cores_available, num_cores_in_grid); TT_FATAL(num_cores_available >= B, "Expect number of cores available to be greater or equal to batch size, got {} and {}", num_cores_available, B); @@ -154,32 +165,53 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( // h_worker2) head_reducer2 to head_reducerk then send the result to head_reducer1, which is also the batch_output1 std::vector core_group; std::vector core_group_idle; - if (is_q_sharded || is_output_sharded) { - int reducer_idx = 0; - int worker_idx = num_output_cores; - - for (int i = 0; i < num_cores_available; ++i) { - CoreCoord core; - if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) { - core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x}; - reducer_idx++; - } else { - core = {worker_idx % grid_size.x, worker_idx / grid_size.x}; - worker_idx++; - } - if (i < num_active_cores) { - core_group.push_back(core); - } else { - core_group_idle.push_back(core); + if (on_subcoregrid) { + if (is_q_sharded || is_output_sharded) { + auto cores_vec = corerange_to_cores(core_grid, num_cores_available, true); + int reducer_idx = 0; + int worker_idx = num_output_cores; + for (int i = 0; i < num_cores_available; ++i) { + if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) { + i < num_active_cores ? core_group.push_back(cores_vec[reducer_idx]) + : core_group_idle.push_back(cores_vec[reducer_idx]); + reducer_idx++; + } else { + i < num_active_cores ? 
core_group.push_back(cores_vec[worker_idx])
+                                             : core_group_idle.push_back(cores_vec[worker_idx]);
+                    worker_idx++;
+                }
+            }
+        } else {
+            TT_FATAL(false, "We only support SDPA on subcoregrids with sharded Q and sharded output");
+        }
     } else {
-        for (int i = 0; i < num_cores_available; ++i) {
-            CoreCoord core = {i % grid_size.x, i / grid_size.x};
-            if (i < num_active_cores) {
-                core_group.push_back(core);
-            } else {
-                core_group_idle.push_back(core);
+        if (is_q_sharded || is_output_sharded) {
+            int reducer_idx = 0;
+            int worker_idx = num_output_cores;
+
+            for (int i = 0; i < num_cores_available; ++i) {
+                CoreCoord core;
+                if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
+                    core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x};
+                    reducer_idx++;
+                } else {
+                    core = {worker_idx % grid_size.x, worker_idx / grid_size.x};
+                    worker_idx++;
+                }
+                if (i < num_active_cores) {
+                    core_group.push_back(core);
+                } else {
+                    core_group_idle.push_back(core);
+                }
+            }
+        } else {
+            for (int i = 0; i < num_cores_available; ++i) {
+                CoreCoord core = {i % grid_size.x, i / grid_size.x};
+                if (i < num_active_cores) {
+                    core_group.push_back(core);
+                } else {
+                    core_group_idle.push_back(core);
+                }
             }
         }
     }
diff --git a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
index a1c1129cea6..75ae4fffad6 100644
--- a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
@@ -21,13 +21,15 @@ namespace py = pybind11;
 void py_module(py::module& module) {
     py::class_<SDPAProgramConfig>(module, "SDPAProgramConfig")
         .def(
-            py::init<CoreCoord, std::size_t, std::size_t, std::optional<bool>>(),
+            py::init<CoreCoord, std::optional<CoreRangeSet>, std::size_t, std::size_t, std::optional<bool>>(),
             py::kw_only(),
             py::arg("compute_with_storage_grid_size"),
+            py::arg("sub_core_grids") = std::nullopt,
             py::arg("q_chunk_size").noconvert(),
             py::arg("k_chunk_size").noconvert(),
             py::arg("exp_approx_mode") = std::nullopt)
         .def_readwrite("compute_with_storage_grid_size", &SDPAProgramConfig::compute_with_storage_grid_size)
+        .def_readwrite("sub_core_grids", &SDPAProgramConfig::sub_core_grids)
         .def_readwrite("q_chunk_size", &SDPAProgramConfig::q_chunk_size)
         .def_readwrite("k_chunk_size", &SDPAProgramConfig::k_chunk_size)
         .def_readwrite("exp_approx_mode", &SDPAProgramConfig::exp_approx_mode);

From e08a8324b725e1e3268ad3a94d13a430ed4dacb5 Mon Sep 17 00:00:00 2001
From: Mohamed Bahnas <116673264+mbahnasTT@users.noreply.github.com>
Date: Thu, 12 Dec 2024 12:46:26 -0800
Subject: [PATCH 12/13] [skip ci] Update SD demo README.md (#15949)

### Problem description
The demo commands in the Stable Diffusion README were embedded inline in the prose, which made them hard to read and to copy.

### What's changed
Moved each demo invocation into its own fenced `sh` code block and tidied the surrounding wording.
### Checklist - [ ] Post commit CI passes - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes --- .../demos/wormhole/stable_diffusion/README.md | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/models/demos/wormhole/stable_diffusion/README.md b/models/demos/wormhole/stable_diffusion/README.md index 1f9b3397114..ada6e5ef5f7 100644 --- a/models/demos/wormhole/stable_diffusion/README.md +++ b/models/demos/wormhole/stable_diffusion/README.md @@ -19,22 +19,36 @@ Inputs by default are provided from `input_data.json`. If you wish to change the > > If you are using Wormhole, you must set the `WH_ARCH_YAML` environment variable. > -> ``` +> ```sh > export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml > ``` To run the demo, make sure to build the project, activate the environment, and set the appropriate environment variables. For more information, refer [installation and build guide](https://github.com/tenstorrent/tt-metal/blob/main/INSTALLING.md). -Use `pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` to run the demo. - -If you wish to run the demo with a different input use `pytest --disable-warnings --input-path="" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo` - -If you would like to run an interactive demo which will prompt you for the input, use `pytest models/demos/wormhole/stable_diffusion/demo/demo.py::test_interactive_demo` - -Our second demo is designed to run poloclub/diffusiondb dataset, run this with `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb`. - -If you wish to run for `num_prompts` samples and `num_inference_steps` denoising steps, use `pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[-]` +```sh +pytest --disable-warnings --input-path="models/demos/wormhole/stable_diffusion/demo/input_data.json" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo +``` + +If you wish to run the demo with a different input: +```sh +pytest --disable-warnings --input-path="" models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo + +``` +If you would like to run an interactive demo which will prompt you for the input: +```sh +pytest models/demos/wormhole/stable_diffusion/demo/demo.py::test_interactive_demo +``` + +Our second demo is designed to run poloclub/diffusiondb dataset, run this with: +```sh +pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb +``` + +If you wish to run for `num_prompts` samples and `num_inference_steps` denoising steps: +```sh +pytest --disable-warnings models/demos/wormhole/stable_diffusion/demo/demo.py::test_demo_diffusiondb[-] +``` Note: ttnn stable diffusion utilizes `PNDMScheduler` and requires `num_inference_steps to be greater than or equal to 4`. 
[Reference](https://arxiv.org/pdf/2202.09778) From 660d2497bc004934ad839cf7c9ffdabaf10c63ce Mon Sep 17 00:00:00 2001 From: Andrew Fuller Date: Thu, 12 Dec 2024 17:56:52 -0500 Subject: [PATCH 13/13] Enable throw-by-val/catch-by-ref check (#15981) ### Ticket #15123 ### Problem description C++ should always throw exceptions by value and catch them by reference. ### What's changed Enabled the check in clang-tidy. Fixed the violations. --- .clang-tidy | 3 --- .../ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp | 2 +- .../ethernet/test_ethernet_hop_latencies_no_edm.cpp | 2 +- .../ethernet/test_ethernet_link_ping_latency_no_edm.cpp | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 9e775a89796..993d62c018a 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -26,11 +26,9 @@ Checks: > -bugprone-unhandled-self-assignment, -bugprone-unused-raii, -cert-env33-c, - -cert-err09-cpp, -cert-err33-c, -cert-err34-c, -cert-err58-cpp, - -cert-err61-cpp, -cert-flp30-c, -cert-msc30-c, -cert-msc32-c, @@ -121,7 +119,6 @@ Checks: > -misc-no-recursion, -misc-non-private-member-variables-in-classes, -misc-redundant-expression, - -misc-throw-by-value-catch-by-reference, -misc-unconventional-assign-operator, -misc-uniqueptr-reset-release, -misc-unused-parameters, diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp index f351b0a75d7..98eba2469d1 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_bidirectional_bandwidth_no_edm.cpp @@ -270,7 +270,7 @@ int main(int argc, char** argv) { } } } - } catch (std::exception e) { + } catch (std::exception& e) { test_fixture.TearDown(); return -1; } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp index 476e9890797..f9adf4f13fb 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_hop_latencies_no_edm.cpp @@ -533,7 +533,7 @@ int main(int argc, char** argv) { } } } - } catch (std::exception e) { + } catch (std::exception& e) { test_fixture.TearDown(); return -1; } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp index d2544f39271..d473716b05c 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_link_ping_latency_no_edm.cpp @@ -287,7 +287,7 @@ int main(int argc, char** argv) { } } } - } catch (std::exception e) { + } catch (std::exception& e) { test_fixture.TearDown(); return -1; }