[WIP] Changes for training entropy model and correcting attention in local models

Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Prepare for adding an entropy model training script
- Fix failing unit tests

Test Plan:
EntilZha committed Jan 17, 2025
1 parent caec8d2 commit 7f305b3
Showing 15 changed files with 349 additions and 138 deletions.
7 changes: 7 additions & 0 deletions bytelatent/args.py
@@ -30,6 +30,7 @@
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs

logger = logging.getLogger()

@@ -163,6 +164,8 @@ class TrainArgs(BaseModel):

seed: int = 42

debug_dynamo: bool = False

# Number of gradient accumulation steps
# Total batch size is batch_size*grad_acc_steps
grad_acc_steps: int = 1
@@ -176,6 +179,10 @@ class TrainArgs(BaseModel):
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
# This is only needed for training the entropy model
entropy_model: LMTransformerArgs | None = None
# Instead of training main model, train entropy model
train_entropy_model: bool = False
distributed: DistributedArgs = DistributedArgs()
env: EnvironmentArgs = EnvironmentArgs()

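Illustrative sketch (not part of the diff): the two new TrainArgs fields let one config drive either training mode, with train_entropy_model selecting the mode and entropy_model carrying the LMTransformerArgs for it. The select_model_args helper below is an assumption for illustration, not code from this commit.

    from bytelatent.args import TrainArgs

    def select_model_args(train_args: TrainArgs):
        # Hypothetical helper: pick which model config a train script should build.
        if train_args.train_entropy_model:
            # Entropy-model training requires the optional LMTransformerArgs config.
            assert train_args.entropy_model is not None
            return train_args.entropy_model
        # Default path: train the main ByteLatentTransformer.
        return train_args.model
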
45 changes: 36 additions & 9 deletions bytelatent/base_transformer.py
@@ -4,7 +4,7 @@
from typing import Optional, Tuple, Union

import torch
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
@@ -15,6 +15,7 @@
from xformers.ops import AttentionBias, fmha

from bytelatent import probe
from bytelatent.tokenizers.constants import EOS_ID

if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
flex_attention_comp = torch.compile(flex_attention)
@@ -30,25 +31,31 @@ class InitStdFactor(Enum):


class BaseTransformerArgs(BaseModel):
model_config = ConfigDict(extra="forbid")
dim: int = 512
n_layers: int = 8
head_dim: Optional[int] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
head_dim: int | None = None
n_heads: int | None = None
n_kv_heads: int | None = None

ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None

multiple_of: int = 256

norm_eps: float = 1e-5

rope_theta: float = 10000.0

init_base_std: Optional[float] = None
init_base_std: float | None = None
init_std_factor: InitStdFactor = InitStdFactor.DISABLED

max_seqlen: int = 1024

attn_impl: str | None = "sdpa"
attn_bias_type: str | None = None
# Special token config
eos_id: int | None = EOS_ID


def cross_entropy(pred, target, **kwargs):
return F.nll_loss(
@@ -294,6 +301,18 @@ def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore


def _reshape_for_attn_bias(
attn_bias: AttentionBias | None,
*tensors: torch.Tensor,
) -> list[torch.Tensor]:
to_transform = list(tensors)
if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
# could be `view` instead of reshape during training, but for inference
# have to reshape due to strides mismatch
to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
return to_transform


class Attention(nn.Module):
def __init__(
self,
@@ -371,9 +390,12 @@ def forward(
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D

elif attn_impl == "fmha":
elif attn_impl == "xformers":
assert mask is None or isinstance(mask, AttentionBias)
query_shape = xq.shape
xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
output = output.view(query_shape)
# This uses B S H D instead of B H S D of pytorch

elif attn_impl == "sdpa":
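
Illustrative aside on the xformers branch above (not from this commit; shapes and sequence lengths are made up): fmha.memory_efficient_attention takes tensors in (batch, seq, heads, head_dim) layout, and block-diagonal biases expect the packed sequences concatenated along the sequence dimension with batch size 1, which is why _reshape_for_attn_bias flattens queries/keys/values to (1, B*S, H, D) and the output is viewed back to the original query shape. A minimal sketch, assuming xformers is installed and a CUDA device is available:

    import torch
    from xformers.ops import fmha

    n_heads, head_dim = 4, 16
    seqlens = [3, 5]  # two documents packed into one batch row
    total = sum(seqlens)

    # B S H D layout with B == 1, as block-diagonal biases require.
    q = torch.randn(1, total, n_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    mask = fmha.BlockDiagonalCausalMask.from_seqlens(seqlens)
    out = fmha.memory_efficient_attention(q, k, v, attn_bias=mask)  # (1, total, n_heads, head_dim)
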
@@ -522,14 +544,16 @@ def forward(
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
attn_impl: str = "sdpa",
) -> torch.Tensor:
h = x + self.attention(
attn_out = self.attention(
self.attention_norm(x),
freq_cis,
tok_idx=tok_idx,
mask=mask,
attn_impl=attn_impl,
)
out = h + self.feed_forward(self.ffn_norm(h))
h = x + attn_out
h_norm = self.ffn_norm(h)
out = h + self.feed_forward(h_norm)
return out

def init_weights(self, init_std=None, factor=1.0):
@@ -545,13 +569,16 @@ def __init__(self, args: BaseTransformerArgs):
super().__init__()
self.dim = args.dim
self.init_base_std = args.init_base_std
self.attn_impl = args.attn_impl
self.attn_bias_type = args.attn_bias_type
self.init_std_factor = InitStdFactor(args.init_std_factor)
self.max_seqlen = args.max_seqlen
self.rope_embeddings = RotaryEmbedding(
theta=args.rope_theta,
head_dim=args.head_dim or args.dim // args.n_heads,
max_seqlen=args.max_seqlen,
)
self.eos_id = args.eos_id

self.layers = nn.ModuleList()
for _ in range(args.n_layers):
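Illustrative note on the BaseTransformerArgs changes above: with pydantic's ConfigDict(extra="forbid"), unknown keys now fail validation instead of being silently ignored, and attn_impl, attn_bias_type, and eos_id become part of the validated schema. A small sketch of that behavior (field values are made up):

    from pydantic import ValidationError

    from bytelatent.base_transformer import BaseTransformerArgs

    args = BaseTransformerArgs(
        dim=512,
        n_layers=8,
        attn_impl="xformers",
        attn_bias_type="block_causal",
    )

    try:
        # A misspelled or legacy key now raises instead of being dropped.
        BaseTransformerArgs(attn_implementation="sdpa")
    except ValidationError as err:
        print(err)
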
3 changes: 1 addition & 2 deletions bytelatent/configs/debug.yaml
@@ -15,7 +15,6 @@ optim:

distributed:
fsdp_type: full_shard
compile: true
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
@@ -58,13 +57,13 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
efficient_attn: "sdpa"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512
3 changes: 3 additions & 0 deletions bytelatent/data/iterators/test_arrow_iterator.py
@@ -27,6 +27,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = initial_state.build()
start_state = arrow_file.get_state()
@@ -55,6 +56,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=251,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = resumed_state.build()
for example in arrow_file.create_iter():
@@ -74,6 +76,7 @@ def test_basic_arrow_file():
dataset_files=[ARROW_TEST_DATA_1],
row_num=0,
arrow_batch_size=100,
s3_profile=None,
)
arrow_file = rank_state.build()
expected_ids = []
1 change: 0 additions & 1 deletion bytelatent/distributed.py
@@ -11,7 +11,6 @@
import subprocess
import sys
import tempfile
from dataclasses import asdict, dataclass
from functools import lru_cache, partial, reduce
from itertools import chain
from typing import List, Optional, Tuple, Union
10 changes: 9 additions & 1 deletion bytelatent/entropy_model.py
@@ -1,19 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
import re

import torch

from bytelatent.transformer import LMTransformer, LMTransformerArgs

logger = logging.getLogger()


def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
reloaded = json.loads(fr.read())

torch.set_default_dtype(torch.bfloat16)
model_params = reloaded["model"]
logger.warning(
"Update checkpoint to load attn and sliding window args from checkpoint"
)
entropy_model = LMTransformer(
LMTransformerArgs(
dim=model_params["dim"],
@@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
max_seqlen=model_params["max_length"],
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
vocab_size=model_params["vocab_size"],
attn_bias_type="local_block_causal",
attn_impl="xformers",
sliding_window=512,
)
)

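A possible usage sketch for the updated loader; the paths and state-dict file name are placeholders, and it assumes load_entropy_model returns the constructed LMTransformer:

    import os

    from bytelatent.entropy_model import load_entropy_model

    checkpoint_dir = "/path/to/entropy_model_checkpoint"  # placeholder
    state_dict_path = os.path.join(checkpoint_dir, "consolidated.pth")  # placeholder file name

    entropy_model = load_entropy_model(checkpoint_dir, state_dict_path, device="cuda")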