[Model]: Add support for Aria model (vllm-project#10514)
Signed-off-by: xffxff <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
xffxff and Isotr0py authored Nov 25, 2024
1 parent 452a4e8 commit b1d9205
Showing 8 changed files with 791 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/models/supported_models.rst
@@ -476,6 +476,12 @@ Text Generation
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`AriaForConditionalGeneration`
- Aria
- T + I
- :code:`rhymes-ai/Aria`
-
- ✅︎
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- T + I\ :sup:`E`
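Note: with this table entry, Aria can be used like any other supported text+image model. A minimal offline-inference sketch, not part of the diff — the engine arguments, prompt format, and stop token IDs mirror the example script diffed below, and the image path is a placeholder:

# Minimal sketch of single-image inference with Aria; assumes a GPU with
# bfloat16 support. Prompt format and stop tokens come from run_aria below.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="rhymes-ai/Aria",
          tokenizer_mode="slow",
          trust_remote_code=True,
          dtype="bfloat16")

prompt = ("<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n"
          "What is in this image?<|im_end|>\n<|im_start|>assistant\n")
image = Image.open("example.jpg")  # placeholder path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128,
                   stop_token_ids=[93532, 93653, 944, 93421, 1019,
                                   93653, 93519]))
print(outputs[0].outputs[0].text)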
18 changes: 18 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -402,6 +402,23 @@ def run_idefics3(question: str, modality: str):
return llm, prompt, stop_token_ids


# Aria
def run_aria(question: str, modality: str):
assert modality == "image"
model_name = "rhymes-ai/Aria"

llm = LLM(model=model_name,
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16")

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
"<|im_end|>\n<|im_start|>assistant\n")

stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -423,6 +440,7 @@ def run_idefics3(question: str, modality: str):
"molmo": run_molmo,
"glm4v": run_glm4v,
"idefics3": run_idefics3,
"aria": run_aria,
}


20 changes: 20 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -321,6 +321,25 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
)


def load_aria(question, image_urls: List[str]) -> ModelRequestData:
model_name = "rhymes-ai/Aria"
llm = LLM(model=model_name,
tokenizer_mode="slow",
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={"image": len(image_urls)})
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None)


model_example_map = {
"phi3_v": load_phi3v,
"h2ovl_chat": load_h2onvl,
@@ -330,6 +349,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
"qwen_vl_chat": load_qwenvl_chat,
"mllama": load_mllama,
"idefics3": load_idefics3,
"aria": load_aria,
}


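Note: the script's shared driver consumes the returned ModelRequestData roughly as in this sketch (a paraphrase, not part of the diff; the question and image URLs are placeholders):

# Hedged sketch of how load_aria's ModelRequestData is consumed: all fetched
# images for the single prompt are passed together under
# multi_modal_data["image"].
from vllm import SamplingParams

image_urls = ["https://example.com/a.jpg",  # placeholder URLs
              "https://example.com/b.jpg"]
req = load_aria("What do these two images have in common?", image_urls)

outputs = req.llm.generate(
    {"prompt": req.prompt,
     "multi_modal_data": {"image": req.image_data}},
    SamplingParams(max_tokens=128, stop_token_ids=req.stop_token_ids))
print(outputs[0].outputs[0].text)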
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -43,6 +43,8 @@ class _HfExamplesInfo:
trust_remote_code=True),
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
trust_remote_code=True),
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
trust_remote_code=True),
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
trust_remote_code=True),
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
2 changes: 2 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -412,6 +412,8 @@ def _placeholder_str(self, modality: ModalityStr,
return ""
if model_type == "idefics3":
return "<image>"
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"

raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
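Note: this mapping is what lets chat callers pass standard OpenAI-style image content parts; vLLM expands each image to "<|fim_prefix|><|img|><|fim_suffix|>" when rendering the Aria prompt. A hedged sketch via LLM.chat, assuming the same engine arguments as the example scripts (the image URL is a placeholder):

# Sketch: with the mapping above, the placeholder is inserted automatically
# for each image_url content part when chatting with an Aria model.
from vllm import LLM

llm = LLM(model="rhymes-ai/Aria",
          tokenizer_mode="slow",
          trust_remote_code=True,
          dtype="bfloat16")

outputs = llm.chat([{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://example.com/duck.jpg"}},  # placeholder
        {"type": "text", "text": "What is in this image?"},
    ],
}])
print(outputs[0].outputs[0].text)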
(Diffs for the remaining 3 changed files not loaded.)
