joined vision embeddings for llava
eaidova committed Sep 26, 2024
1 parent e0da998 commit 109f927
Showing 4 changed files with 68 additions and 136 deletions.
147 changes: 32 additions & 115 deletions optimum/exporters/openvino/model_configs.py
@@ -13,7 +13,6 @@
# limitations under the License.

import enum
import random
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

@@ -25,7 +24,6 @@
from optimum.exporters.onnx.model_configs import (
CLIPOnnxConfig,
CLIPTextOnnxConfig,
CLIPVisionModelOnnxConfig,
CodeGenOnnxConfig,
FalconOnnxConfig,
GemmaOnnxConfig,
@@ -71,6 +69,7 @@
InternVLChatImageEmbeddingModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
LlavaImageEmbeddingModelPatcher,
MistralModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
@@ -1268,92 +1267,25 @@ def rename_ambiguous_inputs(self, inputs):
return model_inputs


class DummyLLavaMultiModalProjectorInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ["image_features"]

def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
self.task = task

if random_batch_size_range:
low, high = random_batch_size_range
self.batch_size = random.randint(low, high)
else:
self.batch_size = batch_size
self.hidden_size = normalized_config.hidden_size
self.num_patches = (normalized_config.image_size // normalized_config.patch_size) ** 2
self.normalized_config = normalized_config

def generate(
self,
input_name: str,
framework: str = "pt",
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
shape = [self.batch_size, self.num_patches, self.hidden_size]
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)


class LLavaMultimodalProjectorOpenVINOConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyLLavaMultiModalProjectorInputGenerator,)
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"image_features": {0: "batch_size"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {"hidden_states": {0: "batch_size"}}


@register_in_tasks_manager("clip-vision-model", *["feature-extraction"], library_name="transformers")
class CLIPVisionModeOpenVIONConfig(CLIPVisionModelOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs
if self._config.output_hidden_states:
for i in range(self._config.num_hidden_layers + 1):
common_outputs[f"hidden_states.{i}"] = {0: "batch_size"}

return common_outputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
model_kwargs = model_kwargs or {}
if self._config.output_hidden_states and "output_hidden_states" not in model_kwargs:
model_kwargs["output_hidden_states"] = True

return super().patch_model_for_export(model, model_kwargs)


class LlavaConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
VISION_EMBEDDINGS = "vision_embeddings"
TEXT_EMBEDDINGS = "text_embeddings"
MULTI_MODAL_PROJECTOR = "multi_modal_projector"


@register_in_tasks_manager("llava", *["image-text-to-text"], library_name="transformers")
class LlavaOpenVINOConfig(OnnxConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: LlavaConfigBehavior = LlavaConfigBehavior.LANGUAGE,
behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
@@ -1364,14 +1296,22 @@ def __init__(
preprocessors=preprocessors,
)
self._behavior = behavior
self._orig_config = config
if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {}
if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return {}
return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {}
if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
return {}
return {"last_hidden_state": {0: "batch_size"}}

def with_behavior(
self,
@@ -1388,7 +1328,7 @@ def with_behavior(
behavior = LlavaConfigBehavior(behavior)

if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
model_type = self._config.text_config.model_type
model_type = self._orig_config.text_config.model_type
model_type = model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise ValueError(
@@ -1403,32 +1343,23 @@ def with_behavior(
"text-generation-with-past"
]
internal_export_config = internal_export_config_class(
self._config.text_config,
self._orig_config.text_config,
use_past=True,
use_past_in_inputs=True,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
export_config = InputEmbedOpenvVINOConfig(
self._config.text_config,
task="feature-extraction",
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
return export_config

if behavior == LlavaConfigBehavior.MULTI_MODAL_PROJECTOR:
export_config = LLavaMultimodalProjectorOpenVINOConfig(
self._config.vision_config,
self._orig_config.text_config,
task="feature-extraction",
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
)
return export_config

if behavior == LlavaConfigBehavior.LANGUAGE:
model_type = self._config.text_config.model_type
model_type = self._orig_config.text_config.model_type
model_type = model_type.replace("_", "-")

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
@@ -1444,7 +1375,7 @@ def with_behavior(
"text-generation-with-past"
]
internal_export_config = internal_export_config_class(
self._config.text_config,
self._orig_config.text_config,
use_past=True,
use_past_in_inputs=True,
int_dtype=self.int_dtype,
@@ -1455,29 +1386,14 @@ def with_behavior(
return export_config

if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
model_type = self._config.vision_config.model_type
model_type = model_type.replace("_", "-")
self._config.vision_config.output_hidden_states = True

if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
raise ValueError(
f"Unsupported vision embedding model type provided `{model_type}`. Please define custom export config"
)

if "feature-extraction" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
raise ValueError(
f"Export config for feature extraction for `{model_type}` is not available. Please define custom export config"
)

export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]["feature-extraction"]
export_config = export_config_class(
self._config.vision_config,
task="feature-extraction",
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)
return export_config

def get_model_for_behaviour(self, model, behavior: Union[str, LlavaConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
@@ -1487,20 +1403,20 @@ def get_model_for_behaviour(self, model, behavior: Union[str, LlavaConfigBehavio
return model.language_model

if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
vision_embedding = model.vision_tower
vision_embedding.config.output_hidden_states = True
vision_embedding.vision_model.config.output_hidden_states = True
return vision_embedding
return model

if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
text_embedding = model.get_input_embeddings()
text_embedding.config = model.language_model.config
return text_embedding

if behavior == LlavaConfigBehavior.MULTI_MODAL_PROJECTOR:
mm_projector = model.multi_modal_projector
mm_projector.config = model.config.vision_config
return mm_projector
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
model_kwargs = model_kwargs or {}
if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
return super().patch_model_for_export(model, model_kwargs)
return LlavaImageEmbeddingModelPatcher(self, model, model_kwargs)


@register_in_tasks_manager("llava-next", *["image-text-to-text"], library_name="transformers")
@@ -1540,6 +1456,7 @@ def __init__(
self._orig_config = config
if self._behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
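For reference, a minimal sketch of how the reworked LlavaOpenVINOConfig above maps behaviors to export configs and submodels. The checkpoint id and the direct constructor call are illustrative simplifications; the real export flow resolves the config through the TasksManager registration:

from transformers import AutoConfig

from optimum.exporters.openvino.model_configs import LlavaOpenVINOConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")  # illustrative checkpoint
main_config = LlavaOpenVINOConfig(config)  # default behavior is now VISION_EMBEDDINGS

# Per-behavior export configs are derived from the main config.
vision_config = main_config.with_behavior("vision_embeddings")  # same class, keeps the full config
language_config = main_config.with_behavior("language")         # resolved via text_config.model_type
text_emb_config = main_config.with_behavior("text_embeddings")

# For vision embeddings the whole model is now exported with a patched forward
# (vision tower + projector) rather than model.vision_tower alone:
#   model = LlavaForConditionalGeneration.from_pretrained(...)
#   main_config.get_model_for_behaviour(model, "vision_embeddings")  # returns model itself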
39 changes: 33 additions & 6 deletions optimum/exporters/openvino/model_patcher.py
@@ -1549,12 +1549,6 @@ class Phi3ModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()

# currently, long RoPE cannot be traced for long-context support; disable it to avoid potential accuracy issues
if self._model.config.max_position_embeddings != getattr(
self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
):
self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

if is_transformers_version(">=", "4.42.0"):
self._model.model._orig_forward = self._model.model.forward
self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
@@ -2639,3 +2633,36 @@ def __init__(
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


def llava_vision_embed_forward(self, pixel_values):
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# this is not memory efficient at all: output_hidden_states=True will save all the hidden states
selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer]

if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

image_features = self.multi_modal_projector(selected_image_feature)
return image_features


class LlavaImageEmbeddingModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward
model.forward = types.MethodType(llava_vision_embed_forward, model)

super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
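
As a rough illustration, the patcher above swaps the model's forward so that tracing the vision-embeddings behavior captures a single pixel_values-to-projected-features graph. A small sketch of the same mechanism outside of the export flow; the checkpoint id and dummy input are assumptions:

import types

import torch
from transformers import LlavaForConditionalGeneration

from optimum.exporters.openvino.model_patcher import llava_vision_embed_forward

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")  # illustrative
model.forward = types.MethodType(llava_vision_embed_forward, model)  # what the patcher does on __init__

image_size = model.config.vision_config.image_size  # 336 for llava-1.5
pixel_values = torch.zeros(1, 3, image_size, image_size)
with torch.no_grad():
    image_features = model(pixel_values)
# image_features: (1, num_patches, text_hidden_size); the vision tower output has already
# passed through multi_modal_projector, so no separate projector model is exported.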
16 changes: 1 addition & 15 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -748,24 +748,10 @@ def can_generate(self):


class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["multi_modal_projector"]

def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
vision_feature_layer = self.config.vision_feature_layer
vision_feature_select_strategy = self.config.vision_feature_select_strategy
image_outputs = self.vision_embeddings(pixel_values)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

image_features = self.multi_modal_projector(selected_image_feature)
image_features = self.vision_embeddings(pixel_values).last_hidden_state

return image_features

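For illustration, a hedged sketch of the runtime path after this simplification. The import path, checkpoint id, prompt format, and export flow are assumptions rather than part of this commit:

from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM  # import path assumed

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
ov_model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=Image.open("example.jpg"), text=prompt, return_tensors="pt")

# The joined vision model already applies the projector, so a single call replaces
# the previous hidden-state selection + multi_modal_projector steps:
image_features = ov_model.get_vision_embeddings(inputs["pixel_values"], inputs["input_ids"])

generated = ov_model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])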
2 changes: 2 additions & 0 deletions optimum/intel/utils/modeling_utils.py
@@ -267,6 +267,8 @@ def _infer_library_from_model_or_model_class(
):
if model.__module__.startswith("open_clip"):
library_name = "open_clip"
elif model.__module__.startswith("torch"):
library_name = "transformers"
elif model.__module__.startswith("optimum"):
# for wrapped models like timm in optimum.intel.openvino.modeling_timm
library_name = TasksManager._infer_library_from_model_or_model_class(model=model.model)
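This new branch matters for submodels like the LLaVA text embeddings handed to the exporter, which are plain torch modules rather than transformers classes. A quick illustration (the sizes are only an example):

import torch

# model.get_input_embeddings() for LLaVA returns a bare torch module, so its
# __module__ starts with "torch" rather than "transformers".
text_embedding = torch.nn.Embedding(32000, 4096)
print(text_embedding.__module__)  # "torch.nn.modules.sparse"
# The new elif therefore classifies such submodules as "transformers" models
# instead of falling through to the remaining checks.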
