Enable quant model support #1074

Open · wants to merge 24 commits into base: main
16 changes: 8 additions & 8 deletions optimum/exporters/ipex/model_patcher.py
@@ -58,12 +58,12 @@ def convert_functions(m, target_m, new_function_name, new_function):
convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, config=None):
def convert_class(m, target_m, new_class, device, config):
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
new_m = new_class(sub_m, config)
new_m = new_class(sub_m, device, config)
setattr(m, name, new_m)
convert_class(sub_m, target_m, new_class, config)
convert_class(sub_m, target_m, new_class, device, config)


def patch_op(m, target_m, new_op_name, new_op):
@@ -81,7 +81,7 @@ def _patch_llama_model(model):
"""
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.config)
convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
return model


@@ -97,7 +97,7 @@ def _patch_falcon_model(model):
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, FalconModel, "forward", _falcon_model_forward)
replace_customized_linear_with_linear(model)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
return model


@@ -110,7 +110,7 @@ def _patch_gpt2_model(model):
setattr(model.config, "num_key_value_heads", num_key_value_heads)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_functions(model, GPT2Block, "forward", _gpt2_block_forward)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config)
convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.device, model.config)
return model


@@ -119,7 +119,7 @@ def _patch_bert_model(model):
Patch bert model:
1. Linear fusion with Linear + Gelu
"""
convert_class(model, BertIntermediate, _IPEXIntermediate)
convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
return model


@@ -128,7 +128,7 @@ def _patch_vit_model(model):
Patch vit model:
1. Linear fusion with Linear + Gelu
"""
convert_class(model, ViTIntermediate, _IPEXIntermediate)
convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
return model


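Read against the flattened diff above, the change is mechanical: `convert_class` now takes an explicit `device` (the caller passes `model.device`) and a mandatory `config`, and forwards both into every wrapper class instead of letting each wrapper call `next(module.parameters()).device` itself. Below is a minimal, self-contained sketch of the new contract; `ToyBlock` and `PatchedToyBlock` are hypothetical stand-ins for the real decoder layers, not code from this PR.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a HF decoder layer (e.g. LlamaDecoderLayer)."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.linear(x)


class PatchedToyBlock(nn.Module):
    """Hypothetical stand-in for an _IPEX* wrapper: device and config arrive explicitly."""

    def __init__(self, module, device, config):
        super().__init__()
        self.module = module
        self.module_device = device  # no next(module.parameters()).device lookup
        self.config = config

    def forward(self, x):
        return self.module(x)


def convert_class(m, target_m, new_class, device, config):
    # Same recursion as the patched helper: swap every matching child in place.
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(m, name, new_class(sub_m, device, config))
        convert_class(sub_m, target_m, new_class, device, config)


model = nn.Sequential(ToyBlock(), ToyBlock())
convert_class(model, ToyBlock, PatchedToyBlock, torch.device("cpu"), config=None)
assert all(isinstance(block, PatchedToyBlock) for block in model)
```

One plausible reason for the explicit `device`: once quantized or hook-dispatched checkpoints are involved, a submodule's own parameters are not a reliable device signal, so the top-level `model.device` is threaded down instead.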
176 changes: 111 additions & 65 deletions optimum/exporters/ipex/modeling_utils.py
@@ -30,6 +30,7 @@
logger = logging.getLogger(__name__)

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
_accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "musa"]


if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -133,6 +134,32 @@ def forward(self, x, y, z):
return x


# Adapted from https://github.com/huggingface/accelerate/blob/v1.2.1/src/accelerate/hooks.py#L183
def _remove_hooks_for_ipex(module, recurse):
if hasattr(module, "_hf_hook"):
module._hf_hook.detach_hook(module)
delattr(module, "_hf_hook")

if hasattr(module, "_old_forward"):
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = module.__class__.forward.__get__(module)
else:
module.forward = module.__class__.forward.__get__(module)
delattr(module, "_old_forward")

# Remove accelerate added warning hooks from dispatch_model
for attr in _accelerate_added_attributes:
module.__dict__.pop(attr, None)

if recurse:
for child in module.children():
_remove_hooks_for_ipex(child, recurse)

return module


# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
def _ipex_rms_layer_norm_forward(self, hidden_states):
return rms_norm(hidden_states, self.weight, self.variance_epsilon)
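`_remove_hooks_for_ipex` mirrors accelerate's `remove_hook_from_module` and additionally pops the device-movement attributes (`to`, `cuda`, `xpu`, ...) that `dispatch_model` attaches, so a patched module is left with its plain class `forward` before any IPEX rewriting happens. A small sanity sketch of what gets attached and detached follows; it assumes accelerate is installed and that a bare `ModelHook` behaves as in accelerate v1.2.x, and is not part of the PR.

```python
import torch
from torch import nn
from accelerate.hooks import ModelHook, add_hook_to_module

layer = nn.Linear(4, 4)
add_hook_to_module(layer, ModelHook())  # accelerate saves _old_forward and sets _hf_hook
assert hasattr(layer, "_hf_hook") and hasattr(layer, "_old_forward")

_remove_hooks_for_ipex(layer, recurse=True)  # the helper defined above

assert not hasattr(layer, "_hf_hook") and not hasattr(layer, "_old_forward")
layer(torch.randn(1, 4))  # back to the plain nn.Linear forward, safe to patch or fuse
```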
@@ -568,11 +595,11 @@ def _gpt2_block_forward(


class _IPEXAttention(nn.Module):
def __init__(self, module, config) -> None:
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device
self.module_device = device
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device
@@ -654,32 +681,38 @@ def forward(


class _IPEXLlamaAttention(_IPEXAttention):
def __init__(self, module, config) -> None:
super().__init__(module, config)
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0).contiguous()
self.concat_linear.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.weight.shape[0]
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
if self.module_device.type == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)

elif self.module_device.type == "xpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = XPULinearAdd(module.o_proj)
def __init__(self, module, device, config) -> None:
super().__init__(module, device, config)
if getattr(config, "quantization_config", None) is None:
concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous()
bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias]
use_bias = bias_list != []
self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
self.concat_qkv.weight = nn.Parameter(concat_weight)
if use_bias:
concat_bias = torch.concat(bias_list, 0).contiguous()
self.concat_linear.bias = nn.Parameter(concat_bias)
self.q_slice = self.q_proj.weight.shape[0]
self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
if self.module_device.type == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)

elif self.module_device.type == "xpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = XPULinearAdd(module.o_proj)

def qkv_gemm(self, hidden_states):
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
if hasattr(self, "concat_qkv"):
qkv_out = self.concat_qkv(hidden_states)
query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim)
key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim)
value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim)
else:
query = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim)
key = self.k_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
value = self.v_proj(hidden_states).view(-1, self.num_key_value_heads, self.head_dim)

return query, key, value
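The `hasattr(self, "concat_qkv")` branch keeps the fused single-GEMM path for unquantized weights and falls back to the original q/k/v projections when quantization is active and the concatenated linear was never built. As a reference point, here is a toy numerical check (not PR code) that slicing one concatenated projection reproduces the three separate projections, using the same `q_slice`/`k_slice` bookkeeping as above.

```python
import torch
from torch import nn

hidden_size, num_heads, num_kv_heads, head_dim = 16, 4, 2, 4
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# Build the concatenated projection the same way the fp path above does.
concat_weight = torch.concat([q_proj.weight, k_proj.weight, v_proj.weight]).contiguous()
concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=False)
concat_qkv.weight = nn.Parameter(concat_weight)

q_slice = q_proj.weight.shape[0]
k_slice = q_slice + k_proj.weight.shape[0]

x = torch.randn(3, hidden_size)
qkv = concat_qkv(x)
assert torch.allclose(qkv[:, :q_slice], q_proj(x), atol=1e-6)
assert torch.allclose(qkv[:, q_slice:k_slice], k_proj(x), atol=1e-6)
assert torch.allclose(qkv[:, k_slice:], v_proj(x), atol=1e-6)
```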

@@ -690,9 +723,9 @@ def rope(self, query, key, **kwargs):


class _IPEXFalconAttention(_IPEXAttention):
def __init__(self, module, config):
def __init__(self, module, device, config):
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, config)
super().__init__(module, device, config)
self.q_slice = self.head_dim * config.num_kv_heads
self.k_slice = self.q_slice + self.head_dim
self.v_slice = self.k_slice + self.head_dim
@@ -717,9 +750,11 @@ def rope(self, query, key, **kwargs):


class _IPEXGPT2Attention(_IPEXAttention):
def __init__(self, module, config) -> None:
def __init__(self, module, device, config) -> None:
self.num_key_value_heads = config.num_key_value_heads
super().__init__(module, config)
super().__init__(module, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)

def qkv_gemm(self, hidden_states):
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1)
@@ -740,21 +775,22 @@ def postprocess_attention_output(self, attn_output):

# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L186
class _IPEXLlamaMLP(nn.Module):
def __init__(self, module, config) -> None:
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
elif self.module_device.type == "xpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = XPULinearAdd(module.down_proj)
self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)
self.module_device = device
if getattr(config, "quantization_config", None) is None:
if self.module_device.type == "cpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = LinearAdd(module.down_proj)
self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj)
elif self.module_device.type == "xpu":
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mlp_linear_add = XPULinearAdd(module.down_proj)
self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj)

def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs):
if hasattr(self, "linear_silu_mul"):
@@ -772,21 +808,22 @@ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **


class _IPEXFalconMLP(nn.Module):
def __init__(self, module, config) -> None:
def __init__(self, module, device, config) -> None:
super().__init__()
_setattr_from_module(self, module)
self.config = config
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
self.module_device = device
if getattr(config, "quantization_config", None) is None:
# LinearAllreduce and LinearLayer cannot use fused op LinearAdd
if self.module_device.type == "cpu":
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
self.linear_gelu = LinearGelu(module.dense_h_to_4h)
elif self.module_device.type == "xpu":
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)
self.linear_gelu = XPULinearGelu(module.dense_h_to_4h)
if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:
if self.module_device.type == "cpu":
self.linear_add_add = LinearAddAdd(module.dense_4h_to_h)
elif self.module_device.type == "xpu":
self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h)

def forward(
self,
@@ -807,11 +844,13 @@ def forward(

# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
class _IPEXLlamaDecoderLayer(nn.Module):
def __init__(self, module, config):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.self_attn = _IPEXLlamaAttention(module.self_attn, config)
self.mlp = _IPEXLlamaMLP(module.mlp, config)
self.self_attn = _IPEXLlamaAttention(module.self_attn, device, config)
self.mlp = _IPEXLlamaMLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)

def forward(self, hidden_states: torch.Tensor, **kwargs):
# Please see the original model's forward to check the parameter
@@ -840,11 +879,13 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):


class _IPEXFalconDecoderLayer(nn.Module):
def __init__(self, module, config):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.self_attention = _IPEXFalconAttention(module.self_attention, config)
self.mlp = _IPEXFalconMLP(module.mlp, config)
self.self_attention = _IPEXFalconAttention(module.self_attention, device, config)
self.mlp = _IPEXFalconMLP(module.mlp, device, config)
if getattr(config, "quantization_config", None):
_remove_hooks_for_ipex(self, True)

def forward(self, hidden_states: torch.Tensor, **kwargs):
# Please see the original model's forward to check the parameter
@@ -867,15 +908,20 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):

# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524
class _IPEXIntermediate(nn.Module):
def __init__(self, module, config):
def __init__(self, module, device, config):
super().__init__()
_setattr_from_module(self, module)
self.module_device = next(module.parameters()).device
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense)
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense)
self.module_device = device
if getattr(config, "quantization_config", None) is None:
if self.module_device.type == "cpu":
self.linear_gelu = LinearGelu(module.dense)
elif self.module_device.type == "xpu":
self.linear_gelu = XPULinearGelu(module.dense)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_gelu(hidden_states)
if hasattr(self, "linear_gelu"):
hidden_states = self.linear_gelu(hidden_states)
else:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
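Every block patched in this file now follows the same two-step pattern: build the fused CPU/XPU op only when `config.quantization_config` is absent, then branch in `forward` on whether the fused attribute exists, leaving quantized linears untouched. Below is a compact illustration of that pattern; the fused op is faked with `nn.Sequential`, and nothing in it is code from this PR.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyIntermediate(nn.Module):
    """Hypothetical stand-in for _IPEXIntermediate; nn.Sequential replaces LinearGelu."""

    def __init__(self, dense: nn.Linear, device, config):
        super().__init__()
        self.dense = dense
        self.module_device = device
        if getattr(config, "quantization_config", None) is None:
            # Fused path (LinearGelu / XPULinearGelu in the real code).
            self.linear_gelu = nn.Sequential(dense, nn.GELU())

    def forward(self, hidden_states):
        if hasattr(self, "linear_gelu"):
            return self.linear_gelu(hidden_states)
        # Quantized path: keep the original (possibly quantized) linear untouched.
        return nn.functional.gelu(self.dense(hidden_states))


dense = nn.Linear(8, 8)
fp = ToyIntermediate(dense, torch.device("cpu"), SimpleNamespace(quantization_config=None))
quant = ToyIntermediate(dense, torch.device("cpu"), SimpleNamespace(quantization_config={"bits": 4}))  # placeholder payload
x = torch.randn(2, 8)
assert torch.allclose(fp(x), quant(x), atol=1e-6)  # both paths compute the same result here
```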
1 change: 1 addition & 0 deletions optimum/intel/ipex/modeling_base.py
@@ -189,6 +189,7 @@ def maybe_apply_torch_compile(self):
self.model.device.type != "cpu"
or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES
or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE)
or getattr(self.config, "quantization_config", None)
):
return
if self.use_cache and not self._supports_static_cache:
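Taken together with the guards above, a checkpoint whose config carries a `quantization_config` is expected to load through the IPEX modeling classes with its original quantized linears, with accelerate hooks stripped, and with `torch.compile` skipped. A hedged end-to-end sketch follows; the model id is a placeholder, and which quantization schemes actually work depends on the rest of this PR.

```python
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Placeholder id: any causal-LM checkpoint whose config.json carries a quantization_config.
model_id = "path/to/quantized-llama-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id)  # quantization_config comes from the checkpoint config

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```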