
Activation based merging - copied over from wip-zipit branch #365

Merged: 13 commits into main on Jul 19, 2024

Conversation

@metric-space (Contributor) commented Jul 10, 2024

What is this?

This PR introduces a way to merge two models via their activations and hidden states on a tiny sample of data.
The method uses these activations and hidden states to form correlation matrices, which are then used to generate permutation and inverse-permutation matrices for the weights of each model; the permuted weights are then combined.

This PR consists of three main scripts:

  1. the first generates the activations/hidden states for each space
  2. the second generates a permutation and inverse-permutation pair for each space
  3. the third applies the permutation and/or inverse permutation to each weight connected to a space and then combines the weights
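As a rough sketch of step 2 (an illustration of the idea, not the PR's actual implementation, and the function name here is hypothetical): given a correlation matrix between the two models' features in a space, a permutation can be recovered by solving a linear assignment problem over the correlations.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def permutation_from_correlation(corr: np.ndarray) -> np.ndarray:
    """Build a permutation matrix from a feature-correlation matrix.

    corr[i, j] is the correlation between feature i of model A and
    feature j of model B (square, since the architectures match).
    Solving a linear assignment problem on -corr picks the matching
    that maximizes total correlation.
    """
    rows, cols = linear_sum_assignment(-corr)  # negate to maximize
    n = corr.shape[0]
    P = np.zeros((n, n))
    P[np.arange(n), cols] = 1.0  # (P @ x_B)[i] == x_B[cols[i]]
    return P
```

The inverse permutation is simply the transpose, `P.T`, which is what gets applied on the other side of each weight so the layers still compose.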

Assumptions

The models to be merged share the same architecture and have equal block/layer counts.
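These assumptions can be checked up front against the models' configs. A minimal sketch (a hypothetical helper, not part of this PR; in practice `cfg_a`/`cfg_b` would be `transformers` `AutoConfig` objects):

```python
from types import SimpleNamespace

def check_mergeable(cfg_a, cfg_b) -> bool:
    """True if the two configs satisfy the assumptions above:
    same architecture and equal block/layer count (plus equal
    hidden size, so the per-space permutations are square)."""
    return (
        cfg_a.architectures == cfg_b.architectures
        and cfg_a.num_hidden_layers == cfg_b.num_hidden_layers
        and cfg_a.hidden_size == cfg_b.hidden_size
    )
```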

Testing

To test this, we need the mergekit/scripts/permute_random.py script from the rope-alignment branch.

(The final inference script, test_by_gen.py, is listed after the shell commands below.)

git clone --branch rope-alignment https://github.com/arcee-ai/mergekit.git  permuter
python3  -mvenv permuter 
cd permuter && source bin/activate
pip install -e .
huggingface-cli login
python mergekit/scripts/permute_random.py meta-llama/Llama-2-7b-chat-hf --permute-head-dims  --out-path random2
cp $HF_HOME/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/{tokenizer*,special_tokens_map.json} random2
deactivate 
cd -

git clone --branch abm https://github.com/arcee-ai/mergekit.git  mergekit
python3  -mvenv mergekit 
cd mergekit && source bin/activate
pip install -e .
mkdir delete_dump_output/
python mergekit/scripts/ABM/extract_activations.py  meta-llama/Llama-2-7b-chat-hf -o ./delete_dump_output  -d arcee-ai/pmc-test-perplexity  -s 8  -c text  -u test  --device cpu
python mergekit/scripts/ABM/extract_activations.py /home/ubuntu/data/permuter/random2 -o ./delete_dump_output  -d arcee-ai/pmc-test-perplexity  -s 8  -c text  -u test  --device cpu
mkdir delete_m_v_out
python mergekit/scripts/ABM/extract_permutation_matrices.py ./delete_dump_output/meta-llama_Llama-2-7b-chat-hf_features.bin ./delete_dump_output/_home_ubuntu_data_permuter_random2_features.bin   --model_path  meta-llama/Llama-2-7b-chat-hf --out_path ./delete_m_v_out
mkdir new_model/
python mergekit/scripts/ABM/activations_based_merge.py  meta-llama/Llama-2-7b-chat-hf  /home/ubuntu/data/permuter/random2  delete_m_v_out -o new_model
python test_by_gen.py new_model

(test_by_gen.py)

import sys

import torch
from transformers import pipeline

model = sys.argv[1] 

pipe = pipeline(
    "text-generation", model=model, torch_dtype=torch.bfloat16, device_map="auto"
)

# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {
        "role": "system",
        "content": "You are a helpful chatbot who pretends to be Richard Feynman",
    },
    {"role": "user", "content": "Could you tell me about the challenger disaster ?"},
]
prompt = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
outputs = pipe(
    prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95
)
print(outputs[0]["generated_text"])

If all goes well, you should see output along these lines:
(screenshot of the generated chat response, dated 2024-07-06)
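For intuition, the final step performed by activations_based_merge.py amounts to something like the following per-weight linear merge (a minimal sketch with assumed helper names and shapes, not the script's actual code):

```python
import numpy as np

def merge_weight(w_a: np.ndarray, w_b: np.ndarray,
                 P_out: np.ndarray, P_in: np.ndarray) -> np.ndarray:
    """Align one of model B's weight matrices to model A's basis, then average.

    w_a, w_b: (d_out, d_in) weight matrices.  P_out permutes the rows
    (output features) into model A's ordering, and P_in.T un-permutes
    the columns so the layer still composes with its permuted inputs.
    """
    w_b_aligned = P_out @ w_b @ P_in.T
    return 0.5 * (w_a + w_b_aligned)
```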

Things that couldn't make it into the final PR

On-the-fly handling of models with grouped-query attention. This hasn't been tested enough for this release, but will be in the near future. For now, users will have to run this script first:
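(The usual GQA-to-MHA conversion, sketched below with numpy for a single k_proj/v_proj matrix, repeats each KV head's projection rows so every query head gets its own copy. This is an illustration of the idea only, not the referenced script.)

```python
import numpy as np

def expand_kv_heads(w_kv: np.ndarray, num_kv_heads: int, num_heads: int) -> np.ndarray:
    """Duplicate grouped KV heads so the model becomes plain multi-head
    attention, as applied to a k_proj or v_proj weight matrix.

    w_kv has shape (num_kv_heads * head_dim, hidden_size).
    """
    group = num_heads // num_kv_heads
    head_dim = w_kv.shape[0] // num_kv_heads
    heads = w_kv.reshape(num_kv_heads, head_dim, -1)  # split into KV heads
    expanded = np.repeat(heads, group, axis=0)        # one copy per query head
    return expanded.reshape(num_heads * head_dim, -1)
```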

Note:

Because this was copied over from another branch (wip-zipit), @shamanez's contributions are missing from the commit history, so this is an explicit acknowledgement that @shamanez worked on this PR alongside the other authors.

@metric-space requested a review from cg123, July 10, 2024 21:08
@cg123 (Collaborator) left a comment:
Looks good to me! One suggestion for a followup, but this should be good to merge.


# average weights and save them
if merge_matrix:
    w = w + w2
@cg123 (Collaborator) commented:
A decent next step for this might be to separate this out - if it just output two modified models then we could feed those directly in to mergekit-yaml and be able to try out merge methods other than linear without needing to bring that infrastructure into the script.

@metric-space (Author) replied:
Understood. I'll be sure to add this as a follow-up PR.

@metric-space merged commit 6447a85 into main, Jul 19, 2024
6 checks passed