Skip to content

Commit

Permalink
Deduplicate helpers & fix lint issues from #1099
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Mar 5, 2024
1 parent a1c0844 commit 924ae3a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 34 deletions.
27 changes: 20 additions & 7 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import BytesIO
from itertools import product
import random
from typing import Any, List
Expand All @@ -7,6 +8,25 @@
test_dims_rng = random.Random(42)


TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)


def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
return buffer


def torch_load_from_buffer(buffer):
buffer.seek(0)
obj = torch.load(buffer)
buffer.seek(0)
return obj


def get_test_dims(min: int, max: int, *, n: int) -> List[int]:
return [test_dims_rng.randint(min, max) for _ in range(n)]

Expand Down Expand Up @@ -42,10 +62,3 @@ def id_formatter(label: str):

def describe_dtype(dtype: torch.dtype) -> str:
return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]


TRUE_FALSE = (True, False)
BOOLEAN_TRIPLES = list(
product(TRUE_FALSE, repeat=3)
) # all combinations of (bool, bool, bool)
BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool)
14 changes: 1 addition & 13 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from io import BytesIO
import os
import pickle
from tempfile import TemporaryDirectory
Expand All @@ -8,7 +7,7 @@
import torch

import bitsandbytes as bnb
from tests.helpers import TRUE_FALSE
from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer

storage = {
"uint8": torch.uint8,
Expand All @@ -17,17 +16,6 @@
"float32": torch.float32,
}

def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
return buffer

def torch_load_from_buffer(buffer):
buffer.seek(0)
obj = torch.load(buffer)
buffer.seek(0)
return obj

@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
@pytest.mark.parametrize("bias", TRUE_FALSE)
Expand Down
21 changes: 7 additions & 14 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from contextlib import nullcontext
from io import BytesIO
import os
from tempfile import TemporaryDirectory

Expand All @@ -10,7 +9,12 @@
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import TRUE_FALSE, id_formatter
from tests.helpers import (
TRUE_FALSE,
id_formatter,
torch_load_from_buffer,
torch_save_to_buffer,
)

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
Expand Down Expand Up @@ -66,17 +70,6 @@ def test_linear_no_igemmlt():
assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None

def torch_save_to_buffer(obj):
buffer = BytesIO()
torch.save(obj, buffer)
buffer.seek(0)
return buffer

def torch_load_from_buffer(buffer):
buffer.seek(0)
obj = torch.load(buffer)
buffer.seek(0)
return obj

@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
Expand Down Expand Up @@ -171,4 +164,4 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
assert torch.allclose(fx_first, fx_second, atol=1e-5)
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
assert torch.allclose(fx_first, fx_third, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)

0 comments on commit 924ae3a

Please sign in to comment.