Skip to content

Commit

Permalink
Merge pull request #1231 from BenjaminBossan/fix-8bit-deepcopy
Browse files Browse the repository at this point in the history
FIX Make Int8Params deepcopy-able
  • Loading branch information
Titus-von-Koeller authored May 30, 2024
2 parents c08653b + ed99b3c commit 3c8c18a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 4 deletions.
19 changes: 15 additions & 4 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,12 @@ def __new__(
CB=None,
SCB=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
if data is None:
data = torch.empty(0)
obj = torch.Tensor._make_subclass(cls, data, requires_grad)
obj.CB, obj.SCB = cls.CB, cls.SCB
obj.CB = CB
obj.SCB = SCB
obj.has_fp16_weights = has_fp16_weights
return obj

def cuda(self, device):
Expand All @@ -585,6 +584,18 @@ def cuda(self, device):

return self

def __deepcopy__(self, memo):
# adjust this if new arguments are added to the constructor
new_instance = type(self).__new__(
type(self),
data=copy.deepcopy(self.data, memo),
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
CB=copy.deepcopy(self.CB, memo),
SCB=copy.deepcopy(self.SCB, memo),
)
return new_instance

@overload
def to(
self: T,
Expand Down
58 changes: 58 additions & 0 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import nullcontext
import copy
import os
import pickle
from tempfile import TemporaryDirectory

import pytest
Expand Down Expand Up @@ -177,3 +179,59 @@ def test_linear_serialization(
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
assert torch.allclose(fx_first, fx_third, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)


@pytest.fixture
def linear8bit():
linear = torch.nn.Linear(32, 96)
linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(),
requires_grad=False,
has_fp16_weights=False,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
return linear_custom


def test_linear8bit_copy_param(linear8bit):
shallow_copy = copy.copy(linear8bit)
assert linear8bit.weight is shallow_copy.weight
assert linear8bit.bias is shallow_copy.bias
assert linear8bit.weight.data.data_ptr() == shallow_copy.weight.data.data_ptr()


def test_linear8bit_deepcopy_param(linear8bit):
deep_copy = copy.deepcopy(linear8bit)
assert linear8bit.weight is not deep_copy.weight
assert linear8bit.bias is not deep_copy.bias
assert linear8bit.weight.data.data_ptr() != deep_copy.weight.data.data_ptr()
assert torch.allclose(linear8bit.weight.data, deep_copy.weight.data)
assert linear8bit.state == deep_copy.state

# check for a bug where SCB and CB were not copied
assert deep_copy.weight.SCB is not None
assert (linear8bit.weight.SCB == deep_copy.weight.SCB).all()
assert deep_copy.weight.CB is not None
assert (linear8bit.weight.CB == deep_copy.weight.CB).all()


def test_linear8bit_serialization(linear8bit):
serialized = pickle.dumps(linear8bit)
deserialized = pickle.loads(serialized)
assert linear8bit.weight.data.data_ptr() != deserialized.weight.data.data_ptr()
assert torch.allclose(linear8bit.weight.data, deserialized.weight.data)
assert linear8bit.bias.data.data_ptr() != deserialized.bias.data.data_ptr()
assert torch.allclose(linear8bit.bias.data, deserialized.bias.data)
assert linear8bit.state == deserialized.state

# check for a bug where SCB and CB were not copied
assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
assert (linear8bit.weight.CB == deserialized.weight.CB).all()

0 comments on commit 3c8c18a

Please sign in to comment.