-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tt-train] Memory efficient option to run GPT2 (#16205)
### Problem description We can't fit GPT2-S with batch size larger than 8 into DRAM of N150. ### What's changed Add memory efficient runner. Performance 320 (default) vs 410 (memory_efficient) ### (GPT2-S, batch size 64, vocab size 96) Total model memory usage (dropout disabled) **Default runner** Peak L1 memory usage (in MB): 7.64453125 Peak DRAM memory usage (in MB): 51038.1875 **Memory efficient runner** Peak L1 memory usage (in MB): 7.64453125 Peak DRAM memory usage (in MB): 12078.1875 **Memory efficient runner after updates (including optimizer and model)** Peak L1 memory usage (in MB): 7.64453125 Peak DRAM memory usage (in MB): 11296.763854980469 **Default runner (single block)** Peak L1 memory usage (in MB): 7.64453125 Peak DRAM memory usage (in MB): 10578.984375 **Memory efficient runner (single block)** Peak L1 memory usage (in MB): 7.64453125 Peak DRAM memory usage (in MB): 10866.984375 ### (NanoGPT, batch size 64, vocab size 96) Total model memory usage (dropout disabled) **Default runner** Peak L1 memory usage (in MB): 1.2578125 Peak DRAM memory usage (in MB): 2334.16796875 **Memory efficient runner** Peak L1 memory usage (in MB): 1.2578125 Peak DRAM memory usage (in MB): 838.16796875 **Default runner (single block)** Peak L1 memory usage (in MB): 1.2578125 Peak DRAM memory usage (in MB): 735.76953125 **Memory efficient runner (single block)** Peak L1 memory usage (in MB): 1.2578125 Peak DRAM memory usage (in MB): 759.76953125 Loss curves completely coincide <img width="341" alt="Screenshot 2024-12-18 at 8 58 58 PM" src="https://github.com/user-attachments/assets/bc613500-308a-40e8-9e75-e99b27f16385" /> ### Checklist - [x] Post commit CI passes https://github.com/tenstorrent/tt-metal/actions/runs/12420110503 - [x] New/Existing tests provide coverage for changes
- Loading branch information
Showing
9 changed files
with
186 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
tt-train/configs/training_shakespear_nanogpt_memory_eff.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
training_config: | ||
project_name: "tt_train_nano_gpt" | ||
seed: 5489 | ||
model_save_interval: 500 | ||
batch_size: 64 | ||
num_epochs: 1 | ||
max_steps: 5000 | ||
learning_rate: 0.0003 | ||
weight_decay: 0.01 | ||
use_moreh_adamw: true | ||
use_kahan_summation: false | ||
transformer_config: | ||
num_heads: 6 | ||
embedding_dim: 384 | ||
dropout_prob: 0.2 | ||
num_blocks: 6 | ||
vocab_size: 96 | ||
max_sequence_length: 256 | ||
runner_type: memory_efficient | ||
positional_embedding_type: trainable | ||
experimental: | ||
use_composite_layernorm: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <utility> | ||
|
||
namespace ttml::core { | ||
|
||
template <typename OpenFunction, typename CloseFunction> | ||
class Scoped { | ||
CloseFunction close_func_; | ||
|
||
public: | ||
Scoped(OpenFunction&& open_func, CloseFunction&& close_func) : close_func_(std::move(close_func)) { | ||
open_func(); | ||
} | ||
|
||
Scoped(const Scoped&) = delete; | ||
Scoped& operator=(const Scoped&) = delete; | ||
Scoped(Scoped&& other) = delete; | ||
Scoped& operator=(Scoped&&) = delete; | ||
|
||
~Scoped() { | ||
close_func_(); | ||
} | ||
}; | ||
|
||
} // namespace ttml::core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "core/scoped.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <core/ttnn_all_includes.hpp> | ||
|
||
TEST(ScopedTest, Scoped) { | ||
int variable = 0; | ||
|
||
{ | ||
EXPECT_EQ(variable, 0); | ||
auto scoped = ttml::core::Scoped([&variable]() { variable = 1; }, [&variable]() { variable = 2; }); | ||
EXPECT_EQ(variable, 1); | ||
} | ||
|
||
EXPECT_EQ(variable, 2); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters