From eb9924c408895cc968ec57de3120468faab4cf20 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 28 Feb 2024 13:21:59 -0800
Subject: [PATCH] Add whole Linear8bitLt/Linear4bit module save/load
 serialization

---
 bitsandbytes/nn/modules.py |  4 +++-
 tests/test_linear4bit.py   | 27 ++++++++++++++++++++++++++-
 tests/test_linear8bitlt.py | 34 +++++++++++++++++++++++++++++++++-
 3 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bd2bd5832..16c8aa9b8 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -449,7 +449,9 @@ def __new__(
         cls.SCB = None
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj.CB, obj.SCB = cls.CB, cls.SCB
+        return obj
 
     def cuda(self, device):
         if self.has_fp16_weights:
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 3e62bdf3b..d1f60423c 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,4 +1,5 @@
 import copy
+from io import BytesIO
 import os
 import pickle
 from tempfile import TemporaryDirectory
@@ -16,12 +17,24 @@
     "float32": torch.float32,
 }
 
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
 
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
-def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
+def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert a.dtype == b.dtype
     assert torch.equal(a, b)
 
+    if save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+
     # Forward test
     x = torch.rand(42, layer_shape[0], device=device)
     a = linear_q(x)
@@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert torch.equal(a, b)
     assert torch.equal(a, c)
 
+    if not save_before_forward:
+        bytes_4bit = torch_save_to_buffer(linear_q)
+    linear_q3 = torch_load_from_buffer(bytes_4bit)
+
     # Test moving to CPU and back to GPU
     linear_q2.to("cpu")
     linear_q2.to(device)
@@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     assert c.device == d.device
     assert torch.equal(c, d)
 
+    d = linear_q3(x)
+    assert c.dtype == d.dtype
+    assert c.device == d.device
+    assert torch.equal(c, d)
+
     # Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
     with TemporaryDirectory() as tmpdir:
         state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 6fa7efb8d..a996b0215 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -1,4 +1,5 @@
 from contextlib import nullcontext
+from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
 
@@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
 @pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
 @pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
-def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
+@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
+@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
+def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
     linear = torch.nn.Linear(32, 96)
     x = torch.randn(3, 32, dtype=torch.half)
 
@@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
+    if save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     x_first = x.clone().cuda().requires_grad_(True)
     fx_first = linear_custom(x_first).float()
     grad_proj = torch.randn_like(fx_first)
     (fx_first * grad_proj).mean().backward()
@@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
     if not serialize_before_forward:
         state_dict_8bit = linear_custom.state_dict()
 
+    if not save_before_forward:
+        bytes_8bit = torch_save_to_buffer(linear_custom)
+
     with TemporaryDirectory() as tmpdir:
         state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
         state_path = os.path.join(tmpdir, "state.pth")
@@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
         with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
             new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
+        if load_before_cuda:
+            new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
         new_linear_custom = new_linear_custom.cuda()
         if not deserialize_before_cuda:
             new_linear_custom.load_state_dict(new_state_dict, strict=True)
 
+        if not load_before_cuda:
+            new_linear_custom2 = torch_load_from_buffer(bytes_8bit)
+
         x_second = x.clone().cuda().requires_grad_(True)
         fx_second = new_linear_custom(x_second).float()
         (fx_second * grad_proj).mean().backward()
 
+        x_third = x.clone().cuda().requires_grad_(True)
+        fx_third = new_linear_custom2(x_third).float()
+        (fx_third * grad_proj).mean().backward()
+
         # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
         if has_fp16_weights or not deserialize_before_cuda:
             assert torch.allclose(fx_first, fx_second, atol=1e-5)
             assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
+            assert torch.allclose(fx_first, fx_third, atol=1e-5)
+            assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
\ No newline at end of file
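
Note (commentary, not part of the patch): the modules.py hunk copies CB/SCB
onto the object returned by torch.Tensor._make_subclass, so the quantization
state appears to live on the Int8Params instance itself rather than only on
the class, which is what the whole-module torch.save/torch.load round-trip in
the new tests relies on. A minimal sketch of that round-trip, assuming a CUDA
device and a bitsandbytes build that includes this patch:

    from io import BytesIO

    import torch
    import bitsandbytes as bnb

    # .cuda() quantizes the fp16 weights to int8 and fills in CB/SCB.
    linear = bnb.nn.Linear8bitLt(32, 96, has_fp16_weights=False).cuda()

    buffer = BytesIO()
    torch.save(linear, buffer)     # pickles the whole module, not just a state_dict
    buffer.seek(0)
    restored = torch.load(buffer)  # uses pickle, so only load buffers you trust

    x = torch.randn(3, 32, dtype=torch.half, device="cuda")
    assert torch.allclose(linear(x), restored(x))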
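The 4-bit path round-trips the same way; the save_before_forward
parametrization additionally checks that saving works both before and after
the first forward call. A similar sketch under the same assumptions (the
layer sizes here are arbitrary, not taken from the tests):

    from io import BytesIO

    import torch
    import bitsandbytes as bnb

    linear = bnb.nn.Linear4bit(64, 128, quant_type="nf4").cuda()  # quantizes on .cuda()

    buffer = BytesIO()
    torch.save(linear, buffer)     # saved before any forward call
    buffer.seek(0)
    restored = torch.load(buffer)

    x = torch.rand(8, 64, dtype=torch.half, device="cuda")
    assert torch.equal(linear(x), restored(x))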