Add marlin to scripts/eval.py
moeiniamir committed Apr 16, 2024
1 parent 996b175 commit 4baf419
Showing 2 changed files with 57 additions and 0 deletions.
12 changes: 12 additions & 0 deletions scripts/eval/README.md
@@ -472,3 +472,15 @@ When formatting samples, `prompt_string` is prepended to the beginning, then `nu
Thus the structure of each question's preamble is `prompt | few shot examples | context | continuation delimiter`. The continuation (aka choices for MC) is then tokenized separately and the tokens of the preamble and tokens of the continuation are concatenated. It is important to note that if the continuation delimiter has a trailing space, it is stripped and instead prepended to the continuation. Furthermore, if the continuation does not have a leading space, one will be prepended.
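For concreteness, the whitespace handling described above can be sketched as follows (a minimal illustration with made-up names, not the eval harness's actual code):

```python
def split_delimiter_and_continuation(continuation_delimiter: str, continuation: str):
    # A trailing space on the delimiter is stripped and moved onto the continuation.
    if continuation_delimiter.endswith(' '):
        continuation_delimiter = continuation_delimiter.rstrip(' ')
        continuation = ' ' + continuation.lstrip(' ')
    # If the continuation still lacks a leading space, one is prepended.
    if not continuation.startswith(' '):
        continuation = ' ' + continuation
    return continuation_delimiter, continuation

# Example: ('Answer: ', 'Paris') -> ('Answer:', ' Paris'); the preamble
# (prompt | few shot examples | context | delimiter) and the continuation are
# then tokenized separately and their token sequences concatenated.
```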

----
# Evaluation using [Marlin](https://github.com/IST-DASLab/marlin)
Add `marlin_path` to the model's YAML config, as in the example below:

```yaml
model:
  name: marlin-demo
  pretrained_model_name_or_path: ${model_name_or_path}
  pretrained: true
  marlin_path: ${marlin_path}
```
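
The `${model_name_or_path}` and `${marlin_path}` entries are YAML variables resolved from elsewhere in the eval config. A hypothetical set of values (placeholders, not part of this commit) might look like:

```yaml
# Hypothetical placeholder values -- point these at your own model and quantized checkpoint.
model_name_or_path: meta-llama/Llama-2-7b-hf
marlin_path: /checkpoints/llama-2-7b.marlin
```

Note that `eval.py` infers the Marlin group size from the checkpoint path: a path ending in `marlin` (as above) is loaded with `groupsize=-1`, while any other path is assumed to be a groupsize-128 checkpoint.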

And run the gauntlet as you normally would.
45 changes: 45 additions & 0 deletions scripts/eval/eval.py
@@ -164,6 +164,37 @@ def evaluate_model(
    return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


###############
import marlin

# Save checkpoint name here since passing around extra args seems to confuse the eval harness
MARLIN_CHECKPOINT = ''

def get_llama_marlin(name, *args, **kwargs):
    """Load a Llama model, swap its projection layers for Marlin quantized linears, and load the quantized checkpoint."""
    import torch

    # Skip weight initialization: the weights are overwritten by the quantized
    # checkpoint below, so random init is wasted work.
    def skip(*args, **kwargs):
        pass
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(name, torch_dtype='auto')
    # Not really sure why this is sometimes > 1, but it messes up quantized inference ...
    # Fortunately, just setting it to 1 doesn't seem to affect standard inference
    model.config.pretraining_tp = 1

    def name_filter(n):
        if 'q_proj' in n or 'k_proj' in n or 'v_proj' in n or 'o_proj' in n:
            return True
        if 'mlp.gate_proj' in n or 'mlp.up_proj' in n or 'mlp.down_proj' in n:
            return True
        return False

    groupsize = -1 if MARLIN_CHECKPOINT.endswith('marlin') else 128
    marlin.replace_linear(model, name_filter, groupsize=groupsize)
    model.load_state_dict(torch.load(MARLIN_CHECKPOINT))
    return model
###############


def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
    # Run user provided code if specified
    code_paths = pop_config(cfg,
@@ -292,6 +323,20 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
eval_gauntlet_config)

    # Initialized once before the loop so the stock loader can be restored
    # for any non-Marlin models that follow a Marlin one.
    original_from_pretrained = None
    for model_cfg in model_configs:
        if model_cfg.model.get('marlin_path', False):
            global MARLIN_CHECKPOINT
            MARLIN_CHECKPOINT = model_cfg.model.get('marlin_path', '')
            # Overwrite model load with marlin load
            import transformers
            if original_from_pretrained is None:
                original_from_pretrained = transformers.AutoModelForCausalLM.from_pretrained
            transformers.AutoModelForCausalLM.from_pretrained = staticmethod(get_llama_marlin)
        else:
            if original_from_pretrained is not None:
                import transformers
                transformers.AutoModelForCausalLM.from_pretrained = original_from_pretrained

        (trainer, logger_keys, eval_gauntlet_callback,
         eval_gauntlet_df) = evaluate_model(
             model_cfg=model_cfg,
