Skip to content

Commit

Permalink
adding whole Linear8bitLt/Linear4bit module save/load serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Feb 28, 2024
1 parent f9eba9c commit eb9924c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
4 changes: 3 additions & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ def __new__(
cls.SCB = None
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
obj = torch.Tensor._make_subclass(cls, data, requires_grad)
obj.CB, obj.SCB = cls.CB, cls.SCB
return obj

def cuda(self, device):
if self.has_fp16_weights:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_linear4bit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from io import BytesIO
import os
import pickle
from tempfile import TemporaryDirectory
Expand All @@ -16,12 +17,24 @@
"float32": torch.float32,
}

def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
return buffer

def torch_load_from_buffer(buffer):
buffer.seek(0)
obj = torch.load(buffer)
buffer.seek(0)
return obj

@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE)
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage, save_before_forward):
original_dtype = torch.float16
compute_dtype = None
device = "cuda"
Expand Down Expand Up @@ -124,6 +137,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert a.dtype == b.dtype
assert torch.equal(a, b)

if save_before_forward:
bytes_4bit = torch_save_to_buffer(linear_q)

# Forward test
x = torch.rand(42, layer_shape[0], device=device)
a = linear_q(x)
Expand All @@ -136,6 +152,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert torch.equal(a, b)
assert torch.equal(a, c)

if not save_before_forward:
bytes_4bit = torch_save_to_buffer(linear_q)
linear_q3 = torch_load_from_buffer(bytes_4bit)

# Test moving to CPU and back to GPU
linear_q2.to("cpu")
linear_q2.to(device)
Expand All @@ -144,6 +164,11 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
assert c.device == d.device
assert torch.equal(c, d)

d = linear_q3(x)
assert c.dtype == d.dtype
assert c.device == d.device
assert torch.equal(c, d)

# Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
with TemporaryDirectory() as tmpdir:
state_path_4bit = os.path.join(tmpdir, "state_4bit.pth")
Expand Down
34 changes: 33 additions & 1 deletion tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import nullcontext
from io import BytesIO
import os
from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -65,12 +66,25 @@ def test_linear_no_igemmlt():
assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None

def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
return buffer

def torch_load_from_buffer(buffer):
buffer.seek(0)
obj = torch.load(buffer)
buffer.seek(0)
return obj

@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
@pytest.mark.parametrize("deserialize_before_cuda", TRUE_FALSE, ids=id_formatter("deserialize_before_cuda"))
@pytest.mark.parametrize("force_no_igemmlt", TRUE_FALSE, ids=id_formatter("force_no_igemmlt"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
@pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward"))
@pytest.mark.parametrize("load_before_cuda", TRUE_FALSE, ids=id_formatter("load_before_cuda"))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt, save_before_forward, load_before_cuda):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)

Expand All @@ -93,6 +107,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()

if save_before_forward:
bytes_8bit = torch_save_to_buffer(linear_custom)

x_first = x.clone().cuda().requires_grad_(True)
fx_first = linear_custom(x_first).float()
grad_proj = torch.randn_like(fx_first)
Expand All @@ -101,6 +118,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
if not serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()

if not save_before_forward:
bytes_8bit = torch_save_to_buffer(linear_custom)

with TemporaryDirectory() as tmpdir:
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
state_path = os.path.join(tmpdir, "state.pth")
Expand All @@ -127,16 +147,28 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
new_linear_custom.load_state_dict(new_state_dict, strict=True)

if load_before_cuda:
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

new_linear_custom = new_linear_custom.cuda()

if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)

if not load_before_cuda:
new_linear_custom2 = torch_load_from_buffer(bytes_8bit)

x_second = x.clone().cuda().requires_grad_(True)
fx_second = new_linear_custom(x_second).float()
(fx_second * grad_proj).mean().backward()

x_third = x.clone().cuda().requires_grad_(True)
fx_third = new_linear_custom2(x_third).float()
(fx_third * grad_proj).mean().backward()

# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
if has_fp16_weights or not deserialize_before_cuda:
assert torch.allclose(fx_first, fx_second, atol=1e-5)
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
assert torch.allclose(fx_first, fx_third, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)

0 comments on commit eb9924c

Please sign in to comment.