Merge pull request #34 from ModelTC/temp_eval2
add down stream evaluation
llmc-reviewer authored Aug 21, 2024
2 parents 080d43b + 6fb9c1d commit b806d50
Showing 12 changed files with 195 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .gitmodules
@@ -0,0 +1,4 @@
[submodule "lm-evaluation-harness"]
path = lm-evaluation-harness
url = https://github.com/ModelTC/llmc.git
branch = lm-eval
17 changes: 17 additions & 0 deletions README.md
@@ -26,6 +26,8 @@

## News

- **Aug 22, 2024:** 🔥We support lots of small language models, including the current SOTA [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) (see [Supported Model List](#supported-model-list)). Additionally, we support downstream task evaluation through our modified [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 🤗. Specifically, you can first employ the `save_trans` mode (see the `save` part in [Configuration](#configuration)) to save a weight-modified model. After obtaining the transformed model, you can directly evaluate the quantized model by referring to [run_lm_eval.sh](scripts/run_lm_eval.sh). More details can be found [here](https://llmc-en.readthedocs.io/en/latest/advanced/model_test.html).
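
  As a minimal illustration (the output directory below is only an assumed example), a model exported by `save_trans` is saved in standard Hugging Face format, so it loads as usual before being handed to lm-evaluation-harness:

  ```
  # Minimal sketch: load a save_trans-exported model as a regular Hugging Face
  # checkpoint. The path is hypothetical; use wherever your config's `save`
  # section writes the transformed model.
  from transformers import AutoModelForCausalLM

  model_path = "./save/transformed_model"  # hypothetical save_trans output directory
  model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
  ```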

- **Jul 23, 2024:** 🍺🍺🍺 We release a brand-new version of our benchmark paper:

[**LLMC: Benchmarking Large Language Model Quantization with a Versatile Compression Toolkit**](https://arxiv.org/abs/2405.06001v2).
@@ -247,6 +249,20 @@ To help users design their configs, we now explain some universal configurations

✅ [LLaVA](https://github.com/haotian-liu/LLaVA)

✅ [InternLM2.5](https://huggingface.co/internlm)

✅ [StableLM](https://github.com/Stability-AI/StableLM)

✅ [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2)

✅ [Phi2](https://huggingface.co/microsoft/phi-2)

✅ [Phi 1.5](https://huggingface.co/microsoft/phi-1_5)

✅ [MiniCPM](https://github.com/OpenBMB/MiniCPM)

✅ [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)

You can add your own model type by referring to the files under `llmc/models/*.py`.

## Supported Algorithm List
@@ -308,6 +324,7 @@ We develop our code referring to the following repos:
- https://github.com/mobiusml/hqq
- [https://github.com/spcl/QuaRot](https://github.com/spcl/QuaRot)
- [https://github.com/locuslab/wanda](https://github.com/locuslab/wanda)
- [https://github.com/EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)

## Star History

23 changes: 23 additions & 0 deletions README_ja.md
@@ -26,6 +26,8 @@

## News

- **Aug 22, 2024:** 🔥We support lots of small language models, including the current SOTA [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) (see [Supported Model List](#supported-model-list)). Additionally, we support downstream task evaluation through our modified [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 🤗. Specifically, you can first employ the `save_trans` mode (see the `save` part in [Configuration](#設定)) to save a weight-modified model. After obtaining the transformed model, you can directly evaluate the quantized model by referring to [run_lm_eval.sh](scripts/run_lm_eval.sh). More details can be found [here](https://llmc-en.readthedocs.io/en/latest/advanced/model_test.html).

- **Jul 23, 2024:** 🍺🍺🍺 We release a brand-new version of our benchmark paper:

[**LLMC: Benchmarking Large Language Model Quantization with a Versatile Compression Toolkit**](https://arxiv.org/abs/2405.06001v2)
@@ -231,6 +233,26 @@

✅ [LLaVA](https://github.com/haotian-liu/LLaVA)

✅ [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral)

✅ [Qwen V2](https://github.com/QwenLM/Qwen2)

✅ [LLaVA](https://github.com/haotian-liu/LLaVA)

✅ [InternLM2.5](https://huggingface.co/internlm)

✅ [StableLM](https://github.com/Stability-AI/StableLM)

✅ [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2)

✅ [Phi2](https://huggingface.co/microsoft/phi-2)

✅ [Phi 1.5](https://huggingface.co/microsoft/phi-1_5)

✅ [MiniCPM](https://github.com/OpenBMB/MiniCPM)

✅ [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)

You can add your own model type by referring to the files under `llmc/models/*.py`.

## Supported Algorithm List
@@ -292,6 +314,7 @@
- https://github.com/mobiusml/hqq
- [https://github.com/spcl/QuaRot](https://github.com/spcl/QuaRot)
- [https://github.com/locuslab/wanda](https://github.com/locuslab/wanda)
- [https://github.com/EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)

## Star History

27 changes: 24 additions & 3 deletions README_zh.md
@@ -25,6 +25,8 @@

## News

- **Aug 22, 2024:** 🔥We support many small language models, including the current SOTA [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) (see [Supported Model List](#supported-model-list)). Additionally, we support downstream task evaluation through our modified [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) 🤗. Specifically, you can first use the `save_trans` mode (see the `save` part in [Configuration](#配置)) to save the modified model weights. After obtaining the transformed model, you can directly evaluate the quantized model by referring to [run_lm_eval.sh](scripts/run_lm_eval.sh). More details can be found [here](https://llmc-zhcn.readthedocs.io/en/latest/advanced/model_test.html#id2).

- **Jul 23, 2024:** 🍺🍺🍺 We release a brand-new version of our benchmark paper:

[**LLMC: Benchmarking Large Language Model Quantization with a Versatile Compression Toolkit**](https://arxiv.org/abs/2405.06001v2)
@@ -70,9 +72,7 @@
- Quantize large language models such as Llama2-70B and OPT-175B, and evaluate their PPL on only one A100/H100/H800 GPU 💥.
- Provide users with a choice of state-of-the-art compression algorithms that are [aligned in accuracy with the original paper code repositories](benchmark/align.md), and allow users to apply multiple algorithms sequentially on one LLM 💥.
- The transformed model exported by our tool with a specific compression algorithm (`save_trans` mode in the `quant` part of the [Configuration](#配置)) can be simply quantized by multiple backends to obtain a model optimized by that algorithm, which the corresponding backend can then use for inference 💥.
- Our compressed model (`save_lightllm` mode in the `quant` part of the \[Configuration\](#

配置)) has a low memory footprint and can be inferred directly by [Lightllm](https://github.com/ModelTC/lightllm) 💥.
- Our compressed model (`save_lightllm` mode in the `quant` part of the \[Configuration\](#配置)) has a low memory footprint and can be inferred directly by [Lightllm](https://github.com/ModelTC/lightllm) 💥.

## Usage

@@ -227,6 +227,26 @@

✅ [LLaMA V3](https://huggingface.co/meta-llama)

✅ [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral)

✅ [Qwen V2](https://github.com/QwenLM/Qwen2)

✅ [LLaVA](https://github.com/haotian-liu/LLaVA)

✅ [InternLM2.5](https://huggingface.co/internlm)

✅ [StableLM](https://github.com/Stability-AI/StableLM)

✅ [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2)

✅ [Phi2](https://huggingface.co/microsoft/phi-2)

✅ [Phi 1.5](https://huggingface.co/microsoft/phi-1_5)

✅ [MiniCPM](https://github.com/OpenBMB/MiniCPM)

✅ [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)

You can add your own model type by referring to the files under `llmc/models/*.py`.

## Supported Algorithm List
@@ -287,6 +307,7 @@
- https://github.com/TimDettmers/bitsandbytes
- https://github.com/mobiusml/hqq
- [https://github.com/locuslab/wanda](https://github.com/locuslab/wanda)
- [https://github.com/EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)

## Star History

29 changes: 28 additions & 1 deletion docs/en/source/advanced/model_test.md
@@ -153,6 +153,33 @@ python run.py configs/eval_lightllm.py
```
When the model has finished inference and metric calculation, we can obtain its evaluation results. An output folder is generated in the current directory: the logs subfolder records the evaluation logs, and the summary file records the accuracy on the measured datasets.

## Use of the lm-evaluation-harness evaluation tool

Besides the methods above, we also recommend using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). We have integrated this tool into llmc as a submodule. After cloning llmc's submodules, you can evaluate a quantized or full-precision model with the following commands:

```
export CUDA_VISIBLE_DEVICES=4,5,6,7
llmc=./llmc
lm_eval=./llmc/lm-evaluation-harness
export PYTHONPATH=$llmc:$PYTHONPATH
export PYTHONPATH=$llmc:$lm_eval:$PYTHONPATH
# Replace the config file with the one you want to use (e.g., an RTN config pointing to the
# algorithm-transformed model path, or a config with the quant part removed pointing to the original model path).
# `--quarot` depends on the transformation algorithm used before.
accelerate launch --multi_gpu --num_processes 4 llmc/tools/llm_eval.py \
--config llmc/configs/quantization/RTN/rtn_quarot.yml \
--model hf \
--quarot \
--tasks lambada_openai,arc_easy \
--model_args parallelize=False \
--batch_size 64 \
--output_path ./save/lm_eval \
--log_samples
```

We preserve the original command-line interface of [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness); there are only two additional arguments, ``--config`` and ``--quarot``. The former loads either the transformed model (saved by ``save_trans``) or the original Hugging Face model, depending on the model path in the config; to evaluate the full-precision model, remove the ``quant`` part from the config. For quantized evaluation we only support RTN quantization, and all related quantization granularities need to align with the settings of the transformed model. The latter flag is required if the model was transformed by [QuaRot](https://arxiv.org/abs/2404.00456).
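
To make this concrete, below is a condensed sketch of what the wrapper does with these two arguments (simplified from the `tools/llm_eval.py` added in this commit; the parallelize check and error handling are omitted):

```
import yaml
import lm_eval.__main__ as lm_eval
from easydict import EasyDict

# Extend the stock lm-evaluation-harness CLI with the two llmc-specific arguments.
parser = lm_eval.setup_parser()
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--quarot', action='store_true')
args = parser.parse_args()

# Load the llmc config and forward its model path to the harness,
# so `pretrained=` does not need to be passed in --model_args.
with open(args.config, 'r') as f:
    args.config = EasyDict(yaml.safe_load(f))
if 'pretrained' not in args.model_args:
    args.model_args += ',pretrained=' + args.config.model.path

lm_eval.cli_evaluate(args)
```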

*Remark: For evaluation, do not enable parallelize (or set `parallelize=False`) and do not pass `pretrained=*` in ``--model_args``; the model path is taken from the config.*

## FAQ

**<font color=red> Q1 </font>**
@@ -169,7 +196,7 @@ The test accuracy of the Humaneval of the LLAMA model is too low

**<font color=green> Solution </font>**

You may need to delete the \n at the end of each entry in the Humaneval jsonl file in the dataset provided by OpenCompass and retest it
You may need to delete the \n at the end of each entry in the Humaneval json file in the dataset provided by OpenCompass and retest it

**<font color=red> Q3 </font>**

25 changes: 25 additions & 0 deletions docs/zh_cn/source/advanced/model_test.md
@@ -152,6 +152,31 @@ python run.py configs/eval_lightllm.py
```
When the model has finished inference and metric calculation, we can obtain its evaluation results. An output folder is generated in the current directory: the logs subfolder records the evaluation logs, and the final summary file records the accuracy on the measured datasets.

## Use of the lm-evaluation-harness evaluation tool

We preserve the original command-line interface of [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness); only two arguments are added, ``--config`` and ``--quarot``. The former loads either the transformed model saved by ``save_trans`` or the original Hugging Face model, depending on the model path; to evaluate the full-precision model, remove the ``quant`` part from the config. We only support RTN quantization, and all related quantization granularities need to align with the settings of the transformed model. The latter is used if the model was transformed by [QuaRot](https://arxiv.org/abs/2404.00456).

```
export CUDA_VISIBLE_DEVICES=4,5,6,7
llmc=./llmc
lm_eval=./llmc/lm-evaluation-harness
export PYTHONPATH=$llmc:$PYTHONPATH
export PYTHONPATH=$llmc:$lm_eval:$PYTHONPATH
# Replace the config file with the one you want to use (e.g., an RTN config pointing to the
# algorithm-transformed model path, or a config with the quant part removed pointing to the original model path).
# `--quarot` depends on the transformation algorithm used before.
accelerate launch --multi_gpu --num_processes 4 llmc/tools/llm_eval.py \
--config llmc/configs/quantization/RTN/rtn_quarot.yml \
--model hf \
--quarot \
--tasks lambada_openai,arc_easy \
--model_args parallelize=False \
--batch_size 64 \
--output_path ./save/lm_eval \
--log_samples
```

*Remark: For evaluation, do not enable parallelization (or set `parallelize=False`) and do not pass `pretrained=*` in ``--model_args``.*

## FAQ

**<font color=red> Q1 </font>**
5 changes: 4 additions & 1 deletion llmc/compression/quantization/base_blockwise_quantization.py
@@ -208,7 +208,10 @@ def set_quant_config(self):

    def replace_rotate_linears(self, block):
        for n, m in block.named_modules():
            if isinstance(m, nn.Linear) and ('down_proj' in n or 'o_proj' in n):
            if isinstance(m, nn.Linear) and ('down_proj' in n
                                             or 'o_proj' in n
                                             or 'fc2' in n
                                             or 'out_proj' in n):
                subset = {'layers': {n: m}}
                self.model.replace_module_subset(
                    RotateLinear,
7 changes: 7 additions & 0 deletions llmc/compression/quantization/quarot.py
@@ -1,4 +1,6 @@
import gc
import json
import os

import torch
import torch.nn as nn
@@ -30,6 +32,11 @@ def preprocess(self):
            self.model.get_embed_layers()[0].weight,
        ):
            logger.info('Tie weight! Copy embed_layer for head_layer!')
            path = os.path.join(self.config.model.path, 'config.json')
            # Read the HF config, untie the embeddings, and write it back.
            with open(path, 'r') as f:
                config = json.load(f)
            config['tie_word_embeddings'] = False
            with open(path, 'w') as f:
                json.dump(config, f, indent=4)
            del self.model.get_head_layers()[0].weight
            w = self.model.get_embed_layers()[0].weight.clone()
            self.model.get_head_layers()[0].weight = nn.Parameter(w)
1 change: 1 addition & 0 deletions lm-evaluation-harness
Submodule lm-evaluation-harness added at 5ac2bb
16 changes: 16 additions & 0 deletions scripts/run_lm_eval.sh
@@ -0,0 +1,16 @@
export CUDA_VISIBLE_DEVICES=4,5,6,7
llmc=./llmc
lm_eval=./llmc/lm-evaluation-harness
export PYTHONPATH=$llmc:$PYTHONPATH
export PYTHONPATH=$llmc:$lm_eval:$PYTHONPATH
# Replace the config file with the one you want to use (e.g., an RTN config pointing to the
# algorithm-transformed model path, or a config with the quant part removed pointing to the original model path).
# `--quarot` depends on the transformation algorithm used before.
accelerate launch --multi_gpu --num_processes 4 llmc/tools/llm_eval.py \
--config llmc/configs/quantization/RTN/rtn_quarot.yml \
--model hf \
--quarot \
--tasks lambada_openai,arc_easy \
--model_args parallelize=False \
--batch_size 64 \
--output_path ./save/lm_eval \
--log_samples
45 changes: 45 additions & 0 deletions tools/llm_eval.py
@@ -0,0 +1,45 @@
import argparse
import copy
import functools
import gc
import os
import sys

import torch
import yaml
from loguru import logger
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM

sys.path.append(os.path.join(os.path.dirname(__file__), '../lm-evaluation-harness'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
import lm_eval.__main__ as lm_eval
import torch.nn as nn
from easydict import EasyDict

from llmc.compression.quantization import FakeQuantLinear, Quantizer
from llmc.compression.quantization.base_blockwise_quantization import \
BaseBlockwiseQuantization
from llmc.compression.quantization.module_utils import LlmcRMSNorm
from llmc.data import BaseDataset, BaseTokenizer
from llmc.eval import PerplexityEval
from llmc.models import *
from llmc.utils import check_config, mkdirs, seed_all
from llmc.utils.registry_factory import ALGO_REGISTRY, MODEL_REGISTRY

if __name__ == '__main__':
    logger.warning('This script only supports transformed/original model type!')
    parser = lm_eval.setup_parser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--quarot', action='store_true')
    args = parser.parse_args()
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)
    config = EasyDict(config)
    args.config = config
    if 'pretrained' not in args.model_args:
        # The model path is injected from the llmc config; parallelize=True is not supported here.
        if 'parallelize=True' in args.model_args:
            logger.error("Please remove 'parallelize=True' from model_args!")
            sys.exit(1)
        args.model_args += ',pretrained=' + config.model.path
    lm_eval.cli_evaluate(args)
2 changes: 1 addition & 1 deletion tools/quant_analysis.py
Expand Up @@ -9,7 +9,7 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM

sys.path.append('..')
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
import matplotlib.pyplot as plt
import torch.nn as nn

