Activation based merging - copied over from wip-zipit branch #365

Merged (13 commits) on Jul 19, 2024
59 changes: 25 additions & 34 deletions mergekit/_data/architectures/llama.json
@@ -8,84 +8,75 @@
         {
             "name": "model.embed_tokens.weight",
             "is_embed": true,
-            "output_space": "h_0"
+            "output_space": "running_residual"
         }
     ],
     "num_layers_config_key": "num_hidden_layers",
     "layer_templates": {
         "weights": [
             {
                 "name": "model.layers.${layer_index}.input_layernorm.weight",
-                "input_space": "h_${layer_index}"
+                "input_space": "running_residual"
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.q_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_qk_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_qk_${layer_index}",
+                "head_split": "output",
+                "is_kq": true
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.k_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_qk_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_qk_${layer_index}",
+                "head_split": "output",
+                "is_kq": true
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.v_proj.weight",
-                "input_space": "h_${layer_index}",
-                "output_space": "attn_v_${layer_index}"
+                "input_space": "running_residual",
+                "output_space": "attn_v_${layer_index}",
+                "head_split": "output"
             },
             {
                 "name": "model.layers.${layer_index}.self_attn.o_proj.weight",
                 "input_space": "attn_v_${layer_index}",
-                "output_space": "post_attn_${layer_index}"
+                "output_space": "running_residual",
+                "head_split": "input"
             },
             {
                 "name": "model.layers.${layer_index}.post_attention_layernorm.weight",
-                "input_space": "h_a_${layer_index}"
+                "input_space": "running_residual"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.up_proj.weight",
-                "input_space": "h_a_${layer_index}",
+                "input_space": "running_residual",
                 "output_space": "up_${layer_index}"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.gate_proj.weight",
-                "input_space": "h_a_${layer_index}",
+                "input_space": "running_residual",
                 "output_space": "up_${layer_index}"
             },
             {
                 "name": "model.layers.${layer_index}.mlp.down_proj.weight",
                 "input_space": "up_${layer_index}",
-                "output_space": "post_mlp_${layer_index}"
-            }
-        ],
-        "procedural_spaces": [
-            {
-                "name": "h_a_${layer_index}",
-                "type": "residual",
-                "inputs": [
-                    "h_${layer_index}",
-                    "post_attn_${layer_index}"
-                ]
-            },
-            {
-                "name": "h_${layer_index+1}",
-                "type": "residual",
-                "inputs": [
-                    "h_a_${layer_index}",
-                    "post_mlp_${layer_index}"
-                ]
+                "output_space": "running_residual"
             }
         ]
     },
     "post_weights": [
         {
             "name": "model.norm.weight",
-            "input_space": "h_${num_layers}"
+            "input_space": "running_residual"
         },
         {
             "name": "lm_head.weight",
-            "input_space": "h_${num_layers}",
-            "is_embed": true
+            "input_space": "running_residual",
+            "is_embed": true,
+            "aliases": [
+                "model.lm_head.weight"
+            ]
         }
     ]
 }
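Taken together, the llama.json changes collapse the per-layer residual spaces (h_${layer_index}, h_a_${layer_index}, post_attn_*, post_mlp_*) and their procedural residual definitions into a single shared space, running_residual, and annotate the attention projections with head_split / is_kq. A minimal sketch of how the resulting space assignments could be inspected, reusing the same helpers the new script imports below; the checkpoint path is an assumed example, not part of this PR:

# Sketch only: list which Llama weights now read from / write to the shared
# "running_residual" space. Uses get_architecture_info / all_weights the same
# way activations_based_merge.py does; the model path is an arbitrary example.
from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference

model = ModelReference.model_validate("meta-llama/Llama-2-7b-hf")  # assumed path
config = model.config(trust_remote_code=False)
arch_info = get_architecture_info(config)

for weight_info in arch_info.all_weights(config=config):
    if "running_residual" in (weight_info.input_space, weight_info.output_space):
        print(f"{weight_info.name}: {weight_info.input_space} -> {weight_info.output_space}")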
2 changes: 2 additions & 0 deletions mergekit/architecture.py
@@ -52,6 +52,8 @@ class WeightInfo(BaseModel, frozen=True):
     optional: bool = False
     aliases: Optional[Tuple[str, ...]] = None
     force_dtype: Optional[str] = None
+    head_split: Literal[None, "input", "output"] = None
+    is_kq: Optional[bool] = False
 
 
 class ProceduralSpaceInfo(BaseModel, frozen=True):
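The two new WeightInfo fields back the annotations used in llama.json above: head_split records which axis of a weight is organised per attention head, and is_kq marks the query/key projections (presumably so their transforms can be kept consistent with each other). A minimal sketch of constructing a WeightInfo with these fields; apart from the names shown in the diffs, the constructor arguments are assumptions based on how llama.json and the script below use them:

# Sketch only (field names from the diff above; the other arguments mirror how
# llama.json describes q_proj). Not example code from the PR itself.
from mergekit.architecture import WeightInfo

q_proj = WeightInfo(
    name="model.layers.0.self_attn.q_proj.weight",
    input_space="running_residual",
    output_space="attn_qk_0",
    head_split="output",  # split/permute along the output (head) dimension
    is_kq=True,           # keep this weight's transform paired with the matching k_proj
)
print(q_proj.head_split, q_proj.is_kq)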
171 changes: 171 additions & 0 deletions mergekit/scripts/ABM/activations_based_merge.py
@@ -0,0 +1,171 @@
import logging
import os
from typing import Optional

import click
import safetensors.torch
import torch
import tqdm
from transformers import AutoTokenizer

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io.tasks import LoaderCache
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions, add_merge_options


@click.command("mergekit-activation-based-merge")
@click.argument("model_path", type=str)
@click.argument("secondary_model_path", type=str)
@click.argument("merge_unmerge_directory", type=str)
@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
@click.option(
    "--dtype",
    type=str,
    default="float16",
    help="Data type to convert weights to",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cuda",
    help="Device to compute on (default: cuda)",
)
@add_merge_options
def main(
    model_path: str,
    secondary_model_path: str,
    merge_unmerge_directory: str,
    out_path: str,
    dtype: Optional[str],
    device: Optional[str],
    merge_options: MergeOptions,
):
    model = ModelReference.model_validate(model_path)
    secondary_model = ModelReference.model_validate(secondary_model_path)

    dtype = dtype_from_name(dtype) if dtype else None

    cache = LoaderCache()
    cache.lazy_unpickle = merge_options.lazy_unpickle
    cache.hf_cache_dir = merge_options.transformers_cache

    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
        cache.get(m)

    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    model_arch_info = get_architecture_info(
        model.config(trust_remote_code=merge_options.trust_remote_code)
    )

    loader_1 = cache.get(model)
    loader_2 = cache.get(secondary_model)

    os.makedirs(out_path, exist_ok=True)

    merge_unmerge_dictionary = {}
    # load files from merge_unmerge_directory
    spaces = [
        f.split("_unmerge")[0]
        for f in os.listdir(merge_unmerge_directory)
        if "_unmerge" in f
    ]
    for i in spaces:
        logging.info(f"Loading merge/unmerge tensors for {i}")
        m = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
            device=device,
        )
        u = safetensors.torch.load_file(
            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
            device=device,
        )
        merge_unmerge_dictionary[i] = (
            m[i].to(device, dtype=dtype),
            u[i].to(device, dtype=dtype),
        )

    for weight_info in model_arch_info.all_weights(config=model_config):
        merge_matrix, unmerge_matrix = None, None

        if weight_info.input_space in merge_unmerge_dictionary:
            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
            # each unmerge tensor stacks the blocks for the two models along dim 0
            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)

        if weight_info.output_space in merge_unmerge_dictionary:
            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
            # each merge tensor stacks the blocks for the two models along dim 1
            merge_matrix = merge_matrix.chunk(2, dim=1)

        original_w = loader_1.get_tensor(weight_info.name, device=device)
        original_w2 = loader_2.get_tensor(weight_info.name, device=device)

        if dtype is not None:
            original_w = original_w.to(dtype=dtype)
            original_w2 = original_w2.to(dtype=dtype)

        w = torch.clone(original_w)
        w2 = torch.clone(original_w2)

        if not merge_matrix and not unmerge_matrix:
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
            )

        if merge_matrix is not None:
            if weight_info.is_embed:
                w = (merge_matrix[0] @ w.T).T
                w2 = (merge_matrix[1] @ w2.T).T
            else:
                w = merge_matrix[0] @ w
                w2 = merge_matrix[1] @ w2

        if unmerge_matrix is not None:
            w = w @ unmerge_matrix[0]
            w2 = w2 @ unmerge_matrix[1]

        # warn if a weight was left unchanged by the merge/unmerge transforms
        if torch.allclose(original_w, w):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
            )

        if torch.allclose(original_w2, w2):
            logging.warning(
                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
            )
        else:
            logging.warning(
                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
            )

        # combine the transformed weights and save the result
        if merge_matrix:
            w = w + w2
[Inline review thread on the line above]

Collaborator: A decent next step for this might be to separate this out - if it just output two modified models, then we could feed those directly into mergekit-yaml and be able to try out merge methods other than linear without needing to bring that infrastructure into the script.

Contributor (author): Understood. I'll be sure to add this in as a follow-up PR.
        else:
            w = (w + w2) / 2
        writer.save_tensor(weight_info.name, w)
    writer.finalize()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.save_pretrained(out_path, safe_serialization=True)

    # write config
    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
    if dtype:
        model_out_config.torch_dtype = dtype
    model_out_config.save_pretrained(out_path)


if __name__ == "__main__":
    main()
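As the loading loop above implies, merge_unmerge_directory is expected to hold one pair of safetensors files per space, named {space}_merge.safetensor and {space}_unmerge.safetensor, each containing a single tensor keyed by the space name. A minimal sketch of that layout and of the shapes the chunk(2, ...) calls assume; the hidden size and random contents are illustrative only, since the real tensors would come from the activation-based alignment step this PR ports over:

# Sketch only: fabricate a merge/unmerge pair for one space in the layout the
# script reads. Shapes assume a hidden size of 4096; real tensors would come
# from the activation-based (ZipIt-style) alignment step, not from random data.
import os

import safetensors.torch
import torch

space = "attn_v_0"
hidden = 4096
out_dir = "merge_unmerge"  # passed to the script as MERGE_UNMERGE_DIRECTORY
os.makedirs(out_dir, exist_ok=True)

# merge: (hidden, 2 * hidden); chunk(2, dim=1) yields one (hidden, hidden) block per model,
# applied as merge_matrix[k] @ w to map each model's outputs into the shared space.
merge = torch.randn(hidden, 2 * hidden)
# unmerge: (2 * hidden, hidden); chunk(2, dim=0) yields one (hidden, hidden) block per model,
# applied as w @ unmerge_matrix[k] to map the shared space back to each model's inputs.
unmerge = torch.randn(2 * hidden, hidden)

safetensors.torch.save_file({space: merge}, os.path.join(out_dir, f"{space}_merge.safetensor"))
safetensors.torch.save_file({space: unmerge}, os.path.join(out_dir, f"{space}_unmerge.safetensor"))

The script would then be invoked with the two model paths, this directory, and --out-path pointing at the desired output location.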