[lmi][python] remove quantization enum and rely on engine validation/… #2561

Merged · 1 commit · Nov 15, 2024
djl_python/properties_manager/hf_properties.py:

```diff
@@ -10,28 +10,6 @@
 from djl_python.properties_manager.properties import Properties, RollingBatchEnum, is_rolling_batch_enabled
 
 
-class HFQuantizeMethods(str, Enum):
-    # added for backward compatibility lmi-dist
-    bitsandbytes = 'bitsandbytes'
-    gptq = 'gptq'
-
-    # huggingface
-    bitsandbytes4 = 'bitsandbytes4'
-    bitsandbytes8 = 'bitsandbytes8'
-
-    # TODO remove this after refactor of all handlers
-    # supported by vllm
-    awq = 'awq'
-    deepspeedfp = 'deepspeedfp'
-    fp8 = 'fp8'
-    fbgemm_fp8 = 'fbgemm_fp8'
-    gptq_marlin = 'gptq_marlin'
-    gptq_marlin_24 = 'gptq_marlin_24'
-    awq_marlin = 'awq_marlin'
-    marlin = 'marlin'
-    squeezellm = 'squeezellm'
-
-
 def get_torch_dtype_from_str(dtype: str):
     if dtype == "auto":
         return dtype
@@ -57,7 +35,7 @@ class HuggingFaceProperties(Properties):
     device_map: str = None
     load_in_4bit: Optional[bool] = None
     load_in_8bit: Optional[bool] = None
-    quantize: Optional[HFQuantizeMethods] = None
+    quantize: Optional[str] = None
     low_cpu_mem_usage: Optional[bool] = False
     disable_flash_attn: Optional[bool] = True
 
@@ -81,15 +59,15 @@ def validate_load_in_8bit(cls, load_in_8bit):
     @model_validator(mode='after')
     def set_quantize_for_backward_compatibility(self):
         if self.load_in_4bit:
-            self.quantize = HFQuantizeMethods.bitsandbytes4
+            self.quantize = "bitsandbytes4"
         elif self.load_in_8bit:
-            self.quantize = HFQuantizeMethods.bitsandbytes8
+            self.quantize = "bitsandbytes8"
 
         # TODO remove this after refactor of all handlers
         # parsing bitsandbytes8, so it can be directly passed to lmi dist model loader.
-        if self.quantize == HFQuantizeMethods.bitsandbytes8 \
+        if self.quantize == "bitsandbytes8" \
                 and self.rolling_batch == RollingBatchEnum.lmidist:
-            self.quantize = HFQuantizeMethods.bitsandbytes
+            self.quantize = "bitsandbytes"
         return self
 
     @model_validator(mode='after')
@@ -152,12 +130,12 @@ def construct_kwargs_quantize(self):
         }:
             return self
 
-        if self.quantize.value == HFQuantizeMethods.bitsandbytes8.value:
+        if self.quantize == "bitsandbytes8":
            if "device_map" not in self.kwargs:
                 raise ValueError(
                     "device_map should be set when load_in_8bit is set")
             self.kwargs["load_in_8bit"] = True
-        if self.quantize.value == HFQuantizeMethods.bitsandbytes4.value:
+        if self.quantize == "bitsandbytes4":
             if "device_map" not in self.kwargs:
                 raise ValueError(
                     "device_map should set when load_in_4bit is set")
```
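The net effect on the HuggingFace properties is easiest to see in isolation. Below is a minimal, self-contained sketch of the relaxed field (pydantic v2; the class name and trimmed field set are illustrative, not the actual djl_python module): an unknown method now passes property parsing, while the load_in_* backward-compatibility mapping is unchanged.

```python
from typing import Optional
from pydantic import BaseModel, model_validator


class HuggingFacePropertiesSketch(BaseModel):
    load_in_4bit: Optional[bool] = None
    load_in_8bit: Optional[bool] = None
    quantize: Optional[str] = None  # was Optional[HFQuantizeMethods]

    @model_validator(mode='after')
    def set_quantize_for_backward_compatibility(self):
        # same mapping as the real validator in hf_properties.py
        if self.load_in_4bit:
            self.quantize = "bitsandbytes4"
        elif self.load_in_8bit:
            self.quantize = "bitsandbytes8"
        return self


# An unrecognized method no longer raises at parse time:
props = HuggingFacePropertiesSketch(quantize="gguf")
assert props.quantize == "gguf"
# The load_in_* flags still map onto quantize:
assert HuggingFacePropertiesSketch(load_in_8bit=True).quantize == "bitsandbytes8"
```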
djl_python/properties_manager/lmi_dist_rb_properties.py:

```diff
@@ -12,26 +12,13 @@
 # the specific language governing permissions and limitations under the License.
 import ast
 from enum import Enum
-from typing import Optional, Mapping, Tuple
+from typing import Optional, Mapping, Tuple, Dict
 
 from pydantic import model_validator, field_validator
 
 from djl_python.properties_manager.properties import Properties
 
 
-class LmiDistQuantizeMethods(str, Enum):
-    awq = 'awq'
-    deepspeedfp = 'deepspeedfp'
-    fp8 = 'fp8'
-    fbgemm_fp8 = 'fbgemm_fp8'
-    gptq = 'gptq'
-    gptq_marlin = 'gptq_marlin'
-    gptq_marlin_24 = 'gptq_marlin_24'
-    awq_marlin = 'awq_marlin'
-    marlin = 'marlin'
-    squeezellm = 'squeezellm'
-
-
 class LmiDistLoadFormats(str, Enum):
     sagemaker_fast_model_loader = 'sagemaker_fast_model_loader'
 
@@ -40,7 +27,7 @@ class LmiDistRbProperties(Properties):
     engine: Optional[str] = None
     dtype: Optional[str] = "auto"
     load_format: Optional[str] = "auto"
-    quantize: Optional[LmiDistQuantizeMethods] = None
+    quantize: Optional[str] = None
     tensor_parallel_degree: int = 1
     pipeline_parallel_degree: int = 1
     max_rolling_batch_prefill_tokens: Optional[int] = None
```
djl_python/properties_manager/vllm_rb_properties.py:

```diff
@@ -12,31 +12,18 @@
 # the specific language governing permissions and limitations under the License.
 import ast
 from enum import Enum
-from typing import Optional, Any, Mapping, Tuple
+from typing import Optional, Any, Mapping, Tuple, Dict
 
 from pydantic import field_validator, model_validator
 
 from djl_python.properties_manager.properties import Properties
 
 
-class VllmQuantizeMethods(str, Enum):
-    awq = 'awq'
-    deepspeedfp = 'deepspeedfp'
-    fp8 = 'fp8'
-    fbgemm_fp8 = 'fbgemm_fp8'
-    gptq = 'gptq'
-    gptq_marlin = 'gptq_marlin'
-    gptq_marlin_24 = 'gptq_marlin_24'
-    awq_marlin = 'awq_marlin'
-    marlin = 'marlin'
-    squeezellm = 'squeezellm'
-
-
 class VllmRbProperties(Properties):
     engine: Optional[str] = None
     dtype: Optional[str] = "auto"
     load_format: Optional[str] = "auto"
-    quantize: Optional[VllmQuantizeMethods] = None
+    quantize: Optional[str] = None
     tensor_parallel_degree: int = 1
     pipeline_parallel_degree: int = 1
     max_rolling_batch_prefill_tokens: Optional[int] = None
```
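Both rolling-batch property classes now accept any string for quantize, so the supported-method check moves to the engine that actually loads the model. A rough sketch of that division of labor, with a hypothetical SUPPORTED_METHODS set and load_engine function (the real list is owned by lmi-dist/vLLM and changes from release to release, which is the motivation for dropping the duplicated enums):

```python
from typing import Optional

# Hypothetical engine-side check; names are illustrative. In practice vLLM and
# lmi-dist raise their own errors for unsupported quantization methods.
SUPPORTED_METHODS = {"awq", "deepspeedfp", "fp8", "gptq", "gptq_marlin", "marlin"}


def load_engine(quantize: Optional[str]) -> str:
    """Pretend engine constructor: validates quantize where this PR moves it."""
    if quantize is not None and quantize not in SUPPORTED_METHODS:
        raise ValueError(f"Unknown quantization method: {quantize}")
    return f"engine(quantize={quantize})"


load_engine("awq")     # ok
# load_engine("gguf")  # would now raise ValueError at engine construction,
#                      # not during property parsing
```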
Test file for the properties managers (under djl_python.tests):

```diff
@@ -8,10 +8,10 @@
     TransformerNeuronXProperties, TnXGenerationStrategy, TnXModelSchema,
     TnXMemoryLayout, TnXDtypeName, TnXModelLoaders)
 from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
-from djl_python.properties_manager.hf_properties import HuggingFaceProperties, HFQuantizeMethods
+from djl_python.properties_manager.hf_properties import HuggingFaceProperties
 from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
 from djl_python.properties_manager.sd_inf2_properties import StableDiffusionNeuronXProperties
-from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties, LmiDistQuantizeMethods
+from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
 from djl_python.properties_manager.scheduler_rb_properties import SchedulerRbProperties
 from djl_python.tests.utils import parameterized, parameters
@@ -405,12 +405,11 @@ def test_hf_quantize(self):
             'rolling_batch': 'lmi-dist'
         }
         hf_configs = HuggingFaceProperties(**properties)
-        self.assertEqual(hf_configs.quantize.value,
-                         HFQuantizeMethods.bitsandbytes.value)
+        self.assertEqual(hf_configs.quantize, "bitsandbytes")
 
     @parameters([{
         "model_id": "model_id",
-        "quantize": HFQuantizeMethods.bitsandbytes4.value
+        "quantize": "bitsandbytes4"
     }, {
         "model_id": "model_id",
         "load_in_8bit": "true"
@@ -445,13 +444,6 @@ def test_vllm_valid(properties):
             self.assertEqual(vllm_configs.gpu_memory_utilization,
                              float(properties['gpu_memory_utilization']))
 
-        # test with invalid quantization
-        def test_invalid_quantization_method(properties):
-            properties['quantize'] = 'gguf'
-            with self.assertRaises(ValueError):
-                VllmRbProperties(**properties)
-            properties['quantize'] = 'awq'
-
         def test_enforce_eager(properties):
             properties.pop('enforce_eager')
             properties.pop('quantize')
@@ -503,7 +495,6 @@ def test_invalid_long_lora_scaling_factors(properties):
             'load_format': 'pt'
         }
         test_vllm_valid(properties.copy())
-        test_invalid_quantization_method(properties.copy())
         test_enforce_eager(properties.copy())
         test_long_lora_scaling_factors(properties.copy())
         test_invalid_long_lora_scaling_factors(properties.copy())
@@ -573,22 +564,10 @@ def test_with_most_properties():
             self.assertEqual(lmi_configs.enable_lora,
                              bool(properties['enable_lora']))
 
-        def test_invalid_quantization():
-            properties = {'quantize': 'invalid'}
-            with self.assertRaises(ValueError):
-                LmiDistRbProperties(**properties, **min_properties)
-
-        def test_quantization_with_dtype_error():
-            # you cannot give both quantization method and dtype
-            properties = {'quantize': 'bitsandbytes', 'dtype': 'int8'}
-            with self.assertRaises(ValueError):
-                LmiDistRbProperties(**properties, **min_properties)
-
         def test_quantization_squeezellm():
             properties = {'quantize': 'squeezellm'}
             lmi_configs = LmiDistRbProperties(**properties, **min_properties)
-            self.assertEqual(lmi_configs.quantize.value,
-                             LmiDistQuantizeMethods.squeezellm.value)
+            self.assertEqual(lmi_configs.quantize, "squeezellm")
 
         def test_long_lora_scaling_factors():
             properties = {"long_lora_scaling_factors": "3.0"}
@@ -627,8 +606,6 @@ def test_invalid_long_lora_scaling_factors():
         }
         test_with_min_properties()
         test_with_most_properties()
-        test_invalid_quantization()
-        test_quantization_with_dtype_error()
         test_quantization_squeezellm()
         test_long_lora_scaling_factors()
         test_invalid_long_lora_scaling_factors()
```
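For completeness, here is roughly what the deleted test_invalid_quantization_method scenario looks like under the new contract. This assumes djl_python is importable; the property keys follow the fixtures above, and the exact required key set lives in the test setup, so treat the dict as illustrative.

```python
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties

properties = {
    "model_id": "model_id",  # fixture-style key, as in the tests above
    "quantize": "gguf",      # was rejected by the old enum; now a plain string
}
# Parsing no longer raises a ValidationError here; any failure for an
# unsupported method now surfaces from the engine at model-load time.
vllm_configs = VllmRbProperties(**properties)
assert vllm_configs.quantize == "gguf"
```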