Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.
BlaGPT is a flexible Transformer implementation in which each of the following features can be turned on or off in the config:
Multi-token prediction - link
Weight tying - link
Grouped query attention - link
Capping logits - link
QKV bias - link
Zero-init projection layer - link
Post and pre-RMSNorm - link
Setting the RoPE base theta to 1_000_000 (as in Llama 3) - increased the final validation loss - best loss: 3.3324
Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527
KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB
Dilated Attention (LongNet) - link
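Several of the toggles above are small enough to show on their own. The following pure-Python sketch uses plain lists instead of tensors, and its function names are hypothetical. It illustrates logit soft-capping, the z-loss term, RoPE inverse frequencies with a configurable base theta, the GQA mapping from query heads to key/value heads, and KV shifting. The cap of 30.0 and the z-loss coefficient of 1e-4 are assumed values. The KV-shift form, a learned mix of each position's key (or value) with the previous position's, is an assumption based on the KV-Shifting attention paper and is not taken from BlaGPT's code.

```python
import math


def soft_cap(logit: float, cap: float = 30.0) -> float:
    """Logit soft-capping: smoothly squashes a logit into (-cap, cap) with tanh."""
    return cap * math.tanh(logit / cap)


def z_loss(logits: list[float], coeff: float = 1e-4) -> float:
    """Z-loss: coeff * (log Z)^2, where log Z = logsumexp(logits) is the softmax normalizer."""
    m = max(logits)  # subtract the max for a numerically stable logsumexp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return coeff * log_z ** 2


def rope_inv_freq(head_dim: int, base: float = 10_000.0) -> list[float]:
    """RoPE inverse frequencies base^(-2i/d), one per rotated pair of channels."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def kv_head_for_query_head(q_head: int, n_head: int, n_kv_head: int) -> int:
    """GQA: consecutive groups of n_head // n_kv_head query heads share one K/V head."""
    assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
    return q_head // (n_head // n_kv_head)


def kv_shift(rows: list[list[float]], a_cur: float, a_prev: float) -> list[list[float]]:
    """KV shifting (assumed form): out[t] = a_cur * x[t] + a_prev * x[t-1], with x[-1] = 0."""
    out = []
    prev = [0.0] * len(rows[0])  # zero padding before the first position
    for row in rows:
        out.append([a_cur * x + a_prev * p for x, p in zip(row, prev)])
        prev = row
    return out


print(soft_cap(100.0))  # close to, but below, the cap of 30.0
print([kv_head_for_query_head(h, 12, 4) for h in range(12)])  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(rope_inv_freq(64)[-1], rope_inv_freq(64, base=1_000_000.0)[-1])
```

With n_head=12 and n_kv_head=4, the values used in the best config, query heads 0-2 share KV head 0, heads 3-5 share KV head 1, and so on. Raising the base from 10_000 to 1_000_000 shrinks the lowest rotary frequencies. That stretches the longest wavelengths, which is what the base-theta experiment changes.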
MegaByte - link - loss: 3.810
FTP (heavily modified) - link - loss: 3.901
Rene - link - loss: 3.340
Rwkv7 - link - loss: 4.450
Zamba2 - link - Zamba2 > Rene > Rwkv7
Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710
Hymba - link - training steps are significantly slower than with the Transformer models. Best validation loss so far: 4.7505
Tokenformer (in BlaGPT model) - link - loss: 3.390
PaLMForeachSOAP - link - almost 2x slower than Adam, but gives the best results
Ademamix - link - Unstable even after trying different learning rates.
Adopt - link - loss went straight to NaN
CAdamW - link - loss: 3.3517
AdamW with independent weight decay - link - loss: 3.320
Adam - loss: 3.3224
AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step_time: 533ms
DeMo - link - saves 7 GB per GPU, but the loss is higher than the baseline and the step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step_time: 820ms
Adam-Mini - link - loss is higher than Adam and AdamW, and it is unexpectedly also slower; saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step_time: 610ms
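The AdamW variants above differ mainly in how weight decay enters the update. This pure-Python sketch uses flat lists of floats instead of tensors, and the function name and defaults are hypothetical. It shows one AdamW step with decoupled weight decay, applied before the adaptive update as torch.optim.AdamW does, plus an independent_wd switch that drops the learning-rate factor from the decay. Reading "independent weight decay" as decay that is not multiplied by the learning rate is an assumption about the linked variant. The real implementation may, for example, still scale it by the schedule.

```python
import math


def adamw_step(params, grads, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.0, independent_wd=False):
    """One AdamW step on flat lists of floats; updates params, m and v in place.

    independent_wd=False: decoupled decay scaled by lr (p -= lr * wd * p), like torch.optim.AdamW.
    independent_wd=True:  decay not scaled by lr (p -= wd * p); assumed meaning of "independent".
    """
    b1, b2 = betas
    decay = weight_decay if independent_wd else lr * weight_decay
    for i, g in enumerate(grads):
        params[i] -= decay * params[i]           # weight decay, kept outside the adaptive term
        m[i] = b1 * m[i] + (1 - b1) * g          # first-moment EMA
        v[i] = b2 * v[i] + (1 - b2) * g * g      # second-moment EMA
        m_hat = m[i] / (1 - b1 ** t)             # bias corrections (t starts at 1)
        v_hat = v[i] / (1 - b2 ** t)
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


params, m, v = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0]
for t in range(1, 4):
    adamw_step(params, [0.5, -0.5], m, v, t, lr=1e-2, weight_decay=0.1)
print(params)
```

Because the independent variant does not shrink with the learning rate, its weight_decay has to be roughly lr times the usual value to give a comparable amount of decay per step.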
BlaGPT with the following configuration:
{
  "params": {
    "norm_layer": "rmsnorm",
    "attention": "GQA",
    "activation": "swiglu",
    "tie_embed_weights": true,
    "zero_init_proj_layers": true,
    "use_rotary_emb": true,
    "rmsnorm_before_qk": true
  },
  "config": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "dropout": 0.0,
    "bias": true,
    "norm_layer": "rmsnorm",
    "attention": "GQA",
    "activation": "swiglu",
    "use_soft_logit_capping": false,
    "n_kv_head": 4,
    "tie_embed_weights": true,
    "zero_init_proj_layers": true,
    "rmsnorm_before_qk": true,
    "use_rotary_emb": true
  },
  "val_loss": 3.2993,
  "memory_usage": 49403
}
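The attention settings in this config have to be mutually consistent. The sketch below derives the per-head sizes they imply. attention_shapes is a hypothetical helper, and only the relevant keys are copied from the config above. With n_embd=768 and n_head=12, each head is 64-dimensional. With n_kv_head=4, three query heads share each key/value head, so the K and V projections are each one third the width of Q.

```python
import json

# Only the attention-related keys from the "config" block above.
CONFIG_JSON = """{
  "block_size": 1024, "vocab_size": 50304, "n_layer": 12,
  "n_head": 12, "n_embd": 768, "n_kv_head": 4
}"""


def attention_shapes(cfg: dict) -> dict:
    """Derive per-head sizes implied by a BlaGPT-style config (hypothetical helper)."""
    assert cfg["n_embd"] % cfg["n_head"] == 0, "n_embd must split evenly across heads"
    assert cfg["n_head"] % cfg["n_kv_head"] == 0, "GQA needs n_head divisible by n_kv_head"
    head_dim = cfg["n_embd"] // cfg["n_head"]
    return {
        "head_dim": head_dim,
        "queries_per_kv_head": cfg["n_head"] // cfg["n_kv_head"],
        "kv_proj_width": cfg["n_kv_head"] * head_dim,  # width of each of the K and V projections
    }


cfg = json.loads(CONFIG_JSON)
print(attention_shapes(cfg))  # {'head_dim': 64, 'queries_per_kv_head': 3, 'kv_proj_width': 256}
```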
To add a new model:
- Implement the model
- Return the loss in the forward function
- Register the model
- Start training
See one of the implementations for details.
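The steps for adding a model can be sketched with a decorator-based registry. MODEL_REGISTRY, register_model and UnigramLM are hypothetical names used only for illustration. BlaGPT's actual registration mechanism is not shown here, so check one of the existing implementations for the real interface. The toy model returns (logits, loss) from forward, following the NanoGPT convention the code is based on.

```python
import math
from typing import Callable, Dict

MODEL_REGISTRY: Dict[str, Callable] = {}  # hypothetical registry: name -> model class


def register_model(name: str):
    """Class decorator that stores the model under a name the trainer can look up."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register_model("unigram")
class UnigramLM:
    """Toy model: one logit per vocabulary entry, ignoring context."""

    def __init__(self, vocab_size: int):
        self.logits = [0.0] * vocab_size

    def forward(self, idx, targets=None):
        logits = [self.logits[:] for _ in idx]  # same distribution at every position
        loss = None
        if targets is not None:
            m = max(self.logits)
            log_z = m + math.log(sum(math.exp(x - m) for x in self.logits))
            # Mean cross-entropy: -log softmax(logits)[target], averaged over positions.
            loss = sum(log_z - self.logits[t] for t in targets) / len(targets)
        return logits, loss


model = MODEL_REGISTRY["unigram"](vocab_size=4)
_, loss = model.forward([0, 1, 2], targets=[1, 2, 3])
print(loss)  # log(4), about 1.386, because all logits start at zero
```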
- Get the data by running data/fineweb10B_cached.py
- Start training with:
torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt
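torchrun --nproc_per_node=8 starts eight worker processes on one node and exports RANK, LOCAL_RANK and WORLD_SIZE to each of them. The sketch below shows how a training script can read these variables along with the two flags from the command above. The flag defaults and the dist_info helper are assumptions and may not match what train.py does.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the two flags used in the torchrun command; defaults are hypothetical."""
    p = argparse.ArgumentParser()
    p.add_argument("--run_name", type=str, default="default")
    p.add_argument("--model_name", type=str, default="blagpt")
    return p.parse_args(argv)


def dist_info():
    """Read the process layout torchrun exports; falls back to a single process."""
    rank = int(os.environ.get("RANK", 0))              # global index of this worker
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # index on this node, usually the GPU id
    world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of workers
    return rank, local_rank, world_size


args = parse_args(["--run_name", "pre_post_norm", "--model_name", "blagpt"])
rank, local_rank, world_size = dist_info()
if rank == 0:  # only the master process logs
    print(f"run={args.run_name} model={args.model_name} world_size={world_size}")
```

Each worker typically binds to GPU LOCAL_RANK, for example with torch.cuda.set_device(local_rank), before initializing the process group.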
The initial code is based on:
NanoGPT - link
Modded NanoGPT - link