Skip to content

Commit

Permalink
Handle tied weights and aliases differently
Browse files: browse the repository at this point in the history
  • Loading branch information
cg123 committed Nov 30, 2024
1 parent afe3780 commit 6fb7196
Show file tree
Hide file tree
Showing 15 changed files with 50 additions and 16 deletions.
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/bert-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
},
{
"name": "cls.predictions.decoder.weight",
"aliases": [
"optional": true,
"tied_names": [
"bert.embeddings.word_embeddings.weight"
],
"is_embed": true
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/distilbert-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
{
"name": "vocab_projector.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"distilbert.embeddings.word_embeddings.weight"
]
},
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/gemma2.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"optional": true
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
]
}
4 changes: 3 additions & 1 deletion mergekit/_data/architectures/gptbigcode.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
},
{
"name": "lm_head.weight",
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"transformer.wte.weight"
]
}
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/internlm2.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
{
"name": "output.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.tok_embeddings.weight"
]
}
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/llama.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@
"name": "lm_head.weight",
"input_space": "running_residual",
"is_embed": true,
"optional": true
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
]
}
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/mamba.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": ["backbone.embeddings.weight"]
"optional": true,
"tied_names": [
"backbone.embeddings.weight"
]
}
],
"num_layers_config_key": "num_hidden_layers",
Expand Down
5 changes: 3 additions & 2 deletions mergekit/_data/architectures/phi3-small.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"post_weights": [
{
"name": "lm_head.weight",
"is_embed":true,
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
},
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/qwen2.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
}
Expand Down
7 changes: 5 additions & 2 deletions mergekit/_data/architectures/roberta-masked-lm.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"name": "roberta.embeddings.position_embeddings.weight"
},
{
"name": "roberta.embeddings.word_embeddings.weight"
"name": "roberta.embeddings.word_embeddings.weight",
"is_embed": true
},
{
"name": "roberta.embeddings.token_type_embeddings.weight"
Expand Down Expand Up @@ -43,7 +44,9 @@
},
{
"name": "lm_head.decoder.weight",
"aliases": [
"is_embed": true,
"optional": true,
"tied_names": [
"roberta.embeddings.word_embeddings.weight"
]
}
Expand Down
3 changes: 2 additions & 1 deletion mergekit/_data/architectures/solar.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
"name": "lm_head.weight",
"input_space": "running_residual",
"is_embed": true,
"aliases": [
"optional": true,
"tied_names": [
"model.lm_head.weight"
]
}
Expand Down
5 changes: 4 additions & 1 deletion mergekit/_data/architectures/starcoder2.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
{
"name": "lm_head.weight",
"is_embed": true,
"aliases": ["model.embed_tokens.weight"]
"optional": true,
"tied_names": [
"model.embed_tokens.weight"
]
},
{
"name": "model.norm.bias"
Expand Down
4 changes: 4 additions & 0 deletions mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True):
Indicates whether the weight can be omitted from a model.
aliases (Optional[List[str]]):
List of alternative names for the weight, if applicable.
tied_names (Optional[List[str]]):
List of names for weights that are tied to this weight, if applicable.
force_dtype (Optional[str]):
Mandatory dtype for the weight, if applicable.
"""
Expand All @@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True):
input_space: Optional[str] = None
output_space: Optional[str] = None
optional: bool = False
tied: bool = False
aliases: Optional[Tuple[str, ...]] = None
tied_names: Optional[Tuple[str, ...]] = None
force_dtype: Optional[str] = None
head_split: Literal[None, "input", "output"] = None
is_kq: Optional[bool] = False
Expand Down
6 changes: 5 additions & 1 deletion mergekit/io/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]):
device: Optional[str] = None
optional: bool = False
aliases: Optional[Tuple[str, ...]] = None
tied_names: Optional[Tuple[str, ...]] = None

def arguments(self) -> Dict[str, Task]:
return {}

def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
all_names = [self.tensor] + list(self.aliases or [])
all_names = (
[self.tensor] + list(self.aliases or []) + list(self.tied_names or [])
)
for name in all_names:
if name in loader.index.tensor_paths:
return name
Expand Down Expand Up @@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]:
device=self.device,
optional=wi.optional,
aliases=wi.aliases,
tied_names=wi.tied_names,
)
for (model, wi) in self.weight_info.items()
}
Expand Down
5 changes: 4 additions & 1 deletion mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def plan_tensor(
any_weight = False
for model, w_in in zip(models, weights_in):
index = LoaderCache().get(model).index
if w_in.name in index.tensor_paths:
if any(
name in index.tensor_paths
for name in [w_in.name] + (w_in.aliases or [])
):
any_weight = True
break

Expand Down

0 comments on commit 6fb7196

Please sign in to comment.