From f4d50db6da2ab52d9e0fb6c3ef0871c82a40ab18 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Mon, 29 Apr 2024 11:39:03 +0400
Subject: [PATCH] fix export for phi1.5 (#400)

---
 llm_bench/python/convert.py        | 39 ++++++++++++++++++++++
 llm_bench/python/utils/ov_utils.py | 52 +++++++++++-------------------
 2 files changed, 57 insertions(+), 34 deletions(-)

diff --git a/llm_bench/python/convert.py b/llm_bench/python/convert.py
index 371ffaf7c0..6a65925b14 100644
--- a/llm_bench/python/convert.py
+++ b/llm_bench/python/convert.py
@@ -1201,6 +1201,44 @@ def convert_falcon(args):
         unpatch_gptq(cuda, post_init)
 
 
+def convert_phi(args):
+    trust_remote_code = False
+    try:
+        config = AutoConfig.from_pretrained(args.model_id)
+    except Exception:
+        config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
+        trust_remote_code = True
+    cuda, post_init = patch_gptq(config)
+    model_kwargs = {}
+    if trust_remote_code:
+        model_kwargs["trust_remote_code"] = trust_remote_code
+    precision = args.precision
+    compression_only = (
+        args.compress_weights
+        and not args.force_convert
+        and not is_torch_compression(args)
+        and is_ov_model_provided(args.model_id, args.output_dir, args.precision)
+    )
+    if post_init is not None:
+        model_kwargs["torch_dtype"] = torch.float32
+    pt_model = None
+    gptq_applied = is_gptq(config)
+    precision = precision if not gptq_applied else GPTQ_DIR.format(precision=args.precision)
+    if not compression_only:
+        pt_model = AutoModelForCausalLM.from_pretrained(
+            args.model_id,
+            config=AutoConfig.from_pretrained(args.model_id),
+            **model_kwargs,
+        )
+        pt_model.config.use_cache = True
+        pt_model.eval()
+
+    convert_optimum_causallm_base(pt_model, args, config, compression_only)
+
+    if post_init is not None:
+        unpatch_gptq(cuda, post_init)
+
+
 def convert_baichaun(args):
     config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
     cuda, post_init = patch_gptq(config)
@@ -1304,6 +1342,7 @@ def convert_aquilachat(args):
     "lcm": convert_lcm,
     "ldm": convert_ldm_super_res,
     "mpt": convert_mpt,
+    "phi-": convert_phi,
     "replit": convert_mpt,
     "chatglm2": convert_causal_lm,
     "chatglm3": convert_causal_lm,
diff --git a/llm_bench/python/utils/ov_utils.py b/llm_bench/python/utils/ov_utils.py
index 7f2303d00a..ed62498fc6 100644
--- a/llm_bench/python/utils/ov_utils.py
+++ b/llm_bench/python/utils/ov_utils.py
@@ -141,40 +141,24 @@ def create_text_gen_model(model_path, device, **kwargs):
     if not model_path_existed:
         raise RuntimeError(f'==Failure ==: model path:{model_path} does not exist')
     else:
-        if model_type in ['replit', 'codegen2', 'chatglm']:
-            start = time.perf_counter()
-            ov_model = model_class.from_pretrained(
-                model_path,
-                device=device,
-                ov_config=ov_config,
-                config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),
-                stateful=kwargs.get("stateful", None)
-            )
-            end = time.perf_counter()
-        elif model_type in ['falcon', "mpt"]:
-            start = time.perf_counter()
-            ov_model = model_class.from_pretrained(
-                model_path,
-                device=device,
-                ov_config=ov_config,
-                stateful=kwargs.get("stateful", None),
-                trust_remote_code=False
-            )
-            end = time.perf_counter()
-        else:
-            start = time.perf_counter()
-            config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-            ov_model = model_class.from_pretrained(
-                model_path,
-                device=device,
-                ov_config=ov_config,
-                config=config,
-                compile=False,
-                stateful=kwargs.get("stateful", None)
-            )
-            if not isinstance(ov_model, OV_MODEL_CLASSES_MAPPING['t5']):
-                patch_inter_processing_and_compile(ov_model, **kwargs)
-            end = time.perf_counter()
+        remote_code = False
+        try:
+            model_config = AutoConfig.from_pretrained(model_path)
+        except Exception:
+            model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+            remote_code = True
+        start = time.perf_counter()
+        ov_model = model_class.from_pretrained(
+            model_path,
+            device=device,
+            ov_config=ov_config,
+            config=model_config,
+            stateful=kwargs.get("stateful", None),
+            trust_remote_code=remote_code
+        )
+        if not isinstance(ov_model, OV_MODEL_CLASSES_MAPPING['t5']):
+            patch_inter_processing_and_compile(ov_model, **kwargs)
+        end = time.perf_counter()
     if kwargs['num_beams'] > 1:
         bench_hook = utils.hook_beam_search.BeamSearchHook()
     else:
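
Note (not part of the patch): both hunks apply the same trust_remote_code
fallback, i.e. first attempt to load the config without remote code and only
opt in to trust_remote_code=True when the plain load fails, remembering the
choice so later from_pretrained calls reuse it. A minimal standalone sketch of
that pattern, assuming transformers is installed; the helper name
load_config_with_fallback and the model id "microsoft/phi-1_5" are
illustrative assumptions, not taken from the patch:

    from transformers import AutoConfig

    def load_config_with_fallback(model_id):
        """Load a model config, enabling remote code only when required."""
        trust_remote_code = False
        try:
            # First try without remote code; succeeds for architectures
            # that ship with the installed transformers version.
            config = AutoConfig.from_pretrained(model_id)
        except Exception:
            # Repos whose config class lives in the model repo itself
            # raise here, so retry with remote code enabled and record it.
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            trust_remote_code = True
        return config, trust_remote_code

    config, remote_code = load_config_with_fallback("microsoft/phi-1_5")
    print(type(config).__name__, remote_code)

This avoids passing trust_remote_code=True unconditionally (executing
arbitrary repo code when it is not needed) while still supporting models,
such as older phi-1.5 revisions, that cannot load without it.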