From 3ccb1308dc078278d8d6f98349ee823a52471955 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 18 Mar 2024 14:14:20 -0400 Subject: [PATCH 01/12] fix diagnostics error within vscode on windows --- bitsandbytes/diagnostics/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index f993dff7e..8974c6400 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -59,7 +59,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path for pth in dir.glob(lib_pattern): if pth.is_file(): yield pth - except PermissionError: + except (OSError, PermissionError): pass From 67e7ee3bc2b0b30ebe520f6e844a11ba5c76cc70 Mon Sep 17 00:00:00 2001 From: Steven Liu Date: Tue, 26 Mar 2024 10:06:07 -0700 Subject: [PATCH 02/12] first draft --- docs/source/fsdp_qlora.md | 106 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 docs/source/fsdp_qlora.md diff --git a/docs/source/fsdp_qlora.md b/docs/source/fsdp_qlora.md new file mode 100644 index 000000000..47922cfcc --- /dev/null +++ b/docs/source/fsdp_qlora.md @@ -0,0 +1,106 @@ +# FSDP-QLoRA + +FSDP-QLoRA combines data parallelism (FSDP enables sharding model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone. + +This guide provides a brief overview of how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries. + +> [!TIP] +> Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing already quantized weights from being quantized again when they're moved from a CPU to GPU, are documented in this [Pull Request](https://github.com/TimDettmers/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA! + +## Quantized data storage + +FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the storage data type. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase. + +```py +import torch +import bitsandbytes as bnb + +model = bnb.nn.Linear4bit( + input_features, + output_features, + quant_type="fp4", + quant_storage=torch.uint8, +) +``` + +With the `quant_storage` parameter, you can select any of the FSDP-supported data types to shard [`~nn.Linear4bit`] with, such as bfloat16, float16, or float32. + +## Training + +bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl).
+ +Before you begin, make sure you have the latest libraries installed. + +```bash +pip install -U bitsandbytes accelerate transformers peft trl +``` + +> [!TIP] +> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation. + +The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type. + +```py +from transformers import BitsAndBytesConfig + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_storage=torch.bfloat16, +) +``` + +Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually. + +```py +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b", + quantization_config=bnb_config, + torch_dtype=torch.bfloat16, +) +``` + +Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`. + +```py +from peft import LoraConfig + +peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=64, + bias="none", + task_type="CAUSAL_LM", + target_modules="all-linear", +) +``` + +Now you can pass everything to the [`~trl.SFTTrainer`] for training. + +```py +from trl import SFTTrainer + +trainer = SFTTrainer( + model=model, + train_dataset=dataset, + peft_config=peft_config, + dataset_text_field="text", + max_seq_length=max_seq_length, + tokenizer=tokenizer, + args=training_arguments, +) +trainer.train() +``` + +## Resources + +To learn more about FSDP and QLoRA, check out the following resources: + +- The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository. +- The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI. +- For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post. +- For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post. 
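The Training section above links to the PEFT FSDP-QLoRA config file and launch script but stops short of showing an invocation. As a minimal, hedged sketch (assuming the `examples/sft` layout of the linked PEFT repository; check `run_peft_qlora_fsdp.sh` for the full set of script arguments before running):

```bash
# Minimal launch sketch for the linked PEFT FSDP-QLoRA example via Accelerate.
# The config path and script name assume PEFT's examples/sft directory layout;
# real runs pass additional model, dataset, and LoRA arguments (see
# run_peft_qlora_fsdp.sh in that repository).
accelerate launch --config_file configs/fsdp_config_qlora.yaml train.py
```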
From e3376abfd4f7923e3a66b13a8f039fbf21ae7f85 Mon Sep 17 00:00:00 2001 From: Steven Liu Date: Tue, 26 Mar 2024 11:01:11 -0700 Subject: [PATCH 03/12] toctree --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2184cce8c..fdfe19ee4 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -12,6 +12,8 @@ title: 8-bit optimizers - local: algorithms title: Algorithms + - local: fsdp_qlora + title: FSDP-QLoRA - local: integrations title: Integrations - local: errors From c6e319072f3c1817460b441aa4135ce956b54e24 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 26 Mar 2024 22:11:08 +0100 Subject: [PATCH 04/12] Bump the major group with 3 updates (#1145) Updates the requirements on [pytest](https://github.com/pytest-dev/pytest), [pandas](https://github.com/pandas-dev/pandas) and [matplotlib](https://github.com/matplotlib/matplotlib) to permit the latest version. Updates `pytest` from 7.2.2 to 8.1.1 - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/7.2.2...8.1.1) Updates `pandas` to 2.2.1 - [Release notes](https://github.com/pandas-dev/pandas/releases) - [Commits](https://github.com/pandas-dev/pandas/compare/v2.2.0...v2.2.1) Updates `matplotlib` to 3.8.3 - [Release notes](https://github.com/matplotlib/matplotlib/releases) - [Commits](https://github.com/matplotlib/matplotlib/compare/v3.8.2...v3.8.3) --- updated-dependencies: - dependency-name: pytest dependency-type: direct:production update-type: version-update:semver-major dependency-group: major - dependency-name: pandas dependency-type: direct:development dependency-group: major - dependency-name: matplotlib dependency-type: direct:development dependency-group: major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 2 +- requirements-dev.txt | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index e6e375ccb..39fa16e08 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,5 +1,5 @@ # Requirements used for GitHub actions -pytest==7.2.2 +pytest==8.1.1 einops==0.6.0 lion-pytorch==0.0.6 scipy==1.10.1; python_version < "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index 7ede5b061..e112365ea 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,9 @@ # Requirements used for local development setuptools>=63 -pytest~=7.2.2 +pytest~=8.1.1 einops~=0.6.0 wheel~=0.40.0 lion-pytorch~=0.0.6 scipy~=1.11.4 -pandas~=2.2.0 -matplotlib~=3.8.2 +pandas~=2.2.1 +matplotlib~=3.8.3 From 040526310ed1b502647510648464d2673de8ad63 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 26 Mar 2024 17:42:25 -0400 Subject: [PATCH 05/12] Add CUDA 12.4 to docs/install helper (#1136) * Add CUDA 12.4 download to utility script, docs * (ci) Add CUDA 12.4.0 build to workflow * Apply ruff format to install_cuda.py --- docs/source/installation.mdx | 2 +- install_cuda.py | 9 +++++++-- install_cuda.sh | 7 +++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 49d8b4ebd..d0dd7ba76 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -84,7 +84,7 @@ Then locally install the CUDA version you need with this script from bitsandbyte ```bash wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh # Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH -# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123} +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124} # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc diff --git a/install_cuda.py b/install_cuda.py index 9e426cbd7..a5d09356d 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -17,6 +17,7 @@ "121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run", "122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run", "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run", + "124": "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run", } @@ -76,7 +77,9 @@ def main(): download_path = "/tmp" # default download path if len(sys.argv) < 2: - print("Usage: python install_cuda.py [user/system] [download_path]") + print( + "Usage: python install_cuda.py [user/system] [download_path]" + ) sys.exit(1) version = sys.argv[1] @@ -97,7 +100,9 @@ def main(): elif version in cuda_versions: install_cuda(version, base_path, download_path) else: - print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") + print( + f"Invalid CUDA version: {version}. 
Available versions are: {', '.join(cuda_versions.keys())}" + ) sys.exit(1) diff --git a/install_cuda.sh b/install_cuda.sh index 8ffbc8478..2e7fe8ed2 100644 --- a/install_cuda.sh +++ b/install_cuda.sh @@ -11,7 +11,7 @@ URL120=https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installer URL121=https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run URL122=https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run URL123=https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run - +URL124=https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run CUDA_VERSION=$1 BASE_PATH=$2 @@ -57,8 +57,11 @@ if [[ -n "$CUDA_VERSION" ]]; then elif [[ "$CUDA_VERSION" -eq "123" ]]; then URL=$URL123 FOLDER=cuda-12.3 + elif [[ "$CUDA_VERSION" -eq "124" ]]; then + URL=$URL124 + FOLDER=cuda-12.4 else - echo "argument error: No cuda version passed as input. Choose among versions 92 to 123" + echo "argument error: No cuda version passed as input. Choose among versions 110 to 124" fi else echo "argument error: No cuda version passed as input. Choose among versions 92 to 123" From fd9d072e02b74348004f197e686e168448883a9e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 18:32:04 +0100 Subject: [PATCH 06/12] Bump the minor-patch group with 4 updates (#1146) Updates the requirements on [einops](https://github.com/arogozhnikov/einops), [wheel](https://github.com/pypa/wheel), [lion-pytorch](https://github.com/lucidrains/lion-pytorch) and [scipy](https://github.com/scipy/scipy) to permit the latest version. Updates `einops` from 0.6.0 to 0.7.0 - [Release notes](https://github.com/arogozhnikov/einops/releases) - [Commits](https://github.com/arogozhnikov/einops/compare/v0.6.0...v0.7.0) Updates `wheel` to 0.43.0 - [Release notes](https://github.com/pypa/wheel/releases) - [Changelog](https://github.com/pypa/wheel/blob/main/docs/news.rst) - [Commits](https://github.com/pypa/wheel/compare/0.40.0...0.43.0) Updates `lion-pytorch` from 0.0.6 to 0.1.2 - [Release notes](https://github.com/lucidrains/lion-pytorch/releases) - [Commits](https://github.com/lucidrains/lion-pytorch/compare/0.0.6...0.1.2) Updates `scipy` from 1.11.4 to 1.12.0 - [Release notes](https://github.com/scipy/scipy/releases) - [Commits](https://github.com/scipy/scipy/compare/v1.11.4...v1.12.0) --- updated-dependencies: - dependency-name: einops dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch - dependency-name: wheel dependency-type: direct:development dependency-group: minor-patch - dependency-name: lion-pytorch dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch - dependency-name: scipy dependency-type: direct:production update-type: version-update:semver-minor dependency-group: minor-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 6 +++--- requirements-dev.txt | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 39fa16e08..61f92018a 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub actions pytest==8.1.1 -einops==0.6.0 -lion-pytorch==0.0.6 +einops==0.7.0 +lion-pytorch==0.1.2 scipy==1.10.1; python_version < "3.9" -scipy==1.11.4; python_version >= "3.9" +scipy==1.12.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index e112365ea..fc5449ba7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,9 @@ # Requirements used for local development setuptools>=63 pytest~=8.1.1 -einops~=0.6.0 -wheel~=0.40.0 -lion-pytorch~=0.0.6 -scipy~=1.11.4 +einops~=0.7.0 +wheel~=0.43.0 +lion-pytorch~=0.1.2 +scipy~=1.12.0 pandas~=2.2.1 matplotlib~=3.8.3 From c17fb8eb4f4b0139229beda0e109e9aab91af957 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:34:09 -0400 Subject: [PATCH 07/12] Fix 4bit quantization with blocksize=4096 --- bitsandbytes/functional.py | 7 ++++--- csrc/ops.cu | 2 +- tests/test_functional.py | 28 +++++++++++++++++++++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index bb6a04892..f915223ca 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64): if data is None: raise NotImplementedError(f"Typename {typename} not supported") - data = Tensor(data) - data /= data.abs().max() + data = torch.tensor(data, device=device) + data.div_(data.abs().max()) + assert data.numel() == 16 - return data.to(device) + return data def quantize_fp4( diff --git a/csrc/ops.cu b/csrc/ops.cu index 796211fed..3a6ffdda8 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -58,7 +58,7 @@ template void quantizeBlockwise(floa num_blocks = n % blocksize == 0 ? 
num_blocks : num_blocks + 1; if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) diff --git a/tests/test_functional.py b/tests/test_functional.py index b9f1a6ead..1cca04511 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1928,7 +1928,9 @@ def test_bench_dequantization(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) -def test_fp4_quant(dtype): +@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) +@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) +def test_4bit_quant(dtype, quant_type, blocksize): vals = list(product([0, 1], repeat=4)) code = {} @@ -1953,8 +1955,8 @@ def test_fp4_quant(dtype): code[idx] = result A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) - qa, SA = F.quantize_fp4(A1, blocksize=64) - A2 = F.dequantize_fp4(qa, SA) + qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) err = (A1 - A2).abs().float() relerr = (err / (A1.abs().float() + 1e-8)).mean() @@ -1962,8 +1964,24 @@ def test_fp4_quant(dtype): err = err.mean() assert A2.dtype == dtype - assert err.item() < 0.1 - assert relerr.item() < 0.28 + + # With larger block sizes, we can expect this to blow up. + # At blocksize>=1024, don't even bother looking at relerr. + if blocksize <= 64: + assert err.item() < 0.1 + assert relerr.item() < 0.28 + elif blocksize <= 256: + assert err.item() < 0.11 + assert relerr.item() < 0.30 + elif blocksize <= 512: + assert err.item() < 0.12 + assert relerr.item() < 0.31 + elif quant_type == "fp4": + # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 + assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 + else: + # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 + assert err.item() < math.log2(blocksize) * 8e-2 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From a471456911168b3ac798ff99967606013c71cc50 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:35:54 -0400 Subject: [PATCH 08/12] fix formatting for install_cuda.py --- install_cuda.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/install_cuda.py b/install_cuda.py index a5d09356d..cf7c8ee71 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -77,9 +77,7 @@ def main(): download_path = "/tmp" # default download path if len(sys.argv) < 2: - print( - "Usage: python install_cuda.py [user/system] [download_path]" - ) + print("Usage: python install_cuda.py [user/system] [download_path]") sys.exit(1) version = sys.argv[1] @@ -100,9 +98,7 @@ def main(): elif version in cuda_versions: install_cuda(version, base_path, download_path) else: - print( - f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}" - ) + print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1) From 494de206ce029cf7d03a12eeb7d72368d04d7458 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:22:35 +0200 Subject: [PATCH 09/12] Bump the minor-patch group with 1 update (#1162) Bumps the minor-patch group with 1 update: [lion-pytorch](https://github.com/lucidrains/lion-pytorch). 
Updates `lion-pytorch` from 0.1.2 to 0.1.4 - [Release notes](https://github.com/lucidrains/lion-pytorch/releases) - [Commits](https://github.com/lucidrains/lion-pytorch/compare/0.1.2...0.1.4) --- updated-dependencies: - dependency-name: lion-pytorch dependency-type: direct:production update-type: version-update:semver-patch dependency-group: minor-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-ci.txt | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-ci.txt b/requirements-ci.txt index 61f92018a..4df975993 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub actions pytest==8.1.1 einops==0.7.0 -lion-pytorch==0.1.2 +lion-pytorch==0.1.4 scipy==1.10.1; python_version < "3.9" scipy==1.12.0; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index fc5449ba7..291a51cb1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ setuptools>=63 pytest~=8.1.1 einops~=0.7.0 wheel~=0.43.0 -lion-pytorch~=0.1.2 +lion-pytorch~=0.1.4 scipy~=1.12.0 pandas~=2.2.1 matplotlib~=3.8.3 From bed0860b8e11ea4a15d729e60f694c46eefe7fd4 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 2 Apr 2024 06:31:03 -0400 Subject: [PATCH 10/12] Tests: improve memory usage (#1147) --- tests/conftest.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 17ffd281c..59146963d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import gc + import pytest import torch @@ -20,6 +22,13 @@ def pytest_runtest_call(item): raise +@pytest.hookimpl(trylast=True) +def pytest_runtest_teardown(item, nextitem): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @pytest.fixture(scope="session") def requires_cuda() -> bool: cuda_available = torch.cuda.is_available() From 2965c765a7d95de35484d374e2ce0159858010b3 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:27:07 +0200 Subject: [PATCH 11/12] CHANGELOG.md: mention accuracy changes when quantizing post v0.42 --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 397dceb77..b671145a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -357,6 +357,10 @@ Bug fixes: - Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061). - Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063). +#### Backwards Compatibility +- After upgrading from `v0.42` to `v0.43`, when using 4bit quantization, models may generate slightly different outputs (approximately up to the 2nd decimal place) due to a fix in the code. For anyone interested in the details, [see this comment](https://github.com/TimDettmers/bitsandbytes/discussions/1094#discussioncomment-8984069). + + #### Internal and Build System Enhancements: - Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. 
These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037). From bfe21182631e8f9575e4b992e70719e01c256901 Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 4 Apr 2024 19:31:40 +0200 Subject: [PATCH 12/12] README: include download badges --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 43eadf5a3..2cf630dcb 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # `bitsandbytes` +[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes) + The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions. The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module.
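The README paragraph above names `bitsandbytes.nn.Linear8bitLt`, `bitsandbytes.nn.Linear4bit`, and the `bitsandbytes.optim` optimizers without showing them in use. A minimal usage sketch follows; the layer sizes and the choice of Adam8bit are illustrative assumptions rather than taken from the README, and a CUDA device is needed before the layers actually quantize:

```py
# Hedged sketch of the modules named above; sizes and optimizer choice are
# illustrative assumptions, not taken from the README.
import bitsandbytes as bnb

linear_8bit = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False)  # LLM.int8() linear layer
linear_4bit = bnb.nn.Linear4bit(1024, 1024, quant_type="nf4")          # 4-bit (NF4) linear layer

# 8-bit optimizer over the 4-bit layer's parameters.
optimizer = bnb.optim.Adam8bit(linear_4bit.parameters(), lr=1e-4)
```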