LLM.int8() Refactoring: Part 1 (#1401)
* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

* Fix unintended change

* New naive mm_dequant kernel for row-major; cleanup

* fix

* int8 refactor: initial sparse decomp, cleanup

* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup

* int8: inference optimizations, some cleanup

* int8: more tests passing, cleanup

* int8 - more cleanup, most tests passing

* int8: specify CUDA stream for int8 ops

* perf: reduce overhead from getting cudaStream ptr

* Mark some functions for deprecation.

* int8 sparse decomp: small perf improvement

* update setup.py

* Update bitsandbytes/autograd/_functions.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update bitsandbytes/research/autograd/_functions.py

Co-authored-by: Aarni Koskela <[email protected]>

* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn

* int8 cleanup

* Ignore ruff rule ISC001 (incompatible with formatter)

* add comment

* int8 more cleanup

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* int8: rename / deprecate old fn signatures

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* type annotation

* format update

* Update bitsandbytes/research/autograd/_functions.py

Co-authored-by: Aarni Koskela <[email protected]>

* cleanup

* Add comment to explain division optimization

* more cleanup

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* Update bitsandbytes/functional.py

Co-authored-by: Aarni Koskela <[email protected]>

* cleanup

* Type annotations, cleanup

* remove unused kernels; improved type annotations

* small perf optimization for single-GPU systems

* small perf optimization for single-GPU systems

* update docstrings

* Improve docs and tests

* Update docstring

* Update test

* add benchmarking script

* test cleanup: add deprecated marker, move benchmarks out

* Add int8 dequant function; misc improvements

* int8 matmul fallback for inner dims not divisible by 4

* improve register usage of kInt8VectorQuant - especially for A100/H100

* disable fail-fast for package build

* maxwell compat

* ptxas verbose

* docs update

* doc update

* backward fix

* Bugfix sparse decomp

* Int8 fix for PEFT OLoRA init

* Fix test for deprecated spmm_coo

* test improvement

* doc update

* typo

* doc cleanup

* docs

* add inference benchmark script

* Add benchmarks, doc update

---------

Co-authored-by: Aarni Koskela <[email protected]>
matthewdouglas and akx authored Dec 5, 2024
1 parent 7dca700 commit 81e6345
Showing 39 changed files with 2,626 additions and 2,323 deletions.
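
Several of the commits above refer to LLM.int8()'s "sparse decomposition" (mixed-precision matmul). As a point of reference only — not the library's implementation, which lives in the fused CUDA kernels touched by this refactor — a minimal PyTorch sketch of the idea looks roughly like this:

```python
import torch

def int8_linear_with_decomposition(A: torch.Tensor, W: torch.Tensor, threshold: float = 6.0) -> torch.Tensor:
    """A @ W.t() with int8 quantization plus a higher-precision path for outlier features.

    A: (tokens, in_features) activations; W: (out_features, in_features) weights.
    """
    # 1. Outlier feature dimensions: columns of A whose absolute max exceeds the threshold.
    outlier_cols = (A.abs().amax(dim=0) > threshold).nonzero().flatten()

    # 2. Outlier columns are multiplied in the original (higher) precision.
    out_hi = A[:, outlier_cols] @ W[:, outlier_cols].t()

    # 3. Remaining columns: zero out the outliers, then quantize A (per row) and W (per row)
    #    to int8 with absmax scaling.
    A_rest = A.clone()
    A_rest[:, outlier_cols] = 0
    scale_a = A_rest.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    scale_w = W.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    A_i8 = torch.round(A_rest / scale_a).to(torch.int8)
    W_i8 = torch.round(W / scale_w).to(torch.int8)

    # 4. Int8 matmul accumulated in int32, then dequantized with the outer product of the scales.
    acc = A_i8.to(torch.int32) @ W_i8.to(torch.int32).t()
    out_i8 = acc.to(A.dtype) * (scale_a * scale_w.t())

    return out_hi + out_i8


# Illustrative check on random data (float32 on CPU for simplicity).
A = torch.randn(4, 64) * 2
W = torch.randn(8, 64)
print((int8_linear_with_decomposition(A, W) - A @ W.t()).abs().max())
```
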
30 changes: 15 additions & 15 deletions .github/scripts/build-cuda.sh
@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
-for NO_CUBLASLT in ON OFF; do
-if [ "${build_os:0:6}" == ubuntu ]; then
-image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
-echo "Using image $image"
-docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
-"apt-get update \
-&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
-&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
-&& cmake --build ."
-else
-pip install cmake==3.28.3
-cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
-cmake --build . --config Release
-fi
-done

if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi


output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
1 change: 1 addition & 0 deletions .github/workflows/python-package.yml
@@ -60,6 +60,7 @@ jobs:
##
build-shared-libs-cuda:
strategy:
+fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
2 changes: 2 additions & 0 deletions .gitignore
@@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/
Debug/
Release/
+cmake-build-*/

# IDE local files
.vs/
+.idea/

# Distribution / packaging
.Python
14 changes: 1 addition & 13 deletions CMakeLists.txt
@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
-# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" )
endif()
-option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_MPS OFF)
-message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES})

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
-if(NO_CUBLASLT)
-string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
-endif()
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS)
if(NOT APPLE)
@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
-if(NO_CUBLASLT)
-target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
-else()
-target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
-endif()
-
+target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
159 changes: 159 additions & 0 deletions benchmarking/README.md
@@ -0,0 +1,159 @@
# Benchmarking

## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

See the example script in [inference_benchmark.py](inference_benchmark.py).
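
A typical invocation (the model id here is just an example) looks like `python inference_benchmark.py meta-llama/Llama-3.1-8B --configs nf4 int8 int8-decomp --batches 1 8 32`; one JSON report per configuration and batch size is written to the directory given by `--out-dir` (default: `reports`).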

### Results (as of v0.45.0)

Compared with v0.44.1, our benchmarking results provide the following insights:
#### LLM.int8()
* **Turing/Ampere/Ada**: Per-token throughput is improved by 60-85%, while per-token latency is decreased by 40-45%.
* **H100**: In our Llama 3.1 70B benchmarks, the new LLM.int8() consistently outperforms NF4 at batch sizes >= 8.

#### NF4/FP4
* **Turing/Ampere/Ada**: With batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.

Summaries of the benchmarking results are provided below. In the tables, INT8 disables the LLM.int8() outlier decomposition (`llm_int8_threshold=0.0`), while INT8+Decomp keeps it enabled with the default threshold of 6.0.

#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement (×) | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>

#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement (×) | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>

<details>
<summary>Qwen 2.5 14B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement (×) | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>


#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement (×) | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>

<details>
<summary>Qwen 2.5 32B Instruct</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>

<details>
<summary>Llama 3.1 70B</summary>

| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>

#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).

For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`

In the RTX 4090 setting, the CUDA 12.4 build of PyTorch was used; in the other settings, the CUDA 12.1 build was used.
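
For reference, the quantization settings behind the INT8 rows can also be exercised directly through 🤗 Transformers, using the same `BitsAndBytesConfig` parameters as `WEIGHTS_CONFIGS` in [inference_benchmark.py](inference_benchmark.py). This is a minimal sketch; the model id is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# llm_int8_threshold=6.0 corresponds to the "INT8+Decomp" rows (outlier decomposition on);
# llm_int8_threshold=0.0 corresponds to the plain "INT8" rows.
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config,
)
```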
134 changes: 134 additions & 0 deletions benchmarking/inference_benchmark.py
@@ -0,0 +1,134 @@
"""
Inference benchmarking tool.
Requirements:
transformers
accelerate
bitsandbytes
optimum-benchmark
Usage: python inference_benchmark.py model_id
options:
-h, --help show this help message and exit
--configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
--bf16
--fp16
--nf4
--nf4-dq
--int8
--int8-decomp
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
"""

import argparse
from pathlib import Path

from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch

# bfloat16 compute requires an Ampere-or-newer GPU (compute capability >= 8.0)
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8

WEIGHTS_CONFIGS = {
"fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
"bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
"nf4": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": False,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"nf4-dq": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": True,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
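    # "int8-decomp" keeps the LLM.int8() outlier decomposition enabled (llm_int8_threshold=6.0),
    # while "int8" below disables it by setting the threshold to 0.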
"int8-decomp": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
},
},
"int8": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 0.0,
},
},
}

if __name__ == "__main__":
setup_logging(level="INFO")

parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")

parser.add_argument("model_id", type=str, help="The model checkpoint to use.")

parser.add_argument(
"--configs",
nargs="+",
choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
default=["nf4", "int8", "int8-decomp"],
)
parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")

parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
parser.add_argument("--input-length", type=int, default=64)

parser.add_argument("--out-dir", type=str, default="reports")

args = parser.parse_args()

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)

out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"

benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
(The remaining changed files are not shown.)
