From 161c1949fdaeaabc272d409f0478a6548a77918d Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 18 Nov 2024 16:01:32 -0500
Subject: [PATCH] Update test

---
 tests/test_modules.py | 35 +++++++++++++++--------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index 9e6a708b9..239c7d3a6 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -17,20 +17,18 @@ def __init__(self, initial_data):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
         super().__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
             dim1,
             dim2,
             has_fp16_weights=has_fp16_weights,
-            memory_efficient_backward=memory_efficient_backward,
             threshold=threshold,
         )
         self.fc2 = bnb.nn.Linear8bitLt(
             dim2,
             dim1,
             has_fp16_weights=has_fp16_weights,
-            memory_efficient_backward=memory_efficient_backward,
             threshold=threshold,
         )
 
@@ -326,7 +324,7 @@ def test_linear8bitlt_accumulated_gradient():
 
     acc_steps = 10
 
-    for i in range(10):
+    for i in range(15):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         o2 = l2(b1)
@@ -356,7 +354,7 @@ def test_linear8bitlt_accumulated_gradient():
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
-def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
+def test_linear8bitlt_no_fp16_weights(threshold):
     l1 = (
         bnb.nn.Linear8bitLt(
             32,
@@ -420,7 +418,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
             64,
             threshold=threshold,
             has_fp16_weights=False,
-            memory_efficient_backward=memory_efficient_backward,
         )
         .half()
         .to("cuda")
@@ -444,7 +441,6 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         64,
         threshold=threshold,
         has_fp16_weights=False,
-        memory_efficient_backward=memory_efficient_backward,
     )
     w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
     mlp = mlp.cuda().half()  # and this line triggers quantization
@@ -463,21 +459,20 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
-    if memory_efficient_backward:
-        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
-        o1 = mlp(b1)
-        assert o1.dtype == torch.float16
-        assert o1.requires_grad
-        grad_proj = torch.randn_like(o1)
+    b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+    o1 = mlp(b1)
+    assert o1.dtype == torch.float16
+    assert o1.requires_grad
+    grad_proj = torch.randn_like(o1)
 
-        mlp.zero_grad()
-        (o1 * grad_proj).sum().backward()
-        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
-        scale = grad_ref.abs().mean()
+    mlp.zero_grad()
+    (o1 * grad_proj).sum().backward()
+    grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+    scale = grad_ref.abs().mean()
 
-        torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
-        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
-        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+    torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+    idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+    assert (idx == 0).sum().item() <= b1.numel() * 0.005
 
 
 @pytest.mark.parametrize(