Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Dec 17, 2024
1 parent f9d5a93 commit 7639d27
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ def test_quantization(self):
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]

if TorchAoConfig._is_cuda_capability_atleast_8_9():
Expand All @@ -253,8 +252,8 @@ def test_quantization(self):

for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint_a16w7"]:
# The dummy flux model that we use requires us to impose some restrictions on group_size here
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
Expand Down

0 comments on commit 7639d27

Please sign in to comment.