
Commit

Update test
matthewdouglas committed Nov 18, 2024
1 parent 6e0a4b3 commit 161c194
Showing 1 changed file with 15 additions and 20 deletions.
tests/test_modules.py (15 additions, 20 deletions)
@@ -17,20 +17,18 @@ def __init__(self, initial_data):


 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
         super().__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
             dim1,
             dim2,
             has_fp16_weights=has_fp16_weights,
-            memory_efficient_backward=memory_efficient_backward,
             threshold=threshold,
         )
         self.fc2 = bnb.nn.Linear8bitLt(
             dim2,
             dim1,
             has_fp16_weights=has_fp16_weights,
-            memory_efficient_backward=memory_efficient_backward,
             threshold=threshold,
         )

@@ -326,7 +324,7 @@ def test_linear8bitlt_accumulated_gradient():

     acc_steps = 10

-    for i in range(10):
+    for i in range(15):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         o2 = l2(b1)
@@ -356,7 +354,7 @@ def test_linear8bitlt_accumulated_gradient():


 @pytest.mark.parametrize("threshold", [0.0, 2.0])
-def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
+def test_linear8bitlt_no_fp16_weights(threshold):
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
@@ -420,7 +418,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
             64,
             threshold=threshold,
             has_fp16_weights=False,
-            memory_efficient_backward=memory_efficient_backward,
         )
         .half()
         .to("cuda")
@@ -444,7 +441,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         64,
         threshold=threshold,
         has_fp16_weights=False,
-        memory_efficient_backward=memory_efficient_backward,
     )
     w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
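
The inline comments in this hunk describe the behavior being exercised: with has_fp16_weights=False, it is the move to the GPU that quantizes the Linear8bitLt weights. A minimal standalone sketch of that behavior (an illustration only, assuming a CUDA device and that the installed bitsandbytes stores the quantized weight as an int8 tensor):

    # Sketch: .cuda() is what triggers int8 quantization when has_fp16_weights=False.
    # Assumes a CUDA device; threshold=2.0 mirrors one of the test's parametrizations.
    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, threshold=2.0)
    print(layer.weight.dtype)    # torch.float32 -- not yet quantized on CPU

    layer = layer.half().cuda()  # moving to the GPU performs the quantization
    print(layer.weight.dtype)    # expected: torch.int8
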
@@ -463,21 +459,20 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

-    if memory_efficient_backward:
-        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
-        o1 = mlp(b1)
-        assert o1.dtype == torch.float16
-        assert o1.requires_grad
-        grad_proj = torch.randn_like(o1)
+    b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+    o1 = mlp(b1)
+    assert o1.dtype == torch.float16
+    assert o1.requires_grad
+    grad_proj = torch.randn_like(o1)

-        mlp.zero_grad()
-        (o1 * grad_proj).sum().backward()
-        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
-        scale = grad_ref.abs().mean()
+    mlp.zero_grad()
+    (o1 * grad_proj).sum().backward()
+    grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+    scale = grad_ref.abs().mean()

-        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
-        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
-        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+    torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+    idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+    assert (idx == 0).sum().item() <= b1.numel() * 0.005


@pytest.mark.parametrize(
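
The reference check kept above, grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half(), is just the chain rule through the two linear layers: for y = fc2(fc1(x)) and loss (y * grad_proj).sum(), the gradient with respect to the input is grad_proj @ W2 @ W1 (biases do not affect the input gradient). A small self-contained sketch of that identity, using plain float32 torch.nn.Linear layers rather than bnb.nn.Linear8bitLt; the shapes mirror the test but are otherwise arbitrary:

    # Sketch: verifies dL/dx = g @ W2 @ W1 for a bias-free two-layer MLP,
    # the same closed form the test compares b1.grad against.
    import torch

    torch.manual_seed(0)
    fc1 = torch.nn.Linear(32, 64, bias=False)
    fc2 = torch.nn.Linear(64, 32, bias=False)

    x = torch.randn(16, 8, 32, requires_grad=True)
    g = torch.randn(16, 8, 32)  # plays the role of grad_proj

    out = fc2(fc1(x))
    (out * g).sum().backward()

    grad_ref = (g @ fc2.weight @ fc1.weight).detach()
    torch.testing.assert_close(x.grad, grad_ref, rtol=1e-4, atol=1e-4)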
