[tt-train] Memory efficient option to run GPT2 (#16205)
### Problem description
We can't fit GPT2-S with a batch size larger than 8 into the DRAM of an N150.

### What's changed
Add a memory-efficient runner that re-runs each transformer block's forward pass during backward instead of keeping its activations in memory.
Step time: ~320 ms (default) vs ~410 ms (memory_efficient).
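
The runner is selected with the new `runner_type` key under `transformer_config` in the training YAML (see `training_shakespear_nanogpt_memory_eff.yaml` below):

```yaml
transformer_config:
  runner_type: memory_efficient  # or "default"
```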

### (GPT2-S, batch size 64, vocab size 96) Total model memory usage (dropout disabled)

**Default runner**

Peak L1 memory usage (in MB): 7.64453125
Peak DRAM memory usage (in MB): 51038.1875

**Memory efficient runner**

Peak L1 memory usage (in MB): 7.64453125
Peak DRAM memory usage (in MB): 12078.1875

**Memory efficient runner after updates (including optimizer and model)**

Peak L1 memory usage (in MB): 7.64453125
Peak DRAM memory usage (in MB): 11296.763854980469

**Default runner (single block)**

Peak L1 memory usage (in MB): 7.64453125
Peak DRAM memory usage (in MB): 10578.984375

**Memory efficient runner (single block)**

Peak L1 memory usage (in MB): 7.64453125
Peak DRAM memory usage (in MB): 10866.984375
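
For the full GPT2-S model this is roughly a 4.2× reduction in peak DRAM (51,038 MB → 12,078 MB). With a single block the memory-efficient runner is slightly larger (10,579 MB → 10,867 MB), since there is only one block's worth of activations to free.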

### (NanoGPT, batch size 64, vocab size 96) Total model memory usage (dropout disabled)

**Default runner**

Peak L1 memory usage (in MB): 1.2578125
Peak DRAM memory usage (in MB): 2334.16796875

**Memory efficient runner**

Peak L1 memory usage (in MB): 1.2578125
Peak DRAM memory usage (in MB): 838.16796875

**Default runner (single block)**

Peak L1 memory usage (in MB): 1.2578125
Peak DRAM memory usage (in MB): 735.76953125

**Memory efficient runner (single block)**

Peak L1 memory usage (in MB): 1.2578125
Peak DRAM memory usage (in MB): 759.76953125
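
For NanoGPT this is roughly a 2.8× reduction in peak DRAM (2,334 MB → 838 MB) for the full model.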

Loss curves for the default and memory-efficient runners completely coincide:
<img width="341" alt="Screenshot 2024-12-18 at 8 58 58 PM"
src="https://github.com/user-attachments/assets/bc613500-308a-40e8-9e75-e99b27f16385"
/>

### Checklist
- [x] Post commit CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/12420110503
- [x] New/Existing tests provide coverage for changes
rfurko-tt authored Jan 7, 2025
1 parent fa84979 commit bf94433
Showing 9 changed files with 186 additions and 5 deletions.
1 change: 1 addition & 0 deletions tt-train/configs/training_shakespear_nanogpt.yaml
@@ -17,5 +17,6 @@ training_config:
    vocab_size: 96
    max_sequence_length: 256
    positional_embedding_type: trainable
    runner_type: default
    experimental:
      use_composite_layernorm: false
22 changes: 22 additions & 0 deletions tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml
@@ -0,0 +1,22 @@
training_config:
  project_name: "tt_train_nano_gpt"
  seed: 5489
  model_save_interval: 500
  batch_size: 64
  num_epochs: 1
  max_steps: 5000
  learning_rate: 0.0003
  weight_decay: 0.01
  use_moreh_adamw: true
  use_kahan_summation: false
  transformer_config:
    num_heads: 6
    embedding_dim: 384
    dropout_prob: 0.2
    num_blocks: 6
    vocab_size: 96
    max_sequence_length: 256
    runner_type: memory_efficient
    positional_embedding_type: trainable
    experimental:
      use_composite_layernorm: false
4 changes: 4 additions & 0 deletions tt-train/sources/ttml/autograd/auto_context.cpp
@@ -12,6 +12,10 @@ std::mt19937& AutoContext::get_generator() {
    return m_generator;
}

void AutoContext::set_generator(const std::mt19937& generator) {
    m_generator = generator;
}

void AutoContext::set_seed(uint32_t seed) {
    m_seed = seed;
    m_generator = std::mt19937(m_seed);
1 change: 1 addition & 0 deletions tt-train/sources/ttml/autograd/auto_context.hpp
@@ -26,6 +26,7 @@ class AutoContext {
    static AutoContext& get_instance();

    std::mt19937& get_generator();
    void set_generator(const std::mt19937& generator);

    void set_seed(uint32_t seed);

30 changes: 30 additions & 0 deletions tt-train/sources/ttml/core/scoped.hpp
@@ -0,0 +1,30 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <utility>

namespace ttml::core {

template <typename OpenFunction, typename CloseFunction>
class Scoped {
    CloseFunction close_func_;

public:
    Scoped(OpenFunction&& open_func, CloseFunction&& close_func) : close_func_(std::move(close_func)) {
        open_func();
    }

    Scoped(const Scoped&) = delete;
    Scoped& operator=(const Scoped&) = delete;
    Scoped(Scoped&& other) = delete;
    Scoped& operator=(Scoped&&) = delete;

    ~Scoped() {
        close_func_();
    }
};

} // namespace ttml::core
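
`Scoped` is a small scope guard: the first callable runs in the constructor, the second in the destructor. A minimal usage sketch, mirroring how `memory_efficient_runner` in `gpt2.cpp` below toggles the gradient mode (it assumes `ttml::autograd::ctx()` is declared in `autograd/auto_context.hpp`):

```cpp
#include "autograd/auto_context.hpp"
#include "core/scoped.hpp"

void run_forward_without_graph() {
    // Disable autograd for this scope and re-enable it on exit,
    // even if the wrapped forward pass throws.
    auto scoped = ttml::core::Scoped(
        []() { ttml::autograd::ctx().set_gradient_mode(ttml::autograd::GradMode::DISABLED); },
        []() { ttml::autograd::ctx().set_gradient_mode(ttml::autograd::GradMode::ENABLED); });
    // ... run the forward pass here; no backward nodes are recorded ...
}
```
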
74 changes: 73 additions & 1 deletion tt-train/sources/ttml/models/gpt2.cpp
@@ -4,12 +4,63 @@

#include "gpt2.hpp"

#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/scoped.hpp"
#include "modules/positional_embeddings.hpp"
#include "ops/binary_ops.hpp"
#include "ops/unary_ops.hpp"

namespace ttml::models::gpt2 {

namespace {

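// Trades extra compute for lower memory: the forward pass below runs with
// autograd disabled, so the block's intermediate activations are not kept
// for backward. The registered grad function later re-runs the block's
// forward on a detached copy of the input, restoring the RNG state saved
// before the first pass so that any dropout masks match, then backpropagates
// through the recomputed subgraph and copies the resulting input gradient
// back onto the original input.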
autograd::TensorPtr memory_efficient_runner(
    auto&& forward_impl, const autograd::TensorPtr& input, const autograd::TensorPtr& mask) {
    if (autograd::ctx().get_gradient_mode() == autograd::GradMode::DISABLED) {
        return forward_impl(input, mask);
    }

    // make a copy of a generator before running forward pass
    auto generator = autograd::ctx().get_generator();

    // running forward pass
    autograd::TensorPtr out;
    {
        auto scoped = ttml::core::Scoped(
            []() { autograd::ctx().set_gradient_mode(autograd::GradMode::DISABLED); },
            []() { autograd::ctx().set_gradient_mode(autograd::GradMode::ENABLED); });
        out = forward_impl(input, mask);
    }

    // define grad function and copy generator (in the state before forward pass)
    autograd::GradFunction grad = [input, mask, out, &forward_impl, generator]() {
        // detach input from existing graph
        auto input_detached = autograd::create_tensor(input->get_value());
        // run forward pass again
        autograd::TensorPtr output;
        {
            // set generator to the state before forward pass during construction
            // restore generator state after grad function is executed
            auto scoped = ttml::core::Scoped(
                [&generator]() { autograd::ctx().set_generator(generator); },
                [generator = autograd::ctx().get_generator()]() { autograd::ctx().set_generator(generator); });
            output = forward_impl(input_detached, mask);
        }
        // use gradients from new output
        output->set_grad(out->get_grad());
        output->backward();
        // reuse gradients from detached input
        input->add_grad(input_detached->get_grad());
    };

    auto links = autograd::get_links(input);
    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
    return out;
}

} // namespace

Transformer::Transformer(const TransformerConfig& config) {
    uint32_t vocab_size = config.vocab_size;
    uint32_t max_sequence_length = config.max_sequence_length;
@@ -19,6 +70,7 @@ Transformer::Transformer(const TransformerConfig& config) {
    uint32_t num_blocks = config.num_blocks;
    auto position_embedding_type = config.positional_embedding_type;
    auto use_composite_layernorm = config.experimental.use_composite_layernorm;
    runner_type = config.runner_type;

    fmt::print("Transformer configuration:\n");
    fmt::print(" Vocab size: {}\n", vocab_size);
@@ -30,6 +82,7 @@ Transformer::Transformer(const TransformerConfig& config) {
    fmt::print(
        " Positional embedding type: {}\n",
        position_embedding_type == PositionalEmbeddingType::Trainable ? "Trainable" : "Fixed");
    fmt::print(" Runner type: {}\n", runner_type == RunnerType::Default ? "Default" : "Memory efficient");
    fmt::print(" Composite layernorm: {}\n", use_composite_layernorm);

    uint32_t vocab_size_divisible_by_32 = (vocab_size + 31) / 32 * 32;
@@ -83,14 +136,32 @@ ttml::autograd::TensorPtr Transformer::operator()(
    auto tok_emb_out = (*tok_emb)(x);
    auto out = (*pos_emb)(tok_emb_out);
    for (auto& block : blocks) {
        out = (*block)(out, mask);
        if (runner_type == RunnerType::MemoryEfficient) {
            out = memory_efficient_runner(*block, out, mask);
        } else if (runner_type == RunnerType::Default) {
            out = (*block)(out, mask);
        } else {
            throw std::runtime_error("Unknown runner type. Supported runner types ['default', 'memory_efficient']");
        }
    }
    out = (*ln_fc)(out);
    auto logits = (*fc)(out);
    auto log_softmax = ttml::ops::log_softmax(logits, 3);
    return log_softmax;
}

RunnerType read_runner_type(const YAML::Node& config) {
    auto runner_type_str = config["runner_type"].as<std::string>("default");
    if (runner_type_str == "default") {
        return RunnerType::Default;
    } else if (runner_type_str == "memory_efficient") {
        return RunnerType::MemoryEfficient;
    } else {
        throw std::runtime_error(fmt::format(
            "Unknown runner type: {}. Supported runner types [default, memory_efficient]", runner_type_str));
    }
}

PositionalEmbeddingType read_positional_embedding_type(const YAML::Node& config) {
    auto positional_embedding_str = config["positional_embedding_type"].as<std::string>("trainable");
    if (positional_embedding_str == "trainable") {
@@ -113,6 +184,7 @@ TransformerConfig read_config(const YAML::Node& config) {
    transformer_config.vocab_size = config["vocab_size"].as<uint32_t>();
    transformer_config.max_sequence_length = config["max_sequence_length"].as<uint32_t>();
    transformer_config.positional_embedding_type = read_positional_embedding_type(config);
    transformer_config.runner_type = read_runner_type(config);

    if (auto experimental_config = config["experimental"]) {
        transformer_config.experimental.use_composite_layernorm =
7 changes: 7 additions & 0 deletions tt-train/sources/ttml/models/gpt2.hpp
@@ -18,13 +18,19 @@ enum class PositionalEmbeddingType {
    Fixed,
};

enum class RunnerType {
    MemoryEfficient,
    Default,
};

struct TransformerConfig {
    uint32_t num_heads = 6;
    uint32_t embedding_dim = 384;
    float dropout_prob = 0.2F;
    uint32_t num_blocks = 6;
    uint32_t vocab_size = 256;
    uint32_t max_sequence_length = 256;
    RunnerType runner_type = RunnerType::Default;
    PositionalEmbeddingType positional_embedding_type = PositionalEmbeddingType::Trainable;

    struct Experimental {
@@ -35,6 +41,7 @@ struct TransformerConfig {

class Transformer : public ttml::autograd::ModuleBase {
private:
    RunnerType runner_type = RunnerType::Default;
    std::shared_ptr<ttml::modules::Embedding> tok_emb;
    std::shared_ptr<ttml::modules::PositionalEmbeddingBase> pos_emb;
    std::vector<std::shared_ptr<ttml::modules::GPTBlock>> blocks;
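
The runner can also be selected programmatically on `TransformerConfig`, as the updated test below does. A minimal sketch, assuming the `models/gpt2.hpp` include path and the public `Transformer(const TransformerConfig&)` constructor defined in `gpt2.cpp`:

```cpp
#include <memory>

#include "models/gpt2.hpp"

std::shared_ptr<ttml::models::gpt2::Transformer> make_memory_efficient_gpt2() {
    ttml::models::gpt2::TransformerConfig config;
    config.runner_type = ttml::models::gpt2::RunnerType::MemoryEfficient;
    // All other fields keep the defaults declared in TransformerConfig.
    return std::make_shared<ttml::models::gpt2::Transformer>(config);
}
```
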
21 changes: 21 additions & 0 deletions tt-train/tests/core/scoped_test.cpp
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "core/scoped.hpp"

#include <gtest/gtest.h>

#include <core/ttnn_all_includes.hpp>

TEST(ScopedTest, Scoped) {
    int variable = 0;

    {
        EXPECT_EQ(variable, 0);
        auto scoped = ttml::core::Scoped([&variable]() { variable = 1; }, [&variable]() { variable = 2; });
        EXPECT_EQ(variable, 1);
    }

    EXPECT_EQ(variable, 2);
};
31 changes: 27 additions & 4 deletions tt-train/tests/model/nano_gpt_test.cpp
@@ -53,9 +53,11 @@ struct TrainingConfig {
    ttml::models::gpt2::TransformerConfig transformer_config;
};

void train_test(bool use_moreh_adamw = false) {
void train_test(bool use_moreh_adamw = false, bool memory_efficient = false) {
    auto config = TrainingConfig();
    config.transformer_config.dropout_prob = 0.0F;
    config.transformer_config.runner_type =
        memory_efficient ? ttml::models::gpt2::RunnerType::MemoryEfficient : ttml::models::gpt2::RunnerType::Default;
    config.data_path = "/shakespeare.txt";

    // set seed
@@ -185,7 +187,10 @@ void train_test(bool use_moreh_adamw = false) {

    // verify time per step
    size_t num_steps_below = 0;
    double expected_time_ms = 330.0;
    const double expected_default_runner_time_ms = 330.0;
    const double expected_memory_efficient_runner_time_ms = 450.0;
    double expected_time_ms =
        memory_efficient ? expected_memory_efficient_runner_time_ms : expected_default_runner_time_ms;
    for (auto &time : steps_time) {
        num_steps_below += (time < expected_time_ms);
    }
@@ -241,7 +246,7 @@ TEST_F(NanoGPTTest, AdamW) {
    GTEST_SKIP() << "Skipping AdamW";
    return;
    if (should_run_tests()) {
        train_test(/* use_moreh_adamw */ false);
        train_test(/* use_moreh_adamw */ false, /* memory_efficient */ false);
    }
}

@@ -250,6 +255,24 @@ TEST_F(NanoGPTTest, MorehAdamW) {
    return;

    if (should_run_tests()) {
        train_test(/* use_moreh_adamw */ true);
        train_test(/* use_moreh_adamw */ true, /* memory_efficient */ false);
    }
}

TEST_F(NanoGPTTest, AdamW_MemoryEfficient) {
    GTEST_SKIP() << "Skipping AdamW + MemoryEfficient";
    return;

    if (should_run_tests()) {
        train_test(/* use_moreh_adamw */ false, /* memory_efficient */ true);
    }
}

TEST_F(NanoGPTTest, MorehAdamW_MemoryEfficient) {
    GTEST_SKIP() << "Skipping MorehAdamW + MemoryEfficient";
    return;

    if (should_run_tests()) {
        train_test(/* use_moreh_adamw */ true, /* memory_efficient */ true);
    }
}
