[Model] Add TP and BNB quantization support to LlavaMultiModalProjector #10834

Merged: 6 commits (merged Dec 2, 2024)
Changes from 5 commits
14 changes: 11 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -1120,7 +1120,14 @@ def _load_weights(self, model_config: ModelConfig,
                                                  model_config.revision,
                                                  pre_quant, load_8bit))

-        model.load_weights(qweight_iterator)
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(qweight_iterator)
+        # Some models may have weights loading tracker unimplemented.
+        if loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")

         torch.cuda.empty_cache()

@@ -1152,9 +1159,10 @@ def _load_weights(self, model_config: ModelConfig,
                         shard_name, weight_name)
                     break

+            # Models like Clip/Siglip may skip some layers in initialization,
+            # causing unused quant_param_name in state_dict.
             if quant_param_name not in param_dict:
-                raise ValueError(
-                    f"Parameter {quant_param_name} not found in the model.")
+                continue

             if quant_param_name not in stacked_quant_state_dict:
                 stacked_quant_state_dict[quant_param_name] = {}
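For context on the loader changes above: `load_weights` now returns the set of parameter names it actually initialized, the BNB stacking loop simply skips quantization parameters for layers the model never instantiates (as Clip/Siglip vision towers may do), and the loader compares the returned set against `model.named_parameters()` to catch anything left uninitialized. A minimal, self-contained sketch of that pattern (not vLLM code; `ToyModel` and `fake_checkpoint` are invented for illustration):

```python
from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(4, 4)
        self.linear_2 = nn.Linear(4, 4)

    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for name, tensor in weights:
            if name not in params:
                # Mirrors the loader change above: silently skip checkpoint
                # entries for layers this model never instantiated.
                continue
            params[name].data.copy_(tensor)
            loaded.add(name)
        return loaded


model = ToyModel()
# Pretend the checkpoint is missing everything for linear_2.
fake_checkpoint = [(name, torch.zeros_like(p))
                   for name, p in model.named_parameters()
                   if not name.startswith("linear_2")]

weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(fake_checkpoint)
missing = weights_to_load - loaded_weights
print(missing)  # the two linear_2 parameters; the real loader raises here
```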
35 changes: 23 additions & 12 deletions vllm/model_executor/models/llava.py
@@ -13,6 +13,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext)
 from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -59,25 +61,32 @@ class LlavaImageEmbeddingInputs(TypedDict):
 LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


-# TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):

-    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
-                 projector_hidden_act: str):
+    def __init__(self,
+                 vision_hidden_size: int,
+                 text_hidden_size: int,
+                 projector_hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix=""):
         super().__init__()

-        self.linear_1 = nn.Linear(vision_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
+                                             text_hidden_size,
+                                             bias=True,
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
-        self.linear_2 = nn.Linear(text_hidden_size,
-                                  text_hidden_size,
-                                  bias=True)
+        self.linear_2 = RowParallelLinear(text_hidden_size,
+                                          text_hidden_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.linear_2")

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear_1(image_features)
+        hidden_states, _ = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
+        hidden_states, _ = self.linear_2(hidden_states)
         return hidden_states
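The projector now follows the standard Megatron-style split: `linear_1` is sharded over its output (intermediate) dimension, the activation is elementwise so it runs on each rank's local shard, and `linear_2` is sharded over its input dimension with a single all-reduce inside `RowParallelLinear` restoring the full hidden state. vLLM's parallel linear layers also return an `(output, output_bias)` tuple, which is why `forward` now unpacks with `, _`. Below is a single-process sketch of the arithmetic using plain torch and made-up shapes, not vLLM's classes; `F.gelu` stands in for `get_act_fn(projector_hidden_act)` and bias is omitted for brevity:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vision_hidden, text_hidden, tp = 6, 4, 2

x = torch.randn(3, vision_hidden)             # image features
w1 = torch.randn(text_hidden, vision_hidden)  # linear_1 weight (out, in)
w2 = torch.randn(text_hidden, text_hidden)    # linear_2 weight (out, in)

# Reference: unsharded projector.
ref = F.linear(F.gelu(F.linear(x, w1)), w2)

# "ColumnParallelLinear": each rank keeps a slice of linear_1's output dim.
w1_shards = w1.chunk(tp, dim=0)
# "RowParallelLinear": each rank keeps the matching slice of linear_2's input dim.
w2_shards = w2.chunk(tp, dim=1)

# Each rank computes a partial result; the activation is elementwise, so it
# applies to the local shard with no communication.
partials = [
    F.linear(F.gelu(F.linear(x, w1_s)), w2_s)
    for w1_s, w2_s in zip(w1_shards, w2_shards)
]
tp_out = sum(partials)  # stands in for the all-reduce inside RowParallelLinear

assert torch.allclose(ref, tp_out, atol=1e-5)
```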


@@ -325,7 +334,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act)
+            projector_hidden_act=config.projector_hidden_act,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))

         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
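With this change merged, loading LLaVA with on-the-fly bitsandbytes quantization (and, where the installed version supports combining the two, tensor parallelism) looks roughly like the sketch below. This is a hedged usage example, not taken from the PR: the checkpoint id, `tensor_parallel_size`, and dtype are placeholders, and exact flag support depends on your vLLM version.

```python
# Hedged usage sketch: argument support varies by vLLM version, and the
# checkpoint id / parallel size below are placeholders.
from vllm import LLM

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",  # example LLaVA checkpoint
    quantization="bitsandbytes",       # quantize weights with BNB at load time
    load_format="bitsandbytes",
    tensor_parallel_size=2,            # the multimodal projector is now sharded too
    dtype="half",
)

# Prompting follows vLLM's multimodal examples: pass the image via
# multi_modal_data alongside the text prompt when calling llm.generate(...).
```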