llava-next support
eaidova committed Aug 30, 2024
1 parent c3fc608 commit 0aaf5f8
Showing 2 changed files with 6 additions and 6 deletions.
optimum/exporters/openvino/convert.py (3 changes: 1 addition, 2 deletions)

@@ -598,8 +598,7 @@ def export_from_model(

     logger.info(f"Automatic task detection to: {task}.")
 
-    stateful = stateful and ensure_export_task_support_stateful(task) or getattr(getattr(model, "config", {}), "model_type", None) == "llava"
-
+    stateful = stateful and ensure_export_task_support_stateful(task) or getattr(getattr(model, "config", {}), "model_type", None) in ["llava", "llava_next"]
     # TODO: support onnx_config.py in the model repo
     if custom_architecture and custom_export_configs is None:
         raise ValueError(
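A note on the updated convert.py line: Python binds "and" tighter than "or", so the expression parses as (stateful and ensure_export_task_support_stateful(task)) or model_type in [...], meaning stateful export is turned on for llava and llava_next model types even when the stateful flag or the task check alone would disable it. A minimal sketch of that evaluation (the helper name and values are illustrative, not from the commit):

    def resolve_stateful(stateful, task_supports_stateful, model_type):
        # "and" binds tighter than "or": the model_type check wins on its own
        return (stateful and task_supports_stateful) or model_type in ["llava", "llava_next"]

    assert resolve_stateful(False, False, "llava_next") is True
    assert resolve_stateful(True, False, "gpt2") is False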
optimum/intel/openvino/modeling_visual_language.py (9 changes: 5 additions, 4 deletions)

@@ -812,7 +812,7 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):

 class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
-        from trnasformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
+        from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
         """
         Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

@@ -893,20 +893,21 @@ def get_multimodal_embeddings(self, input_ids, pixel_values=None, attention_mask
             raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
         vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
         if vision_embeds is not None:
+            image_newline = torch.zeros(self.config.text_config.hidden_size, dtype=torch.float32)
             image_features = torch.split(torch.from_numpy(vision_embeds), image_num_patches, dim=0)
             image_features, feature_lens = self.pack_image_features(
                 image_features,
                 image_sizes,
-                image_newline=self.image_newline,
+                image_newline=image_newline,
             )
-            inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(image_features, inputs_embeds, feature_les=feature_lens, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, **kwargs)
+            inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(image_features, inputs_embeds, feature_lens=feature_lens, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, **kwargs)
 
         if pixel_values is not None and past_key_values is not None:
             attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
 
         return inputs_embeds, attention_mask, position_ids
 
-    def merge_vision_text_embeddings(self, vision_embeds, inputs_embeds, input_ids, attention_mask, position_ids=None, **kwargs):
+    def merge_vision_text_embeddings(self, vision_embeds, inputs_embeds, feature_lens, input_ids, attention_mask, position_ids=None, **kwargs):
         image_token_index = self.config.image_token_index
         image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
         inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
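For context on the new image_newline placeholder and the feature_lens plumbing above, here is a standalone sketch of the split-and-pack pattern this hunk touches. It is an illustration under assumed toy shapes, not the library's pack_image_features (which also unpads each image's patch grid using image_sizes before packing):

    import torch

    # Toy dimensions; the real values come from the model config and processor
    hidden_size, tokens_per_patch = 8, 4
    image_num_patches = [3, 5]  # patch count for each of two images

    # Zero vector standing in for the learned newline embedding,
    # mirroring the placeholder the diff introduces
    image_newline = torch.zeros(hidden_size, dtype=torch.float32)

    # (total_patches, tokens_per_patch, hidden) -> one tensor per image
    vision_embeds = torch.randn(sum(image_num_patches), tokens_per_patch, hidden_size)
    per_image = torch.split(vision_embeds, image_num_patches, dim=0)

    # Flatten each image's patches, append the newline row, record lengths
    packed, feature_lens = [], []
    for feats in per_image:
        flat = feats.reshape(-1, hidden_size)
        flat = torch.cat([flat, image_newline[None, :]], dim=0)
        packed.append(flat)
        feature_lens.append(flat.shape[0])

    image_features = torch.cat(packed, dim=0)
    print(feature_lens)  # [13, 21] for these toy shapes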
