From 6fb71967ed9cd7553318ca6cddd30c2a5d606159 Mon Sep 17 00:00:00 2001
From: Charles Goddard
Date: Sat, 30 Nov 2024 12:50:38 -0800
Subject: [PATCH] Handle tied weights and aliases differently

---
 mergekit/_data/architectures/bert-masked-lm.json       | 3 ++-
 mergekit/_data/architectures/distilbert-masked-lm.json | 3 ++-
 mergekit/_data/architectures/gemma2.json               | 5 ++++-
 mergekit/_data/architectures/gptbigcode.json           | 4 +++-
 mergekit/_data/architectures/internlm2.json            | 3 ++-
 mergekit/_data/architectures/llama.json                | 5 ++++-
 mergekit/_data/architectures/mamba.json                | 5 ++++-
 mergekit/_data/architectures/phi3-small.json           | 5 +++--
 mergekit/_data/architectures/qwen2.json                | 3 ++-
 mergekit/_data/architectures/roberta-masked-lm.json    | 7 +++++--
 mergekit/_data/architectures/solar.json                | 3 ++-
 mergekit/_data/architectures/starcoder2.json           | 5 ++++-
 mergekit/architecture.py                               | 4 ++++
 mergekit/io/tasks.py                                   | 6 +++++-
 mergekit/plan.py                                       | 5 ++++-
 15 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/mergekit/_data/architectures/bert-masked-lm.json b/mergekit/_data/architectures/bert-masked-lm.json
index 3b0620fb..d6430e40 100644
--- a/mergekit/_data/architectures/bert-masked-lm.json
+++ b/mergekit/_data/architectures/bert-masked-lm.json
@@ -44,7 +44,8 @@
         },
         {
             "name": "cls.predictions.decoder.weight",
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "bert.embeddings.word_embeddings.weight"
             ],
             "is_embed": true
diff --git a/mergekit/_data/architectures/distilbert-masked-lm.json b/mergekit/_data/architectures/distilbert-masked-lm.json
index 6828cca2..1a079811 100644
--- a/mergekit/_data/architectures/distilbert-masked-lm.json
+++ b/mergekit/_data/architectures/distilbert-masked-lm.json
@@ -40,7 +40,8 @@
         {
             "name": "vocab_projector.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "distilbert.embeddings.word_embeddings.weight"
             ]
         },
diff --git a/mergekit/_data/architectures/gemma2.json b/mergekit/_data/architectures/gemma2.json
index 0c6372f0..52505245 100644
--- a/mergekit/_data/architectures/gemma2.json
+++ b/mergekit/_data/architectures/gemma2.json
@@ -54,7 +54,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "optional": true
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         }
     ]
 }
diff --git a/mergekit/_data/architectures/gptbigcode.json b/mergekit/_data/architectures/gptbigcode.json
index 4b086278..c12bac5c 100644
--- a/mergekit/_data/architectures/gptbigcode.json
+++ b/mergekit/_data/architectures/gptbigcode.json
@@ -21,7 +21,9 @@
         },
         {
             "name": "lm_head.weight",
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "transformer.wte.weight"
             ]
         }
diff --git a/mergekit/_data/architectures/internlm2.json b/mergekit/_data/architectures/internlm2.json
index 057bc649..888faa48 100644
--- a/mergekit/_data/architectures/internlm2.json
+++ b/mergekit/_data/architectures/internlm2.json
@@ -16,7 +16,8 @@
         {
             "name": "output.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.tok_embeddings.weight"
             ]
         }
diff --git a/mergekit/_data/architectures/llama.json b/mergekit/_data/architectures/llama.json
index 7106806b..00918a2c 100644
--- a/mergekit/_data/architectures/llama.json
+++ b/mergekit/_data/architectures/llama.json
@@ -74,7 +74,10 @@
             "name": "lm_head.weight",
             "input_space": "running_residual",
             "is_embed": true,
-            "optional": true
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         }
     ]
 }
diff --git a/mergekit/_data/architectures/mamba.json b/mergekit/_data/architectures/mamba.json
index b3727dba..1c473532 100644
--- a/mergekit/_data/architectures/mamba.json
+++ b/mergekit/_data/architectures/mamba.json
@@ -16,7 +16,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": ["backbone.embeddings.weight"]
+            "optional": true,
+            "tied_names": [
+                "backbone.embeddings.weight"
+            ]
         }
     ],
     "num_layers_config_key": "num_hidden_layers",
diff --git a/mergekit/_data/architectures/phi3-small.json b/mergekit/_data/architectures/phi3-small.json
index 7b3a1e80..f27dfac4 100644
--- a/mergekit/_data/architectures/phi3-small.json
+++ b/mergekit/_data/architectures/phi3-small.json
@@ -12,8 +12,9 @@
     "post_weights": [
         {
             "name": "lm_head.weight",
-            "is_embed":true,
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "model.embed_tokens.weight"
             ]
         },
diff --git a/mergekit/_data/architectures/qwen2.json b/mergekit/_data/architectures/qwen2.json
index 638b3630..c7131523 100644
--- a/mergekit/_data/architectures/qwen2.json
+++ b/mergekit/_data/architectures/qwen2.json
@@ -16,7 +16,8 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.embed_tokens.weight"
             ]
         }
diff --git a/mergekit/_data/architectures/roberta-masked-lm.json b/mergekit/_data/architectures/roberta-masked-lm.json
index 492127a5..1aae76a1 100644
--- a/mergekit/_data/architectures/roberta-masked-lm.json
+++ b/mergekit/_data/architectures/roberta-masked-lm.json
@@ -8,7 +8,8 @@
             "name": "roberta.embeddings.position_embeddings.weight"
         },
         {
-            "name": "roberta.embeddings.word_embeddings.weight"
+            "name": "roberta.embeddings.word_embeddings.weight",
+            "is_embed": true
         },
         {
             "name": "roberta.embeddings.token_type_embeddings.weight"
@@ -43,7 +44,9 @@
         },
         {
             "name": "lm_head.decoder.weight",
-            "aliases": [
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
                 "roberta.embeddings.word_embeddings.weight"
             ]
         }
diff --git a/mergekit/_data/architectures/solar.json b/mergekit/_data/architectures/solar.json
index 7bd6a751..78fd5998 100644
--- a/mergekit/_data/architectures/solar.json
+++ b/mergekit/_data/architectures/solar.json
@@ -73,7 +73,8 @@
             "name": "lm_head.weight",
             "input_space": "running_residual",
             "is_embed": true,
-            "aliases": [
+            "optional": true,
+            "tied_names": [
                 "model.lm_head.weight"
             ]
         }
diff --git a/mergekit/_data/architectures/starcoder2.json b/mergekit/_data/architectures/starcoder2.json
index 851fdd1a..c2266899 100644
--- a/mergekit/_data/architectures/starcoder2.json
+++ b/mergekit/_data/architectures/starcoder2.json
@@ -13,7 +13,10 @@
         {
             "name": "lm_head.weight",
             "is_embed": true,
-            "aliases": ["model.embed_tokens.weight"]
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
         },
         {
             "name": "model.norm.bias"
diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index 4c7b4625..40872160 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True):
             Indicates whether the weight can be omitted from a model.
         aliases (Optional[List[str]]):
             List of alternative names for the weight, if applicable.
+        tied_names (Optional[List[str]]):
+            List of names for weights that are tied to this weight, if applicable.
         force_dtype (Optional[str]):
             Mandatory dtype for the weight, if applicable.
""" @@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True): input_space: Optional[str] = None output_space: Optional[str] = None optional: bool = False + tied: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None head_split: Literal[None, "input", "output"] = None is_kq: Optional[bool] = False diff --git a/mergekit/io/tasks.py b/mergekit/io/tasks.py index 70dffc41..499ad4c0 100644 --- a/mergekit/io/tasks.py +++ b/mergekit/io/tasks.py @@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]): device: Optional[str] = None optional: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None def arguments(self) -> Dict[str, Task]: return {} def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]: - all_names = [self.tensor] + list(self.aliases or []) + all_names = ( + [self.tensor] + list(self.aliases or []) + list(self.tied_names or []) + ) for name in all_names: if name in loader.index.tensor_paths: return name @@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]: device=self.device, optional=wi.optional, aliases=wi.aliases, + tied_names=wi.tied_names, ) for (model, wi) in self.weight_info.items() } diff --git a/mergekit/plan.py b/mergekit/plan.py index bdcd7004..5b34eddc 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -139,7 +139,10 @@ def plan_tensor( any_weight = False for model, w_in in zip(models, weights_in): index = LoaderCache().get(model).index - if w_in.name in index.tensor_paths: + if any( + name in index.tensor_paths + for name in [w_in.name] + (w_in.aliases or []) + ): any_weight = True break