diff --git a/tt-train/sources/examples/mnist_mlp/main.cpp b/tt-train/sources/examples/mnist_mlp/main.cpp
index 0528933d7bc..868e827d296 100644
--- a/tt-train/sources/examples/mnist_mlp/main.cpp
+++ b/tt-train/sources/examples/mnist_mlp/main.cpp
@@ -4,6 +4,8 @@
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -19,7 +21,6 @@
 #include "optimizers/sgd.hpp"
 #include "utils.hpp"
 #include "yaml-cpp/node/node.h"
-
 using ttml::autograd::TensorPtr;

 using DatasetSample = std::pair<std::vector<uint8_t>, uint8_t>;
@@ -95,7 +96,6 @@ int main(int argc, char **argv) {
     CLI11_PARSE(app, argc, argv);
     auto yaml_config = YAML::LoadFile(config_name);
     TrainingConfig config = parse_config(yaml_config);
-
     // Load MNIST data
     const size_t num_targets = 10;
     const size_t num_features = 784;
@@ -151,7 +151,7 @@ int main(int argc, char **argv) {
     auto optimizer = ttml::optimizers::SGD(model->parameters(), sgd_config);
     if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
         fmt::print("Loading model from {}\n", config.model_path);
-        load_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
+        load_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
     }

     // evaluate model before training (sanity check to get reasonable accuracy
@@ -176,7 +176,7 @@ int main(int argc, char **argv) {
         }
         if (!config.model_path.empty() && training_step % config.model_save_interval == 0) {
             fmt::print("Saving model to {}\n", config.model_path);
-            save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
+            save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
         }

         loss->backward();
@@ -196,7 +196,7 @@ int main(int argc, char **argv) {

     if (!config.model_path.empty()) {
         fmt::print("Saving model to {}\n", config.model_path);
-        save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
+        save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
     }

     return 0;
diff --git a/tt-train/sources/examples/mnist_mlp/utils.hpp b/tt-train/sources/examples/mnist_mlp/utils.hpp
index 00b28a6ffe7..863cb9311eb 100644
--- a/tt-train/sources/examples/mnist_mlp/utils.hpp
+++ b/tt-train/sources/examples/mnist_mlp/utils.hpp
@@ -38,7 +38,7 @@ class Timers {
 };

 template <typename Model, typename Optimizer>
-void save_model_and_optimizer(
+void save_training_state(
     std::string &model_path,
     const std::shared_ptr<Model> &model,
     Optimizer &optimizer,
@@ -51,7 +51,7 @@
 }

 template <typename Model, typename Optimizer>
-void load_model_and_optimizer(
+void load_training_state(
     std::string &model_path,
     const std::shared_ptr<Model> &model,
     Optimizer &optimizer,
diff --git a/tt-train/sources/examples/nano_gpt/main.cpp b/tt-train/sources/examples/nano_gpt/main.cpp
index e086b21095e..87136c9f079 100644
--- a/tt-train/sources/examples/nano_gpt/main.cpp
+++ b/tt-train/sources/examples/nano_gpt/main.cpp
@@ -142,6 +142,8 @@ struct TrainingConfig {
     uint32_t gradient_accumulation_steps = 1;
     std::string model_path;
     std::string data_path;
+    std::string scheduler_type = "identity";
+
     ttml::models::gpt2::TransformerConfig transformer_config;
 };

@@ -161,10 +163,17 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) {
         training_config["gradient_accumulation_steps"].as<uint32_t>(config.gradient_accumulation_steps);
     config.model_path = training_config["model_path"].as<std::string>("");
     config.data_path = training_config["data_path"].as<std::string>(std::string(DATA_FOLDER) + "/shakespeare.txt");
+    config.scheduler_type = training_config["scheduler_type"].as<std::string>(config.scheduler_type);
+
     config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]);
     return config;
 }
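The new `scheduler_type` key is read with a fallback, so configs that predate this change keep working. A minimal, self-contained sketch of the lookup (the inline YAML document and its layout are assumptions; the real config is loaded from a file with `YAML::LoadFile`):

```cpp
// Sketch only: mirrors how parse_config falls back to the "identity" default
// when the key is missing. The YAML layout here is an assumption.
#include <iostream>
#include <string>
#include <yaml-cpp/yaml.h>

int main() {
    const auto yaml_config = YAML::Load(R"(
training_config:
  scheduler_type: warmup_linear
)");
    std::string scheduler_type = "identity";  // default, as in TrainingConfig
    scheduler_type = yaml_config["training_config"]["scheduler_type"].as<std::string>(scheduler_type);
    std::cout << scheduler_type << '\n';  // prints "warmup_linear"
    return 0;
}
```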
+const std::unordered_map<
+    std::string,
+    std::function<std::unique_ptr<ttml::schedulers::LRSchedulerBase>(ttml::optimizers::OptimizerBase *, size_t)>>
+    schedulers = {{"identity", create_identity_scheduler}, {"warmup_linear", create_warmup_with_linear_scheduler}};
+
 int main(int argc, char **argv) {
     auto result = signal(SIGINT, signal_handler);
     if (result == SIG_ERR) {
@@ -186,7 +195,6 @@ int main(int argc, char **argv) {
     CLI11_PARSE(app, argc, argv);
     auto yaml_config = YAML::LoadFile(config_name);
     TrainingConfig config = parse_config(yaml_config);
-
     wandbcpp::init({.project = config.project_name, .name = generate_run_name(config, add_time_to_name)});
     wandbcpp::update_config({
         {"model", "transformer"},
@@ -206,10 +214,12 @@ int main(int argc, char **argv) {
             config.transformer_config.positional_embedding_type == ttml::models::gpt2::PositionalEmbeddingType::Trainable
                 ? "trainable"
                 : "fixed"},
+        {"scheduler_type", config.scheduler_type},
     });

     // set seed
     ttml::autograd::ctx().set_seed(config.seed);
+    auto schedule_func = schedulers.at(config.scheduler_type);

     std::string text;
     try {
@@ -218,11 +228,11 @@ int main(int argc, char **argv) {
         std::cerr << e.what() << std::endl;
         return -1;
     }
-
     fmt::print("Max steps {}\n", config.max_steps);
     fmt::print("Batch size {}\n", config.batch_size);
     fmt::print("Gradient accumulation steps {}\n", config.gradient_accumulation_steps);
     fmt::print("Total batch size {}\n", config.batch_size * config.gradient_accumulation_steps);
+    fmt::print("Scheduler type {}\n", config.scheduler_type);
     fmt::print("Seed {}\n", ttml::autograd::ctx().get_seed());

     auto sequence_length = config.transformer_config.max_sequence_length;
@@ -304,10 +314,10 @@ int main(int argc, char **argv) {
     fmt::print("    Weight decay: {}\n", adamw_params.weight_decay);
     fmt::print("    Use Kahan summation: {}\n", adamw_params.use_kahan_summation);
     auto optimizer = ttml::optimizers::AdamW(model->parameters(), adamw_params);
-
+    auto scheduler = schedule_func(&optimizer, config.max_steps);
     if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
         fmt::print("Loading model from {}\n", config.model_path);
-        load_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw");
+        load_training_state(config.model_path, model, scheduler, "transformer", "adamw");
         fmt::print("Model loaded after {} steps\n", optimizer.get_steps());
     }

@@ -345,6 +355,7 @@ int main(int argc, char **argv) {

             if (gradient_accumulator_helper.should_step()) {
                 optimizer.step();
+                scheduler->step();
                 auto global_step = optimizer.get_steps();
                 fmt::print("Step: {}, Loss: {}\n", global_step, gradient_accumulator_helper.average_loss());
                 loss_meter.update(gradient_accumulator_helper.average_loss());
@@ -353,11 +364,12 @@ int main(int argc, char **argv) {
                     wandbcpp::log(
                         {{"Step", (int)global_step},
                          {"Samples", (int)get_samples_count(global_step)},
-                         {"Loss", loss_meter.average()}});
+                         {"Loss", loss_meter.average()},
+                         {"Learning rate", optimizer.get_lr()}});
                     loss_meter.reset();
                 }
                 if (!config.model_path.empty() && global_step % config.model_save_interval == 0) {
-                    save_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw");
+                    save_training_state(config.model_path, model, scheduler, "transformer", "adamw");
                 }

                 if (global_step >= config.max_steps) {
@@ -379,7 +391,7 @@ int main(int argc, char **argv) {
     }

     if (!config.model_path.empty()) {
-        save_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw");
+        save_training_state(config.model_path, model, scheduler, "transformer", "adamw");
     }

     auto end_timer = std::chrono::high_resolution_clock::now();
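Scheduler selection above is a name-to-factory map, and `schedulers.at(config.scheduler_type)` throws a bare `std::out_of_range` on a typo in the config. A hypothetical hardening sketch (not part of this diff) that fails with a clearer message, using the `schedulers` map and `config` introduced above:

```cpp
// Hypothetical alternative to `schedulers.at(config.scheduler_type)`:
auto it = schedulers.find(config.scheduler_type);
if (it == schedulers.end()) {
    throw std::runtime_error(fmt::format(
        "Unknown scheduler_type '{}'; expected 'identity' or 'warmup_linear'", config.scheduler_type));
}
auto schedule_func = it->second;
```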
diff --git a/tt-train/sources/examples/nano_gpt/utils.cpp b/tt-train/sources/examples/nano_gpt/utils.cpp
index 408a3c01a38..e89b90a8f29 100644
--- a/tt-train/sources/examples/nano_gpt/utils.cpp
+++ b/tt-train/sources/examples/nano_gpt/utils.cpp
@@ -73,3 +73,22 @@ void GradientAccumulator::reset() {
 float GradientAccumulator::average_loss() const {
     return m_total_loss / static_cast<float>(m_total_samples);
 }
+
+std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_identity_scheduler(
+    ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps) {
+    return std::make_unique<ttml::schedulers::LambdaScheduler>(optimizer, [](int epoch) { return 1.0F; });
+}
+
+std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_warmup_with_linear_scheduler(
+    ttml::optimizers::OptimizerBase *optimizer, size_t total_steps) {
+    const float default_warmup_factor = 0.1F;
+    const size_t warmup_steps = size_t(total_steps * default_warmup_factor);
+    const size_t linear_decay_steps = total_steps - warmup_steps;
+
+    std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> schedulers;
+    schedulers.push_back(std::make_unique<ttml::schedulers::LinearScheduler>(optimizer, 0.0F, 1.0F, warmup_steps));
+    schedulers.push_back(
+        std::make_unique<ttml::schedulers::LinearScheduler>(optimizer, 1.0F, 0.01F, linear_decay_steps));
+    std::vector<size_t> steps = {warmup_steps, linear_decay_steps};
+    return std::make_unique<ttml::schedulers::SequentialScheduler>(optimizer, std::move(schedulers), std::move(steps));
+}
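For concreteness, the 10% warmup split above works out as follows; the total step count is an assumed example value, not one from this diff:

```cpp
// Standalone arithmetic check of create_warmup_with_linear_scheduler's split.
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t total_steps = 5000;  // assumed max_steps
    const float default_warmup_factor = 0.1F;
    const std::size_t warmup_steps = std::size_t(total_steps * default_warmup_factor);  // 500
    const std::size_t linear_decay_steps = total_steps - warmup_steps;                  // 4500
    // LR ramps 0 -> base_lr over the first 500 steps, then decays to 0.01 * base_lr.
    std::printf("warmup=%zu decay=%zu\n", warmup_steps, linear_decay_steps);
    return 0;
}
```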
diff --git a/tt-train/sources/examples/nano_gpt/utils.hpp b/tt-train/sources/examples/nano_gpt/utils.hpp
index e390dc3f483..c7383c1e9ac 100644
--- a/tt-train/sources/examples/nano_gpt/utils.hpp
+++ b/tt-train/sources/examples/nano_gpt/utils.hpp
@@ -10,6 +10,10 @@
 #include
 #include

 #include "autograd/tensor.hpp"
+#include "schedulers/lambda_scheduler.hpp"
+#include "schedulers/linear_scheduler.hpp"
+#include "schedulers/scheduler_base.hpp"
+#include "schedulers/sequential_scheduler.hpp"
 #include "serialization/msgpack_file.hpp"
 #include "serialization/serialization.hpp"

@@ -25,32 +29,42 @@ class LossAverageMeter {
     void reset();
 };

+std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_identity_scheduler(
+    ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps);
+
+std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_warmup_with_linear_scheduler(
+    ttml::optimizers::OptimizerBase *optimizer, size_t total_steps);
+
 std::string read_file_to_str(const std::string &file_path);

-template <typename Model, typename Optimizer>
-void save_model_and_optimizer(
+template <typename Model>
+void save_training_state(
     std::string &model_path,
     const std::shared_ptr<Model> &model,
-    Optimizer &optimizer,
+    const std::unique_ptr<ttml::schedulers::LRSchedulerBase> &scheduler,
     const std::string &model_name,
     const std::string &optimizer_name) {
     ttml::serialization::MsgPackFile serializer;
     ttml::serialization::write_module(serializer, model_name, model.get());
-    ttml::serialization::write_optimizer(serializer, optimizer_name, &optimizer);
+    ttml::serialization::write_optimizer(serializer, optimizer_name, scheduler->get_optimizer().get());
+    ttml::serialization::write_state_dict(serializer, "scheduler", scheduler->get_state_dict());
     serializer.serialize(model_path);
 }

-template <typename Model, typename Optimizer>
-void load_model_and_optimizer(
+template <typename Model>
+void load_training_state(
     std::string &model_path,
     const std::shared_ptr<Model> &model,
-    Optimizer &optimizer,
+    const std::unique_ptr<ttml::schedulers::LRSchedulerBase> &scheduler,
     const std::string &model_name,
     const std::string &optimizer_name) {
     ttml::serialization::MsgPackFile deserializer;
     deserializer.deserialize(model_path);
     ttml::serialization::read_module(deserializer, model_name, model.get());
-    ttml::serialization::read_optimizer(deserializer, optimizer_name, &optimizer);
+    ttml::serialization::read_optimizer(deserializer, optimizer_name, scheduler->get_optimizer().get());
+    auto state_dict = scheduler->get_state_dict();
+    ttml::serialization::read_state_dict(deserializer, "scheduler", state_dict);
+    scheduler->set_state_dict(state_dict);
 }

 uint32_t round_up_to_tile(uint32_t value, uint32_t tile_size = 32);
@@ -110,11 +124,12 @@ std::string generate_run_name(const TrainingConfig &config, bool add_time_to_run
     if (config.gradient_accumulation_steps > 1) {
         ss << "_grad_acc_" << config.gradient_accumulation_steps;
     }
-
+    ss << "_sched_" << config.scheduler_type;
     if (add_time_to_run_name) {
         auto now = std::chrono::system_clock::now();
         std::time_t current_time = std::chrono::system_clock::to_time_t(now);
         ss << "_date_" << std::put_time(std::localtime(&current_time), "%Y-%m-%d_%H:%M:%S");
     }
+
     return ss.str();
 }
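A usage sketch for the renamed helpers above. This is a fragment, not runnable on its own: `model`, `adamw_params`, and the checkpoint path are placeholders, and the real call sites live in nano_gpt/main.cpp:

```cpp
// Fragment: assumes a constructed `model` and an AdamW config.
auto optimizer = ttml::optimizers::AdamW(model->parameters(), adamw_params);
auto scheduler = create_warmup_with_linear_scheduler(&optimizer, /*total_steps=*/5000);

std::string model_path = "checkpoints/run.msgpack";  // hypothetical path
save_training_state(model_path, model, scheduler, "transformer", "adamw");

// On resume: rebuild the same model/optimizer/scheduler structure first, then
// restore module weights, optimizer moments/steps, and scheduler state in one call.
load_training_state(model_path, model, scheduler, "transformer", "adamw");
```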
diff --git a/tt-train/sources/ttml/autograd/module_base.cpp b/tt-train/sources/ttml/autograd/module_base.cpp
index 4cc13b09826..b3f771b3688 100644
--- a/tt-train/sources/ttml/autograd/module_base.cpp
+++ b/tt-train/sources/ttml/autograd/module_base.cpp
@@ -4,8 +4,6 @@

 #include "module_base.hpp"

-#include "auto_context.hpp"
-
 namespace ttml::autograd {

 void ModuleBase::register_tensor(const TensorPtr& tensor_ptr, const std::string& name) {
@@ -30,8 +28,8 @@ const std::string& ModuleBase::get_name() const {
     return m_name;
 }

-NamedParameters ModuleBase::parameters() const {
-    NamedParameters params;
+serialization::NamedParameters ModuleBase::parameters() const {
+    serialization::NamedParameters params;

     std::queue<std::pair<const ModuleBase*, std::string>> modules_to_process;
     modules_to_process.emplace(this, get_name() + "/");
diff --git a/tt-train/sources/ttml/autograd/module_base.hpp b/tt-train/sources/ttml/autograd/module_base.hpp
index 442d0dc36f1..b2729bde46e 100644
--- a/tt-train/sources/ttml/autograd/module_base.hpp
+++ b/tt-train/sources/ttml/autograd/module_base.hpp
@@ -7,6 +7,7 @@
 #include
 #include

+#include "serialization/serializable.hpp"
 #include "tensor.hpp"

 namespace ttml::autograd {
@@ -15,7 +16,6 @@ enum class RunMode { TRAIN, EVAL };

 class ModuleBase;
 using ModuleBasePtr = std::shared_ptr<ModuleBase>;
-using NamedParameters = std::unordered_map<std::string, TensorPtr>;

 class ModuleBase {
 private:
@@ -39,7 +39,7 @@ class ModuleBase {
     ModuleBase& operator=(ModuleBase&&) = default;

     [[nodiscard]] const std::string& get_name() const;
-    [[nodiscard]] NamedParameters parameters() const;
+    [[nodiscard]] serialization::NamedParameters parameters() const;

     void train();
     void eval();
diff --git a/tt-train/sources/ttml/optimizers/adamw.cpp b/tt-train/sources/ttml/optimizers/adamw.cpp
index 8770b9da5a8..d18901797d8 100644
--- a/tt-train/sources/ttml/optimizers/adamw.cpp
+++ b/tt-train/sources/ttml/optimizers/adamw.cpp
@@ -10,19 +10,19 @@
 #include "core/debug.hpp"
 #include "core/tt_tensor_utils.hpp"
 #include "optimizers/optimizer_base.hpp"
+#include "serialization/serializable.hpp"
 #include "ttnn_fixed/trivial_ttnn_ops.hpp"
-
 namespace {

-const std::string kFirstMoment = "first_moment/";
-const std::string kSecondMoment = "second_moment/";
-const std::string kKahanCompensation = "kahan_compensation/";
-
+const std::string kFirstMoment = "first_moment";
+const std::string kSecondMoment = "second_moment";
+const std::string kKahanCompensation = "kahan_compensation";
+const std::string kSteps = "steps";
 }  // namespace

 namespace ttml::optimizers {

-MorehAdamW::MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config) :
+MorehAdamW::MorehAdamW(serialization::NamedParameters parameters, const AdamWConfig& config) :
     OptimizerBase(std::move(parameters)), m_config(config) {
     if (m_config.use_kahan_summation) {
         throw std::runtime_error("MorehAdamW: Kahan summation is not supported. Use default AdamW instead.");
@@ -95,29 +95,19 @@ void MorehAdamW::step() {
     }
 }

-[[nodiscard]] autograd::NamedParameters MorehAdamW::get_state_dict() const {
-    autograd::NamedParameters state_dict;
-    for (const auto& [key, first_moment] : m_first_moment) {
-        state_dict.emplace(kFirstMoment + key, first_moment);
-    }
-
-    for (const auto& [key, second_moment] : m_second_moment) {
-        state_dict.emplace(kSecondMoment + key, second_moment);
-    }
+[[nodiscard]] serialization::StateDict MorehAdamW::get_state_dict() const {
+    serialization::StateDict state_dict;
+    state_dict[kFirstMoment] = m_first_moment;
+    state_dict[kSecondMoment] = m_second_moment;
+    state_dict[kSteps] = m_steps;

     return state_dict;
 }

-void MorehAdamW::set_state_dict(const autograd::NamedParameters& dict) {
-    for (const auto& [key, tensor] : dict) {
-        if (key.starts_with(kFirstMoment)) {
-            m_first_moment[key.substr(kFirstMoment.size())] = tensor;
-        } else if (key.starts_with(kSecondMoment)) {
-            m_second_moment[key.substr(kSecondMoment.size())] = tensor;
-        } else {
-            throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key));
-        }
-    }
+void MorehAdamW::set_state_dict(const serialization::StateDict& dict) {
+    m_first_moment = std::get<serialization::NamedParameters>(dict.at(kFirstMoment));
+    m_second_moment = std::get<serialization::NamedParameters>(dict.at(kSecondMoment));
+    m_steps = serialization::get_value_type<size_t>(dict, kSteps);
 }

 [[nodiscard]] size_t MorehAdamW::get_steps() const {
@@ -128,7 +118,14 @@ void MorehAdamW::set_steps(size_t steps) {
     m_steps = steps;
 }

-AdamW::AdamW(autograd::NamedParameters parameters, const AdamWConfig& config) :
+float MorehAdamW::get_lr() const {
+    return m_config.lr;
+}
+void MorehAdamW::set_lr(float lr) {
+    m_config.lr = lr;
+}
+
+AdamW::AdamW(serialization::NamedParameters parameters, const AdamWConfig& config) :
     OptimizerBase(std::move(parameters)), m_config(config) {
     for (const auto& [key, tensor_ptr] : m_parameters) {
         if (tensor_ptr->get_requires_grad()) {
@@ -226,35 +223,21 @@ void AdamW::step() {
     }
 }

-[[nodiscard]] autograd::NamedParameters AdamW::get_state_dict() const {
-    autograd::NamedParameters state_dict;
-    for (const auto& [key, first_moment] : m_first_moment) {
-        state_dict.emplace(kFirstMoment + key, first_moment);
-    }
-
-    for (const auto& [key, second_moment] : m_second_moment) {
-        state_dict.emplace(kSecondMoment + key, second_moment);
-    }
-
-    for (const auto& [key, kahan_compensation] : m_kahan_compensation) {
-        state_dict.emplace(kKahanCompensation + key, kahan_compensation);
-    }
+[[nodiscard]] serialization::StateDict AdamW::get_state_dict() const {
+    serialization::StateDict state_dict;
+    state_dict[kFirstMoment] = m_first_moment;
+    state_dict[kSecondMoment] = m_second_moment;
+    state_dict[kKahanCompensation] = m_kahan_compensation;
+    state_dict[kSteps] = m_steps;

     return state_dict;
 }

-void AdamW::set_state_dict(const autograd::NamedParameters& dict) {
-    for (const auto& [key, tensor] : dict) {
-        if (key.starts_with(kFirstMoment)) {
-            m_first_moment[key.substr(kFirstMoment.size())] = tensor;
-        } else if (key.starts_with(kSecondMoment)) {
-            m_second_moment[key.substr(kSecondMoment.size())] = tensor;
-        } else if (key.starts_with(kKahanCompensation)) {
-            m_kahan_compensation[key.substr(kKahanCompensation.size())] = tensor;
-        } else {
-            throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key));
-        }
-    }
+void AdamW::set_state_dict(const serialization::StateDict& dict) {
+    m_first_moment = std::get<serialization::NamedParameters>(dict.at(kFirstMoment));
+    m_second_moment = std::get<serialization::NamedParameters>(dict.at(kSecondMoment));
+    m_kahan_compensation = std::get<serialization::NamedParameters>(dict.at(kKahanCompensation));
+    m_steps = serialization::get_value_type<size_t>(dict, kSteps);
 }

 [[nodiscard]] size_t AdamW::get_steps() const {
@@ -265,4 +248,10 @@ void AdamW::set_steps(size_t steps) {
     m_steps = steps;
 }

+float AdamW::get_lr() const {
+    return m_config.lr;
+}
+void AdamW::set_lr(float lr) {
+    m_config.lr = lr;
+}
 }  // namespace ttml::optimizers
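The switch from prefix-encoded tensor maps to a typed `StateDict` is the core of this refactor: moments stay `NamedParameters`, while scalars like the step counter ride along through the msgpack `ValueType` variant. A small, self-contained illustration of that shape (values are placeholders):

```cpp
// Sketch of the StateDict contract used by AdamW::get/set_state_dict above.
#include <cstddef>
#include "serialization/serializable.hpp"

int main() {
    ttml::serialization::StateDict state_dict;
    state_dict["steps"] = std::size_t{42};                                // scalar -> ValueType
    state_dict["first_moment"] = ttml::serialization::NamedParameters{};  // name -> tensor map

    // Typed read-back, exactly as set_state_dict does it:
    auto steps = ttml::serialization::get_value_type<std::size_t>(state_dict, "steps");
    return steps == 42 ? 0 : 1;
}
```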
Key = {}", key)); - } - } +void AdamW::set_state_dict(const serialization::StateDict& dict) { + m_first_moment = std::get(dict.at(kFirstMoment)); + m_second_moment = std::get(dict.at(kSecondMoment)); + m_kahan_compensation = std::get(dict.at(kKahanCompensation)); + m_steps = serialization::get_value_type(dict, kSteps); } [[nodiscard]] size_t AdamW::get_steps() const { @@ -265,4 +248,10 @@ void AdamW::set_steps(size_t steps) { m_steps = steps; } +float AdamW::get_lr() const { + return m_config.lr; +} +void AdamW::set_lr(float lr) { + m_config.lr = lr; +} } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/adamw.hpp b/tt-train/sources/ttml/optimizers/adamw.hpp index da3847f66db..d4505d8cb01 100644 --- a/tt-train/sources/ttml/optimizers/adamw.hpp +++ b/tt-train/sources/ttml/optimizers/adamw.hpp @@ -4,8 +4,8 @@ #include -#include "autograd/module_base.hpp" #include "optimizer_base.hpp" +#include "serialization/serializable.hpp" namespace ttml::optimizers { @@ -23,45 +23,52 @@ struct AdamWConfig { class MorehAdamW : public OptimizerBase { public: - MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config); + MorehAdamW(serialization::NamedParameters parameters, const AdamWConfig& config); void zero_grad() override; void step() override; - [[nodiscard]] autograd::NamedParameters get_state_dict() const override; - void set_state_dict(const autograd::NamedParameters& dict) override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict& dict) override; [[nodiscard]] size_t get_steps() const override; void set_steps(size_t steps) override; + [[nodiscard]] float get_lr() const override; + void set_lr(float lr) override; + private: size_t m_steps{0}; AdamWConfig m_config; - autograd::NamedParameters m_first_moment; - autograd::NamedParameters m_second_moment; + serialization::NamedParameters m_first_moment; + serialization::NamedParameters m_second_moment; }; class AdamW : public OptimizerBase { public: - AdamW(autograd::NamedParameters parameters, const AdamWConfig& config); + AdamW(serialization::NamedParameters parameters, const AdamWConfig& config); void zero_grad() override; void step() override; - [[nodiscard]] autograd::NamedParameters get_state_dict() const override; - void set_state_dict(const autograd::NamedParameters& dict) override; + [[nodiscard]] serialization::StateDict get_state_dict() const override; + void set_state_dict(const serialization::StateDict& dict) override; [[nodiscard]] size_t get_steps() const override; void set_steps(size_t steps) override; + [[nodiscard]] float get_lr() const override; + + void set_lr(float lr) override; + private: size_t m_steps{0}; AdamWConfig m_config; - autograd::NamedParameters m_first_moment; - autograd::NamedParameters m_second_moment; - autograd::NamedParameters m_kahan_compensation; + serialization::NamedParameters m_first_moment; + serialization::NamedParameters m_second_moment; + serialization::NamedParameters m_kahan_compensation; }; } // namespace ttml::optimizers diff --git a/tt-train/sources/ttml/optimizers/optimizer_base.cpp b/tt-train/sources/ttml/optimizers/optimizer_base.cpp index 446f23d6714..7971998d087 100644 --- a/tt-train/sources/ttml/optimizers/optimizer_base.cpp +++ b/tt-train/sources/ttml/optimizers/optimizer_base.cpp @@ -8,7 +8,7 @@ namespace ttml::optimizers { -OptimizerBase::OptimizerBase(autograd::NamedParameters&& parameters) : m_parameters(std::move(parameters)) { 
+OptimizerBase::OptimizerBase(serialization::NamedParameters&& parameters) : m_parameters(std::move(parameters)) {
 }

 void OptimizerBase::print_stats() const {
diff --git a/tt-train/sources/ttml/optimizers/optimizer_base.hpp b/tt-train/sources/ttml/optimizers/optimizer_base.hpp
index 49f1f4a32aa..690d0fd9ed6 100644
--- a/tt-train/sources/ttml/optimizers/optimizer_base.hpp
+++ b/tt-train/sources/ttml/optimizers/optimizer_base.hpp
@@ -4,13 +4,13 @@

 #pragma once

-#include "autograd/module_base.hpp"
+#include "serialization/serializable.hpp"

 namespace ttml::optimizers {

 class OptimizerBase {
 public:
-    explicit OptimizerBase(autograd::NamedParameters&& parameters);
+    explicit OptimizerBase(serialization::NamedParameters&& parameters);
     OptimizerBase(const OptimizerBase&) = delete;
     OptimizerBase& operator=(const OptimizerBase&) = delete;
     OptimizerBase(OptimizerBase&&) = delete;
@@ -21,16 +21,19 @@ class OptimizerBase {

     virtual void step() = 0;

-    [[nodiscard]] virtual autograd::NamedParameters get_state_dict() const = 0;
-    virtual void set_state_dict(const autograd::NamedParameters& dict) = 0;
+    [[nodiscard]] virtual serialization::StateDict get_state_dict() const = 0;
+    virtual void set_state_dict(const serialization::StateDict& dict) = 0;

     [[nodiscard]] virtual size_t get_steps() const = 0;
     virtual void set_steps(size_t steps) = 0;

+    virtual void set_lr(float lr) = 0;
+    [[nodiscard]] virtual float get_lr() const = 0;
+
     virtual void print_stats() const;

 protected:
-    autograd::NamedParameters m_parameters;
+    serialization::NamedParameters m_parameters;
 };

 }  // namespace ttml::optimizers
diff --git a/tt-train/sources/ttml/optimizers/sgd.cpp b/tt-train/sources/ttml/optimizers/sgd.cpp
index 0e25feb95fe..48298585644 100644
--- a/tt-train/sources/ttml/optimizers/sgd.cpp
+++ b/tt-train/sources/ttml/optimizers/sgd.cpp
@@ -9,10 +9,11 @@
 #include "autograd/autocast_tensor.hpp"
 #include "core/debug.hpp"
 #include "core/tt_tensor_utils.hpp"
+#include "serialization/serializable.hpp"

 namespace ttml::optimizers {

-SGD::SGD(ttml::autograd::NamedParameters parameters, const SGDConfig& config) :
+SGD::SGD(ttml::serialization::NamedParameters parameters, const SGDConfig& config) :
     OptimizerBase(std::move(parameters)), m_config(config) {
     for (const auto& [name, tensor_ptr] : m_parameters) {
         if (tensor_ptr->get_requires_grad()) {
@@ -53,7 +54,7 @@ void SGD::step() {
         }

         if (m_config.momentum != 0.0F) {
-            if (steps != 0) {
+            if (m_steps != 0) {
                 // apply momentum
                 theta = ttnn::multiply(theta, m_config.momentum);
                 // dampening
@@ -76,23 +77,27 @@ void SGD::step() {
         tensor_ptr->set_value(ttnn::subtract(
             tensor_ptr->get_value(autograd::PreferredPrecision::FULL), ttnn::multiply(gradients, m_config.lr)));
     }
-    steps++;
+    m_steps++;
 }

-autograd::NamedParameters SGD::get_state_dict() const {
-    return m_theta;
+serialization::StateDict SGD::get_state_dict() const {
+    serialization::StateDict dict;
+    dict["theta"] = m_theta;
+    dict["steps"] = m_steps;
+    return dict;
 }

-void SGD::set_state_dict(const autograd::NamedParameters& dict) {
-    m_theta = dict;
+void SGD::set_state_dict(const serialization::StateDict& dict) {
+    m_theta = std::get<serialization::NamedParameters>(dict.at("theta"));
+    m_steps = serialization::get_value_type<size_t>(dict, "steps");
 }

 size_t SGD::get_steps() const {
-    return steps;
+    return m_steps;
 }

 void SGD::set_steps(size_t steps) {
-    this->steps = steps;
+    this->m_steps = steps;
 }

 }  // namespace ttml::optimizers
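The `steps` to `m_steps` rename matters beyond style: the counter now round-trips through the state dict, and `step()` consults it to decide whether the momentum buffer is being seeded or updated. In the standard momentum/dampening form (PyTorch convention), which the visible hunks follow up to details outside the diff:

$$
\theta_t =
\begin{cases}
g_t & t = 0\\[2pt]
\mu\,\theta_{t-1} + (1-\tau)\,g_t & t > 0
\end{cases}
\qquad\qquad
p_t = p_{t-1} - \eta\,\theta_t
$$

where $\mu$ is `momentum`, $\tau$ is `dampening`, $\eta$ is `lr`, and $g_t$ the (possibly weight-decayed) gradient. Serializing `m_steps` keeps the $t=0$ seeding branch from re-running when training resumes from a checkpoint.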
diff --git a/tt-train/sources/ttml/optimizers/sgd.hpp b/tt-train/sources/ttml/optimizers/sgd.hpp
index 756facdf26c..298aef045f6 100644
--- a/tt-train/sources/ttml/optimizers/sgd.hpp
+++ b/tt-train/sources/ttml/optimizers/sgd.hpp
@@ -4,12 +4,8 @@

 #pragma once

-#include
-
-#include "autograd/module_base.hpp"
-#include "autograd/tensor.hpp"
-#include "core/tt_tensor_utils.hpp"
 #include "optimizers/optimizer_base.hpp"
+#include "serialization/serializable.hpp"

 namespace ttml::optimizers {

@@ -23,22 +19,30 @@ struct SGDConfig {

 class SGD : public OptimizerBase {
 public:
-    explicit SGD(ttml::autograd::NamedParameters parameters, const SGDConfig& config);
+    explicit SGD(ttml::serialization::NamedParameters parameters, const SGDConfig& config);

     void zero_grad() override;

     void step() override;

-    [[nodiscard]] autograd::NamedParameters get_state_dict() const override;
-    void set_state_dict(const autograd::NamedParameters& dict) override;
+    [[nodiscard]] serialization::StateDict get_state_dict() const override;
+    void set_state_dict(const serialization::StateDict& dict) override;

     [[nodiscard]] size_t get_steps() const override;
     void set_steps(size_t steps) override;

+    [[nodiscard]] float get_lr() const override {
+        return m_config.lr;
+    }
+
+    void set_lr(float lr) override {
+        m_config.lr = lr;
+    }
+
 private:
-    size_t steps{0};
+    size_t m_steps{0};
     SGDConfig m_config;
-    ttml::autograd::NamedParameters m_theta;
+    ttml::serialization::NamedParameters m_theta;
 };

 }  // namespace ttml::optimizers
diff --git a/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp b/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp
new file mode 100644
index 00000000000..f49ac00732d
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/lambda_scheduler.cpp
@@ -0,0 +1,40 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "lambda_scheduler.hpp"
+
+#include "optimizers/optimizer_base.hpp"
+
+namespace ttml::schedulers {
+
+LambdaScheduler::LambdaScheduler(optimizers::OptimizerBase *optimizer, std::function<float(int)> lr_lambda) :
+    LRSchedulerBase(optimizer),
+    m_lr_lambda(std::move(lr_lambda)),
+    m_last_step(0),
+    m_base_lr(optimizer->get_lr()),
+    m_last_lr(optimizer->get_lr()) {
+}
+void LambdaScheduler::step() {
+    m_last_step += 1;
+    float lr_factor = m_lr_lambda(m_last_step);
+    float new_lr = m_base_lr * lr_factor;
+    get_optimizer()->set_lr(new_lr);
+    m_last_lr = new_lr;
+}
+float LambdaScheduler::get_last_lr() const {
+    return m_last_lr;
+}
+float LambdaScheduler::get_current_lr() const {
+    return get_optimizer()->get_lr();
+}
+void LambdaScheduler::set_state_dict(const serialization::StateDict &dict) {
+    m_last_step = serialization::get_value_type<size_t>(dict, "m_last_step");
+    m_last_lr = serialization::get_value_type<float>(dict, "m_last_lr");
+}
+serialization::StateDict LambdaScheduler::get_state_dict() const {
+    serialization::StateDict res;
+    res["m_last_step"] = m_last_step;
+    res["m_last_lr"] = m_last_lr;
+    return res;
+};
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp b/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp
new file mode 100644
index 00000000000..e75b167104c
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/lambda_scheduler.hpp
@@ -0,0 +1,32 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <functional>
+
+#include "scheduler_base.hpp"
+
+namespace ttml::schedulers {
+class LambdaScheduler : public LRSchedulerBase {
+public:
+    explicit LambdaScheduler(optimizers::OptimizerBase *optimizer, std::function<float(int)> lr_lambda);
+
+    void step() override;
+
+    [[nodiscard]] float get_last_lr() const override;
+
+    [[nodiscard]] float get_current_lr() const override;
+
+    [[nodiscard]] serialization::StateDict get_state_dict() const override;
+
+    void set_state_dict(const serialization::StateDict &dict) override;
+
+private:
+    std::function<float(int)> m_lr_lambda;
+    size_t m_last_step = 0;
+    float m_base_lr = 0.0F;
+    float m_last_lr = 0.0F;
+};
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.cpp b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp
new file mode 100644
index 00000000000..964bfb0b4f1
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/linear_scheduler.cpp
@@ -0,0 +1,55 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "linear_scheduler.hpp"
+
+#include "optimizers/optimizer_base.hpp"
+
+namespace ttml::schedulers {
+
+LinearScheduler::LinearScheduler(
+    optimizers::OptimizerBase* optimizer, float start_factor, float end_factor, size_t total_steps) :
+    LRSchedulerBase(optimizer),
+    m_base_lr(optimizer->get_lr()),
+    m_last_lr(m_base_lr),
+    m_start_factor(start_factor),
+    m_end_factor(end_factor),
+    m_total_steps(total_steps),
+    m_last_step(0) {
+}
+
+void LinearScheduler::step() {
+    m_last_step += 1;
+
+    float progress = static_cast<float>(m_last_step) / m_total_steps;
+    progress = std::min(progress, 1.0f);
+
+    float current_factor = m_start_factor + (m_end_factor - m_start_factor) * progress;
+    float new_lr = m_base_lr * current_factor;
+
+    get_optimizer()->set_lr(new_lr);
+    m_last_lr = new_lr;
+}
+
+void LinearScheduler::set_state_dict(const serialization::StateDict& dict) {
+    m_last_step = serialization::get_value_type<size_t>(dict, "m_last_step");
+    m_last_lr = serialization::get_value_type<float>(dict, "m_last_lr");
+}
+
+serialization::StateDict LinearScheduler::get_state_dict() const {
+    serialization::StateDict res;
+    res["m_last_step"] = m_last_step;
+    res["m_last_lr"] = m_last_lr;
+    return res;
+};
+
+float LinearScheduler::get_last_lr() const {
+    return m_last_lr;
+}
+
+float LinearScheduler::get_current_lr() const {
+    return get_optimizer()->get_lr();
+}
+
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/linear_scheduler.hpp b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp
new file mode 100644
index 00000000000..9a8edcf18bb
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/linear_scheduler.hpp
@@ -0,0 +1,32 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "scheduler_base.hpp"
+
+namespace ttml::schedulers {
+
+class LinearScheduler : public LRSchedulerBase {
+public:
+    LinearScheduler(optimizers::OptimizerBase *optimizer, float start_factor, float end_factor, size_t total_steps);
+
+    void step() override;
+
+    [[nodiscard]] float get_last_lr() const override;
+
+    [[nodiscard]] float get_current_lr() const override;
+
+    [[nodiscard]] serialization::StateDict get_state_dict() const override;
+    void set_state_dict(const serialization::StateDict &dict) override;
+
+private:
+    float m_base_lr = 0.F;
+    float m_start_factor = 0.F;
+    float m_end_factor = 0.F;
+    int m_total_steps = 0;
+    size_t m_last_step = 0;
+    float m_last_lr = 0.F;
+};
+}  // namespace ttml::schedulers
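The lambda contract is worth spelling out: the factor returned by the callback multiplies the base LR captured at construction, not the optimizer's current LR, and `step()` passes the 1-based step counter. A hypothetical inverse-square-root decay built on `LambdaScheduler` (not part of this diff; only the API above is assumed):

```cpp
// Hypothetical helper showing the LambdaScheduler factor contract.
#include <algorithm>
#include <cmath>
#include <memory>

#include "schedulers/lambda_scheduler.hpp"

std::unique_ptr<ttml::schedulers::LRSchedulerBase> make_inv_sqrt_scheduler(
    ttml::optimizers::OptimizerBase *optimizer) {
    // factor(step) = 1/sqrt(step); lr(step) = base_lr * factor(step)
    return std::make_unique<ttml::schedulers::LambdaScheduler>(
        optimizer, [](int step) { return 1.0F / std::sqrt(static_cast<float>(std::max(step, 1))); });
}
```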
diff --git a/tt-train/sources/ttml/schedulers/scheduler_base.cpp b/tt-train/sources/ttml/schedulers/scheduler_base.cpp
new file mode 100644
index 00000000000..7e9e90c9092
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/scheduler_base.cpp
@@ -0,0 +1,15 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "scheduler_base.hpp"
+
+namespace ttml::schedulers {
+
+core::not_null<optimizers::OptimizerBase *> ttml::schedulers::LRSchedulerBase::get_optimizer() const {
+    return m_optimizer;
+}
+LRSchedulerBase::LRSchedulerBase(optimizers::OptimizerBase *optimizer) : m_optimizer(optimizer) {
+}
+
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/scheduler_base.hpp b/tt-train/sources/ttml/schedulers/scheduler_base.hpp
new file mode 100644
index 00000000000..4fd52ff5526
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/scheduler_base.hpp
@@ -0,0 +1,37 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "core/not_null.hpp"
+#include "serialization/serializable.hpp"
+
+namespace ttml::optimizers {
+class OptimizerBase;
+}
+
+namespace ttml::schedulers {
+
+class LRSchedulerBase {
+public:
+    explicit LRSchedulerBase(optimizers::OptimizerBase *optimizer);
+
+    virtual ~LRSchedulerBase() = default;
+
+    virtual void step() = 0;
+
+    [[nodiscard]] virtual float get_last_lr() const = 0;
+
+    [[nodiscard]] virtual float get_current_lr() const = 0;
+
+    [[nodiscard]] core::not_null<optimizers::OptimizerBase *> get_optimizer() const;
+
+    [[nodiscard]] virtual serialization::StateDict get_state_dict() const = 0;
+    virtual void set_state_dict(const serialization::StateDict &dict) = 0;
+
+private:
+    core::not_null<optimizers::OptimizerBase *> m_optimizer;
+};
+
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp b/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp
new file mode 100644
index 00000000000..cac72d08156
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/sequential_scheduler.cpp
@@ -0,0 +1,88 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "sequential_scheduler.hpp"
+
+#include "optimizers/optimizer_base.hpp"
+#include "serialization/serializable.hpp"
+
+namespace {
+const std::string kCurrentScheduler = "current_scheduler/";
+}
+
+namespace ttml::schedulers {
+SequentialScheduler::SequentialScheduler(
+    optimizers::OptimizerBase *optimizer,
+    std::vector<std::unique_ptr<LRSchedulerBase>> schedulers,
+    std::vector<size_t> milestones) :
+    LRSchedulerBase(optimizer),
+    m_schedulers(std::move(schedulers)),
+    m_milestones(std::move(milestones)),
+    m_current_scheduler_index(0),
+    m_current_step_in_scheduler(0),
+    m_last_lr(optimizer->get_lr()) {
+    if (m_schedulers.empty()) {
+        throw std::invalid_argument("SequentialScheduler requires at least one scheduler.");
+    }
+
+    // Validate that each scheduler is non-null
+    for (auto &scheduler : m_schedulers) {
+        if (!scheduler) {
+            throw std::invalid_argument("Null scheduler provided to SequentialScheduler.");
+        }
+    }
+}
+void SequentialScheduler::step() {
+    if (m_current_scheduler_index >= m_schedulers.size()) {
+        return;
+    }
+
+    auto &current_scheduler = m_schedulers[m_current_scheduler_index];
+    auto current_sched_steps = m_milestones[m_current_scheduler_index];
+    current_scheduler->step();
+    m_current_step_in_scheduler += 1;
+    m_last_lr = current_scheduler->get_last_lr();
+
+    if (m_current_step_in_scheduler >= current_sched_steps) {
+        m_current_scheduler_index += 1;
+        m_current_step_in_scheduler = 0;
+    }
+}
+float SequentialScheduler::get_last_lr() const {
+    if (m_current_scheduler_index == 0) {
+        return (m_current_scheduler_index < m_schedulers.size())
+                   ? m_schedulers[m_current_scheduler_index]->get_last_lr()
+                   : m_last_lr;
+    } else if (m_current_scheduler_index < m_schedulers.size()) {
+        return m_schedulers[m_current_scheduler_index]->get_last_lr();
+    }
+    return m_last_lr;
+}
+float SequentialScheduler::get_current_lr() const {
+    // The current LR of the optimizer should reflect the last scheduler's step
+    return get_optimizer()->get_lr();
+}
+
+void SequentialScheduler::set_state_dict(const serialization::StateDict &dict) {
+    m_current_step_in_scheduler = serialization::get_value_type<int>(dict, "m_current_step_in_scheduler");
+    m_last_lr = serialization::get_value_type<float>(dict, "m_last_lr");
+    m_current_scheduler_index = serialization::get_value_type<size_t>(dict, "m_current_scheduler_index");
+    serialization::StateDict current_scheduler_dict;
+    for (auto &[key, value] : dict) {
+        if (key.find(kCurrentScheduler) == 0) {
+            current_scheduler_dict[key.substr(kCurrentScheduler.length())] = value;
+        }
+    }
+    m_schedulers[m_current_scheduler_index]->set_state_dict(current_scheduler_dict);
+}
+serialization::StateDict SequentialScheduler::get_state_dict() const {
+    serialization::StateDict res;
+    res["m_current_step_in_scheduler"] = m_current_step_in_scheduler;
+    res["m_last_lr"] = m_last_lr;
+    res["m_current_scheduler_index"] = m_current_scheduler_index;
+    for (auto &[key, value] : m_schedulers[m_current_scheduler_index]->get_state_dict()) {
+        res[kCurrentScheduler + key] = value;
+    }
+    return res;
+};
+
+}  // namespace ttml::schedulers
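One subtlety in the (de)serialization above: only the currently active child's state is persisted, nested under the `current_scheduler/` key prefix, so restoring requires a `SequentialScheduler` rebuilt with the same child list and milestones. A fragment illustrating the contract (`make_schedule` is a placeholder for a factory such as `create_warmup_with_linear_scheduler`, and `optimizer` is assumed to exist):

```cpp
// Fragment: checkpoint/resume contract for SequentialScheduler.
auto scheduler = make_schedule(&optimizer, /*total_steps=*/1000);
for (int i = 0; i < 100; ++i) {
    scheduler->step();
}
auto saved = scheduler->get_state_dict();  // index, step-in-child, last lr, active child's state

auto resumed = make_schedule(&optimizer, /*total_steps=*/1000);  // identical structure
resumed->set_state_dict(saved);  // continues inside the same child scheduler
```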
diff --git a/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp
new file mode 100644
index 00000000000..eac686c1b2c
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/sequential_scheduler.hpp
@@ -0,0 +1,42 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "scheduler_base.hpp"
+
+namespace ttml::schedulers {
+
+class SequentialScheduler : public LRSchedulerBase {
+public:
+    // Each element in the schedulers vector is a (scheduler, steps) pair:
+    // the scheduler runs for 'steps' steps, then we move on to the next one.
+    // This differs slightly from the PyTorch implementation, where the number
+    // of milestones can be smaller than the number of schedulers, which is misleading.
+    SequentialScheduler(
+        optimizers::OptimizerBase *optimizer,
+        std::vector<std::unique_ptr<LRSchedulerBase>> schedulers,
+        std::vector<size_t> milestones);
+
+    void step() override;
+
+    [[nodiscard]] float get_last_lr() const override;
+
+    [[nodiscard]] float get_current_lr() const override;
+
+    [[nodiscard]] serialization::StateDict get_state_dict() const override;
+    void set_state_dict(const serialization::StateDict &dict) override;
+
+private:
+    std::vector<std::unique_ptr<LRSchedulerBase>> m_schedulers;
+    std::vector<size_t> m_milestones;
+    size_t m_current_scheduler_index = 0;
+    int m_current_step_in_scheduler = 0;
+    float m_last_lr = 0.F;
+};
+
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/step_scheduler.cpp b/tt-train/sources/ttml/schedulers/step_scheduler.cpp
new file mode 100644
index 00000000000..ec1acf8cb01
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/step_scheduler.cpp
@@ -0,0 +1,50 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "step_scheduler.hpp"
+
+#include "optimizers/optimizer_base.hpp"
+
+namespace ttml::schedulers {
+
+StepScheduler::StepScheduler(optimizers::OptimizerBase *optimizer, size_t step_size, float gamma) :
+    LRSchedulerBase(optimizer),
+    m_step_size(step_size),
+    m_gamma(gamma),
+    m_last_step(0),
+    m_base_lr(optimizer->get_lr()),
+    m_last_lr(m_base_lr) {
+    if (gamma <= 0.0f) {
+        throw std::invalid_argument(fmt::format("gamma = {} must be greater than zero.", gamma));
+    }
+}
+void StepScheduler::step() {
+    m_last_step += 1;
+
+    // Every step_size steps, lr is scaled by gamma
+    int num_steps = m_last_step / m_step_size;
+    float new_lr = m_base_lr * std::pow(m_gamma, static_cast<float>(num_steps));
+
+    get_optimizer()->set_lr(new_lr);
+    m_last_lr = new_lr;
+}
+float StepScheduler::get_last_lr() const {
+    return m_last_lr;
+}
+float StepScheduler::get_current_lr() const {
+    return get_optimizer()->get_lr();
+}
+
+void StepScheduler::set_state_dict(const serialization::StateDict &dict) {
+    m_last_step = serialization::get_value_type<size_t>(dict, "m_last_step");
+    m_last_lr = serialization::get_value_type<float>(dict, "m_last_lr");
+}
+serialization::StateDict StepScheduler::get_state_dict() const {
+    serialization::StateDict res;
+    res["m_last_step"] = m_last_step;
+    res["m_last_lr"] = m_last_lr;
+    return res;
+};
+
+}  // namespace ttml::schedulers
diff --git a/tt-train/sources/ttml/schedulers/step_scheduler.hpp b/tt-train/sources/ttml/schedulers/step_scheduler.hpp
new file mode 100644
index 00000000000..2f0189e9d78
--- /dev/null
+++ b/tt-train/sources/ttml/schedulers/step_scheduler.hpp
@@ -0,0 +1,36 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include
+
+#include "scheduler_base.hpp"
+
+namespace ttml::schedulers {
+
+class StepScheduler : public LRSchedulerBase {
+public:
+    StepScheduler(optimizers::OptimizerBase *optimizer, size_t step_size, float gamma = 0.1f);
+
+    void step() override;
+
+    [[nodiscard]] float get_last_lr() const override;
+
+    [[nodiscard]] float get_current_lr() const override;
+
+    [[nodiscard]] serialization::StateDict get_state_dict() const override;
+
+    void set_state_dict(const serialization::StateDict &dict) override;
+
+private:
+    size_t m_step_size = 0;
+    float m_gamma = 0;
+    size_t m_last_step = 0;
+
+    float m_base_lr = 0.F;
+    float m_last_lr = 0.F;
+};
+
+}  // namespace ttml::schedulers
diff --git
a/tt-train/sources/ttml/serialization/msgpack_file.cpp b/tt-train/sources/ttml/serialization/msgpack_file.cpp index 42fb0b53378..573218a0b5d 100644 --- a/tt-train/sources/ttml/serialization/msgpack_file.cpp +++ b/tt-train/sources/ttml/serialization/msgpack_file.cpp @@ -6,12 +6,11 @@ #include -#include +#include #include #define MSGPACK_NO_BOOST #include #include -#include #include #include #include @@ -122,6 +121,10 @@ class MsgPackFile::Impl { } // Overloads for std::span + void put(std::string_view key, std::span value) { + m_data[std::string(key)] = std::vector(value.begin(), value.end()); + } + void put(std::string_view key, std::span value) { m_data[std::string(key)] = std::vector(value.begin(), value.end()); } @@ -142,6 +145,10 @@ class MsgPackFile::Impl { m_data[std::string(key)] = std::vector(value.begin(), value.end()); } + void put(std::string_view key, const ValueType& value) { + m_data[std::string(key)] = value; + } + // Serialization method void serialize(const std::string& filename) { // Create a buffer for packing @@ -216,6 +223,10 @@ class MsgPackFile::Impl { return get_value(key, value); } + bool get(std::string_view key, std::vector& value) const { + return get_value(key, value); + } + bool get(std::string_view key, std::vector& value) const { return get_value(key, value); } @@ -236,23 +247,11 @@ class MsgPackFile::Impl { return get_value(key, value); } -private: - using ValueType = std::variant< - bool, - char, - int, - float, - double, - uint32_t, - size_t, - std::string, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector, - std::vector>; + bool get(std::string_view key, ValueType& value) const { + return get_value(key, value); + } +private: std::unordered_map m_data; // Helper function to get value from m_data @@ -271,6 +270,17 @@ class MsgPackFile::Impl { throw std::runtime_error(fmt::format("Key not found: {}", key)); } } + template <> + bool get_value(std::string_view key, ValueType& value) const { + auto it = m_data.find(std::string(key)); + if (it != m_data.end()) { + value = it->second; + return true; + } else { + // Key not found + throw std::runtime_error(fmt::format("Key not found: {}", key)); + } + } }; MsgPackFile::MsgPackFile() : m_impl(std::make_unique()) { @@ -312,6 +322,10 @@ void MsgPackFile::put(std::string_view key, std::string_view value) { m_impl->put(key, value); } +void MsgPackFile::put(std::string_view key, std::span value) { + m_impl->put(key, value); +} + void MsgPackFile::put(std::string_view key, std::span value) { m_impl->put(key, value); } @@ -332,6 +346,14 @@ void MsgPackFile::put(std::string_view key, std::span value) m_impl->put(key, value); } +void MsgPackFile::put(std::string_view key, const char* value) { + put(key, std::string_view(value)); +} + +void MsgPackFile::put(std::string_view key, const ValueType& value) { + m_impl->put(key, value); +} + void MsgPackFile::serialize(const std::string& filename) { m_impl->serialize(filename); } @@ -372,6 +394,10 @@ void MsgPackFile::get(std::string_view key, std::string& value) const { m_impl->get(key, value); } +void MsgPackFile::get(std::string_view key, std::vector& value) const { + m_impl->get(key, value); +} + void MsgPackFile::get(std::string_view key, std::vector& value) const { m_impl->get(key, value); } @@ -392,7 +418,8 @@ void MsgPackFile::get(std::string_view key, std::vector& value) con m_impl->get(key, value); } -void MsgPackFile::put(std::string_view key, const char* value) { - put(key, std::string_view(value)); +void MsgPackFile::get(std::string_view key, 
ValueType& value) const {
+    m_impl->get(key, value);
 }
+
 }  // namespace ttml::serialization
diff --git a/tt-train/sources/ttml/serialization/msgpack_file.hpp b/tt-train/sources/ttml/serialization/msgpack_file.hpp
index 19f36f6cca9..6e170483a6f 100644
--- a/tt-train/sources/ttml/serialization/msgpack_file.hpp
+++ b/tt-train/sources/ttml/serialization/msgpack_file.hpp
@@ -13,6 +13,23 @@

 namespace ttml::serialization {

+using ValueType = std::variant<
+    bool,
+    char,
+    int,
+    float,
+    double,
+    uint32_t,
+    size_t,
+    std::string,
+    std::vector<uint8_t>,
+    std::vector<char>,
+    std::vector<int>,
+    std::vector<float>,
+    std::vector<double>,
+    std::vector<uint32_t>,
+    std::vector<std::string>>;
+
 class MsgPackFile {
 public:
     MsgPackFile();
@@ -44,12 +61,14 @@ class MsgPackFile {
     void put(std::string_view key, const char* value);

     // Overloads for std::span
+    void put(std::string_view key, std::span<const uint8_t> value);
     void put(std::string_view key, std::span<const int> value);
     void put(std::string_view key, std::span<const float> value);
     void put(std::string_view key, std::span<const double> value);
     void put(std::string_view key, std::span<const uint32_t> value);
     void put(std::string_view key, std::span<const std::string> value);

+    void put(std::string_view key, const ValueType& value);
     // Serialization method
     void serialize(const std::string& filename);

@@ -67,12 +86,15 @@ class MsgPackFile {
     void get(std::string_view key, std::string& value) const;

     // Methods to get vectors (from spans)
+    void get(std::string_view key, std::vector<uint8_t>& value) const;
     void get(std::string_view key, std::vector<int>& value) const;
     void get(std::string_view key, std::vector<float>& value) const;
     void get(std::string_view key, std::vector<double>& value) const;
     void get(std::string_view key, std::vector<uint32_t>& value) const;
     void get(std::string_view key, std::vector<std::string>& value) const;

+    void get(std::string_view key, ValueType& value) const;
+
 private:
     class Impl;
     std::unique_ptr<Impl> m_impl;
diff --git a/tt-train/sources/ttml/serialization/serializable.hpp b/tt-train/sources/ttml/serialization/serializable.hpp
new file mode 100644
index 00000000000..689aa24d7ed
--- /dev/null
+++ b/tt-train/sources/ttml/serialization/serializable.hpp
@@ -0,0 +1,28 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+#include <unordered_map>
+#include <variant>
+
+#include "autograd/tensor.hpp"
+#include "msgpack_file.hpp"
+
+namespace ttml::serialization {
+using NamedParameters = std::unordered_map<std::string, ttml::autograd::TensorPtr>;
+using SerializableType = std::variant<ValueType, tt::tt_metal::Tensor, ttml::autograd::TensorPtr, NamedParameters>;
+using StateDict = std::unordered_map<std::string, SerializableType>;
+
+template <typename T>
+concept IsValueType = requires {
+    { std::get<T>(std::declval<ValueType>()) };
+};
+
+template <IsValueType T>
+const T& get_value_type(const StateDict& dict, const std::string& key) {
+    const auto& val_type = std::get<ValueType>(dict.at(key));
+    return std::get<T>(val_type);
+}
+
+}  // namespace ttml::serialization
diff --git a/tt-train/sources/ttml/serialization/serialization.cpp b/tt-train/sources/ttml/serialization/serialization.cpp
index d96e26f014f..401b96a26bf 100644
--- a/tt-train/sources/ttml/serialization/serialization.cpp
+++ b/tt-train/sources/ttml/serialization/serialization.cpp
@@ -21,24 +21,23 @@ namespace ttml::serialization {

 // trivial type to the std::string
 template <typename T>
-std::string to_bytes(const T& value) {
-    static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
-    std::string bytes(sizeof(T), '\0');
-    std::memcpy(bytes.data(), &value, sizeof(T));
-    return bytes;
+std::span<uint8_t> to_bytes(T& value) {
+    static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
+    auto ptr = reinterpret_cast<uint8_t*>(&value);
+    return std::span(ptr, sizeof(T));
 }

 template <typename T>
-void from_bytes(const std::string& bytes, T& value) {
-    static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
+void from_bytes(std::span<const uint8_t> bytes, T& value) {
+    static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
     if (bytes.size() != sizeof(T)) {
-        throw std::invalid_argument(fmt::format(
-            "Invalid byte size for conversion to type T. Expected: {} Actual: {}, type: {} ",
-            sizeof(T),
-            bytes.size(),
-            core::demangle(typeid(T).name())));
+        std::ostringstream oss;
+        oss << "Invalid byte size for conversion to type T. Expected: " << sizeof(T) << " Actual: " << bytes.size()
+            << ", type: " << typeid(T).name();
+        throw std::invalid_argument(oss.str());
     }
+
     std::memcpy(&value, bytes.data(), sizeof(T));
 }

@@ -77,7 +76,7 @@ void read_ttnn_tensor(MsgPackFile& file, std::string_view name, tt::tt_metal::Te
     tt::tt_metal::StorageType storage_type{};
     auto shape = core::create_shape({1, 1, 1, 1});

-    std::string bytes;
+    std::vector<uint8_t> bytes;
     file.get(std::string(name) + "/shape", bytes);
     from_bytes(bytes, shape);

@@ -127,12 +126,13 @@ void read_autograd_tensor(MsgPackFile& file, std::string_view name, ttml::autogr
     }
 }

-void write_named_parameters(MsgPackFile& file, std::string_view name, const ttml::autograd::NamedParameters& params) {
+void write_named_parameters(
+    MsgPackFile& file, std::string_view name, const ttml::serialization::NamedParameters& params) {
     for (const auto& [key, value] : params) {
         write_autograd_tensor(file, std::string(name) + "/" + key, value);
     }
 }
-void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autograd::NamedParameters& params) {
+void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::serialization::NamedParameters& params) {
     for (auto& [key, value] : params) {
         read_autograd_tensor(file, std::string(name) + "/" + key, value);
     }
@@ -141,22 +141,15 @@ void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autog
 void write_optimizer(MsgPackFile& file, std::string_view name, const optimizers::OptimizerBase* optimizer) {
     assert(optimizer);
     auto state_dict = optimizer->get_state_dict();
-    for (const auto& [key, value] : state_dict) {
-        ttml::serialization::write_autograd_tensor(file, std::string(name) + "/" + key, value);
-    }
-    file.put(std::string(name) + "/steps", optimizer->get_steps());
+    write_state_dict(file, std::string(name), state_dict);
 }

 void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::OptimizerBase* optimizer) {
     assert(optimizer);
     size_t steps = 0;
     auto state_dict = optimizer->get_state_dict();
-    for (auto& [key, value] : state_dict) {
-        ttml::serialization::read_autograd_tensor(file, std::string(name) + "/" + key, value);
-    }
+    read_state_dict(file, name, state_dict);
     optimizer->set_state_dict(state_dict);
-    file.get(std::string(name) + "/steps", steps);
-    optimizer->set_steps(steps);
 }

 void write_module(MsgPackFile& file, std::string_view name, const autograd::ModuleBase* module) {
@@ -171,4 +164,35 @@ void read_module(MsgPackFile& file, std::string_view name, autograd::ModuleBase*
     read_named_parameters(file, name, named_parameters);
 }
+void write_state_dict(MsgPackFile& file, std::string_view name, const serialization::StateDict& state_dict) {
+    for (const auto& [key, value] : state_dict) {
+        if (std::holds_alternative<serialization::ValueType>(value)) {
+            file.put(std::string(name) + "/" + key, std::get<serialization::ValueType>(value));
+        } else if (std::holds_alternative<tt::tt_metal::Tensor>(value)) {
+            write_ttnn_tensor(file, std::string(name) + "/" + key, std::get<tt::tt_metal::Tensor>(value));
+        } else if (std::holds_alternative<ttml::autograd::TensorPtr>(value)) {
+            write_autograd_tensor(file, std::string(name) + "/" + key, std::get<ttml::autograd::TensorPtr>(value));
+        } else if (std::holds_alternative<serialization::NamedParameters>(value)) {
+            write_named_parameters(file, std::string(name) + "/" + key, std::get<serialization::NamedParameters>(value));
+        } else {
+            throw std::runtime_error("Unsupported type in state dict");
+        }
+    }
+}
+void read_state_dict(MsgPackFile& file, std::string_view name, serialization::StateDict& state_dict) {
+    for (auto& [key, value] : state_dict) {
+        if (std::holds_alternative<serialization::ValueType>(value)) {
+            file.get(std::string(name) + "/" + key, std::get<serialization::ValueType>(value));
+        } else if (std::holds_alternative<tt::tt_metal::Tensor>(value)) {
+            read_ttnn_tensor(file, std::string(name) + "/" + key, std::get<tt::tt_metal::Tensor>(value));
+        } else if (std::holds_alternative<ttml::autograd::TensorPtr>(value)) {
+            read_autograd_tensor(file, std::string(name) + "/" + key, std::get<ttml::autograd::TensorPtr>(value));
+        } else if (std::holds_alternative<serialization::NamedParameters>(value)) {
+            read_named_parameters(file, std::string(name) + "/" + key, std::get<serialization::NamedParameters>(value));
+        } else {
+            throw std::runtime_error("Unsupported type in state dict");
+        }
+    }
+}
+
 }  // namespace ttml::serialization
diff --git a/tt-train/sources/ttml/serialization/serialization.hpp b/tt-train/sources/ttml/serialization/serialization.hpp
index 1d4198e9996..6eee8247b53 100644
--- a/tt-train/sources/ttml/serialization/serialization.hpp
+++ b/tt-train/sources/ttml/serialization/serialization.hpp
@@ -23,8 +23,9 @@ void write_autograd_tensor(
     MsgPackFile& file, std::string_view name, const ttml::autograd::TensorPtr& tensor, bool save_grads = false);
 void read_autograd_tensor(MsgPackFile& file, std::string_view name, ttml::autograd::TensorPtr& tensor);

-void write_named_parameters(MsgPackFile& file, std::string_view name, const ttml::autograd::NamedParameters& params);
-void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::autograd::NamedParameters& params);
+void write_named_parameters(
+    MsgPackFile& file, std::string_view name, const ttml::serialization::NamedParameters& params);
+void read_named_parameters(MsgPackFile& file, std::string_view name, ttml::serialization::NamedParameters& params);

 void write_optimizer(MsgPackFile& file, std::string_view name, const optimizers::OptimizerBase* optimizer);
 void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::OptimizerBase* optimizer);
@@ -32,4 +33,7 @@ void read_optimizer(MsgPackFile& file, std::string_view name, optimizers::Optimi
 void write_module(MsgPackFile& file, std::string_view name, const autograd::ModuleBase* module);
 void read_module(MsgPackFile& file, std::string_view name, autograd::ModuleBase* module);

+void write_state_dict(MsgPackFile& file, std::string_view name, const serialization::StateDict& state_dict);
+void read_state_dict(MsgPackFile& file, std::string_view name, serialization::StateDict& state_dict);
+
 }  // namespace ttml::serialization
diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
index dee98552ef6..564c985c198 100644
--- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
+++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -4,7 +4,6 @@
 #pragma once

 #include
-#include

 namespace ttml::ttnn_fixed {
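An end-to-end sketch of the new `write_state_dict`/`read_state_dict` pair declared above. Note that reading is shape-driven: keys and types come from a pre-populated dict, mirroring how `read_optimizer` seeds it from `get_state_dict()`. The file path is illustrative:

```cpp
// Round-trip sketch: scalars travel through MsgPackFile under "<name>/<key>" keys.
#include <cstddef>
#include "serialization/msgpack_file.hpp"
#include "serialization/serialization.hpp"

int main() {
    ttml::serialization::StateDict out;
    out["m_last_step"] = std::size_t{128};
    out["m_last_lr"] = 3e-4F;

    ttml::serialization::MsgPackFile writer;
    ttml::serialization::write_state_dict(writer, "scheduler", out);
    writer.serialize("/tmp/scheduler_state.msgpack");  // hypothetical path

    // Seed the target dict with the expected keys/types, then read into it.
    ttml::serialization::StateDict in;
    in["m_last_step"] = std::size_t{0};
    in["m_last_lr"] = 0.0F;
    ttml::serialization::MsgPackFile reader;
    reader.deserialize("/tmp/scheduler_state.msgpack");
    ttml::serialization::read_state_dict(reader, "scheduler", in);
    return 0;
}
```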
"modules/dropout_module.hpp" #include "modules/layer_norm_module.hpp" #include "modules/linear_module.hpp" diff --git a/tt-train/tests/model/linear_regression_full_test.cpp b/tt-train/tests/model/linear_regression_full_test.cpp index 1af4f315405..0915b05abf7 100644 --- a/tt-train/tests/model/linear_regression_full_test.cpp +++ b/tt-train/tests/model/linear_regression_full_test.cpp @@ -8,6 +8,7 @@ #include #include "autograd/auto_context.hpp" +#include "core/tt_tensor_utils.hpp" #include "modules/linear_module.hpp" #include "ops/losses.hpp" #include "optimizers/sgd.hpp" diff --git a/tt-train/tests/schedulers/schedulers_test.cpp b/tt-train/tests/schedulers/schedulers_test.cpp new file mode 100644 index 00000000000..5fc6d8ae80d --- /dev/null +++ b/tt-train/tests/schedulers/schedulers_test.cpp @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "core/not_null.hpp" +#include "optimizers/optimizer_base.hpp" +#include "schedulers/lambda_scheduler.hpp" +#include "schedulers/linear_scheduler.hpp" +#include "schedulers/sequential_scheduler.hpp" +#include "schedulers/step_scheduler.hpp" + +namespace ttml::optimizers { +class MockOptimizer : public OptimizerBase { +public: + explicit MockOptimizer(float lr) : OptimizerBase(ttml::serialization::NamedParameters{}), m_lr(lr) { + } + + void zero_grad() override {}; + + void step() override {}; + + [[nodiscard]] serialization::StateDict get_state_dict() const override { + return {}; + } + + void set_state_dict(const serialization::StateDict &dict) override {}; + + [[nodiscard]] size_t get_steps() const override { + return {}; + }; + void set_steps(size_t steps) override {}; + + void set_lr(float lr) override { + m_lr = lr; + } + + [[nodiscard]] float get_lr() const override { + return m_lr; + } + +private: + float m_lr = 0; +}; +} // namespace ttml::optimizers + +// ---------------------------------- +// Tests for LambdaScheduler +// ---------------------------------- +TEST(LambdaSchedulerTest, ConstantFactor) { + auto optimizer = std::make_unique(0.1F); + + // Lambda that keeps the LR constant + // The learning rate of each parameter group is set to the initial lr times a given function. When last_epoch=-1, + // sets initial lr as lr. 
+TEST(LambdaSchedulerTest, ConstantFactor) {
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(0.1F);
+
+    // Lambda that keeps the LR factor constant: the optimizer's learning rate
+    // is always the initial lr times the factor returned by the lambda.
+    ttml::schedulers::LambdaScheduler scheduler(optimizer.get(), [](int epoch) {
+        (void)epoch;
+        return 0.5F;
+    });
+
+    // Initial LR
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F);
+
+    scheduler.step();  // step 1
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F);
+
+    scheduler.step();  // step 2
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F);
+
+    scheduler.step();  // step 3
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F * 0.5F);
+}
+
+TEST(LambdaSchedulerTest, VaryingFactor) {
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(1.0f);
+
+    // Lambda: lr_factor = 1.0 / (epoch + 1), where the scheduler passes in its
+    // 1-based step counter after incrementing it.
+    ttml::schedulers::LambdaScheduler scheduler(optimizer.get(), [](int epoch) { return 1.0F / (epoch + 1); });
+
+    // Before stepping
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 1.0F);
+
+    scheduler.step();  // step 1: factor = 1/2 = 0.5, lr = 1.0 * 0.5 = 0.5
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.5F);
+
+    scheduler.step();  // step 2: factor = 1/3, lr = 1.0 / 3
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 1.F / 3.F);
+
+    scheduler.step();  // step 3: factor = 1/4, lr = 0.25
+    EXPECT_NEAR(optimizer->get_lr(), 1.F / 4.F, 1e-5);
+
+    scheduler.step();  // step 4: factor = 1/5, lr = 0.2
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F);
+}
+
+// ----------------------------------
+// Tests for StepScheduler
+// ----------------------------------
+TEST(StepLRSchedulerTest, BasicDecay) {
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(0.2F);
+
+    // Decrease LR by factor of 0.1 every 3 steps
+    ttml::schedulers::StepScheduler scheduler(optimizer.get(), 3, 0.1F);
+
+    for (int i = 0; i < 3; ++i) {
+        EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F);
+        scheduler.step();
+    }
+
+    for (int i = 0; i < 3; ++i) {
+        // After 3 steps: lr = base_lr * 0.1
+        EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F * 0.1F);
+        scheduler.step();
+    }
+    // After 6 steps: lr = base_lr * 0.1^2
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.2F * 0.1F * 0.1F);
+}
+
+// ----------------------------------
+// Tests for LinearScheduler
+// ----------------------------------
+TEST(LinearSchedulerTest, DecreasingLR) {
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(0.2F);
+
+    // Linearly go from 0.2 to 0.0 in 4 steps
+    ttml::schedulers::LinearScheduler scheduler(optimizer.get(), 1.0F, 0.0F, 4);
+
+    // step 1: progress = 1/4 = 0.25, lr = 0.2 + (0.0-0.2)*0.25 = 0.2 - 0.05 = 0.15
+    scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.15F);
+
+    // step 2: progress = 0.5, lr = 0.2 + (0.0-0.2)*0.5 = 0.2 - 0.1 = 0.1
+    scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F);
+
+    // step 3: progress = 0.75, lr = 0.2 + (0.0-0.2)*0.75 = 0.2 - 0.15 = 0.05
+    scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.05F);
+
+    // step 4: progress = 1.0, lr = 0.2 + (0.0-0.2)*1.0 = 0.0
+    scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.0f);
+
+    // Extra steps keep it at 0.0
+    scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.0f);
+}
+
+// ----------------------------------
+// Tests for SequentialScheduler
+// ----------------------------------
+TEST(SequentialSchedulerTest, ChainSchedulers) {
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(1.0f);
+
+    // First: StepScheduler for 3 steps (gamma=0.5 every step_size=1)
+    auto step_scheduler = std::make_unique<ttml::schedulers::StepScheduler>(optimizer.get(), 1, 0.5F);
+
+    // Then: LinearScheduler for 2 steps, from its construction-time base LR (1.0) down to 0.1 * base
+    auto linear_scheduler = std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 1.0F, 0.1F, 2);
+
+    std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> schedulers;
+    std::vector<size_t> milestones;
+    schedulers.push_back(std::move(step_scheduler));
+    schedulers.push_back(std::move(linear_scheduler));
+    milestones.push_back(3);
+    milestones.push_back(2);
+    ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(schedulers), std::move(milestones));
+
+    // Initial LR = 1.0
+    // Run StepScheduler for 3 steps; every step reduces LR by factor 0.5
+    seq_scheduler.step();  // 1st step: LR=1.0*0.5=0.5
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.5F);
+
+    seq_scheduler.step();  // 2nd step: LR=0.5*0.5=0.25
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.25F);
+
+    seq_scheduler.step();  // 3rd step: LR=0.25*0.5=0.125
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.125F);
+
+    // Linear phase: total_steps=2, base_lr=1.0 (captured at construction), factors 1.0 -> 0.1
+    // step 1: progress=1/2=0.5, lr=1.0*(1.0+(0.1-1.0)*0.5)=0.55
+    seq_scheduler.step();
+    EXPECT_NEAR(optimizer->get_lr(), 0.55, 1e-5);
+
+    // step 2: progress=2/2=1.0, lr=1.0*(1.0+(0.1-1.0)*1.0)=0.1 (min lr in linear scheduler)
+    seq_scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F);
+
+    // Further steps do nothing (we finished all schedulers)
+    seq_scheduler.step();
+    EXPECT_FLOAT_EQ(optimizer->get_lr(), 0.1F);
+}
+
+TEST(SequentialSchedulerTest, WarmupSetup) {
+    auto start_lr = 3.e-4F;
+    auto optimizer = std::make_unique<ttml::optimizers::MockOptimizer>(start_lr);
+
+    // First: LinearScheduler for 10 steps from 0 to start_lr
+    auto warmup_scheduler = std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 0.0F, 1.0F, 10);
+
+    // Then: LinearScheduler for 50 steps from start_lr to 0.1F * start_lr
+    auto linear_scheduler = std::make_unique<ttml::schedulers::LinearScheduler>(optimizer.get(), 1.F, 0.1F, 50);
+
+    std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> schedulers;
+    std::vector<size_t> milestones;
+    schedulers.push_back(std::move(warmup_scheduler));
+    schedulers.push_back(std::move(linear_scheduler));
+    milestones.push_back(10);
+    milestones.push_back(50);
+    ttml::schedulers::SequentialScheduler seq_scheduler(optimizer.get(), std::move(schedulers), std::move(milestones));
+
+    for (int i = 0; i < 10; i++) {
+        // Linear warmup: 10 steps from 0 to start_lr
+        seq_scheduler.step();
+        EXPECT_NEAR(optimizer->get_lr(), start_lr * (i + 1) / 10, 1e-5);
+    }
+    for (int i = 0; i < 50; i++) {
+        // Linear decay: 50 steps from start_lr to 0.1F * start_lr
+        seq_scheduler.step();
+        EXPECT_NEAR(optimizer->get_lr(), start_lr * (1.0F - 0.9F * (i + 1) / 50.F), 1e-5);
+    }
+}
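The expectations in `WarmupSetup` follow directly from `LinearScheduler`'s interpolation. With base LR $\eta_0$ captured at construction, start/end factors $s$ and $e$, and $N$ total steps, the LR after step $i$ is

$$\eta_i \;=\; \eta_0\left(s + (e-s)\,\min\!\left(\tfrac{i}{N},\,1\right)\right)$$

The warmup phase ($s=0$, $e=1$, $N=10$) gives $\eta_i = \eta_0 \cdot i/10$, and the decay phase ($s=1$, $e=0.1$, $N=50$) gives $\eta_i = \eta_0\,(1 - 0.9\,i/50)$, matching the two `EXPECT_NEAR` loops above.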