diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py
index a687a3f7c..b50dd43dd 100644
--- a/src/axolotl/cli/integrations/convert_differential_transformer.py
+++ b/src/axolotl/cli/integrations/convert_differential_transformer.py
@@ -14,9 +14,7 @@
 
 from axolotl.cli import load_cfg, print_axolotl_text_art
 from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.differential_transformer.convert import (
-    convert_to_differential_attention,
-)
+from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
 
 LOG = logging.getLogger(__name__)
 
@@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):
 
     # Convert attention
     LOG.info("Converting to differential attention...")
+    if cli_args.split_heads and cli_args.zero_init:
+        LOG.warning(
+            Fore.YELLOW
+            + "Warning: Using split_heads with zero_init is not recommended; "
+            + "split_heads will preclude the effects of zero_init"
+            + Fore.RESET
+        )
     try:
-        model = convert_to_differential_attention(
+        model = convert_to_diff_attn(
             model=model,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
+            split_heads=cli_args.split_heads,
         )
         model.to(cfg.device, dtype=cfg.torch_dtype)
     except Exception as exc:
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 2d6a5bb31..c51c4e2ab 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
     sublayer_norm: bool = field(default=True)
+    split_heads: bool = field(default=False)
 
 
 def load_model_and_tokenizer(
diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py
index d516f9476..4beaea7ae 100644
--- a/src/axolotl/integrations/differential_transformer/convert.py
+++ b/src/axolotl/integrations/differential_transformer/convert.py
@@ -80,14 +80,18 @@ def copy_attention_weights(
     )
 
 
-def convert_to_differential_attention(
-    model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
+def convert_to_diff_attn(
+    model: PreTrainedModel,
+    zero_init: bool = False,
+    sublayer_norm: bool = True,
+    split_heads: bool = True,
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
     layer_idx = 0
 
     # Set sublayer norm as config on the model.
     model.config.sublayer_norm = sublayer_norm
+    model.config.split_heads = split_heads
 
     def convert_module(module):
         nonlocal layer_idx
@@ -111,7 +115,8 @@ def convert_module(module):
 
                 # Copy weights from old attention to new attention
                 new_attention.to(child.q_proj.weight.device)
-                copy_attention_weights(child, new_attention, zero_init=zero_init)
+                if not split_heads:
+                    copy_attention_weights(child, new_attention, zero_init=zero_init)
 
                 # Replace the layer
                 setattr(module, name, new_attention)
diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py
index 1543981ea..58d4b94ec 100644
--- a/src/axolotl/integrations/differential_transformer/differential_attention.py
+++ b/src/axolotl/integrations/differential_transformer/differential_attention.py
@@ -70,26 +70,51 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.base_num_heads = config.num_attention_heads
         self.base_num_kv_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_attention_heads
+
+        if config.split_heads:
+            self.head_dim = config.hidden_size // config.num_attention_heads // 2
+        else:
+            self.head_dim = config.hidden_size // config.num_attention_heads
         self.layer_idx = layer_idx
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
-
-        # For Q1 and Q2
-        self.q_proj = nn.Linear(
-            self.hidden_size,
-            self.hidden_size * 2,
-            bias=False,
-        )
-
-        # For K1 and K2
-        self.k_proj = nn.Linear(
-            self.hidden_size,
-            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
-            bias=False,
-        )
+        self.split_heads = config.split_heads
+
+        if config.split_heads:
+            # Split heads mode
+            assert (
+                self.base_num_heads % 2 == 0
+            ), "Number of heads must be even for splitting"
+            self.heads_per_component = self.base_num_heads // 2
+
+            # Single projections
+            self.q_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=False,
+            )
+            self.k_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
+                bias=False,
+            )
+        else:
+            # Double projection mode
+            self.heads_per_component = self.base_num_heads
+
+            # Double-sized projections
+            self.q_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size * 2,
+                bias=False,
+            )
+            self.k_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
+                bias=False,
+            )
 
         # Single V projection
         self.v_proj = nn.Linear(
             self.hidden_size,
@@ -125,8 +150,14 @@ def __init__(
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
         sublayer_norm = getattr(config, "sublayer_norm", True)
+
+        if self.split_heads:
+            subln_dim = 2 * self.head_dim
+        else:
+            subln_dim = self.head_dim
+
         self.subln = (
-            LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
+            LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
             if sublayer_norm
             else nn.Identity()
         )
@@ -167,7 +198,10 @@ def forward(
         k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         # Reshape V
-        v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         # Apply rotary embeddings
         if position_embeddings is None:
@@ -177,6 +211,10 @@
         else:
             cos, sin = position_embeddings
 
+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
         q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
         q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
 
@@ -192,8 +230,6 @@ def forward(
         v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
 
         # Calculate attention scores for both parts
-        # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
-        # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
         attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
         attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
 
@@ -307,13 +343,18 @@ def forward(
         k1, k2 = kp.chunk(2, dim=-1)
 
         # Reshape Q1,Q2 for attention
-        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
-        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
         # Reshape K1,K2 for attention
-        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
-        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
         # Reshape V
-        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         # Apply rotary embeddings
         if position_embeddings is None:
@@ -323,6 +364,10 @@
         else:
             cos, sin = position_embeddings
 
+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
         q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
         q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
 
@@ -468,13 +513,18 @@ def forward(
         k1, k2 = kp.chunk(2, dim=-1)
 
         # Reshape Q1,Q2 for attention
-        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
-        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
         # Reshape K1,K2 for attention
-        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
-        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
         # Reshape V
-        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
         # Apply rotary embeddings
         if position_embeddings is None:
@@ -484,6 +534,10 @@
         else:
             cos, sin = position_embeddings
 
+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
         q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
         q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
 
@@ -506,20 +560,54 @@ def forward(
         # Calculate attention using Flash Attention
        dropout_p = self.attention_dropout if self.training else 0.0
 
-        attn1 = flash_attn_func(
-            q1,
-            k1,
-            v,
-            dropout_p=dropout_p,
-            causal=True,
-        )
-        attn2 = flash_attn_func(
-            q2,
-            k2,
-            v,
-            dropout_p=dropout_p,
-            causal=True,
-        )
+        if self.split_heads:
+            v1, v2 = v.chunk(2, dim=-1)
+            attn11 = flash_attn_func(
+                q1,
+                k1,
+                v1,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn12 = flash_attn_func(
+                q1,
+                k1,
+                v2,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn1 = torch.cat([attn11, attn12], dim=-1)
+
+            attn21 = flash_attn_func(
+                q2,
+                k2,
+                v1,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn22 = flash_attn_func(
+                q2,
+                k2,
+                v2,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn2 = torch.cat([attn21, attn22], dim=-1)
+        else:
+            attn1 = flash_attn_func(
+                q1,
+                k1,
+                v,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn2 = flash_attn_func(
+                q2,
+                k2,
+                v,
+                dropout_p=dropout_p,
+                causal=True,
+            )
 
         attn1 = attn1.transpose(1, 2)
         attn2 = attn2.transpose(1, 2)
diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py
index 4349287bd..84e5fdaa1 100644
--- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py
+++ b/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py
@@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
     assert (output_dir / "model.safetensors").exists()
     assert (output_dir / "config.json").exists()
     assert (output_dir / "axolotl_config.yml").exists()
+
+
+@pytest.mark.parametrize(
+    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+)
+def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
+    output_dir = tmp_path / "converted"
+    base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
+    base_config["output_dir"] = str(output_dir)
+    base_config[attention] = True
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is False
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
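
Reviewer note (not part of the patch): a minimal sketch of the two Q/K projection layouts the converted attention module's `__init__` now chooses between. The dimensions are arbitrary example values and the `n_params` helper is hypothetical; the point is that `split_heads=True` halves `head_dim` and keeps the original projection widths, while the default mode doubles the Q/K projections, which is presumably why `copy_attention_weights` (and therefore `zero_init`) is only applied on the non-split path.

```python
# Standalone sketch (not part of the diff): compare the Q/K projection shapes
# built in each mode. Dimensions below are example values only.
import torch.nn as nn

hidden_size, num_heads, num_kv_heads = 2048, 32, 32


def n_params(module: nn.Module) -> int:
    # Hypothetical helper for the comparison; not an axolotl API.
    return sum(p.numel() for p in module.parameters())


# Default mode: head_dim stays hidden_size // num_heads and the Q/K
# projections are doubled so Q1/Q2 (and K1/K2) get their own weights.
head_dim = hidden_size // num_heads
q_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size // num_heads * num_kv_heads * 2, bias=False)

# split_heads mode: projections keep their original width and head_dim is
# halved, so the two components come from splitting the existing heads.
split_head_dim = hidden_size // num_heads // 2
q_proj_split = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj_split = nn.Linear(
    hidden_size, hidden_size // num_heads * num_kv_heads, bias=False
)

print(head_dim, split_head_dim)                  # 64 32
print(n_params(q_proj), n_params(q_proj_split))  # 8388608 4194304
```

With split heads the converted layer keeps the pretrained projection sizes, at the cost of halving the per-component head dimension.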
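
Reviewer note (not part of the patch): the eager path computes `attn1` and `attn2` exactly as in the hunk above, and the unchanged downstream code presumably combines them in the standard differential-attention form. Below is a minimal, self-contained sketch of that combination with a fixed scalar `lam` standing in for the learned lambda; RoPE, GQA key/value repetition, causal masking, dropout, and the sublayer RMSNorm are all omitted.

```python
# Reference-only sketch of the differential attention combination, using a
# fixed scalar lambda in place of the learned per-layer parametrization.
import math

import torch


def diff_attention(q1, q2, k1, k2, v, lam: float = 0.5):
    """(softmax(Q1 K1^T / sqrt(d)) - lam * softmax(Q2 K2^T / sqrt(d))) V"""
    d = q1.size(-1)
    attn1 = torch.softmax(q1 @ k1.transpose(-1, -2) / math.sqrt(d), dim=-1)
    attn2 = torch.softmax(q2 @ k2.transpose(-1, -2) / math.sqrt(d), dim=-1)
    return (attn1 - lam * attn2) @ v


bsz, heads, q_len, head_dim = 2, 4, 5, 8
q1, q2, k1, k2, v = (torch.randn(bsz, heads, q_len, head_dim) for _ in range(5))
out = diff_attention(q1, q2, k1, k2, v)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```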