-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tt-train] Add nanogpt tests with AdamW and MorehAdamW (#15443)
### Problem description Changes in tt-metal can influence tt-train results. Tests will provide us a layer of protection against it. ### What's changed Add two tests for NanoGPT training. ### Checklist - [x] Post commit CI passes - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [x] New/Existing tests provide coverage for changes All post-commit tests: https://github.com/tenstorrent/tt-metal/actions/runs/12036195701
- Loading branch information
Showing
4 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <fmt/format.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include <core/ttnn_all_includes.hpp> | ||
|
||
#include "autograd/auto_context.hpp" | ||
#include "core/tt_tensor_utils.hpp" | ||
#include "datasets/dataloader.hpp" | ||
#include "datasets/in_memory_token_dataset.hpp" | ||
#include "datasets/utils.hpp" | ||
#include "models/gpt2.hpp" | ||
#include "ops/losses.hpp" | ||
#include "optimizers/adamw.hpp" | ||
#include "optimizers/optimizer_base.hpp" | ||
#include "tokenizers/char_tokenizer.hpp" | ||
|
||
using ttml::autograd::TensorPtr; | ||
|
||
using DatasetSample = std::pair<std::span<const uint32_t>, std::span<const uint32_t>>; | ||
// tokens, targets, mask, positions | ||
using BatchType = std::tuple<TensorPtr, TensorPtr, TensorPtr, TensorPtr>; | ||
using DataLoader = ttml::datasets::DataLoader< | ||
ttml::datasets::InMemoryTokenDataset, | ||
std::function<BatchType(std::vector<DatasetSample> &&samples)>, | ||
BatchType>; | ||
|
||
struct TrainingConfig { | ||
std::string project_name; | ||
uint32_t seed = 5489U; | ||
uint32_t model_save_interval = 500; | ||
uint32_t batch_size = 64; | ||
uint32_t num_epochs = 1; | ||
uint32_t max_steps = 100; | ||
float learning_rate = 3e-4F; | ||
float weight_decay = 1e-2F; | ||
std::string model_path; | ||
std::string data_path; | ||
ttml::models::gpt2::TransformerConfig transformer_config; | ||
}; | ||
|
||
void train_test(bool use_moreh_adamw = false) { | ||
auto config = TrainingConfig(); | ||
config.transformer_config.dropout_prob = 0.0F; | ||
config.data_path = std::string(TEST_DATA_DIR) + "/shakespeare.txt"; | ||
|
||
// set seed | ||
ttml::autograd::ctx().set_seed(config.seed); | ||
|
||
std::string text; | ||
// reading training data from txt file | ||
{ | ||
std::ifstream file(config.data_path); | ||
if (!file.is_open()) { | ||
throw std::runtime_error("Failed to open file: " + config.data_path); | ||
} | ||
|
||
std::stringstream buffer; | ||
buffer << file.rdbuf(); | ||
|
||
text = buffer.str(); | ||
} | ||
|
||
auto *device = &ttml::autograd::ctx().get_device(); | ||
device->enable_program_cache(); | ||
|
||
auto sequence_length = config.transformer_config.max_sequence_length; | ||
|
||
auto [dataset, tokenizer] = | ||
ttml::datasets::create_in_memory_token_dataset<ttml::tokenizers::CharTokenizer>(text, sequence_length); | ||
|
||
struct CachedHostData { | ||
std::vector<uint32_t> data; | ||
std::vector<int32_t> targets; | ||
ttml::autograd::TensorPtr masks_tensor; | ||
ttml::autograd::TensorPtr positions_tensor; | ||
}; | ||
CachedHostData cached_data; | ||
std::vector<uint32_t> positions; | ||
std::vector<float> mask; | ||
positions.reserve((size_t)config.batch_size * sequence_length); | ||
for (int sample_idx = 0; sample_idx < config.batch_size; ++sample_idx) { | ||
for (int i = 0; i < sequence_length; ++i) { | ||
positions.push_back(i); | ||
} | ||
} | ||
auto num_heads = config.transformer_config.num_heads; | ||
mask.reserve((size_t)config.batch_size * sequence_length * sequence_length * num_heads); | ||
for (int sample_idx = 0; sample_idx < config.batch_size; ++sample_idx) { | ||
for (int head = 0; head < num_heads; ++head) { | ||
for (int i = 0; i < sequence_length; ++i) { | ||
for (int j = 0; j < sequence_length; ++j) { | ||
mask.push_back(i >= j ? 1.0F : 0.0F); | ||
} | ||
} | ||
} | ||
} | ||
cached_data.masks_tensor = ttml::autograd::create_tensor(ttml::core::from_vector( | ||
mask, ttml::core::create_shape({config.batch_size, num_heads, sequence_length, sequence_length}), device)); | ||
cached_data.positions_tensor = ttml::autograd::create_tensor(ttml::core::from_vector<uint32_t, DataType::UINT32>( | ||
positions, ttml::core::create_shape({config.batch_size, 1, 1, sequence_length}), device, Layout::ROW_MAJOR)); | ||
|
||
std::function<BatchType(std::vector<DatasetSample> && samples)> collate_fn = | ||
[sequence_length, num_heads, vocab_size = tokenizer.get_vocab_size(), device, &cached_data]( | ||
std::vector<DatasetSample> &&samples) { | ||
auto start_timer = std::chrono::high_resolution_clock::now(); | ||
const uint32_t batch_size = samples.size(); | ||
std::vector<uint32_t> &data = cached_data.data; | ||
std::vector<int32_t> &targets = cached_data.targets; | ||
|
||
data.clear(); | ||
targets.clear(); | ||
|
||
data.reserve((size_t)batch_size * sequence_length); | ||
targets.reserve((size_t)batch_size * sequence_length); | ||
for (auto &[features, target_span] : samples) { | ||
std::copy(features.begin(), features.end(), std::back_inserter(data)); | ||
std::copy(target_span.begin(), target_span.end(), std::back_inserter(targets)); | ||
} | ||
auto end_timer = std::chrono::high_resolution_clock::now(); | ||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_timer - start_timer).count(); | ||
fmt::print("dataloader host only step time {} ms\n", (double)duration / 1000.); | ||
auto data_tensor = ttml::autograd::create_tensor(ttml::core::from_vector<uint32_t, DataType::UINT32>( | ||
data, ttml::core::create_shape({batch_size, 1, 1, sequence_length}), device, Layout::ROW_MAJOR)); | ||
auto targets_tensor = ttml::autograd::create_tensor( | ||
ttml::core::from_vector<int32_t, DataType::INT32>(targets, {batch_size * sequence_length}, device)); | ||
end_timer = std::chrono::high_resolution_clock::now(); | ||
duration = std::chrono::duration_cast<std::chrono::microseconds>(end_timer - start_timer).count(); | ||
fmt::print("dataloader step time {} ms\n", (double)duration / 1000.); | ||
return std::make_tuple(data_tensor, targets_tensor, cached_data.masks_tensor, cached_data.positions_tensor); | ||
}; | ||
auto train_dataloader = DataLoader(dataset, /* batch_size */ config.batch_size, /* shuffle */ true, collate_fn); | ||
|
||
fmt::print("Overriding vocab size to be divisible by 32\n"); | ||
config.transformer_config.vocab_size = (tokenizer.get_vocab_size() + 31) / 32 * 32; | ||
auto model = ttml::models::gpt2::create(config.transformer_config); | ||
|
||
auto adamw_params = ttml::optimizers::AdamWConfig(); | ||
adamw_params.lr = config.learning_rate; | ||
adamw_params.weight_decay = config.weight_decay; | ||
|
||
auto create_optimizer = [&]() -> std::shared_ptr<ttml::optimizers::OptimizerBase> { | ||
if (use_moreh_adamw) { | ||
return std::make_shared<ttml::optimizers::MorehAdamW>(model->parameters(), adamw_params); | ||
} else { | ||
return std::make_shared<ttml::optimizers::AdamW>(model->parameters(), adamw_params); | ||
} | ||
}; | ||
|
||
auto optimizer = create_optimizer(); | ||
|
||
std::vector<double> steps_time; | ||
std::vector<float> losses; | ||
|
||
for (auto [features, target, masks, positions] : train_dataloader) { | ||
auto start_timer = std::chrono::high_resolution_clock::now(); | ||
optimizer->zero_grad(); | ||
auto output = (*model)(features, positions, masks); | ||
auto loss = ttml::ops::nll_loss(output, target); | ||
auto loss_float = ttml::core::to_vector(loss->get_value())[0]; | ||
loss->backward(); | ||
optimizer->step(); | ||
ttml::autograd::ctx().reset_graph(); | ||
auto global_step = optimizer->get_steps(); | ||
losses.emplace_back(loss_float); | ||
if (global_step >= config.max_steps) { | ||
break; | ||
} | ||
auto end_timer = std::chrono::high_resolution_clock::now(); | ||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_timer - start_timer).count(); | ||
steps_time.emplace_back((double)duration / 1000.0); | ||
} | ||
|
||
// verify program cache | ||
auto program_cache_entries = device->num_program_cache_entries(); | ||
if (!use_moreh_adamw) { | ||
EXPECT_EQ(program_cache_entries, 124); | ||
} else { | ||
EXPECT_EQ(program_cache_entries, 103); | ||
} | ||
|
||
// verify time per step | ||
size_t num_steps_below = 0; | ||
double expected_time_ms = 330.0; | ||
for (auto &time : steps_time) { | ||
num_steps_below += (time < expected_time_ms); | ||
} | ||
if (num_steps_below / static_cast<double>(steps_time.size()) < 0.9) { | ||
EXPECT_TRUE(false); | ||
} | ||
|
||
// verify loss | ||
if (!use_moreh_adamw) { | ||
EXPECT_EQ(losses.size(), config.max_steps); | ||
EXPECT_EQ(losses[0], 4.6875); | ||
EXPECT_EQ(losses[9], 2.96875); | ||
EXPECT_EQ(losses[19], 2.703125); | ||
EXPECT_EQ(losses[29], 2.59375); | ||
EXPECT_EQ(losses[39], 2.546875); | ||
EXPECT_EQ(losses[49], 2.5); | ||
EXPECT_EQ(losses[59], 2.484375); | ||
EXPECT_EQ(losses[69], 2.46875); | ||
EXPECT_EQ(losses[79], 2.453125); | ||
EXPECT_EQ(losses[89], 2.4375); | ||
EXPECT_EQ(losses[99], 2.453125); | ||
} else { | ||
EXPECT_EQ(losses.size(), config.max_steps); | ||
EXPECT_EQ(losses[0], 4.6875); | ||
EXPECT_EQ(losses[9], 2.96875); | ||
EXPECT_EQ(losses[19], 2.703125); | ||
EXPECT_EQ(losses[29], 2.59375); | ||
EXPECT_EQ(losses[39], 2.546875); | ||
EXPECT_EQ(losses[49], 2.484375); | ||
EXPECT_EQ(losses[59], 2.484375); | ||
EXPECT_EQ(losses[69], 2.46875); | ||
EXPECT_EQ(losses[79], 2.453125); | ||
EXPECT_EQ(losses[89], 2.4375); | ||
EXPECT_EQ(losses[99], 2.4375); | ||
} | ||
} | ||
|
||
bool should_run_tests() { | ||
const char *env_var = std::getenv("ENABLE_CI_ONLY_TT_TRAIN_TESTS"); | ||
return env_var ? true : ENABLE_CI_ONLY_TT_TRAIN_TESTS; | ||
} | ||
|
||
/* | ||
These tests are supposed to run only in CI. | ||
To run them elsewhere, set the ENABLE_CI_ONLY_TT_TRAIN_TESTS environment variable or change the value of ENABLE_CI_ONLY_TT_TRAIN_TESTS to true. | ||
If one of these tests fails, it means one (or more) of the following: | ||
- program cache size changed (new ops added/removed silently) | ||
- time per step changed (performance regression) | ||
- loss values changed (regression in ops accuracy) | ||
*/ | ||
|
||
// NanoGPT training regression test using the AdamW optimizer.
// Reported as skipped (rather than silently passing) when CI-only tests are disabled.
TEST(NanoGPTTest, AdamW) {
    if (!should_run_tests()) {
        GTEST_SKIP() << "CI-only test; set ENABLE_CI_ONLY_TT_TRAIN_TESTS to run it";
    }
    train_test(/* use_moreh_adamw */ false);
}
|
||
// NanoGPT training regression test using the MorehAdamW optimizer.
// Reported as skipped (rather than silently passing) when CI-only tests are disabled.
TEST(NanoGPTTest, MorehAdamW) {
    if (!should_run_tests()) {
        GTEST_SKIP() << "CI-only test; set ENABLE_CI_ONLY_TT_TRAIN_TESTS to run it";
    }
    train_test(/* use_moreh_adamw */ true);
}