[TT-Train] Added Yaml Configs support (#15352)
### What's changed
* Moved the model definitions out of the examples and into shared `models` sources.
* Added YAML config support (see the loading sketch below).
* A few small improvements.
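In short, each example now reads its hyperparameters from a YAML file under `tt-train/configs` instead of a list of CLI flags. A minimal sketch of the new flow, pieced together from the hunks below (the literal path stands in for the `CONFIGS_FOLDER` compile definition added in the CMake changes; error handling is an assumption, not part of this commit):

```cpp
#include <yaml-cpp/yaml.h>

#include <cstdint>
#include <iostream>

int main() {
    // In the real code the directory comes from the -DCONFIGS_FOLDER=... compile
    // definition added in the CMakeLists.txt hunks below; a literal path stands in here.
    YAML::Node yaml_config = YAML::LoadFile("tt-train/configs/training_mnist_mlp.yaml");

    // Scalars are pulled out of the nested training_config block with as<T>().
    auto training_config = yaml_config["training_config"];
    std::cout << training_config["batch_size"].as<uint32_t>() << '\n';    // 128
    std::cout << training_config["learning_rate"].as<float>() << '\n';   // 0.1
    return 0;
}
```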

### Checklist
- [x] Post commit CI passes
- [x] Blackhole Post commit (if applicable)
- [x] Model regression CI testing passes (if applicable)
- [x] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
dmakoviichuk-tt authored Nov 25, 2024
1 parent 37fc6b6 commit dce7015
Showing 29 changed files with 428 additions and 268 deletions.
4 changes: 3 additions & 1 deletion tt-train/.vscode/settings.json
@@ -1,3 +1,5 @@
 {
-    "editor.formatOnSave": true
+    "editor.formatOnSave": true,
+    "files.autoSave": "afterDelay",
+    "C_Cpp.clang_format_style": ".clang-format"
 }
14 changes: 14 additions & 0 deletions tt-train/configs/training_mnist_mlp.yaml
@@ -0,0 +1,14 @@
+training_config:
+  batch_size: 128
+  logging_interval: 50
+  num_epochs: 10
+  learning_rate: 0.1
+  momentum: 0.9
+  weight_decay: 0.0
+  is_eval: false
+  model_save_interval: 500
+  model_path: "/tmp/mnist_mlp.msgpack"
+  mlp_config:
+    input_features: 784
+    hidden_features: [128]
+    output_features: 10
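The nested `mlp_config` block above is consumed by `ttml::models::mlp::read_config` (called in the `mnist_mlp/main.cpp` hunk below). That function's body is not part of this excerpt; a plausible sketch, with the struct mirrored from the renamed `MultiLayerPerceptronParameters` fields in the `graph_capture` hunk and the integer types assumed:

```cpp
#include <yaml-cpp/yaml.h>

#include <cstdint>
#include <vector>

// Assumed stand-in for ttml::modules::MultiLayerPerceptronParameters after the
// m_ prefix removal visible in the graph_capture hunk below.
struct MultiLayerPerceptronParameters {
    uint32_t input_features = 784;
    std::vector<uint32_t> hidden_features;
    uint32_t output_features = 10;
};

MultiLayerPerceptronParameters read_config(const YAML::Node &node) {
    MultiLayerPerceptronParameters params;
    params.input_features = node["input_features"].as<uint32_t>();
    // yaml-cpp converts YAML sequences such as [128] to std::vector out of the box.
    params.hidden_features = node["hidden_features"].as<std::vector<uint32_t>>();
    params.output_features = node["output_features"].as<uint32_t>();
    return params;
}
```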
17 changes: 17 additions & 0 deletions tt-train/configs/training_shakespear_nanogpt.yaml
@@ -0,0 +1,17 @@
+training_config:
+  project_name: "tt_train_nano_gpt"
+  seed: 5489
+  model_save_interval: 500
+  batch_size: 64
+  num_epochs: 1
+  max_steps: 5000
+  learning_rate: 0.0003
+  weight_decay: 0.01
+
+  transformer_config:
+    num_heads: 6
+    embedding_dim: 384
+    dropout_prob: 0.2
+    num_blocks: 6
+    vocab_size: 96
+    max_sequence_length: 256
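The updated `nano_gpt/main.cpp` falls in the truncated portion of this diff, but it presumably parses this file the same way the MNIST example does, handing `training_config["transformer_config"]` to a reader in the new `models` sources. A hypothetical reader for that block, with the struct name and field types assumed (the real definition is not visible in this excerpt):

```cpp
#include <yaml-cpp/yaml.h>

#include <cstdint>

// Hypothetical mirror of the transformer_config block above; the actual
// struct and its types live in headers not shown in this diff.
struct TransformerConfig {
    uint32_t num_heads = 6;
    uint32_t embedding_dim = 384;
    float dropout_prob = 0.2F;
    uint32_t num_blocks = 6;
    uint32_t vocab_size = 96;
    uint32_t max_sequence_length = 256;
};

TransformerConfig read_transformer_config(const YAML::Node &node) {
    TransformerConfig config;
    config.num_heads = node["num_heads"].as<uint32_t>();
    config.embedding_dim = node["embedding_dim"].as<uint32_t>();
    config.dropout_prob = node["dropout_prob"].as<float>();
    config.num_blocks = node["num_blocks"].as<uint32_t>();
    config.vocab_size = node["vocab_size"].as<uint32_t>();
    config.max_sequence_length = node["max_sequence_length"].as<uint32_t>();
    return config;
}
```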
2 changes: 1 addition & 1 deletion tt-train/sources/examples/graph_capture/main.cpp
@@ -75,7 +75,7 @@ int main() {
         ttml::core::zeros(ttml::core::create_shape({batch_size, 1, 1, num_targets}), device));

     auto model_params = ttml::modules::MultiLayerPerceptronParameters{
-        .m_input_features = num_features, .m_hidden_features = {128}, .m_output_features = num_targets};
+        .input_features = num_features, .hidden_features = {128}, .output_features = num_targets};
     auto model = ttml::modules::MultiLayerPerceptron(model_params);

     auto mode = tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH;
1 change: 0 additions & 1 deletion tt-train/sources/examples/linear_regression/CMakeLists.txt
@@ -1,6 +1,5 @@
 project(linear_regression)

 set(SOURCES main.cpp)
-
 add_executable(linear_regression ${SOURCES})
 target_link_libraries(linear_regression PRIVATE ttml)
7 changes: 4 additions & 3 deletions tt-train/sources/examples/linear_regression/main.cpp
@@ -12,6 +12,7 @@
 #include "core/tt_tensor_utils.hpp"
 #include "datasets/dataloader.hpp"
 #include "datasets/generators.hpp"
+#include "models/linear_regression.hpp"
 #include "modules/linear_module.hpp"
 #include "ops/losses.hpp"
 #include "optimizers/sgd.hpp"

@@ -66,18 +67,18 @@ int main() {
     const uint32_t batch_size = 128;
     auto train_dataloader = DataLoader(training_dataset, batch_size, /* shuffle */ true, collate_fn);

-    auto model = ttml::modules::LinearLayer(num_features, num_targets);
+    auto model = ttml::models::linear_regression::create(num_features, num_targets);

     float learning_rate = 0.1F * num_targets * (batch_size / 128.F);
     auto sgd_config = ttml::optimizers::SGDConfig{.lr = learning_rate, .momentum = 0.0F};
-    auto optimizer = ttml::optimizers::SGD(model.parameters(), sgd_config);
+    auto optimizer = ttml::optimizers::SGD(model->parameters(), sgd_config);

     int training_step = 0;
     const int num_epochs = 10;
     for (int epoch = 0; epoch < num_epochs; ++epoch) {
         for (const auto& [data, targets] : train_dataloader) {
             optimizer.zero_grad();
-            auto output = model(data);
+            auto output = (*model)(data);
             auto loss = ttml::ops::mse_loss(output, targets);
             auto loss_float = ttml::core::to_vector(loss->get_value())[0];
             fmt::print("Step: {} Loss: {}\n", training_step++, loss_float);
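Note the call sites change from `model.parameters()` to `model->parameters()` and from `model(data)` to `(*model)(data)`: the new `create()` factory evidently returns a pointer type rather than a value. A sketch of what that factory might look like; the exact return type and the contents of `models/linear_regression.hpp` are assumptions, since that header is not part of this excerpt:

```cpp
#include <cstdint>
#include <memory>

#include "modules/linear_module.hpp"

namespace ttml::models::linear_regression {

// Assumed: wraps the previous direct LinearLayer construction behind a
// factory returning a smart pointer, matching the -> call sites above.
inline std::shared_ptr<ttml::modules::LinearLayer> create(uint32_t num_features, uint32_t num_targets) {
    return std::make_shared<ttml::modules::LinearLayer>(num_features, num_targets);
}

}  // namespace ttml::models::linear_regression
```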
3 changes: 1 addition & 2 deletions tt-train/sources/examples/mnist_mlp/CMakeLists.txt
@@ -3,13 +3,12 @@ project(mnist_mlp)
 set(SOURCES
     main.cpp
     utils.cpp
-    models.cpp
 )

 CPMAddPackage(NAME mnist_dataset GITHUB_REPOSITORY wichtounet/mnist GIT_TAG master)
 include_directories(${mnist_dataset_SOURCE_DIR}/include)

-# Add executable and link libraries
 add_executable(mnist_mlp ${SOURCES})
 target_link_libraries(mnist_mlp PRIVATE ttml)
 target_compile_definitions(mnist_mlp PRIVATE MNIST_DATA_LOCATION="${mnist_dataset_SOURCE_DIR}/")
+add_definitions(-DCONFIGS_FOLDER="${CMAKE_SOURCE_DIR}/configs")
82 changes: 51 additions & 31 deletions tt-train/sources/examples/mnist_mlp/main.cpp
@@ -14,10 +14,11 @@
 #include "core/tt_tensor_utils.hpp"
 #include "datasets/dataloader.hpp"
 #include "datasets/in_memory_dataset.hpp"
-#include "models.hpp"
+#include "models/mlp.hpp"
 #include "ops/losses.hpp"
 #include "optimizers/sgd.hpp"
 #include "utils.hpp"
+#include "yaml-cpp/node/node.h"

 using ttml::autograd::TensorPtr;
@@ -55,27 +56,46 @@ float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) {
     return num_correct / num_samples;
 };

-int main(int argc, char **argv) {
-    CLI::App app{"Mnist Example"};
-    argv = app.ensure_utf8(argv);
-
+struct TrainingConfig {
+    uint32_t batch_size = 128;
+    int logging_interval = 50;
+    size_t num_epochs = 10;
+    bool is_eval = false;
+    float learning_rate = 0.1;
+    float momentum = 0.9F;
+    float weight_decay = 0.F;
+    int model_save_interval = 500;
+    std::string model_path = "/tmp/mnist_mlp.msgpack";
+    ttml::modules::MultiLayerPerceptronParameters mlp_config;
+};
+
-    app.add_option("-b,--batch_size", batch_size, "Batch size")->default_val(batch_size);
-    app.add_option("-l,--logging_interval", logging_interval, "Logging interval")->default_val(logging_interval);
-    app.add_option("-m,--model_save_interval", model_save_interval, "model save interval")
-        ->default_val(model_save_interval);
+TrainingConfig parse_config(const YAML::Node &yaml_config) {
+    TrainingConfig config;
+    auto training_config = yaml_config["training_config"];
+
+    config.batch_size = training_config["batch_size"].as<uint32_t>();
+    config.logging_interval = training_config["logging_interval"].as<int>();
+    config.num_epochs = training_config["num_epochs"].as<size_t>();
+    config.learning_rate = training_config["learning_rate"].as<float>();
+    config.momentum = training_config["momentum"].as<float>();
+    config.weight_decay = training_config["weight_decay"].as<float>();
+    config.model_save_interval = training_config["model_save_interval"].as<int>();
+    config.mlp_config = ttml::models::mlp::read_config(training_config["mlp_config"]);
+    return config;
+}
+
-    app.add_option("-n,--num_epochs", num_epochs, "Number of epochs")->default_val(num_epochs);
-    app.add_option("-s,--model_path", model_path, "Model path")->default_val(model_path);
-    app.add_option("-e,--eval", is_eval, "eval only mode")->default_val(is_eval);
+int main(int argc, char **argv) {
+    CLI::App app{"Mnist Example"};
+    argv = app.ensure_utf8(argv);
+
+    std::string config_name = std::string(CONFIGS_FOLDER) + "/training_mnist_mlp.yaml";
+    bool is_eval = false;
+    app.add_option("-c,--config", config_name, "Yaml Config name")->default_val(config_name);
+    app.add_option("-e,--eval", is_eval, "Evaluate")->default_val(is_eval);

     CLI11_PARSE(app, argc, argv);
+    auto yaml_config = YAML::LoadFile(config_name);
+    TrainingConfig config = parse_config(yaml_config);

     // Load MNIST data
     const size_t num_targets = 10;
     const size_t num_features = 784;
Expand Down Expand Up @@ -111,14 +131,14 @@ int main(int argc, char **argv) {
return std::make_pair(data_tensor, targets_tensor);
};

auto train_dataloader = DataLoader(training_dataset, batch_size, /* shuffle */ true, collate_fn);
auto test_dataloader = DataLoader(test_dataset, batch_size, /* shuffle */ false, collate_fn);
auto train_dataloader = DataLoader(training_dataset, config.batch_size, /* shuffle */ true, collate_fn);
auto test_dataloader = DataLoader(test_dataset, config.batch_size, /* shuffle */ false, collate_fn);

auto model = create_base_mlp(784, 10);
auto model = ttml::models::mlp::create(config.mlp_config);

const float learning_rate = 0.1F * (static_cast<float>(batch_size) / 128.F);
const float momentum = 0.9F;
const float weight_decay = 0.F;
const float learning_rate = config.learning_rate * (static_cast<float>(config.batch_size) / 128.F);
const float momentum = config.momentum;
const float weight_decay = config.weight_decay;
auto sgd_config =
ttml::optimizers::SGDConfig{.lr = learning_rate, .momentum = momentum, .weight_decay = weight_decay};

@@ -129,9 +149,9 @@ int main(int argc, char **argv) {
     fmt::print(" Weight decay: {}\n", sgd_config.weight_decay);
     fmt::print(" Nesterov: {}\n", sgd_config.nesterov);
     auto optimizer = ttml::optimizers::SGD(model->parameters(), sgd_config);
-    if (!model_path.empty() && std::filesystem::exists(model_path)) {
-        fmt::print("Loading model from {}\n", model_path);
-        load_model_and_optimizer(model_path, model, optimizer, model_name, optimizer_name);
+    if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
+        fmt::print("Loading model from {}\n", config.model_path);
+        load_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
     }

     // evaluate model before training (sanity check to get reasonable accuracy
@@ -144,19 +164,19 @@ int main(int argc, char **argv) {

     LossAverageMeter loss_meter;
     int training_step = 0;
-    for (size_t epoch = 0; epoch < num_epochs; ++epoch) {
+    for (size_t epoch = 0; epoch < config.num_epochs; ++epoch) {
         for (const auto &[data, target] : train_dataloader) {
             optimizer.zero_grad();
             auto output = (*model)(data);
             auto loss = ttml::ops::cross_entropy_loss(output, target);
             auto loss_float = ttml::core::to_vector(loss->get_value())[0];
-            loss_meter.update(loss_float, batch_size);
-            if (training_step % logging_interval == 0) {
+            loss_meter.update(loss_float, config.batch_size);
+            if (training_step % config.logging_interval == 0) {
                 fmt::print("Step: {:5d} | Average Loss: {:.4f}\n", training_step, loss_meter.average());
             }
-            if (!model_path.empty() && training_step % model_save_interval == 0) {
-                fmt::print("Saving model to {}\n", model_path);
-                save_model_and_optimizer(model_path, model, optimizer, model_name, optimizer_name);
+            if (!config.model_path.empty() && training_step % config.model_save_interval == 0) {
+                fmt::print("Saving model to {}\n", config.model_path);
+                save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
             }

             loss->backward();
@@ -174,9 +194,9 @@ int main(int argc, char **argv) {
         loss_meter.reset();
     }

-    if (!model_path.empty()) {
-        fmt::print("Saving model to {}\n", model_path);
-        save_model_and_optimizer(model_path, model, optimizer, model_name, optimizer_name);
+    if (!config.model_path.empty()) {
+        fmt::print("Saving model to {}\n", config.model_path);
+        save_model_and_optimizer(config.model_path, model, optimizer, model_name, optimizer_name);
     }

     return 0;
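One behavioral detail worth noting in the hunks above: the in-struct defaults (`batch_size = 128`, and so on) are effectively dead, because `parse_config` calls `as<T>()` unconditionally and yaml-cpp throws when a key is missing. If partial configs were desired, yaml-cpp's fallback overload could keep the defaults alive; a sketch of that alternative pattern (not what this commit does):

```cpp
#include <yaml-cpp/yaml.h>

#include <cstddef>
#include <cstdint>

// Trimmed-down copy of the commit's TrainingConfig for illustration.
struct TrainingConfig {
    uint32_t batch_size = 128;
    size_t num_epochs = 10;
};

// as<T>(fallback) returns the fallback instead of throwing when the key is
// absent, so the in-struct initializers act as real defaults.
TrainingConfig parse_config_with_defaults(const YAML::Node &yaml_config) {
    TrainingConfig config;
    auto training_config = yaml_config["training_config"];
    config.batch_size = training_config["batch_size"].as<uint32_t>(config.batch_size);
    config.num_epochs = training_config["num_epochs"].as<size_t>(config.num_epochs);
    return config;
}
```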
46 changes: 0 additions & 46 deletions tt-train/sources/examples/mnist_mlp/models.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions tt-train/sources/examples/mnist_mlp/models.hpp

This file was deleted.

3 changes: 1 addition & 2 deletions tt-train/sources/examples/nano_gpt/CMakeLists.txt
@@ -3,13 +3,12 @@ project(nano_gpt)
 set(SOURCES
     main.cpp
     utils.cpp
-    models.cpp
 )
-
 add_executable(nano_gpt ${SOURCES})
 target_link_libraries(nano_gpt PRIVATE ttml)

 add_definitions(-DDATA_FOLDER="${CMAKE_SOURCE_DIR}/data")
+add_definitions(-DCONFIGS_FOLDER="${CMAKE_SOURCE_DIR}/configs")

 # Define the target file location
 set(SHAKESPEARE_URL "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt")
(Diff truncated: the remaining changed files are not shown.)