Skip to content

Commit

Permalink
adding split_heads argument for retaining original (Q, K) dimensionan…
Browse files Browse the repository at this point in the history
…lity
  • Loading branch information
djsaunde committed Dec 18, 2024
1 parent 4e2dd6d commit 8290095
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 50 deletions.
14 changes: 10 additions & 4 deletions src/axolotl/cli/integrations/convert_differential_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.differential_transformer.convert import (
convert_to_differential_attention,
)
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):

# Convert attention
LOG.info("Converting to differential attention...")
if cli_args.split_heads and cli_args.zero_init:
LOG.warning(
Fore.YELLOW
+ "Warning: Using split_heads with zero_init is not recommended; "
+ "split_heads will preclude the effects of zero_init"
+ Fore.RESET
)
try:
model = convert_to_differential_attention(
model = convert_to_diff_attn(
model=model,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)


def load_model_and_tokenizer(
Expand Down
11 changes: 8 additions & 3 deletions src/axolotl/integrations/differential_transformer/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,18 @@ def copy_attention_weights(
)


def convert_to_differential_attention(
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
def convert_to_diff_attn(
model: PreTrainedModel,
zero_init: bool = False,
sublayer_norm: bool = True,
split_heads: bool = True,
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
layer_idx = 0

# Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm
model.config.split_heads = split_heads

def convert_module(module):
nonlocal layer_idx
Expand All @@ -111,7 +115,8 @@ def convert_module(module):

# Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device)
copy_attention_weights(child, new_attention, zero_init=zero_init)
if not split_heads:
copy_attention_weights(child, new_attention, zero_init=zero_init)

# Replace the layer
setattr(module, name, new_attention)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,51 @@ def __init__(
self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads

if config.split_heads:
self.head_dim = config.hidden_size // config.num_attention_heads // 2
else:
self.head_dim = config.hidden_size // config.num_attention_heads

self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

# For Q1 and Q2
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size * 2,
bias=False,
)

# For K1 and K2
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
)
self.split_heads = config.split_heads

if config.split_heads:
# Split heads mode
assert (
self.base_num_heads % 2 == 0
), "Number of heads must be even for splitting"
self.heads_per_component = self.base_num_heads // 2

# Single projections
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size,
bias=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
)
else:
# Double projection mode
self.heads_per_component = self.base_num_heads

# Double-sized projections
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size * 2,
bias=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
)

# Single V projection
self.v_proj = nn.Linear(
Expand Down Expand Up @@ -125,8 +150,14 @@ def __init__(

self.rotary_emb = LlamaRotaryEmbedding(config=config)
sublayer_norm = getattr(config, "sublayer_norm", True)

if self.split_heads:
subln_dim = 2 * self.head_dim
else:
subln_dim = self.head_dim

self.subln = (
LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
Expand Down Expand Up @@ -167,7 +198,10 @@ def forward(
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Reshape V
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Apply rotary embeddings
if position_embeddings is None:
Expand All @@ -177,6 +211,10 @@ def forward(
else:
cos, sin = position_embeddings

if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)

q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

Expand All @@ -192,8 +230,6 @@ def forward(
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

# Calculate attention scores for both parts
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)

Expand Down Expand Up @@ -307,13 +343,18 @@ def forward(
k1, k2 = kp.chunk(2, dim=-1)

# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Apply rotary embeddings
if position_embeddings is None:
Expand All @@ -323,6 +364,10 @@ def forward(
else:
cos, sin = position_embeddings

if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)

q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

Expand Down Expand Up @@ -468,13 +513,18 @@ def forward(
k1, k2 = kp.chunk(2, dim=-1)

# Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

# Apply rotary embeddings
if position_embeddings is None:
Expand All @@ -484,6 +534,10 @@ def forward(
else:
cos, sin = position_embeddings

if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)

q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

Expand All @@ -506,20 +560,54 @@ def forward(

# Calculate attention using Flash Attention
dropout_p = self.attention_dropout if self.training else 0.0
attn1 = flash_attn_func(
q1,
k1,
v,
dropout_p=dropout_p,
causal=True,
)
attn2 = flash_attn_func(
q2,
k2,
v,
dropout_p=dropout_p,
causal=True,
)
if self.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(
q1,
k1,
v1,
dropout_p=dropout_p,
causal=True,
)
attn12 = flash_attn_func(
q1,
k1,
v2,
dropout_p=dropout_p,
causal=True,
)
attn1 = torch.cat([attn11, attn12], dim=-1)

attn21 = flash_attn_func(
q2,
k2,
v1,
dropout_p=dropout_p,
causal=True,
)
attn22 = flash_attn_func(
q2,
k2,
v2,
dropout_p=dropout_p,
causal=True,
)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(
q1,
k1,
v,
dropout_p=dropout_p,
causal=True,
)
attn2 = flash_attn_func(
q2,
k2,
v,
dropout_p=dropout_p,
causal=True,
)

attn1 = attn1.transpose(1, 2)
attn2 = attn2.transpose(1, 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()


@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True

config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)

cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))

assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()

0 comments on commit 8290095

Please sign in to comment.