Commit 7312b7a: fix all auto-gptq tests

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Dec 17, 2024
1 parent 69cf2e3 commit 7312b7a
Showing 4 changed files with 17 additions and 7 deletions.
optimum/gptq/quantizer.py (11 changes: 9 additions & 2 deletions)
```diff
@@ -32,7 +32,14 @@
 from ..version import __version__ as optimum_version
 from .constants import GPTQ_CONFIG
 from .data import get_dataset, prepare_dataset
-from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen
+from .utils import (
+    get_block_name_with_pattern,
+    get_device,
+    get_layers,
+    get_preceding_modules,
+    get_seqlen,
+    nested_move_to,
+)
 
 
 if is_accelerate_available():
@@ -53,7 +60,7 @@
     from gptqmodel import exllama_set_max_input_length
     from gptqmodel.quantization import GPTQ
     from gptqmodel.utils.importer import hf_select_quant_linear
-    from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, nested_move_to
+    from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
     from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
     from gptqmodel.version import __version__ as gptqmodel_version
 
```
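Net effect of these two hunks: `nested_move_to` now resolves from `optimum.gptq.utils` at module import time, while gptqmodel-only symbols stay behind the availability guard, so an environment with only auto-gptq installed can still import the quantizer. A minimal sketch of the resulting import pattern (the guard placement and the import path of the availability check are assumptions based on the surrounding context):

```python
from optimum.utils.import_utils import is_gptqmodel_available

# Shared helper: importable no matter which GPTQ backend is installed.
from optimum.gptq.utils import nested_move_to

if is_gptqmodel_available():
    # Backend-specific helpers stay behind the guard, so an auto-gptq-only
    # environment never tries to import from gptqmodel.
    from gptqmodel.utils.model import (
        hf_convert_gptq_v1_to_v2_format,
        hf_convert_gptq_v2_to_v1_format,
    )
```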
optimum/gptq/utils.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -115,7 +115,7 @@ def get_seqlen(model: nn.Module):
     return 2048
 
 
-def move_to(obj: torch.Tensor | nn.Module, device: torch.device):
+def move_to(obj: torch.Tensor, device: torch.device):
     if get_device(obj) != device:
         obj = obj.to(device)
     return obj
```
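`nested_move_to`'s body is not part of this hunk; a plausible sketch of such a helper (an assumption, not the file's actual code) recurses through lists and tuples and defers to `move_to` for leaf tensors:

```python
import torch


def move_to(obj: torch.Tensor, device: torch.device):
    # As in the diff above, but comparing obj.device directly instead of
    # using the module's get_device helper, to keep the sketch self-contained.
    if obj.device != device:
        obj = obj.to(device)
    return obj


def nested_move_to(v, device: torch.device):
    # Hypothetical recursive variant: move leaf tensors, rebuild containers,
    # and pass anything else (ints, None, ...) through untouched.
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    if isinstance(v, (list, tuple)):
        return type(v)(nested_move_to(e, device) for e in v)
    return v
```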
optimum/utils/testing_utils.py (6 changes: 3 additions & 3 deletions)
```diff
@@ -65,9 +65,9 @@ def require_gptq(test_case):
     """
     Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
     """
-    return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires auto-gptq")(
-        test_case
-    )
+    return unittest.skipUnless(
+        is_auto_gptq_available() or is_gptqmodel_available(), "test requires gptqmodel or auto-gptq"
+    )(test_case)
 
 
 def require_torch_gpu(test_case):
```
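The reflowed return is behaviorally identical: `unittest.skipUnless(condition, reason)` produces a decorator that is then applied to `test_case`; only the skip message changes. Typical usage, with a hypothetical test class for illustration:

```python
import unittest

from optimum.utils.testing_utils import require_gptq


class QuantizationSmokeTest(unittest.TestCase):  # hypothetical test class
    @require_gptq
    def test_something_quantized(self):
        # Skipped with "test requires gptqmodel or auto-gptq" unless one
        # of the two backends is importable.
        self.assertTrue(True)
```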
tests/gptq/test_quantization.py (5 changes: 4 additions & 1 deletion)
```diff
@@ -152,6 +152,9 @@ def test_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights
         """
+        # AutoGPTQ does not support CPU
+        if self.device_map_for_quantization == "cpu" and not is_gptqmodel_available():
+            return
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantizer.save(self.quantized_model, tmpdirname)
@@ -309,7 +312,7 @@ def test_exllama_serialization(self):
             device_map={"": self.device_for_inference},
         )
         self.check_quantized_layers_type(
-            quantized_model_from_saved, "exllama" if is_gptqmodel_available else "exllamav2"
+            quantized_model_from_saved, "exllama" if is_gptqmodel_available() else "exllamav2"
         )
 
         # transformers and auto-gptq compatibility
```
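The second hunk fixes a truthiness bug: without the call parentheses, `is_gptqmodel_available` is a bare function object, which is always truthy, so the conditional expression picked "exllama" even when gptqmodel was not installed. A self-contained illustration (the stub below is hypothetical):

```python
def is_gptqmodel_available() -> bool:  # hypothetical stub standing in for the real check
    return False


# Before the fix: the bare function reference is truthy, so the availability
# check is effectively bypassed and "exllama" is always chosen.
assert ("exllama" if is_gptqmodel_available else "exllamav2") == "exllama"

# After the fix: the function is actually called, so the result reflects
# whether gptqmodel is installed.
assert ("exllama" if is_gptqmodel_available() else "exllamav2") == "exllamav2"
```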
