[Model] Add TP and BNB quantization support to LlavaMultiModalProjector (#10834)

Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Isotr0py and DarkLight1337 authored Dec 2, 2024
1 parent 9b14d97 commit 4c05edb
Showing 2 changed files with 34 additions and 15 deletions.
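
For context, the two diffs below let a LLaVA checkpoint be served with bitsandbytes quantization while the multimodal projector participates in tensor parallelism. A hypothetical invocation via vLLM's public LLM API (the model ID and flag values are assumptions, not taken from this commit):

    from vllm import LLM

    # Illustrative only: bitsandbytes quantization in vLLM of this era also
    # required the matching load_format; the projector's new TP support is
    # exercised by tensor_parallel_size > 1.
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
        tensor_parallel_size=2,
    )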
vllm/model_executor/model_loader/loader.py: 14 changes (11 additions, 3 deletions)
@@ -1120,7 +1120,14 @@ def _load_weights(self, model_config: ModelConfig,
                                                      model_config.revision,
                                                      pre_quant, load_8bit))
 
-        model.load_weights(qweight_iterator)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
 
         torch.cuda.empty_cache()
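
The new guard depends on load_weights() returning the set of parameter names it actually initialized; as the comment in the hunk notes, models without that tracking return None and skip the audit. A minimal self-contained sketch of the contract, using a hypothetical ToyModel rather than vLLM's real loader:

    from typing import Iterable, Optional, Set, Tuple

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def load_weights(
                self,
                weights: Iterable[Tuple[str, torch.Tensor]]) -> Optional[Set[str]]:
            # Copy matching tensors and record which parameters were
            # initialized; a model without the tracker would return None.
            params = dict(self.named_parameters())
            loaded: Set[str] = set()
            for name, tensor in weights:
                if name in params:
                    params[name].data.copy_(tensor)
                    loaded.add(name)
            return loaded


    model = ToyModel()
    weights_to_load = {name for name, _ in model.named_parameters()}
    # Deliberately incomplete checkpoint: "proj.bias" is missing.
    loaded_weights = model.load_weights([("proj.weight", torch.zeros(4, 4))])
    if loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            # The real loader raises ValueError at this point.
            print(f"Not initialized from checkpoint: {weights_not_loaded}")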

@@ -1152,9 +1159,10 @@ def _load_weights(self, model_config: ModelConfig,
                             shard_name, weight_name)
                         break
 
+                # Models like Clip/Siglip may skip some layers in initialization,
+                # causing unused quant_param_name in state_dict.
                 if quant_param_name not in param_dict:
-                    raise ValueError(
-                        f"Parameter {quant_param_name} not found in the model.")
+                    continue
 
                 if quant_param_name not in stacked_quant_state_dict:
                     stacked_quant_state_dict[quant_param_name] = {}
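
The relaxed lookup tolerates checkpoint entries whose target module was never instantiated, e.g. vision towers that drop their final layers. A toy illustration of the before/after behavior (the module and checkpoint names here are hypothetical, not the actual loader's):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(2, 2))  # only layer "0" exists
    param_dict = dict(model.named_parameters())

    checkpoint = {
        "0.weight": torch.zeros(2, 2),
        "0.bias": torch.zeros(2),
        "1.weight": torch.zeros(2, 2),  # layer skipped at initialization
    }

    for quant_param_name, tensor in checkpoint.items():
        if quant_param_name not in param_dict:
            continue  # previously this raised ValueError
        param_dict[quant_param_name].data.copy_(tensor)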
vllm/model_executor/models/llava.py: 35 changes (23 additions, 12 deletions)
@@ -13,6 +13,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -59,25 +61,32 @@ class LlavaImageEmbeddingInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
 
 
-# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str):
+    def __init__(self,
+                 vision_hidden_size: int,
+                 text_hidden_size: int,
+                 projector_hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
+                                             text_hidden_size,
+                                             bias=True,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_2 = RowParallelLinear(text_hidden_size,
+                                          text_hidden_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.linear_2")
 
     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_1(image_features)
+        hidden_states, _ = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
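
Pairing ColumnParallelLinear with RowParallelLinear keeps the intermediate activation sharded: linear_1 splits its output dimension across ranks, the elementwise activation applies per shard, and linear_2 splits its input dimension and all-reduces the result, so no gather is needed between the two matmuls. vLLM's parallel linear layers return an (output, bias) tuple, hence the new `hidden_states, _ = ...` unpacking. A single-process numerical sketch of why the composition is exact (gelu stands in for the configured activation; biases omitted for brevity):

    import torch
    import torch.nn.functional as F

    tp_size = 2
    x = torch.randn(3, 8)     # (num_image_tokens, vision_hidden_size)
    w1 = torch.randn(8, 16)   # linear_1: vision_hidden -> text_hidden
    w2 = torch.randn(16, 16)  # linear_2: text_hidden -> text_hidden

    ref = F.gelu(x @ w1) @ w2  # the unsharded projector

    partial_outputs = []
    for rank in range(tp_size):
        w1_shard = w1.chunk(tp_size, dim=1)[rank]  # column slice: output dim
        w2_shard = w2.chunk(tp_size, dim=0)[rank]  # row slice: input dim
        # The activation is elementwise, so it runs on each shard independently.
        partial_outputs.append(F.gelu(x @ w1_shard) @ w2_shard)

    # Summing the partials is what RowParallelLinear's all-reduce performs.
    out = sum(partial_outputs)
    assert torch.allclose(ref, out, atol=1e-4)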


@@ -325,7 +334,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act)
+            projector_hidden_act=config.projector_hidden_act,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
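
The prefix plumbing matters for quantization: quantized layers are located by their dotted module path, so the projector's sublayers now report names like "multi_modal_projector.linear_1", which bitsandbytes uses to find their weights. A sketch matching vLLM's maybe_prefix helper (treat the exact implementation as an assumption):

    def maybe_prefix(prefix: str, name: str) -> str:
        # Join with a dot only when a parent prefix exists.
        return name if not prefix else f"{prefix}.{name}"

    assert maybe_prefix("", "multi_modal_projector") == "multi_modal_projector"
    assert maybe_prefix("llava", "multi_modal_projector") == \
        "llava.multi_modal_projector"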
