diff --git a/tt-train/sources/examples/nano_gpt/main.cpp b/tt-train/sources/examples/nano_gpt/main.cpp
index a0732c43156..3495584bf6a 100644
--- a/tt-train/sources/examples/nano_gpt/main.cpp
+++ b/tt-train/sources/examples/nano_gpt/main.cpp
@@ -141,6 +141,8 @@ struct TrainingConfig {
     uint32_t max_steps = 5000;
     float learning_rate = 3e-4F;
     float weight_decay = 1e-2F;
+    // works only with the default AdamW optimizer (MorehAdamW does not support it)
+    bool use_kahan_summation = false;
     std::string model_path;
     std::string data_path;
     ttml::models::gpt2::TransformerConfig transformer_config;
@@ -157,6 +159,7 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) {
     config.max_steps = training_config["max_steps"].as<uint32_t>();
     config.learning_rate = training_config["learning_rate"].as<float>();
     config.weight_decay = training_config["weight_decay"].as<float>();
+    config.use_kahan_summation = training_config["use_kahan_summation"].as<bool>(config.use_kahan_summation);
     config.model_path = training_config["model_path"].as<std::string>("");
     config.data_path = training_config["data_path"].as<std::string>(std::string(DATA_FOLDER) + "/shakespeare.txt");
     config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]);
@@ -295,9 +298,11 @@ int main(int argc, char **argv) {
     auto adamw_params = ttml::optimizers::AdamWConfig();
     adamw_params.lr = config.learning_rate;
     adamw_params.weight_decay = config.weight_decay;
+    adamw_params.use_kahan_summation = config.use_kahan_summation;
     fmt::print("AdamW configuration:\n");
     fmt::print("    Learning rate: {}\n", adamw_params.lr);
     fmt::print("    Weight decay: {}\n", adamw_params.weight_decay);
+    fmt::print("    Use Kahan summation: {}\n", adamw_params.use_kahan_summation);
     auto optimizer = ttml::optimizers::AdamW(model->parameters(), adamw_params);
 
     if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
diff --git a/tt-train/sources/ttml/optimizers/adamw.cpp b/tt-train/sources/ttml/optimizers/adamw.cpp
index c11724ac17d..8770b9da5a8 100644
--- a/tt-train/sources/ttml/optimizers/adamw.cpp
+++ b/tt-train/sources/ttml/optimizers/adamw.cpp
@@ -16,6 +16,7 @@ namespace {
 
 const std::string kFirstMoment = "first_moment/";
 const std::string kSecondMoment = "second_moment/";
+const std::string kKahanCompensation = "kahan_compensation/";
 
 }  // namespace
 
@@ -23,6 +24,10 @@ namespace ttml::optimizers {
 
 MorehAdamW::MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config) :
     OptimizerBase(std::move(parameters)), m_config(config) {
+    if (m_config.use_kahan_summation) {
+        throw std::runtime_error("MorehAdamW: Kahan summation is not supported. Use default AdamW instead.");
+    }
+
     for (const auto& [key, tensor_ptr] : m_parameters) {
         if (tensor_ptr->get_requires_grad()) {
             m_first_moment.emplace(
@@ -137,6 +142,13 @@
                 autograd::create_tensor(
                     core::zeros_like(tensor_ptr->get_value(autograd::PreferredPrecision::FULL)),
                     /* requires_grad */ false));
+            if (m_config.use_kahan_summation) {
+                m_kahan_compensation.emplace(
+                    key,
+                    autograd::create_tensor(
+                        core::zeros_like(tensor_ptr->get_value(autograd::PreferredPrecision::FULL)),
+                        /* requires_grad */ false));
+            }
         }
     }
 }
@@ -188,11 +200,29 @@ void AdamW::step() {
         // weights -= lr * first_moment_hat / (sqrt(second_moment_hat) + epsilon)
         first_moment_ptr->set_value(first_moment);
         second_moment_ptr->set_value(second_moment);
-        tensor_ptr->set_value(ttnn::subtract(
-            tensor_ptr->get_value(autograd::PreferredPrecision::FULL),
-            ttnn_fixed::divide(
-                ttnn::multiply(first_moment_hat, m_config.lr),
-                ttnn::add(ttnn::sqrt(second_moment_hat), m_config.epsilon))));
+
+        auto update_tensor = ttnn_fixed::divide(
+            ttnn::multiply(first_moment_hat, -m_config.lr), ttnn::add(ttnn::sqrt(second_moment_hat), m_config.epsilon));
+
+        if (!m_config.use_kahan_summation) {
+            tensor_ptr->set_value(ttnn::add(tensor_ptr->get_value(autograd::PreferredPrecision::FULL), update_tensor));
+        } else {
+            auto value_tensor = tensor_ptr->get_value(autograd::PreferredPrecision::FULL);
+
+            const auto& kahan_compensation_ptr = m_kahan_compensation.at(key);
+            // running compensation for the low-order bits lost in previous updates
+            auto compensation_tensor = kahan_compensation_ptr->get_value(autograd::PreferredPrecision::FULL);
+            // fold the carried compensation into this step's update
+            auto adjusted_update = ttnn::subtract(update_tensor, compensation_tensor);
+            // apply the adjusted update to the weights
+            auto result = ttnn::add(value_tensor, adjusted_update);
+            // (result - value_tensor) is the part of adjusted_update that was actually applied;
+            // subtracting adjusted_update leaves the negated low-order bits lost to rounding
+            compensation_tensor = ttnn::subtract(ttnn::subtract(result, value_tensor), adjusted_update);
+
+            tensor_ptr->set_value(result);
+            kahan_compensation_ptr->set_value(compensation_tensor);
+        }
     }
 }
 
@@ -206,6 +236,10 @@
         state_dict.emplace(kSecondMoment + key, second_moment);
     }
 
+    for (const auto& [key, kahan_compensation] : m_kahan_compensation) {
+        state_dict.emplace(kKahanCompensation + key, kahan_compensation);
+    }
+
     return state_dict;
 }
 
@@ -215,6 +249,8 @@ void AdamW::set_state_dict(const autograd::NamedParameters& dict) {
             m_first_moment[key.substr(kFirstMoment.size())] = tensor;
         } else if (key.starts_with(kSecondMoment)) {
             m_second_moment[key.substr(kSecondMoment.size())] = tensor;
+        } else if (key.starts_with(kKahanCompensation)) {
+            m_kahan_compensation[key.substr(kKahanCompensation.size())] = tensor;
         } else {
             throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key));
         }
diff --git a/tt-train/sources/ttml/optimizers/adamw.hpp b/tt-train/sources/ttml/optimizers/adamw.hpp
index 001b3e5c683..da3847f66db 100644
--- a/tt-train/sources/ttml/optimizers/adamw.hpp
+++ b/tt-train/sources/ttml/optimizers/adamw.hpp
@@ -16,6 +16,9 @@ struct AdamWConfig {
     float epsilon{1e-8F};
     float weight_decay{0.01F};
     // TODO: add amsgrad
+
+    // flag to enable Kahan summation in the weight update to reduce floating-point error accumulation
+    bool use_kahan_summation{false};
 };
 
 class MorehAdamW : public OptimizerBase {
@@ -58,6 +61,7 @@ class AdamW : public OptimizerBase {
     AdamWConfig m_config;
    autograd::NamedParameters m_first_moment;
     autograd::NamedParameters m_second_moment;
+    autograd::NamedParameters m_kahan_compensation;
 };
 
 }  // namespace ttml::optimizers
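Note on the update path above: the `use_kahan_summation` branch in `AdamW::step()` is the tensor form of classic Kahan (compensated) summation. Each step, the compensation tensor re-injects the low-order bits that the previous `value + update` lost to rounding, which matters when the AdamW step `lr * m_hat / (sqrt(v_hat) + eps)` is orders of magnitude smaller than the weight it updates. Below is a minimal scalar sketch of the same algebra; it is standalone illustration code with a hypothetical `KahanAccumulator` type, not part of tt-train.

#include <cstdio>

// Scalar Kahan-compensated accumulation mirroring the tensor ops in AdamW::step():
// 'value' plays the role of the weight, 'update' the scaled AdamW step, and
// 'compensation' the per-parameter m_kahan_compensation state.
struct KahanAccumulator {
    float value = 0.0F;
    float compensation = 0.0F;  // running compensation for lost low-order bits

    void add(float update) {
        float adjusted_update = update - compensation;  // re-inject bits lost last step
        float result = value + adjusted_update;         // low-order bits may round away here
        // (result - value) is the part of adjusted_update that was actually applied;
        // subtracting adjusted_update leaves the negated lost low-order part.
        compensation = (result - value) - adjusted_update;
        value = result;
    }
};

int main() {
    // Many small updates against a large value: with plain float addition the
    // increment is absorbed by rounding; Kahan summation keeps accumulating.
    float naive = 1.0e8F;
    KahanAccumulator kahan;
    kahan.value = 1.0e8F;
    for (int i = 0; i < 10000; ++i) {
        naive += 1.0F;
        kahan.add(1.0F);
    }
    std::printf("naive: %.1f\n", naive);        // stays at 100000000.0
    std::printf("kahan: %.1f\n", kahan.value);  // ~100010000.0
}

The naive loop stalls because 1.0F is below half an ulp of 1.0e8F in float32, so every addition rounds back to the old value, while the compensated loop carries those lost bits forward; `m_kahan_compensation` plays exactly the role of `compensation` here. (If compiling the sketch, avoid -ffast-math, which may reassociate away the cancellation the technique relies on.)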