Activation based merging - copied over from wip-zipit branch (#365)
# What is this?
This PR introduces a way to merge two models via their activations and
hidden states on a small sample of data. The method uses these activations
and hidden states to form correlation matrices, which are then used to
generate a permutation matrix and an inverse permutation matrix for the
weights of each model; the permuted weights are finally combined.

This PR consists of three main scripts:
1. the first generates the activations/hidden states for each space
2. the second generates a permutation and inverse-permutation pair for each
space (see the sketch after this list)
3. the third applies the permutation and/or inverse permutation to each
weight connected to a space, and then combines the weights

# Assumptions
The models to be merged must be of the same architecture and have an equal
block/layer count.

# Things that couldn't make it into the final PR
On-the-fly handling of models with grouped-query attention (GQA). This
hasn't been tested enough for this release but will be in the near future.
For now, users will have to run this conversion script first:
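(The script referenced above is not included in this commit view. Purely as
a hypothetical illustration of the kind of preprocessing involved,
grouped-query attention is typically expanded to standard multi-head
attention by repeating each key/value head; all names and dimensions below
are assumed.)

```python
import torch

hidden_size, n_heads, n_kv_heads = 4096, 32, 8  # assumed dimensions
head_dim = hidden_size // n_heads
n_rep = n_heads // n_kv_heads                   # query heads per kv head

k_proj = torch.randn(n_kv_heads * head_dim, hidden_size)  # stand-in GQA weight

# Repeat each kv head n_rep times so every query head gets its own copy.
k_mha = (
    k_proj.view(n_kv_heads, head_dim, hidden_size)
    .repeat_interleave(n_rep, dim=0)
    .reshape(n_heads * head_dim, hidden_size)
)
```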

## Note:
Because this was copied over from another branch (`wip-zipit`), @shamanez's
contributions are missing from the commit history, so this is an explicit
acknowledgement that @shamanez worked on this PR alongside the other
authors.
metric-space authored Jul 19, 2024
1 parent 5fa7782 commit 6447a85
Showing 6 changed files with 773 additions and 34 deletions.
59 changes: 25 additions & 34 deletions mergekit/_data/architectures/llama.json
@@ -8,84 +8,75 @@
         {
             "name": "model.embed_tokens.weight",
             "is_embed": true,
-            "output_space": "h_0"
+            "output_space": "running_residual"
         }
     ],
     "num_layers_config_key": "num_hidden_layers",
     "layer_templates": {
         "weights": [
             {
                 "name": "model.layers.${layer_index}.input_layernorm.weight",
-                "input_space": "h_${layer_index}"
+                "input_space": "running_residual"
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.q_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_qk_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_qk_${layer_index}",
+                "head_split": "output",
+                "is_kq": true
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.k_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_qk_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_qk_${layer_index}",
+                "head_split": "output",
+                "is_kq": true
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.v_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_v_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_v_${layer_index}",
+                "head_split": "output"
            },
             {
                 "name": "model.layers.${layer_index}.self_attn.o_proj.weight",
                 "input_space": "attn_v_${layer_index}",
-                "output_space": "post_attn_${layer_index}"
+                "output_space": "running_residual",
+                "head_split": "input"
             },
             {
                 "name": "model.layers.${layer_index}.post_attention_layernorm.weight",
-                "input_space": "h_a_${layer_index}"
+                "input_space": "running_residual"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.up_proj.weight",
-                "input_space": "h_a_${layer_index}",
+                "input_space": "running_residual",
                 "output_space": "up_${layer_index}"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.gate_proj.weight",
-                "input_space": "h_a_${layer_index}",
+                "input_space": "running_residual",
                 "output_space": "up_${layer_index}"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.down_proj.weight",
                 "input_space": "up_${layer_index}",
-                "output_space": "post_mlp_${layer_index}"
-            }
-        ],
-        "procedural_spaces": [
-            {
-                "name": "h_a_${layer_index}",
-                "type": "residual",
-                "inputs": [
-                    "h_${layer_index}",
-                    "post_attn_${layer_index}"
-                ]
-            },
-            {
-                "name": "h_${layer_index+1}",
-                "type": "residual",
-                "inputs": [
-                    "h_a_${layer_index}",
-                    "post_mlp_${layer_index}"
-                ]
+                "output_space": "running_residual"
             }
         ]
     },
     "post_weights": [
         {
             "name": "model.norm.weight",
-            "input_space": "h_${num_layers}"
+            "input_space": "running_residual"
         },
         {
             "name": "lm_head.weight",
-            "input_space": "h_${num_layers}",
-            "is_embed": true
+            "input_space": "running_residual",
+            "is_embed": true,
+            "aliases": [
+                "model.lm_head.weight"
+            ]
         }
     ]
 }
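Collapsing the per-layer `h_*` spaces and the procedural residual spaces
into a single `running_residual` space means every weight reading from or
writing to the residual stream shares one merge/unmerge pair. A minimal
sketch, with illustrative dimensions, of why applying an invertible
transform to a shared space and its inverse on the consuming side preserves
the computation:

```python
import torch

d = 8
M = torch.linalg.qr(torch.randn(d, d))[0]  # toy "merge" transform on the space
U = torch.linalg.inv(M)                    # matching "unmerge" (its inverse)

w_out = torch.randn(d, d)  # a weight whose output_space is the shared space
w_in = torch.randn(d, d)   # a weight whose input_space is the shared space
x = torch.randn(d)

y = w_in @ (w_out @ x)                 # original computation
y_t = (w_in @ U) @ ((M @ w_out) @ x)   # transformed weights, same result

assert torch.allclose(y, y_t, atol=1e-5)
```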
2 changes: 2 additions & 0 deletions mergekit/architecture.py
@@ -52,6 +52,8 @@ class WeightInfo(BaseModel, frozen=True):
     optional: bool = False
     aliases: Optional[Tuple[str, ...]] = None
     force_dtype: Optional[str] = None
+    head_split: Literal[None, "input", "output"] = None
+    is_kq: Optional[bool] = False
 
 
 class ProceduralSpaceInfo(BaseModel, frozen=True):
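One plausible reading of the two new fields, going by how `llama.json` uses
them above: `head_split` marks which dimension of an attention projection is
divided into heads, so transforms can move whole heads rather than
individual rows, and `is_kq` flags the query/key projections whose head
order must stay aligned with each other. A hypothetical sketch of a
head-granular permutation (names and dimensions assumed):

```python
import torch

def permute_heads(w: torch.Tensor, head_perm: list, num_heads: int) -> torch.Tensor:
    """Reorder whole heads along the output dim of a q/k/v projection."""
    head_dim = w.shape[0] // num_heads
    return w.reshape(num_heads, head_dim, -1)[head_perm].reshape(w.shape[0], -1)

# q and k must receive the same head permutation (is_kq), or their
# dot-product attention scores would pair mismatched heads.
q = torch.randn(32 * 128, 4096)
k = torch.randn(32 * 128, 4096)
order = list(range(31, -1, -1))
q_p, k_p = permute_heads(q, order, 32), permute_heads(k, order, 32)
```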
171 changes: 171 additions & 0 deletions mergekit/scripts/ABM/activations_based_merge.py
@@ -0,0 +1,171 @@
import logging
import os
from typing import Optional

import click
import safetensors.torch
import torch
import tqdm
from transformers import AutoTokenizer

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io.tasks import LoaderCache
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions, add_merge_options


@click.command("mergekit-activation-based-merge")
@click.argument("model_path", type=str)
@click.argument("secondary_model_path", type=str)
@click.argument("merge_unmerge_directory", type=str)
@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
@click.option(
    "--dtype",
    type=str,
    default="float16",
    help="Data type to convert weights to",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cuda",
    help="Device to compute on (default: cuda)",
)
@add_merge_options
def main(
    model_path: str,
    secondary_model_path: str,
    merge_unmerge_directory: str,
    out_path: str,
    dtype: Optional[str],
    device: Optional[str],
    merge_options: MergeOptions,
):
    model = ModelReference.model_validate(model_path)
    secondary_model = ModelReference.model_validate(secondary_model_path)

    dtype = dtype_from_name(dtype) if dtype else None

    cache = LoaderCache()
    cache.lazy_unpickle = merge_options.lazy_unpickle
    cache.hf_cache_dir = merge_options.transformers_cache

    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
        cache.get(m)

    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    model_arch_info = get_architecture_info(
        model.config(trust_remote_code=merge_options.trust_remote_code)
    )

    loader_1 = cache.get(model)
    loader_2 = cache.get(secondary_model)

    os.makedirs(out_path, exist_ok=True)

    merge_unmerge_dictionary = {}
    # load files from merge_unmerge_directory
    spaces = [
        f.split("_unmerge")[0]
        for f in os.listdir(merge_unmerge_directory)
        if "_unmerge" in f
    ]
    for i in spaces:
        logging.info(f"Loading merge/unmerge tensors for {i}")
        m = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
            device=device,
        )
        u = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
            device=device,
        )
        merge_unmerge_dictionary[i] = (
            m[i].to(device, dtype=dtype),
            u[i].to(device, dtype=dtype),
        )

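    # Each merge tensor is (d, 2*d) and each unmerge tensor is (2*d, d); the
    # two halves produced by chunk() below correspond to model 1's and
    # model 2's slices of the concatenated feature space.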
    for weight_info in model_arch_info.all_weights(config=model_config):
        merge_matrix, unmerge_matrix = None, None

        if weight_info.input_space in merge_unmerge_dictionary:
            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)

        if weight_info.output_space in merge_unmerge_dictionary:
            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
            merge_matrix = merge_matrix.chunk(2, dim=1)

        original_w = loader_1.get_tensor(weight_info.name, device=device)
        original_w2 = loader_2.get_tensor(weight_info.name, device=device)

        if dtype is not None:
            original_w = original_w.to(dtype=dtype)
            original_w2 = original_w2.to(dtype=dtype)

        w = torch.clone(original_w)
        w2 = torch.clone(original_w2)

        if merge_matrix is None and unmerge_matrix is None:
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
            )

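        # Embedding weights are (vocab, hidden), so the is_embed branch below
        # applies the merge transform along the hidden dimension via transposes.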
        if merge_matrix is not None:
            if weight_info.is_embed:
                w = (merge_matrix[0] @ w.T).T
                w2 = (merge_matrix[1] @ w2.T).T
            else:
                w = merge_matrix[0] @ w
                w2 = merge_matrix[1] @ w2

        if unmerge_matrix is not None:
            w = w @ unmerge_matrix[0]
            w2 = w2 @ unmerge_matrix[1]

        # Warn either way so the user can see whether each weight was
        # actually transformed during the merge.
        if torch.allclose(original_w, w):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
            )

        if torch.allclose(original_w2, w2):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
            )

        # Combine the two transformed weights and save. The merge/unmerge
        # halves already act on each model's slice of the concatenated space,
        # so summing completes the block-matrix product; without a merge
        # matrix, fall back to a plain average.
        if merge_matrix is not None:
            w = w + w2
        else:
            w = (w + w2) / 2
        writer.save_tensor(weight_info.name, w)
    writer.finalize()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.save_pretrained(out_path, safe_serialization=True)

    # write config
    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    if dtype:
        model_out_config.torch_dtype = dtype
    model_out_config.save_pretrained(out_path)


if __name__ == "__main__":
    main()
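Based on the loader code above, the script expects `merge_unmerge_directory`
to contain a `<space>_merge.safetensor` and `<space>_unmerge.safetensor`
pair per space, each holding a single tensor keyed by the space name. A
minimal sketch of producing such a pair; the space name, width, and random
tensors are placeholders for the output of the permutation-generation step:

```python
import os

import safetensors.torch
import torch

space, d = "attn_v_0", 4096     # hypothetical space name and width
out_dir = "merge_unmerge"       # passed as MERGE_UNMERGE_DIRECTORY above
os.makedirs(out_dir, exist_ok=True)

merge = torch.randn(d, 2 * d)    # chunked along dim=1 into per-model halves
unmerge = torch.randn(2 * d, d)  # chunked along dim=0 into per-model halves

safetensors.torch.save_file({space: merge}, os.path.join(out_dir, f"{space}_merge.safetensor"))
safetensors.torch.save_file({space: unmerge}, os.path.join(out_dir, f"{space}_unmerge.safetensor"))
```

The merge itself would then be invoked along the lines of
`mergekit-activation-based-merge <model> <secondary_model> merge_unmerge
--out-path <out_path>`, per the click declarations above.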
