Skip to content

Commit

Permalink
[lmi][python] remove quantization enum and rely on engine validation/support
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Nov 15, 2024
1 parent df52ef1 commit bf76ff0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,6 @@
from djl_python.properties_manager.properties import Properties, RollingBatchEnum, is_rolling_batch_enabled


class HFQuantizeMethods(str, Enum):
    """Quantization methods accepted by the HuggingFace handler.

    Mixes in ``str`` so each member compares equal to (and serializes as)
    its raw string value, which lets pydantic coerce config strings
    directly into members.
    """

    # added for backward compatibility lmi-dist
    bitsandbytes = 'bitsandbytes'
    gptq = 'gptq'

    # huggingface
    bitsandbytes4 = 'bitsandbytes4'
    bitsandbytes8 = 'bitsandbytes8'

    # TODO remove this after refactor of all handlers
    # supported by vllm
    awq = 'awq'
    deepspeedfp = 'deepspeedfp'
    fp8 = 'fp8'
    fbgemm_fp8 = 'fbgemm_fp8'
    gptq_marlin = 'gptq_marlin'
    gptq_marlin_24 = 'gptq_marlin_24'
    awq_marlin = 'awq_marlin'
    marlin = 'marlin'
    squeezellm = 'squeezellm'


def get_torch_dtype_from_str(dtype: str):
if dtype == "auto":
return dtype
Expand All @@ -57,7 +35,7 @@ class HuggingFaceProperties(Properties):
device_map: str = None
load_in_4bit: Optional[bool] = None
load_in_8bit: Optional[bool] = None
quantize: Optional[HFQuantizeMethods] = None
quantize: Optional[str] = None
low_cpu_mem_usage: Optional[bool] = False
disable_flash_attn: Optional[bool] = True

Expand All @@ -81,15 +59,15 @@ def validate_load_in_8bit(cls, load_in_8bit):
@model_validator(mode='after')
def set_quantize_for_backward_compatibility(self):
if self.load_in_4bit:
self.quantize = HFQuantizeMethods.bitsandbytes4
self.quantize = "bitsandbytes4"
elif self.load_in_8bit:
self.quantize = HFQuantizeMethods.bitsandbytes8
self.quantize = "bitsandbytes8"

# TODO remove this after refactor of all handlers
# parsing bitsandbytes8, so it can be directly passed to lmi dist model loader.
if self.quantize == HFQuantizeMethods.bitsandbytes8 \
if self.quantize == "bitsandbytes8" \
and self.rolling_batch == RollingBatchEnum.lmidist:
self.quantize = HFQuantizeMethods.bitsandbytes
self.quantize = "bitsandbytes"
return self

@model_validator(mode='after')
Expand Down Expand Up @@ -152,12 +130,12 @@ def construct_kwargs_quantize(self):
}:
return self

if self.quantize.value == HFQuantizeMethods.bitsandbytes8.value:
if self.quantize == "bitsandbytes8":
if "device_map" not in self.kwargs:
raise ValueError(
"device_map should be set when load_in_8bit is set")
self.kwargs["load_in_8bit"] = True
if self.quantize.value == HFQuantizeMethods.bitsandbytes4.value:
if self.quantize == "bitsandbytes4":
if "device_map" not in self.kwargs:
raise ValueError(
"device_map should set when load_in_4bit is set")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,13 @@
# the specific language governing permissions and limitations under the License.
import ast
from enum import Enum
from typing import Optional, Mapping, Tuple
from typing import Optional, Mapping, Tuple, Dict

from pydantic import model_validator, field_validator

from djl_python.properties_manager.properties import Properties


class LmiDistQuantizeMethods(str, Enum):
    """Quantization methods accepted by the lmi-dist rolling-batch engine.

    Mixes in ``str`` so members compare equal to their raw string values,
    allowing pydantic to validate plain config strings against this set.
    """

    awq = 'awq'
    deepspeedfp = 'deepspeedfp'
    fp8 = 'fp8'
    fbgemm_fp8 = 'fbgemm_fp8'
    gptq = 'gptq'
    gptq_marlin = 'gptq_marlin'
    gptq_marlin_24 = 'gptq_marlin_24'
    awq_marlin = 'awq_marlin'
    marlin = 'marlin'
    squeezellm = 'squeezellm'


class LmiDistLoadFormats(str, Enum):
    """Special model-weight load formats recognized by lmi-dist."""

    # NOTE(review): presumably checkpoints pre-processed by the SageMaker
    # fast model loader — confirm against the lmi-dist loader code.
    sagemaker_fast_model_loader = 'sagemaker_fast_model_loader'

Expand All @@ -40,7 +27,7 @@ class LmiDistRbProperties(Properties):
engine: Optional[str] = None
dtype: Optional[str] = "auto"
load_format: Optional[str] = "auto"
quantize: Optional[LmiDistQuantizeMethods] = None
quantize: Optional[str] = None
tensor_parallel_degree: int = 1
pipeline_parallel_degree: int = 1
max_rolling_batch_prefill_tokens: Optional[int] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,18 @@
# the specific language governing permissions and limitations under the License.
import ast
from enum import Enum
from typing import Optional, Any, Mapping, Tuple
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator

from djl_python.properties_manager.properties import Properties


class VllmQuantizeMethods(str, Enum):
    """Quantization methods accepted by the vLLM rolling-batch engine.

    Mixes in ``str`` so members compare equal to their raw string values,
    allowing pydantic to validate plain config strings against this set.
    """

    awq = 'awq'
    deepspeedfp = 'deepspeedfp'
    fp8 = 'fp8'
    fbgemm_fp8 = 'fbgemm_fp8'
    gptq = 'gptq'
    gptq_marlin = 'gptq_marlin'
    gptq_marlin_24 = 'gptq_marlin_24'
    awq_marlin = 'awq_marlin'
    marlin = 'marlin'
    squeezellm = 'squeezellm'


class VllmRbProperties(Properties):
engine: Optional[str] = None
dtype: Optional[str] = "auto"
load_format: Optional[str] = "auto"
quantize: Optional[VllmQuantizeMethods] = None
quantize: Optional[str] = None
tensor_parallel_degree: int = 1
pipeline_parallel_degree: int = 1
max_rolling_batch_prefill_tokens: Optional[int] = None
Expand Down
33 changes: 5 additions & 28 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
TransformerNeuronXProperties, TnXGenerationStrategy, TnXModelSchema,
TnXMemoryLayout, TnXDtypeName, TnXModelLoaders)
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.properties_manager.hf_properties import HuggingFaceProperties, HFQuantizeMethods
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from djl_python.properties_manager.sd_inf2_properties import StableDiffusionNeuronXProperties
from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties, LmiDistQuantizeMethods
from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
from djl_python.properties_manager.scheduler_rb_properties import SchedulerRbProperties
from djl_python.tests.utils import parameterized, parameters

Expand Down Expand Up @@ -405,12 +405,11 @@ def test_hf_quantize(self):
'rolling_batch': 'lmi-dist'
}
hf_configs = HuggingFaceProperties(**properties)
self.assertEqual(hf_configs.quantize.value,
HFQuantizeMethods.bitsandbytes.value)
self.assertEqual(hf_configs.quantize, "bitsandbytes")

@parameters([{
"model_id": "model_id",
"quantize": HFQuantizeMethods.bitsandbytes4.value
"quantize": "bitsandbytes4"
}, {
"model_id": "model_id",
"load_in_8bit": "true"
Expand Down Expand Up @@ -445,13 +444,6 @@ def test_vllm_valid(properties):
self.assertEqual(vllm_configs.gpu_memory_utilization,
float(properties['gpu_memory_utilization']))

# test with invalid quantization
def test_invalid_quantization_method(properties):
properties['quantize'] = 'gguf'
with self.assertRaises(ValueError):
VllmRbProperties(**properties)
properties['quantize'] = 'awq'

def test_enforce_eager(properties):
properties.pop('enforce_eager')
properties.pop('quantize')
Expand Down Expand Up @@ -503,7 +495,6 @@ def test_invalid_long_lora_scaling_factors(properties):
'load_format': 'pt'
}
test_vllm_valid(properties.copy())
test_invalid_quantization_method(properties.copy())
test_enforce_eager(properties.copy())
test_long_lora_scaling_factors(properties.copy())
test_invalid_long_lora_scaling_factors(properties.copy())
Expand Down Expand Up @@ -573,22 +564,10 @@ def test_with_most_properties():
self.assertEqual(lmi_configs.enable_lora,
bool(properties['enable_lora']))

def test_invalid_quantization():
properties = {'quantize': 'invalid'}
with self.assertRaises(ValueError):
LmiDistRbProperties(**properties, **min_properties)

def test_quantization_with_dtype_error():
# you cannot give both quantization method and dtype
properties = {'quantize': 'bitsandbytes', 'dtype': 'int8'}
with self.assertRaises(ValueError):
LmiDistRbProperties(**properties, **min_properties)

def test_quantization_squeezellm():
properties = {'quantize': 'squeezellm'}
lmi_configs = LmiDistRbProperties(**properties, **min_properties)
self.assertEqual(lmi_configs.quantize.value,
LmiDistQuantizeMethods.squeezellm.value)
self.assertEqual(lmi_configs.quantize, "squeezellm")

def test_long_lora_scaling_factors():
properties = {"long_lora_scaling_factors": "3.0"}
Expand Down Expand Up @@ -627,8 +606,6 @@ def test_invalid_long_lora_scaling_factors():
}
test_with_min_properties()
test_with_most_properties()
test_invalid_quantization()
test_quantization_with_dtype_error()
test_quantization_squeezellm()
test_long_lora_scaling_factors()
test_invalid_long_lora_scaling_factors()
Expand Down

0 comments on commit bf76ff0

Please sign in to comment.