Skip to content

Commit

Permalink
prioretize config model type under path-based task determination (ope…
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Jan 20, 2025
1 parent fe6311d commit d3da17e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tools/llm_bench/llm_bench_utils/config_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

USE_CASES = {
'image_gen': ['stable-diffusion-', 'ssd-', 'tiny-sd', 'small-sd', 'lcm-', 'sdxl', 'dreamlike', "flux"],
"vlm": ["llava", "llava-next", "qwen2-vl", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v"],
"vlm": ["llava", "llava-next", "qwen2-vl", "llava-qwen2", "internvl-chat", "minicpmv", "phi3-v", "minicpm-v"],
'speech2text': ['whisper'],
'image_cls': ['vit'],
'code_gen': ['replit', 'codegen2', 'codegen', 'codet5', "stable-code"],
Expand Down
27 changes: 12 additions & 15 deletions tools/llm_bench/llm_bench_utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,22 +186,10 @@ def analyze_args(args):


def get_use_case(model_name_or_path):
# 1. try to get use_case from model name
path = os.path.normpath(model_name_or_path)
model_names = path.split(os.sep)
for model_name in reversed(model_names):
for case, model_ids in USE_CASES.items():
for model_id in model_ids:
if model_name.lower().startswith(model_id):
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
return case, model_name

# 2. try to get use_case from model config
try:
config_file = Path(model_name_or_path) / "config.json"
config_file = Path(model_name_or_path) / "config.json"
config = None
if config_file.exists():
config = json.loads(config_file.read_text())
except Exception:
config = None
if (Path(model_name_or_path) / "model_index.json").exists():
diffusers_config = json.loads((Path(model_name_or_path) / "model_index.json").read_text())
pipe_type = diffusers_config.get("_class_name")
Expand All @@ -214,6 +202,15 @@ def get_use_case(model_name_or_path):
if config.get("model_type").lower().replace('_', '-').startswith(model_id):
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_id}')
return case, model_ids[idx]
# try to get use_case from model name
path = os.path.normpath(model_name_or_path)
model_names = path.split(os.sep)
for model_name in reversed(model_names):
for case, model_ids in USE_CASES.items():
for model_id in model_ids:
if model_name.lower().startswith(model_id):
log.info(f'==SUCCESS FOUND==: use_case: {case}, model_type: {model_name}')
return case, model_name

raise RuntimeError('==Failure FOUND==: no use_case found')

Expand Down

0 comments on commit d3da17e

Please sign in to comment.