Skip to content

Commit

Permalink
fixed linear scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
dmakoviichuk-tt committed Dec 10, 2024
1 parent 10ef463 commit 131360a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 34 deletions.
48 changes: 25 additions & 23 deletions tt-train/sources/ttml/schedulers/linear_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,47 @@
#include "optimizers/optimizer_base.hpp"

namespace ttml::schedulers {
LinearScheduler::LinearScheduler(optimizers::OptimizerBase *optimizer, float end_lr, int total_steps) :

// Builds the scheduler around `optimizer`. The optimizer's learning rate at construction time is
// stored as the base LR; step() later multiplies it by a factor that moves from `start_factor`
// to `end_factor` as the step counter approaches `total_steps`.
LinearScheduler::LinearScheduler(
optimizers::OptimizerBase* optimizer, float start_factor, float end_factor, size_t total_steps) :
LRSchedulerBase(optimizer),
m_start_lr(optimizer->get_lr()),
m_end_lr(end_lr),
// Base LR is read from the optimizer once, here; the factors are applied to this value.
m_base_lr(optimizer->get_lr()),
// Before the first step() the last LR reported is the base LR itself.
m_last_lr(m_base_lr),
m_start_factor(start_factor),
m_end_factor(end_factor),
m_total_steps(total_steps),
m_current_step(0),
m_last_lr(m_start_lr) {
// NOTE(review): step() divides by m_total_steps, so this guard is what rules out
// total_steps == 0 - confirm it is still present in the factor-based constructor.
if (total_steps <= 0) {
throw std::invalid_argument("total_steps must be a positive integer.");
}
m_last_step(0) {
}
// Advances the schedule by one step: bumps the step counter, recomputes the learning rate and
// writes it to the optimizer.
void LinearScheduler::step() {
m_current_step += 1;
m_last_step += 1;

// Compute progress ratio (clamped at 1.0)
// Because of the clamp, calls made after the counter exceeds m_total_steps keep progress at 1.0.
float progress = static_cast<float>(m_current_step) / m_total_steps;
float progress = static_cast<float>(m_last_step) / m_total_steps;
progress = std::min(progress, 1.0f);

// Linearly interpolate between start_lr and end_lr
float new_lr = m_start_lr + (m_end_lr - m_start_lr) * progress;
// Factor-based form: interpolate the multiplier between m_start_factor and m_end_factor, then
// scale the base LR captured in the constructor.
float current_factor = m_start_factor + (m_end_factor - m_start_factor) * progress;
float new_lr = m_base_lr * current_factor;

// Apply the new LR to the optimizer and cache it in m_last_lr.
get_optimizer()->set_lr(new_lr);
m_last_lr = new_lr;
}
float LinearScheduler::get_last_lr() const {
return m_last_lr;
}
float LinearScheduler::get_current_lr() const {
return get_optimizer()->get_lr();
}

// Restores the scheduler's progress from `dict`: the step counter and the last learning rate.
// NOTE(review): nothing in this function writes to the optimizer, and the base LR is not among
// the restored keys - confirm a resumed scheduler gets the intended base LR from its constructor.
void LinearScheduler::set_state_dict(const serialization::StateDict &dict) {
m_current_step = serialization::get_value_type<int>(dict, "m_current_step");
void LinearScheduler::set_state_dict(const serialization::StateDict& dict) {
m_last_step = serialization::get_value_type<int>(dict, "m_last_step");
m_last_lr = serialization::get_value_type<float>(dict, "m_last_lr");
}

// Serializes the scheduler's progress into a fresh StateDict: the step counter and the last
// learning rate, each stored under a key named after the member it comes from.
serialization::StateDict LinearScheduler::get_state_dict() const {
serialization::StateDict res;
res["m_current_step"] = m_current_step;
res["m_last_step"] = m_last_step;
res["m_last_lr"] = m_last_lr;
return res;
};

float LinearScheduler::get_last_lr() const {
return m_last_lr;
}

float LinearScheduler::get_current_lr() const {
return get_optimizer()->get_lr();
}

} // namespace ttml::schedulers
13 changes: 6 additions & 7 deletions tt-train/sources/ttml/schedulers/linear_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

#pragma once

#include "core/not_null.hpp"
#include "scheduler_base.hpp"

// Assuming necessary includes and that LRSchedulerBase and OptimizerBase are defined
namespace ttml::schedulers {

// Learning-rate scheduler that changes the optimizer's LR linearly over a fixed number of steps.
// The factor-based constructor takes the optimizer to drive, a start factor, an end factor and
// the number of steps over which to move between them.
class LinearScheduler : public LRSchedulerBase {
public:
LinearScheduler(optimizers::OptimizerBase *optimizer, float end_lr, int total_steps);
LinearScheduler(optimizers::OptimizerBase *optimizer, float start_factor, float end_factor, size_t total_steps);

// Advances the schedule by one step.
void step() override;

Expand All @@ -20,14 +19,14 @@ class LinearScheduler : public LRSchedulerBase {
[[nodiscard]] float get_current_lr() const override;

// Save / restore the scheduler's progress.
[[nodiscard]] serialization::StateDict get_state_dict() const override;

void set_state_dict(const serialization::StateDict &dict) override;

private:
float m_start_lr = 0.F;
float m_end_lr = 0.F;
// Learning rate the factors are applied to.
float m_base_lr = 0.F;
// Multipliers at the beginning and at the end of the schedule.
float m_start_factor = 0.F;
float m_end_factor = 0.F;
// NOTE(review): the factor-based constructor takes total_steps as size_t while this member is
// int, so the value is converted on construction - consider aligning the two types.
int m_total_steps = 0;
int m_current_step = 0;
// Step counter for the schedule.
int m_last_step = 0;
// Most recent learning rate cached by the scheduler.
float m_last_lr = 0.F;
};
} // namespace ttml::schedulers
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/schedulers/sequential_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class SequentialScheduler : public LRSchedulerBase {
[[nodiscard]] float get_last_lr() const override;

[[nodiscard]] float get_current_lr() const override;
[[nodiscard]] serialization::StateDict get_state_dict() const override;

[[nodiscard]] serialization::StateDict get_state_dict() const override;
void set_state_dict(const serialization::StateDict &dict) override;

private:
Expand Down
36 changes: 33 additions & 3 deletions tt-train/tests/schedulers/schedulers_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ TEST(StepLRSchedulerTest, BasicDecay) {
// Tests for LinearScheduler
// ----------------------------------
TEST(LinearSchedulerTest, DecreasingLR) {
auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(0.2f);
auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(0.2F);

// Linearly go from 0.2 to 0.0 in 4 steps
ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 0.0f, 4);
ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 1.0F, 0.0F, 4);

// step 1: progress = 1/4=0.25 lr = 0.2 + (0.0-0.2)*0.25 = 0.2 - 0.05=0.15
scheduler.step();
Expand Down Expand Up @@ -159,7 +159,7 @@ TEST(SequentialSchedulerTest, ChainSchedulers) {
auto step_scheduler = std::make_unique<ttml::schedulers::StepScheduler>(optimizer.get(), 1, 0.5F);

// Then: LinearScheduler for 2 steps from current LR to 0.1
auto linear_scheduler = std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 0.1F, 2);
auto linear_scheduler = std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 1.0F, 0.1F, 2);

std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> schedulers;
std::vector<size_t> milestones;
Expand Down Expand Up @@ -194,3 +194,33 @@ TEST(SequentialSchedulerTest, ChainSchedulers) {
seq_scheduler.step();
EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F);
}

TEST(SequentialSchedulerTest, WarmupSetup) {
    // Scenario: the LR ramps linearly from 0 up to base_lr, then decays linearly to 10% of it.
    const float base_lr = 3.e-4F;
    const size_t warmup_steps = 10;
    const size_t decay_steps = 50;
    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(base_lr);

    // Both phases are created up front, before any step is taken.
    std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> phases;
    // Phase 1: factor goes 0.0 -> 1.0 over the warmup steps.
    phases.push_back(
        std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 0.0F, 1.0F, warmup_steps));
    // Phase 2: factor goes 1.0 -> 0.1 over the decay steps.
    phases.push_back(
        std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 1.F, 0.1F, decay_steps));

    std::vector<size_t> milestones{warmup_steps, decay_steps};
    ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(phases), std::move(milestones));

    // Warmup: after the k-th step the LR must equal base_lr * k / warmup_steps.
    for (size_t step = 1; step <= warmup_steps; ++step) {
        seq_scheduler.step();
        EXPECT_FLOAT_EQ(optimizer->get_lr(), base_lr * step / warmup_steps);
    }

    // Decay: after the k-th step the factor has dropped by 0.9 * k / decay_steps from 1.0.
    for (size_t step = 1; step <= decay_steps; ++step) {
        seq_scheduler.step();
        EXPECT_FLOAT_EQ(optimizer->get_lr(), base_lr * (1.0F - 0.9F * step / decay_steps));
    }
}

0 comments on commit 131360a

Please sign in to comment.