[TT-Train] Added LR Schedulers and updated serialization #15625

Merged: 15 commits, Dec 12, 2024
Changes from 11 commits
4 changes: 2 additions & 2 deletions tt-train/sources/examples/mnist_mlp/main.cpp
@@ -4,6 +4,8 @@

#include <CLI/CLI.hpp>
#include <core/ttnn_all_includes.hpp>
#include <functional>
#include <memory>
#include <mnist/mnist_reader.hpp>
#include <ttnn/operations/eltwise/ternary/where.hpp>
#include <ttnn/tensor/tensor_utils.hpp>
@@ -19,7 +21,6 @@
#include "optimizers/sgd.hpp"
#include "utils.hpp"
#include "yaml-cpp/node/node.h"

using ttml::autograd::TensorPtr;

using DatasetSample = std::pair<std::vector<uint8_t>, uint8_t>;
@@ -95,7 +96,6 @@ int main(int argc, char **argv) {
CLI11_PARSE(app, argc, argv);
auto yaml_config = YAML::LoadFile(config_name);
TrainingConfig config = parse_config(yaml_config);

// Load MNIST data
const size_t num_targets = 10;
const size_t num_features = 784;
20 changes: 16 additions & 4 deletions tt-train/sources/examples/nano_gpt/main.cpp
@@ -142,6 +142,8 @@ struct TrainingConfig {
uint32_t gradient_accumulation_steps = 1;
std::string model_path;
std::string data_path;
std::string scheduler_type = "identity";

ttml::models::gpt2::TransformerConfig transformer_config;
};

@@ -161,10 +163,17 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) {
training_config["gradient_accumulation_steps"].as<uint32_t>(config.gradient_accumulation_steps);
config.model_path = training_config["model_path"].as<std::string>("");
config.data_path = training_config["data_path"].as<std::string>(std::string(DATA_FOLDER) + "/shakespeare.txt");
config.scheduler_type = training_config["scheduler_type"].as<std::string>(config.scheduler_type);

config.transformer_config = ttml::models::gpt2::read_config(training_config["transformer_config"]);
return config;
}

const std::unordered_map<
std::string,
std::function<std::unique_ptr<ttml::schedulers::LRSchedulerBase>(ttml::optimizers::OptimizerBase *, size_t)>>
schedulers = {{"identity", create_idendity_scheduler}, {"warmup_linear", create_warmup_with_linear_scheduler}};

int main(int argc, char **argv) {
auto result = signal(SIGINT, signal_handler);
if (result == SIG_ERR) {
@@ -186,7 +195,6 @@ int main(int argc, char **argv) {
CLI11_PARSE(app, argc, argv);
auto yaml_config = YAML::LoadFile(config_name);
TrainingConfig config = parse_config(yaml_config);

wandbcpp::init({.project = config.project_name, .name = generate_run_name(config, add_time_to_name)});
wandbcpp::update_config({
{"model", "transformer"},
@@ -206,10 +214,12 @@ int main(int argc, char **argv) {
config.transformer_config.positional_embedding_type == ttml::models::gpt2::PositionalEmbeddingType::Trainable
? "trainable"
: "fixed"},
{"scheduler_type", config.scheduler_type},
});

// set seed
ttml::autograd::ctx().set_seed(config.seed);
auto schedule_func = schedulers.at(config.scheduler_type);

std::string text;
try {
@@ -218,11 +228,11 @@
std::cerr << e.what() << std::endl;
return -1;
}

fmt::print("Max steps {}\n", config.max_steps);
fmt::print("Batch size {}\n", config.batch_size);
fmt::print("Gradient accumulation steps {}\n", config.gradient_accumulation_steps);
fmt::print("Total batch size {}\n", config.batch_size * config.gradient_accumulation_steps);
fmt::print("Scheduler type {}\n", config.scheduler_type);
fmt::print("Seed {}\n", ttml::autograd::ctx().get_seed());
auto sequence_length = config.transformer_config.max_sequence_length;

@@ -304,7 +314,7 @@ int main(int argc, char **argv) {
fmt::print(" Weight decay: {}\n", adamw_params.weight_decay);
fmt::print(" Use Kahan summation: {}\n", adamw_params.use_kahan_summation);
auto optimizer = ttml::optimizers::AdamW(model->parameters(), adamw_params);

auto scheduler = schedule_func(&optimizer, config.max_steps);
if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
fmt::print("Loading model from {}\n", config.model_path);
load_model_and_optimizer(config.model_path, model, optimizer, "transformer", "adamw");
@@ -345,6 +355,7 @@ int main(int argc, char **argv) {

if (gradient_accumulator_helper.should_step()) {
optimizer.step();
scheduler->step();
auto global_step = optimizer.get_steps();
fmt::print("Step: {}, Loss: {}\n", global_step, gradient_accumulator_helper.average_loss());
loss_meter.update(gradient_accumulator_helper.average_loss());
@@ -353,7 +364,8 @@
wandbcpp::log(
{{"Step", (int)global_step},
{"Samples", (int)get_samples_count(global_step)},
{"Loss", loss_meter.average()}});
{"Loss", loss_meter.average()},
{"Learning rate", optimizer.get_lr()}});
loss_meter.reset();
}
if (!config.model_path.empty() && global_step % config.model_save_interval == 0) {
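
Taken together, the main.cpp changes wire the scheduler in as follows: the new scheduler_type key in the training YAML (for example scheduler_type: "warmup_linear", falling back to "identity" when the key is absent) selects a factory from the schedulers map, the factory builds the scheduler from the optimizer and max_steps, and scheduler->step() runs after every effective optimizer.step(), with the current learning rate logged to wandb. The standalone sketch below illustrates that string-keyed factory dispatch; it is not part of this PR's diff, and the stub types stand in for the real ttml classes.

#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct OptimizerStub {};  // stand-in for ttml::optimizers::OptimizerBase
struct SchedulerStub {    // stand-in for ttml::schedulers::LRSchedulerBase
    virtual ~SchedulerStub() = default;
    virtual void step() = 0;
};
struct IdentitySchedulerStub : SchedulerStub {
    void step() override { std::printf("identity scheduler: multiplier stays 1.0\n"); }
};

using SchedulerFactory = std::function<std::unique_ptr<SchedulerStub>(OptimizerStub *, size_t)>;

int main() {
    const std::unordered_map<std::string, SchedulerFactory> schedulers = {
        {"identity",
         [](OptimizerStub *, size_t /*total_steps*/) { return std::make_unique<IdentitySchedulerStub>(); }},
    };

    OptimizerStub optimizer;
    // .at() throws std::out_of_range for unknown names, the same failure mode as
    // schedulers.at(config.scheduler_type) in the diff above.
    auto scheduler = schedulers.at("identity")(&optimizer, /*total_steps=*/100);
    scheduler->step();
    return 0;
}
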
19 changes: 19 additions & 0 deletions tt-train/sources/examples/nano_gpt/utils.cpp
@@ -73,3 +73,22 @@ void GradientAccumulator::reset() {
float GradientAccumulator::average_loss() const {
return m_total_loss / static_cast<float>(m_total_samples);
}

std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_idendity_scheduler(
ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps) {
return std::make_unique<ttml::schedulers::LambdaScheduler>(optimizer, [](int epoch) { return 1.0F; });
}

std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_warmup_with_linear_scheduler(
ttml::optimizers::OptimizerBase *optimizer, size_t total_steps) {
const float default_warmup_factor = 0.1F;
const size_t warmup_steps = size_t(total_steps * default_warmup_factor);
const size_t linear_decay_steps = total_steps - warmup_steps;

std::vector<std::unique_ptr<ttml::schedulers::LRSchedulerBase>> schedulers;
schedulers.push_back(std::make_unique<ttml::schedulers::LinearScheduler>(optimizer, 0.0F, 1.0F, warmup_steps));
schedulers.push_back(
std::make_unique<ttml::schedulers::LinearScheduler>(optimizer, 1.0F, 0.01F, linear_decay_steps));
std::vector<size_t> steps = {warmup_steps, linear_decay_steps};
return std::make_unique<ttml::schedulers::SequentialScheduler>(optimizer, std::move(schedulers), std::move(steps));
}
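
For intuition on the warmup schedule built above: with the default warmup factor of 0.1 and, say, total_steps = 100, the multiplier ramps linearly from 0 to 1 over the first 10 steps, then decays linearly from 1.0 to 0.01 over the remaining 90, with SequentialScheduler switching between the two phases at the milestones. The standalone sketch below (not part of this PR's diff) reproduces that shape directly; it assumes the LinearScheduler start and end values act as multipliers on the base learning rate and that the endpoints interpolate as shown, so treat the numbers as illustrative.

#include <cstddef>
#include <cstdio>

// Illustrative multiplier for the warmup + linear-decay schedule composed above,
// assuming a 10% warmup and a decay target of 0.01 (endpoint handling is assumed).
float lr_multiplier(std::size_t step, std::size_t total_steps) {
    const auto warmup_steps = static_cast<std::size_t>(total_steps * 0.1F);
    const auto decay_steps = total_steps - warmup_steps;
    if (step < warmup_steps) {
        return static_cast<float>(step) / static_cast<float>(warmup_steps);  // 0.0 -> 1.0
    }
    const float progress = static_cast<float>(step - warmup_steps) / static_cast<float>(decay_steps);
    return 1.0F + (0.01F - 1.0F) * progress;  // 1.0 -> 0.01
}

int main() {
    const std::size_t total_steps = 100;
    for (std::size_t step : {std::size_t{0}, std::size_t{5}, std::size_t{10}, std::size_t{55}, std::size_t{100}}) {
        std::printf("step %3zu -> multiplier %.3f\n", step, lr_multiplier(step, total_steps));
    }
    return 0;
}
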
28 changes: 22 additions & 6 deletions tt-train/sources/examples/nano_gpt/utils.hpp
@@ -8,8 +8,13 @@
#include <fstream>
#include <iostream>
#include <sstream>
#include <third_party/taskflow/taskflow/utility/serializer.hpp>
Contributor review comment on the include above: leaked include?

#include "autograd/tensor.hpp"
#include "schedulers/lambda_scheduler.hpp"
#include "schedulers/linear_scheduler.hpp"
#include "schedulers/scheduler_base.hpp"
#include "schedulers/sequential_scheduler.hpp"
#include "serialization/msgpack_file.hpp"
#include "serialization/serialization.hpp"

@@ -25,32 +30,42 @@ class LossAverageMeter {
void reset();
};

std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_idendity_scheduler(
ttml::optimizers::OptimizerBase *optimizer, [[maybe_unused]] size_t total_steps);

std::unique_ptr<ttml::schedulers::LRSchedulerBase> create_warmup_with_linear_scheduler(
ttml::optimizers::OptimizerBase *optimizer, size_t total_steps);

std::string read_file_to_str(const std::string &file_path);

template <typename Model, typename Optimizer>
template <typename Model>
void save_model_and_optimizer(
std::string &model_path,
const std::shared_ptr<Model> &model,
Optimizer &optimizer,
const std::unique_ptr<ttml::schedulers::LRSchedulerBase> &scheduler,
const std::string &model_name,
const std::string &optimizer_name) {
ttml::serialization::MsgPackFile serializer;
ttml::serialization::write_module(serializer, model_name, model.get());
ttml::serialization::write_optimizer(serializer, optimizer_name, &optimizer);
ttml::serialization::write_state_dict(serializer, "scheduler", scheduler->get_state_dict());
serializer.serialize(model_path);
}

template <typename Model, typename Optimizer>
template <typename Model>
void load_model_and_optimizer(
std::string &model_path,
const std::shared_ptr<Model> &model,
Optimizer &optimizer,
const std::unique_ptr<ttml::schedulers::LRSchedulerBase> &scheduler,
const std::string &model_name,
const std::string &optimizer_name) {
ttml::serialization::MsgPackFile deserializer;
deserializer.deserialize(model_path);
ttml::serialization::read_module(deserializer, model_name, model.get());
ttml::serialization::read_optimizer(deserializer, optimizer_name, &optimizer);
ttml::serialization::read_optimizer(deserializer, optimizer_name, scheduler->get_optimizer().get());
auto state_dict = scheduler->get_state_dict();
ttml::serialization::read_state_dict(deserializer, "scheduler", state_dict);
scheduler->set_state_dict(state_dict);
}

uint32_t round_up_to_tile(uint32_t value, uint32_t tile_size = 32);
@@ -110,11 +125,12 @@ std::string generate_run_name(const TrainingConfig &config, bool add_time_to_run_name
if (config.gradient_accumulation_steps > 1) {
ss << "_grad_acc_" << config.gradient_accumulation_steps;
}

ss << "_sched_" << config.scheduler_type;
if (add_time_to_run_name) {
auto now = std::chrono::system_clock::now();
std::time_t current_time = std::chrono::system_clock::to_time_t(now);
ss << "_date_" << std::put_time(std::localtime(&current_time), "%Y-%m-%d_%H:%M:%S");
}

return ss.str();
}
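
The net effect of the utils.hpp changes is that a checkpoint now carries the module, the optimizer state, and the scheduler's state dict in one MsgPack file, and the load path reaches the optimizer through scheduler->get_optimizer(), so a resumed run continues the LR schedule from the saved step instead of restarting the warmup. A minimal standalone sketch of that round-trip idea follows; the stub types and key names are stand-ins, not the ttml serialization API, which goes through MsgPackFile and write_state_dict/read_state_dict.

#include <cstddef>
#include <cstdio>
#include <string>
#include <unordered_map>

// Stand-in scheduler whose only persistent state is its step counter.
using StubStateDict = std::unordered_map<std::string, std::size_t>;

class StubScheduler {
public:
    void step() { ++m_current_step; }
    StubStateDict get_state_dict() const { return {{"current_step", m_current_step}}; }
    void set_state_dict(const StubStateDict &dict) { m_current_step = dict.at("current_step"); }
    std::size_t current_step() const { return m_current_step; }

private:
    std::size_t m_current_step = 0;
};

int main() {
    StubScheduler scheduler;
    for (int i = 0; i < 100; ++i) {
        scheduler.step();
    }
    // "Serialize" alongside the model and optimizer; here the checkpoint is just a map.
    StubStateDict checkpoint = scheduler.get_state_dict();

    StubScheduler resumed;
    resumed.set_state_dict(checkpoint);  // the resumed run continues from step 100, not 0
    std::printf("resumed at step %zu\n", resumed.current_step());
    return 0;
}
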
6 changes: 2 additions & 4 deletions tt-train/sources/ttml/autograd/module_base.cpp
@@ -4,8 +4,6 @@

#include "module_base.hpp"

#include "auto_context.hpp"

namespace ttml::autograd {

void ModuleBase::register_tensor(const TensorPtr& tensor_ptr, const std::string& name) {
@@ -30,8 +28,8 @@ const std::string& ModuleBase::get_name() const {
return m_name;
}

NamedParameters ModuleBase::parameters() const {
NamedParameters params;
serialization::NamedParameters ModuleBase::parameters() const {
serialization::NamedParameters params;

std::queue<std::pair<const ModuleBase*, std::string>> modules_to_process;
modules_to_process.emplace(this, get_name() + "/");
4 changes: 2 additions & 2 deletions tt-train/sources/ttml/autograd/module_base.hpp
@@ -7,6 +7,7 @@
#include <memory>
#include <unordered_map>

#include "serialization/serializable.hpp"
#include "tensor.hpp"

namespace ttml::autograd {
@@ -15,7 +16,6 @@ enum class RunMode { TRAIN, EVAL };

class ModuleBase;
using ModuleBasePtr = std::shared_ptr<ModuleBase>;
using NamedParameters = std::unordered_map<std::string, TensorPtr>;

class ModuleBase {
private:
@@ -39,7 +39,7 @@ class ModuleBase {
ModuleBase& operator=(ModuleBase&&) = default;

[[nodiscard]] const std::string& get_name() const;
[[nodiscard]] NamedParameters parameters() const;
[[nodiscard]] serialization::NamedParameters parameters() const;

void train();
void eval();
91 changes: 40 additions & 51 deletions tt-train/sources/ttml/optimizers/adamw.cpp
@@ -10,19 +10,19 @@
#include "core/debug.hpp"
#include "core/tt_tensor_utils.hpp"
#include "optimizers/optimizer_base.hpp"
#include "serialization/serializable.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

namespace {

const std::string kFirstMoment = "first_moment/";
const std::string kSecondMoment = "second_moment/";
const std::string kKahanCompensation = "kahan_compensation/";

const std::string kFirstMoment = "first_moment";
const std::string kSecondMoment = "second_moment";
const std::string kKahanCompensation = "kahan_compensation";
const std::string kSteps = "steps";
} // namespace

namespace ttml::optimizers {

MorehAdamW::MorehAdamW(autograd::NamedParameters parameters, const AdamWConfig& config) :
MorehAdamW::MorehAdamW(serialization::NamedParameters parameters, const AdamWConfig& config) :
OptimizerBase(std::move(parameters)), m_config(config) {
if (m_config.use_kahan_summation) {
throw std::runtime_error("MorehAdamW: Kahan summation is not supported. Use default AdamW instead.");
@@ -95,29 +95,19 @@ void MorehAdamW::step() {
}
}

[[nodiscard]] autograd::NamedParameters MorehAdamW::get_state_dict() const {
autograd::NamedParameters state_dict;
for (const auto& [key, first_moment] : m_first_moment) {
state_dict.emplace(kFirstMoment + key, first_moment);
}

for (const auto& [key, second_moment] : m_second_moment) {
state_dict.emplace(kSecondMoment + key, second_moment);
}
[[nodiscard]] serialization::StateDict MorehAdamW::get_state_dict() const {
serialization::StateDict state_dict;
state_dict[kFirstMoment] = m_first_moment;
state_dict[kSecondMoment] = m_second_moment;
state_dict[kSteps] = m_steps;

return state_dict;
}

void MorehAdamW::set_state_dict(const autograd::NamedParameters& dict) {
for (const auto& [key, tensor] : dict) {
if (key.starts_with(kFirstMoment)) {
m_first_moment[key.substr(kFirstMoment.size())] = tensor;
} else if (key.starts_with(kSecondMoment)) {
m_second_moment[key.substr(kSecondMoment.size())] = tensor;
} else {
throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key));
}
}
void MorehAdamW::set_state_dict(const serialization::StateDict& dict) {
m_first_moment = std::get<serialization::NamedParameters>(dict.at(kFirstMoment));
m_second_moment = std::get<serialization::NamedParameters>(dict.at(kSecondMoment));
m_steps = serialization::get_value_type<size_t>(dict, kSteps);
}

[[nodiscard]] size_t MorehAdamW::get_steps() const {
@@ -128,7 +118,14 @@ void MorehAdamW::set_steps(size_t steps) {
m_steps = steps;
}

AdamW::AdamW(autograd::NamedParameters parameters, const AdamWConfig& config) :
float MorehAdamW::get_lr() const {
return m_config.lr;
}
void MorehAdamW::set_lr(float lr) {
m_config.lr = lr;
}

AdamW::AdamW(serialization::NamedParameters parameters, const AdamWConfig& config) :
OptimizerBase(std::move(parameters)), m_config(config) {
for (const auto& [key, tensor_ptr] : m_parameters) {
if (tensor_ptr->get_requires_grad()) {
Expand Down Expand Up @@ -226,35 +223,21 @@ void AdamW::step() {
}
}

[[nodiscard]] autograd::NamedParameters AdamW::get_state_dict() const {
autograd::NamedParameters state_dict;
for (const auto& [key, first_moment] : m_first_moment) {
state_dict.emplace(kFirstMoment + key, first_moment);
}

for (const auto& [key, second_moment] : m_second_moment) {
state_dict.emplace(kSecondMoment + key, second_moment);
}

for (const auto& [key, kahan_compensation] : m_kahan_compensation) {
state_dict.emplace(kKahanCompensation + key, kahan_compensation);
}
[[nodiscard]] serialization::StateDict AdamW::get_state_dict() const {
serialization::StateDict state_dict;
state_dict[kFirstMoment] = m_first_moment;
state_dict[kSecondMoment] = m_second_moment;
state_dict[kKahanCompensation] = m_kahan_compensation;
state_dict[kSteps] = m_steps;

return state_dict;
}

void AdamW::set_state_dict(const autograd::NamedParameters& dict) {
for (const auto& [key, tensor] : dict) {
if (key.starts_with(kFirstMoment)) {
m_first_moment[key.substr(kFirstMoment.size())] = tensor;
} else if (key.starts_with(kSecondMoment)) {
m_second_moment[key.substr(kSecondMoment.size())] = tensor;
} else if (key.starts_with(kKahanCompensation)) {
m_kahan_compensation[key.substr(kKahanCompensation.size())] = tensor;
} else {
throw std::runtime_error(fmt::format("AdamW: Invalid key in state dict. Key = {}", key));
}
}
void AdamW::set_state_dict(const serialization::StateDict& dict) {
m_first_moment = std::get<serialization::NamedParameters>(dict.at(kFirstMoment));
m_second_moment = std::get<serialization::NamedParameters>(dict.at(kSecondMoment));
m_kahan_compensation = std::get<serialization::NamedParameters>(dict.at(kKahanCompensation));
m_steps = serialization::get_value_type<size_t>(dict, kSteps);
}

[[nodiscard]] size_t AdamW::get_steps() const {
@@ -265,4 +248,10 @@ void AdamW::set_steps(size_t steps) {
m_steps = steps;
}

float AdamW::get_lr() const {
return m_config.lr;
}
void AdamW::set_lr(float lr) {
m_config.lr = lr;
}
} // namespace ttml::optimizers
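
The optimizer state now round-trips through a heterogeneous serialization::StateDict instead of flattening moment tensors into prefixed parameter names, which lets scalar state such as the step counter live next to the tensor maps. Below is a standalone sketch of that variant-based pattern; it is not this PR's code, and TensorStub, NamedParams, and the local get_value_type helper are stand-ins for the ttml serialization types, whose exact definitions may differ.

#include <cstddef>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <variant>

struct TensorStub {};  // stand-in for the tensor pointed to by ttml::autograd::TensorPtr
using NamedParams = std::unordered_map<std::string, std::shared_ptr<TensorStub>>;
using StateDict = std::unordered_map<std::string, std::variant<NamedParams, std::size_t, float>>;

// Typed accessor mirroring the role of serialization::get_value_type<T>(dict, key).
template <typename T>
T get_value_type(const StateDict &dict, const std::string &key) {
    const auto it = dict.find(key);
    if (it == dict.end()) {
        throw std::runtime_error("StateDict: missing key " + key);
    }
    return std::get<T>(it->second);
}

int main() {
    NamedParams first_moment;
    first_moment["layer0/weight"] = std::make_shared<TensorStub>();

    StateDict state_dict;
    state_dict["first_moment"] = first_moment;  // the whole tensor map is stored under one key
    state_dict["steps"] = std::size_t{42};      // scalar state lives alongside it

    const auto restored_moments = std::get<NamedParams>(state_dict.at("first_moment"));
    const auto restored_steps = get_value_type<std::size_t>(state_dict, "steps");
    std::printf("restored %zu moment tensor(s), steps = %zu\n", restored_moments.size(), restored_steps);
    return 0;
}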