diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.cpp b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp index 72468246e0e..0fcf51cc4d4 100644 --- a/tt-train/sources/ttml/schedulers/linear_scheduler.cpp +++ b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp @@ -7,45 +7,47 @@ #include "optimizers/optimizer_base.hpp" namespace ttml::schedulers { -LinearScheduler::LinearScheduler(optimizers::OptimizerBase *optimizer, float end_lr, int total_steps) : + +LinearScheduler::LinearScheduler( + optimizers::OptimizerBase* optimizer, float start_factor, float end_factor, size_t total_steps) : LRSchedulerBase(optimizer), - m_start_lr(optimizer->get_lr()), - m_end_lr(end_lr), + m_base_lr(optimizer->get_lr()), + m_last_lr(m_base_lr), + m_start_factor(start_factor), + m_end_factor(end_factor), m_total_steps(total_steps), - m_current_step(0), - m_last_lr(m_start_lr) { - if (total_steps <= 0) { - throw std::invalid_argument("total_steps must be a positive integer."); - } + m_last_step(0) { } void LinearScheduler::step() { - m_current_step += 1; + m_last_step += 1; - // Compute progress ratio (clamped at 1.0) - float progress = static_cast(m_current_step) / m_total_steps; + float progress = static_cast(m_last_step) / m_total_steps; progress = std::min(progress, 1.0f); - // Linearly interpolate between start_lr and end_lr - float new_lr = m_start_lr + (m_end_lr - m_start_lr) * progress; + float current_factor = m_start_factor + (m_end_factor - m_start_factor) * progress; + float new_lr = m_base_lr * current_factor; get_optimizer()->set_lr(new_lr); m_last_lr = new_lr; } -float LinearScheduler::get_last_lr() const { - return m_last_lr; -} -float LinearScheduler::get_current_lr() const { - return get_optimizer()->get_lr(); -} - -void LinearScheduler::set_state_dict(const serialization::StateDict &dict) { - m_current_step = serialization::get_value_type(dict, "m_current_step"); +void LinearScheduler::set_state_dict(const serialization::StateDict& dict) { + m_last_step = serialization::get_value_type(dict, "m_last_step"); m_last_lr = serialization::get_value_type(dict, "m_last_lr"); } + serialization::StateDict LinearScheduler::get_state_dict() const { serialization::StateDict res; - res["m_current_step"] = m_current_step; + res["m_last_step"] = m_last_step; res["m_last_lr"] = m_last_lr; return res; }; + +float LinearScheduler::get_last_lr() const { + return m_last_lr; +} + +float LinearScheduler::get_current_lr() const { + return get_optimizer()->get_lr(); +} + } // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.hpp b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp index 35e9ad2a617..7879945d9f2 100644 --- a/tt-train/sources/ttml/schedulers/linear_scheduler.hpp +++ b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp @@ -4,14 +4,13 @@ #pragma once -#include "core/not_null.hpp" #include "scheduler_base.hpp" -// Assuming necessary includes and that LRSchedulerBase and OptimizerBase are defined namespace ttml::schedulers { + class LinearScheduler : public LRSchedulerBase { public: - LinearScheduler(optimizers::OptimizerBase *optimizer, float end_lr, int total_steps); + LinearScheduler(optimizers::OptimizerBase *optimizer, float start_factor, float end_factor, size_t total_steps); void step() override; @@ -20,14 +19,14 @@ class LinearScheduler : public LRSchedulerBase { [[nodiscard]] float get_current_lr() const override; [[nodiscard]] serialization::StateDict get_state_dict() const override; - void set_state_dict(const serialization::StateDict &dict) override; private: - float m_start_lr = 0.F; - float m_end_lr = 0.F; + float m_base_lr = 0.F; + float m_start_factor = 0.F; + float m_end_factor = 0.F; int m_total_steps = 0; - int m_current_step = 0; + int m_last_step = 0; float m_last_lr = 0.F; }; } // namespace ttml::schedulers diff --git a/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp index f096fd5cc7e..eac686c1b2c 100644 --- a/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp +++ b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp @@ -27,8 +27,8 @@ class SequentialScheduler : public LRSchedulerBase { [[nodiscard]] float get_last_lr() const override; [[nodiscard]] float get_current_lr() const override; - [[nodiscard]] serialization::StateDict get_state_dict() const override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; void set_state_dict(const serialization::StateDict &dict) override; private: diff --git a/tt-train/tests/schedulers/schedulers_test.cpp b/tt-train/tests/schedulers/schedulers_test.cpp index ad9b7ee6079..14b7791dcc3 100644 --- a/tt-train/tests/schedulers/schedulers_test.cpp +++ b/tt-train/tests/schedulers/schedulers_test.cpp @@ -123,10 +123,10 @@ TEST(StepLRSchedulerTest, BasicDecay) { // Tests for LinearScheduler // ---------------------------------- TEST(LinearSchedulerTest, DecreasingLR) { - auto optimizer = std::make_unique(0.2f); + auto optimizer = std::make_unique(0.2F); // Linearly go from 0.2 to 0.0 in 4 steps - ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 0.0f, 4); + ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 1.0F, 0.0F, 4); // step 1: progress = 1/4=0.25 lr = 0.2 + (0.0-0.2)*0.25 = 0.2 - 0.05=0.15 scheduler.step(); @@ -159,7 +159,7 @@ TEST(SequentialSchedulerTest, ChainSchedulers) { auto step_scheduler = std::make_unique(optimizer.get(), 1, 0.5F); // Then: LinearScheduler for 2 steps from current LR to 0.1 - auto linear_scheduler = std::make_unique(optimizer.get(), 0.1F, 2); + auto linear_scheduler = std::make_unique(optimizer.get(), 1.0F, 0.1F, 2); std::vector> schedulers; std::vector milestones; @@ -194,3 +194,33 @@ TEST(SequentialSchedulerTest, ChainSchedulers) { seq_scheduler.step(); EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F); } + +TEST(SequentialSchedulerTest, WarmupSetup) { + auto start_lr = 3.e-4F; + auto optimizer = std::make_unique(start_lr); + + // First: LinearScheduler for 10 steps from 0 to start_lr + auto warmup_scheduler = std::make_unique(optimizer.get(), 0.0F, 1.0F, 10); + + // Then: LinearScheduler for 50 steps from start_lr to 0.1F * start_lr + auto linear_scheduler = std::make_unique(optimizer.get(), 1.F, 0.1F, 50); + + std::vector> schedulers; + std::vector milestones; + schedulers.push_back(std::move(warmup_scheduler)); + schedulers.push_back(std::move(linear_scheduler)); + milestones.push_back(10); + milestones.push_back(50); + ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(schedulers), std::move(milestones)); + + for (int i = 0; i < 10; i++) { + // Linear warmup: 10 steps from 0 to start_lr + seq_scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), start_lr * (i + 1) / 10); + } + for (int i = 0; i < 50; i++) { + // Linear decay: 50 steps from start_lr to 0.1F * start_lr + seq_scheduler.step(); + EXPECT_FLOAT_EQ(optimizer->get_lr(), start_lr * (1.0F - 0.9F * (i + 1) / 50)); + } +}