From 048a2d404c6a909e6f835ba18182fbfae130ba09 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 6 Mar 2024 09:30:51 +0200 Subject: [PATCH] Deduplicate helpers & fix lint issues from #1099 (#1107) --- tests/helpers.py | 27 ++++++++++++++++++++------- tests/test_linear4bit.py | 14 +------------- tests/test_linear8bitlt.py | 21 +++++++-------------- 3 files changed, 28 insertions(+), 34 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index f82a8631f..02cb881a3 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,4 @@ +from io import BytesIO from itertools import product import random from typing import Any, List @@ -7,6 +8,25 @@ test_dims_rng = random.Random(42) +TRUE_FALSE = (True, False) +BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3)) # all combinations of (bool, bool, bool) +BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) + + +def torch_save_to_buffer(obj): + buffer = BytesIO() + torch.save(obj, buffer) + buffer.seek(0) + return buffer + + +def torch_load_from_buffer(buffer): + buffer.seek(0) + obj = torch.load(buffer) + buffer.seek(0) + return obj + + def get_test_dims(min: int, max: int, *, n: int) -> List[int]: return [test_dims_rng.randint(min, max) for _ in range(n)] @@ -42,10 +62,3 @@ def id_formatter(label: str): def describe_dtype(dtype: torch.dtype) -> str: return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2] - - -TRUE_FALSE = (True, False) -BOOLEAN_TRIPLES = list( - product(TRUE_FALSE, repeat=3) -) # all combinations of (bool, bool, bool) -BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2)) # all combinations of (bool, bool) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index d1f60423c..567e1a466 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -1,5 +1,4 @@ import copy -from io import BytesIO import os import pickle from tempfile import TemporaryDirectory @@ -8,7 +7,7 @@ import torch import bitsandbytes as bnb -from tests.helpers import TRUE_FALSE +from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer storage = { "uint8": torch.uint8, @@ -17,17 +16,6 @@ "float32": torch.float32, } -def torch_save_to_buffer(obj): - buffer = BytesIO() - torch.save(obj, buffer) - buffer.seek(0) - return buffer - -def torch_load_from_buffer(buffer): - buffer.seek(0) - obj = torch.load(buffer) - buffer.seek(0) - return obj @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index a996b0215..edc3409cd 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -1,5 +1,4 @@ from contextlib import nullcontext -from io import BytesIO import os from tempfile import TemporaryDirectory @@ -10,7 +9,12 @@ from bitsandbytes import functional as F from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.nn.modules import Linear8bitLt -from tests.helpers import TRUE_FALSE, id_formatter +from tests.helpers import ( + TRUE_FALSE, + id_formatter, + torch_load_from_buffer, + torch_save_to_buffer, +) # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py @@ -66,17 +70,6 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CB is not None assert linear_custom.state.CxB is None -def torch_save_to_buffer(obj): - buffer = BytesIO() - torch.save(obj, buffer) - buffer.seek(0) - return buffer - -def torch_load_from_buffer(buffer): - buffer.seek(0) - obj = torch.load(buffer) - buffer.seek(0) - return obj @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights")) @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward")) @@ -171,4 +164,4 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri assert torch.allclose(fx_first, fx_second, atol=1e-5) assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) assert torch.allclose(fx_first, fx_third, atol=1e-5) - assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) \ No newline at end of file + assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)