From bf944339391347469909c851ffaa774d3a84e17e Mon Sep 17 00:00:00 2001
From: Roman Furko
Date: Mon, 6 Jan 2025 16:35:32 -0800
Subject: [PATCH] [tt-train] Memory efficient option to run GPT2 (#16205)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Problem description

We can't fit GPT2-S with a batch size larger than 8 into the DRAM of an N150.

### What's changed

Add a memory efficient runner that recomputes each block's activations during the backward pass instead of keeping them in DRAM. Performance: 320 (default) vs 410 (memory_efficient).

### GPT2-S (batch size 64, vocab size 96): total model memory usage, dropout disabled

| Runner | Peak L1 memory (MB) | Peak DRAM memory (MB) |
| --- | --- | --- |
| Default | 7.64453125 | 51038.1875 |
| Memory efficient | 7.64453125 | 12078.1875 |
| Memory efficient, after updates (including optimizer and model) | 7.64453125 | 11296.763854980469 |
| Default (single block) | 7.64453125 | 10578.984375 |
| Memory efficient (single block) | 7.64453125 | 10866.984375 |

### NanoGPT (batch size 64, vocab size 96): total model memory usage, dropout disabled

| Runner | Peak L1 memory (MB) | Peak DRAM memory (MB) |
| --- | --- | --- |
| Default | 1.2578125 | 2334.16796875 |
| Memory efficient | 1.2578125 | 838.16796875 |
| Default (single block) | 1.2578125 | 735.76953125 |
| Memory efficient (single block) | 1.2578125 | 759.76953125 |

Loss curves for the two runners completely coincide.

### Checklist
- [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12420110503
- [x] New/Existing tests provide coverage for changes
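### How it works (sketch)

The memory efficient runner is essentially activation checkpointing: each transformer block's forward pass runs with gradient mode disabled (so intermediate activations are not kept), and the block's backward node re-runs the forward pass with the RNG restored to its pre-forward state before backpropagating through the recomputed subgraph. The snippet below is a simplified, self-contained illustration of the two ingredients this relies on, not the ttml implementation itself: a `Scoped` RAII guard (mirroring `core/scoped.hpp` from this PR) that flips a global mode on construction and restores it on destruction, and a by-value snapshot of `std::mt19937` so a recomputed forward pass replays exactly the same random draws (e.g. dropout masks). The names `g_grad_enabled` and `random_draw` are illustrative stand-ins only.

```cpp
#include <cassert>
#include <random>
#include <utility>

// Simplified stand-in for ttml::core::Scoped: runs open_func on construction,
// close_func on destruction.
template <typename OpenFunction, typename CloseFunction>
class Scoped {
    CloseFunction close_func_;

public:
    Scoped(OpenFunction&& open_func, CloseFunction&& close_func) : close_func_(std::move(close_func)) {
        open_func();
    }
    Scoped(const Scoped&) = delete;
    Scoped& operator=(const Scoped&) = delete;
    ~Scoped() {
        close_func_();
    }
};

// Illustrative globals standing in for autograd::GradMode and the AutoContext generator.
bool g_grad_enabled = true;
std::mt19937 g_generator(5489U);

// Stand-in for an op that consumes randomness (e.g. sampling a dropout mask).
float random_draw() {
    return std::uniform_real_distribution<float>(0.0F, 1.0F)(g_generator);
}

int main() {
    // 1. Snapshot the generator state before the forward pass (std::mt19937 copies by value).
    std::mt19937 generator_before_forward = g_generator;

    // 2. Run the forward pass with gradients disabled so intermediate activations
    //    are not cached; the guard restores the mode when the scope exits.
    float first_run = 0.0F;
    {
        auto scoped = Scoped([]() { g_grad_enabled = false; }, []() { g_grad_enabled = true; });
        first_run = random_draw();
    }
    assert(g_grad_enabled);

    // 3. In the backward node, restore the snapshot and recompute the forward pass:
    //    the same random draws are replayed, so the recomputed output matches the original.
    g_generator = generator_before_forward;
    float recomputed = random_draw();
    assert(first_run == recomputed);
    return 0;
}
```

Enabling the new path is just a config switch: set `runner_type: memory_efficient` under `transformer_config` (see `training_shakespear_nanogpt_memory_eff.yaml` below); `runner_type: default` keeps the previous behavior.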
---
 .../configs/training_shakespear_nanogpt.yaml  |  1 +
 ...raining_shakespear_nanogpt_memory_eff.yaml | 22 ++++++
 .../sources/ttml/autograd/auto_context.cpp    |  4 +
 .../sources/ttml/autograd/auto_context.hpp    |  1 +
 tt-train/sources/ttml/core/scoped.hpp         | 30 ++++++++
 tt-train/sources/ttml/models/gpt2.cpp         | 74 ++++++++++++++++++-
 tt-train/sources/ttml/models/gpt2.hpp         |  7 ++
 tt-train/tests/core/scoped_test.cpp           | 21 ++++++
 tt-train/tests/model/nano_gpt_test.cpp        | 31 +++++++-
 9 files changed, 186 insertions(+), 5 deletions(-)
 create mode 100644 tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml
 create mode 100644 tt-train/sources/ttml/core/scoped.hpp
 create mode 100644 tt-train/tests/core/scoped_test.cpp

diff --git a/tt-train/configs/training_shakespear_nanogpt.yaml b/tt-train/configs/training_shakespear_nanogpt.yaml
index 2c61fd0c7b8..1933373cd2b 100644
--- a/tt-train/configs/training_shakespear_nanogpt.yaml
+++ b/tt-train/configs/training_shakespear_nanogpt.yaml
@@ -17,5 +17,6 @@ training_config:
     vocab_size: 96
     max_sequence_length: 256
     positional_embedding_type: trainable
+    runner_type: default
     experimental:
       use_composite_layernorm: false
diff --git a/tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml b/tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml
new file mode 100644
index 00000000000..70919333c84
--- /dev/null
+++ b/tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml
@@ -0,0 +1,22 @@
+training_config:
+  project_name: "tt_train_nano_gpt"
+  seed: 5489
+  model_save_interval: 500
+  batch_size: 64
+  num_epochs: 1
+  max_steps: 5000
+  learning_rate: 0.0003
+  weight_decay: 0.01
+  use_moreh_adamw: true
+  use_kahan_summation: false
+  transformer_config:
+    num_heads: 6
+    embedding_dim: 384
+    dropout_prob: 0.2
+    num_blocks: 6
+    vocab_size: 96
+    max_sequence_length: 256
+    runner_type: memory_efficient
+    positional_embedding_type: trainable
+    experimental:
+      use_composite_layernorm: false
diff --git a/tt-train/sources/ttml/autograd/auto_context.cpp b/tt-train/sources/ttml/autograd/auto_context.cpp
index ea0e27e269b..dff1ac0d5ff 100644
--- a/tt-train/sources/ttml/autograd/auto_context.cpp
+++ b/tt-train/sources/ttml/autograd/auto_context.cpp
@@ -12,6 +12,10 @@ std::mt19937& AutoContext::get_generator() {
     return m_generator;
 }
 
+void AutoContext::set_generator(const std::mt19937& generator) {
+    m_generator = generator;
+}
+
 void AutoContext::set_seed(uint32_t seed) {
     m_seed = seed;
     m_generator = std::mt19937(m_seed);
diff --git a/tt-train/sources/ttml/autograd/auto_context.hpp b/tt-train/sources/ttml/autograd/auto_context.hpp
index a4124862ed3..cd62b151137 100644
--- a/tt-train/sources/ttml/autograd/auto_context.hpp
+++ b/tt-train/sources/ttml/autograd/auto_context.hpp
@@ -26,6 +26,7 @@ class AutoContext {
     static AutoContext& get_instance();
 
     std::mt19937& get_generator();
+    void set_generator(const std::mt19937& generator);
 
     void set_seed(uint32_t seed);
 
diff --git a/tt-train/sources/ttml/core/scoped.hpp b/tt-train/sources/ttml/core/scoped.hpp
new file mode 100644
index 00000000000..e2922daa277
--- /dev/null
+++ b/tt-train/sources/ttml/core/scoped.hpp
@@ -0,0 +1,30 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <utility>
+
+namespace ttml::core {
+
+template <typename OpenFunction, typename CloseFunction>
+class Scoped {
+    CloseFunction close_func_;
+
+public:
+    Scoped(OpenFunction&& open_func, CloseFunction&& close_func) : close_func_(std::move(close_func)) {
+        open_func();
+    }
+
+    Scoped(const Scoped&) = delete;
+    Scoped& operator=(const Scoped&) = delete;
+    Scoped(Scoped&& other) = delete;
+    Scoped& operator=(Scoped&&) = delete;
+
+    ~Scoped() {
+        close_func_();
+    }
+};
+
+}  // namespace ttml::core
diff --git a/tt-train/sources/ttml/models/gpt2.cpp b/tt-train/sources/ttml/models/gpt2.cpp
index 74628323e75..7116a7ce2b9 100644
--- a/tt-train/sources/ttml/models/gpt2.cpp
+++ b/tt-train/sources/ttml/models/gpt2.cpp
@@ -4,12 +4,63 @@
 
 #include "gpt2.hpp"
 
+#include "autograd/graph_utils.hpp"
+#include "autograd/tensor.hpp"
+#include "core/scoped.hpp"
 #include "modules/positional_embeddings.hpp"
 #include "ops/binary_ops.hpp"
 #include "ops/unary_ops.hpp"
 
 namespace ttml::models::gpt2 {
 
+namespace {
+
+autograd::TensorPtr memory_efficient_runner(
+    auto&& forward_impl, const autograd::TensorPtr& input, const autograd::TensorPtr& mask) {
+    if (autograd::ctx().get_gradient_mode() == autograd::GradMode::DISABLED) {
+        return forward_impl(input, mask);
+    }
+
+    // make a copy of a generator before running forward pass
+    auto generator = autograd::ctx().get_generator();
+
+    // running forward pass
+    autograd::TensorPtr out;
+    {
+        auto scoped = ttml::core::Scoped(
+            []() { autograd::ctx().set_gradient_mode(autograd::GradMode::DISABLED); },
+            []() { autograd::ctx().set_gradient_mode(autograd::GradMode::ENABLED); });
+        out = forward_impl(input, mask);
+    }
+
+    // define grad function and copy generator (in the state before forward pass)
+    autograd::GradFunction grad = [input, mask, out, &forward_impl, generator]() {
+        // detach input from existing graph
+        auto input_detached = autograd::create_tensor(input->get_value());
+        // run forward pass again
+        autograd::TensorPtr output;
+        {
+            // set generator to the state before forward pass during construction
+            // restore generator state after grad function is executed
+            auto scoped = ttml::core::Scoped(
+                [&generator]() { autograd::ctx().set_generator(generator); },
+                [generator = autograd::ctx().get_generator()]() { autograd::ctx().set_generator(generator); });
+            output = forward_impl(input_detached, mask);
+        }
+        // use gradients from new output
+        output->set_grad(out->get_grad());
+        output->backward();
+        // reuse gradients from detached input
+        input->add_grad(input_detached->get_grad());
+    };
+
+    auto links = autograd::get_links(input);
+    out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
+    return out;
+}
+
+}  // namespace
+
 Transformer::Transformer(const TransformerConfig& config) {
     uint32_t vocab_size = config.vocab_size;
     uint32_t max_sequence_length = config.max_sequence_length;
@@ -19,6 +70,7 @@ Transformer::Transformer(const TransformerConfig& config) {
     uint32_t num_blocks = config.num_blocks;
     auto position_embedding_type = config.positional_embedding_type;
     auto use_composite_layernorm = config.experimental.use_composite_layernorm;
+    runner_type = config.runner_type;
 
     fmt::print("Transformer configuration:\n");
     fmt::print("    Vocab size: {}\n", vocab_size);
@@ -30,6 +82,7 @@ Transformer::Transformer(const TransformerConfig& config) {
     fmt::print(
         "    Positional embedding type: {}\n",
         position_embedding_type == PositionalEmbeddingType::Trainable ? "Trainable" : "Fixed");
+    fmt::print("    Runner type: {}\n", runner_type == RunnerType::Default ? "Default" : "Memory efficient");
     fmt::print("    Composite layernorm: {}\n", use_composite_layernorm);
 
     uint32_t vocab_size_divisible_by_32 = (vocab_size + 31) / 32 * 32;
@@ -83,7 +136,13 @@ ttml::autograd::TensorPtr Transformer::operator()(
     auto tok_emb_out = (*tok_emb)(x);
     auto out = (*pos_emb)(tok_emb_out);
     for (auto& block : blocks) {
-        out = (*block)(out, mask);
+        if (runner_type == RunnerType::MemoryEfficient) {
+            out = memory_efficient_runner(*block, out, mask);
+        } else if (runner_type == RunnerType::Default) {
+            out = (*block)(out, mask);
+        } else {
+            throw std::runtime_error("Unknown runner type. Supported runner types ['default', 'memory_efficient']");
+        }
     }
     out = (*ln_fc)(out);
     auto logits = (*fc)(out);
@@ -91,6 +150,18 @@
     return log_softmax;
 }
 
+RunnerType read_runner_type(const YAML::Node& config) {
+    auto runner_type_str = config["runner_type"].as<std::string>("default");
+    if (runner_type_str == "default") {
+        return RunnerType::Default;
+    } else if (runner_type_str == "memory_efficient") {
+        return RunnerType::MemoryEfficient;
+    } else {
+        throw std::runtime_error(fmt::format(
+            "Unknown runner type: {}. Supported runner types [default, memory_efficient]", runner_type_str));
+    }
+}
+
 PositionalEmbeddingType read_positional_embedding_type(const YAML::Node& config) {
     auto positional_embedding_str = config["positional_embedding_type"].as<std::string>("trainable");
     if (positional_embedding_str == "trainable") {
@@ -113,6 +184,7 @@ TransformerConfig read_config(const YAML::Node& config) {
     transformer_config.vocab_size = config["vocab_size"].as<uint32_t>();
     transformer_config.max_sequence_length = config["max_sequence_length"].as<uint32_t>();
     transformer_config.positional_embedding_type = read_positional_embedding_type(config);
+    transformer_config.runner_type = read_runner_type(config);
 
     if (auto experimental_config = config["experimental"]) {
         transformer_config.experimental.use_composite_layernorm =
diff --git a/tt-train/sources/ttml/models/gpt2.hpp b/tt-train/sources/ttml/models/gpt2.hpp
index c630dcfc0ff..0ff2fe8215f 100644
--- a/tt-train/sources/ttml/models/gpt2.hpp
+++ b/tt-train/sources/ttml/models/gpt2.hpp
@@ -18,6 +18,11 @@ enum class PositionalEmbeddingType {
     Fixed,
 };
 
+enum class RunnerType {
+    MemoryEfficient,
+    Default,
+};
+
 struct TransformerConfig {
     uint32_t num_heads = 6;
     uint32_t embedding_dim = 384;
@@ -25,6 +30,7 @@ struct TransformerConfig {
     uint32_t num_blocks = 6;
     uint32_t vocab_size = 256;
     uint32_t max_sequence_length = 256;
+    RunnerType runner_type = RunnerType::Default;
     PositionalEmbeddingType positional_embedding_type = PositionalEmbeddingType::Trainable;
 
     struct Experimental {
@@ -35,6 +41,7 @@
 
 class Transformer : public ttml::autograd::ModuleBase {
 private:
+    RunnerType runner_type = RunnerType::Default;
     std::shared_ptr tok_emb;
     std::shared_ptr pos_emb;
     std::vector> blocks;
diff --git a/tt-train/tests/core/scoped_test.cpp b/tt-train/tests/core/scoped_test.cpp
new file mode 100644
index 00000000000..0626cf8b483
--- /dev/null
+++ b/tt-train/tests/core/scoped_test.cpp
@@ -0,0 +1,21 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "core/scoped.hpp"
+
+#include <gtest/gtest.h>
+
+#include
+
+TEST(ScopedTest, Scoped) {
+    int variable = 0;
+
+    {
+        EXPECT_EQ(variable, 0);
+        auto scoped = ttml::core::Scoped([&variable]() { variable = 1; }, [&variable]() { variable = 2; });
+        EXPECT_EQ(variable, 1);
+    }
+
+    EXPECT_EQ(variable, 2);
+};
diff --git a/tt-train/tests/model/nano_gpt_test.cpp b/tt-train/tests/model/nano_gpt_test.cpp
index 5dcf53d26d0..a9883437725 100644
--- a/tt-train/tests/model/nano_gpt_test.cpp
+++ b/tt-train/tests/model/nano_gpt_test.cpp
@@ -53,9 +53,11 @@ struct TrainingConfig {
     ttml::models::gpt2::TransformerConfig transformer_config;
 };
 
-void train_test(bool use_moreh_adamw = false) {
+void train_test(bool use_moreh_adamw = false, bool memory_efficient = false) {
     auto config = TrainingConfig();
     config.transformer_config.dropout_prob = 0.0F;
+    config.transformer_config.runner_type =
+        memory_efficient ? ttml::models::gpt2::RunnerType::MemoryEfficient : ttml::models::gpt2::RunnerType::Default;
     config.data_path = "/shakespeare.txt";
 
     // set seed
@@ -185,7 +187,10 @@
 
     // verify time per step
     size_t num_steps_below = 0;
-    double expected_time_ms = 330.0;
+    const double expected_default_runner_time_ms = 330.0;
+    const double expected_memory_efficient_runner_time_ms = 450.0;
+    double expected_time_ms =
+        memory_efficient ? expected_memory_efficient_runner_time_ms : expected_default_runner_time_ms;
     for (auto &time : steps_time) {
         num_steps_below += (time < expected_time_ms);
     }
@@ -241,7 +246,7 @@ TEST_F(NanoGPTTest, AdamW) {
     GTEST_SKIP() << "Skipping AdamW";
     return;
 
     if (should_run_tests()) {
-        train_test(/* use_moreh_adamw */ false);
+        train_test(/* use_moreh_adamw */ false, /* memory_efficient */ false);
     }
 }
@@ -250,6 +255,24 @@ TEST_F(NanoGPTTest, MorehAdamW) {
     return;
 
     if (should_run_tests()) {
-        train_test(/* use_moreh_adamw */ true);
+        train_test(/* use_moreh_adamw */ true, /* memory_efficient */ false);
+    }
+}
+
+TEST_F(NanoGPTTest, AdamW_MemoryEfficient) {
+    GTEST_SKIP() << "Skipping AdamW + MemoryEfficient";
+    return;
+
+    if (should_run_tests()) {
+        train_test(/* use_moreh_adamw */ false, /* memory_efficient */ true);
+    }
+}
+
+TEST_F(NanoGPTTest, MorehAdamW_MemoryEfficient) {
+    GTEST_SKIP() << "Skipping MorehAdamW + MemoryEfficient";
+    return;
+
+    if (should_run_tests()) {
+        train_test(/* use_moreh_adamw */ true, /* memory_efficient */ true);
+    }
 }