From 8c11f03a4d68d8a8a913890544ea3f2b03a9cd85 Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 18:58:05 +0000
Subject: [PATCH 1/6] added support for passing model names to inference servers

---
 gradio_utils/grclient.py | 4 ++++
 src/evaluate_params.py   | 1 +
 src/gen.py               | 3 +++
 3 files changed, 8 insertions(+)

diff --git a/gradio_utils/grclient.py b/gradio_utils/grclient.py
index dbf187c5a..a7c7cb4e6 100644
--- a/gradio_utils/grclient.py
+++ b/gradio_utils/grclient.py
@@ -674,6 +674,7 @@ def query_or_summarize_or_extract(self,
         Returns:
             summary/answer: str or extraction List[str]
         """
+        print(model)
         if self.config is None:
             self.setup()
         if self.persist:
@@ -799,6 +800,7 @@ def query_or_summarize_or_extract(self,
                              hyde_show_only_final=hyde_show_only_final,
                              doc_json_mode=doc_json_mode,
                              metadata_in_context=metadata_in_context,
+                             model = model
                              )

         # in case server changed, update in case clone()
@@ -976,6 +978,7 @@ def simple_stream(self,
                 time.sleep(0.01)
         # ensure get last output to avoid race
         res_all = job.outputs().copy()
+        print("res all", res_all)
         success = job.communicator.job.latest_status.success
         timeout = 0.1 if success else 10
         if len(res_all) > 0:
@@ -985,6 +988,7 @@ def simple_stream(self,
                 strex = ''.join(traceback.format_tb(e.__traceback__))

             res = res_all[-1]
+            print("res_all", res_all, "res", res)
             res_dict = ast.literal_eval(res)
             text = res_dict['response']
             sources = res_dict.get('sources')
diff --git a/src/evaluate_params.py b/src/evaluate_params.py
index 04520dbae..0d68343fc 100644
--- a/src/evaluate_params.py
+++ b/src/evaluate_params.py
@@ -81,6 +81,7 @@

     'image_file',
     'image_control',
+    'base_model'
 ]

 # form evaluate defaults for submit_nochat_api
diff --git a/src/gen.py b/src/gen.py
index 3afe9779c..97b61d313 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -2594,6 +2594,7 @@ def get_inf_models(inference_server):
     elif inference_server.startswith('anthropic'):
         models.extend(list(anthropic_mapping.keys()))
     elif inference_server.startswith('http'):
+        print("get models")
         inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
         if gr_client is not None:
             res = gr_client.predict(api_name='/model_names')
@@ -4457,6 +4458,7 @@ def evaluate(
                                      num_return_sequences=num_return_sequences,
                                      do_sample=do_sample,
                                      chat=chat_client,
+                                     base_model=base_model
                                      )
             # account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
             if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value,
@@ -5283,6 +5285,7 @@ def mean(a):""", ''] + params_list,
         tts_speed,
         image_file,
         image_control,
+        None,
     ]
     # adjust examples if non-chat mode
     if not chat:

From 57c9309eaadb6337f15e2d864a6ad869484847ac Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 19:01:34 +0000
Subject: [PATCH 2/6] cleaned print statements

---
 gradio_utils/grclient.py | 3 ---
 src/gen.py               | 1 -
 2 files changed, 4 deletions(-)

diff --git a/gradio_utils/grclient.py b/gradio_utils/grclient.py
index a7c7cb4e6..60122ae25 100644
--- a/gradio_utils/grclient.py
+++ b/gradio_utils/grclient.py
@@ -674,7 +674,6 @@ def query_or_summarize_or_extract(self,
         Returns:
             summary/answer: str or extraction List[str]
         """
-        print(model)
         if self.config is None:
             self.setup()
         if self.persist:
@@ -978,7 +977,6 @@ def simple_stream(self,
                 time.sleep(0.01)
         # ensure get last output to avoid race
         res_all = job.outputs().copy()
-        print("res all", res_all)
         success = job.communicator.job.latest_status.success
         timeout = 0.1 if success else 10
         if len(res_all) > 0:
@@ -988,7 +986,6 @@ def simple_stream(self,
                 strex = ''.join(traceback.format_tb(e.__traceback__))

             res = res_all[-1]
-            print("res_all", res_all, "res", res)
             res_dict = ast.literal_eval(res)
             text = res_dict['response']
             sources = res_dict.get('sources')
diff --git a/src/gen.py b/src/gen.py
index 97b61d313..9dcabc7b6 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -2594,7 +2594,6 @@ def get_inf_models(inference_server):
     elif inference_server.startswith('anthropic'):
         models.extend(list(anthropic_mapping.keys()))
     elif inference_server.startswith('http'):
-        print("get models")
         inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
         if gr_client is not None:
             res = gr_client.predict(api_name='/model_names')

From 9e4c967ac26ae8037d1aeb53ece0c3f9cd5f58d3 Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 22:36:36 +0000
Subject: [PATCH 3/6] fixed docai inference issues

---
 src/gen.py | 156 +----------------------------------------------------
 1 file changed, 1 insertion(+), 155 deletions(-)

diff --git a/src/gen.py b/src/gen.py
index ff0bdfb7d..021854c69 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -4529,161 +4529,6 @@ def evaluate(
         if not stream_output and img_file == 1:
             from src.vision.utils_vision import get_llava_response
             response, _ = get_llava_response(**llava_kwargs)
-
-        if gr_client is not None:
-            # Note: h2oGPT gradio server could handle input token size issues for prompt,
-            # but best to handle here so send less data to server
-
-            chat_client = chat
-            where_from = "gr_client"
-            client_langchain_mode = 'Disabled'
-            client_add_chat_history_to_context = add_chat_history_to_context
-            client_add_search_to_context = False
-            client_langchain_action = LangChainAction.QUERY.value
-            client_langchain_agents = []
-            gen_server_kwargs = dict(temperature=temperature,
-                                     top_p=top_p,
-                                     top_k=top_k,
-                                     penalty_alpha=penalty_alpha,
-                                     num_beams=num_beams,
-                                     max_new_tokens=max_new_tokens,
-                                     min_new_tokens=min_new_tokens,
-                                     early_stopping=early_stopping,
-                                     max_time=max_time,
-                                     repetition_penalty=repetition_penalty,
-                                     num_return_sequences=num_return_sequences,
-                                     do_sample=do_sample,
-                                     chat=chat_client,
-                                     base_model=base_model
-                                     )
-            # account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
-            if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value,
-                               str(PromptType.plain.value)]:
-                # if our prompt is plain, assume either correct or gradio server knows different prompt type,
-                # so pass empty prompt_Type
-                gr_prompt_type = ''
-                gr_prompt_dict = ''
-                gr_prompt = prompt  # already prepared prompt
-                gr_context = ''
-                gr_iinput = ''
-            else:
-                # if already have prompt_type that is not plain, None, or '', then already applied some prompting
-                # But assume server can handle prompting, and need to avoid double-up.
-                # Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
-                # So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
-                # Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
-                # because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
-                # since those won't appear
-                gr_context = context
-                gr_prompt = instruction
-                gr_iinput = iinput
-                gr_prompt_type = prompt_type
-                gr_prompt_dict = prompt_dict
-
-            # ensure image in correct format
-            img_file = get_image_file(image_file, image_control, document_choice)
-            if img_file is not None and os.path.isfile(img_file):
-                from src.vision.utils_vision import img_to_base64
-                img_file = img_to_base64(img_file)
-            elif isinstance(img_file, str):
-                # assume already bytes
-                img_file = img_file
-            else:
-                img_file = None
-
-            client_kwargs = dict(instruction=gr_prompt if chat_client else '',  # only for chat=True
-                                 iinput=gr_iinput,  # only for chat=True
-                                 context=gr_context,
-                                 # streaming output is supported, loops over and outputs each generation in streaming mode
-                                 # but leave stream_output=False for simple input/output mode
-                                 stream_output=stream_output,
-
-                                 **gen_server_kwargs,
-
-                                 prompt_type=gr_prompt_type,
-                                 prompt_dict=gr_prompt_dict,
-
-                                 instruction_nochat=gr_prompt if not chat_client else '',
-                                 iinput_nochat=gr_iinput,  # only for chat=False
-                                 langchain_mode=client_langchain_mode,
-
-                                 add_chat_history_to_context=client_add_chat_history_to_context,
-                                 chat_conversation=chat_conversation,
-                                 text_context_list=text_context_list,
-
-                                 chatbot_role=chatbot_role,
-                                 speaker=speaker,
-                                 tts_language=tts_language,
-                                 tts_speed=tts_speed,
-
-                                 langchain_action=client_langchain_action,
-                                 langchain_agents=client_langchain_agents,
-                                 top_k_docs=top_k_docs,
-                                 chunk=chunk,
-                                 chunk_size=chunk_size,
-                                 document_subset=DocumentSubset.Relevant.name,
-                                 document_choice=[DocumentChoice.ALL.value],
-                                 document_source_substrings=[],
-                                 document_source_substrings_op='and',
-                                 document_content_substrings=[],
-                                 document_content_substrings_op='and',
-                                 pre_prompt_query=pre_prompt_query,
-                                 prompt_query=prompt_query,
-                                 pre_prompt_summary=pre_prompt_summary,
-                                 prompt_summary=prompt_summary,
-                                 hyde_llm_prompt=hyde_llm_prompt,
-                                 system_prompt=system_prompt,
-                                 image_audio_loaders=image_audio_loaders,
-                                 pdf_loaders=pdf_loaders,
-                                 url_loaders=url_loaders,
-                                 jq_schema=jq_schema,
-                                 extract_frames=extract_frames,
-                                 llava_prompt=llava_prompt,
-                                 visible_models=visible_models,
-                                 h2ogpt_key=h2ogpt_key,
-                                 add_search_to_context=client_add_search_to_context,
-                                 docs_ordering_type=docs_ordering_type,
-                                 min_max_new_tokens=min_max_new_tokens,
-                                 max_input_tokens=max_input_tokens,
-                                 max_total_input_tokens=max_total_input_tokens,
-                                 docs_token_handling=docs_token_handling,
-                                 docs_joiner=docs_joiner,
-                                 hyde_level=hyde_level,
-                                 hyde_template=hyde_template,
-                                 hyde_show_only_final=hyde_show_only_final,
-                                 doc_json_mode=doc_json_mode,
-                                 metadata_in_context=metadata_in_context,
-
-                                 image_file=img_file,
-                                 image_control=None,  # already stuffed into image_file
-                                 )
-            assert len(set(list(client_kwargs.keys())).symmetric_difference(eval_func_param_names)) == 0
-            api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
-            response = ''
-            text = ''
-            sources = []
-            strex = ''
-            if not stream_output:
-                res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
-                res_dict = ast.literal_eval(res)
-                text = res_dict['response']
-                sources = res_dict['sources']
-                response = prompter.get_response(prompt + text, prompt=prompt,
-                                                 sanitize_bot_response=sanitize_bot_response)
-            else:
-                new_stream = False  # hanging for many chatbots
-                gr_stream_kwargs = dict(client_kwargs=client_kwargs,
-                                        api_name=api_name,
-                                        prompt=prompt, prompter=prompter,
-                                        sanitize_bot_response=sanitize_bot_response,
-                                        max_time=max_time,
-                                        is_public=is_public,
-                                        verbose=verbose)
-                if new_stream:
-                    res_dict = yield from gr_client.stream(**gr_stream_kwargs)
-                else:
-                    res_dict = yield from gr_client.simple_stream(**gr_stream_kwargs)
-                response = res_dict.get('response', '')
         elif hf_client:
             # quick sanity check to avoid long timeouts, just see if can reach server
             requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
@@ -4853,6 +4698,7 @@ def evaluate(

                              image_file=img_file,
                              image_control=None,  # already stuffed into image_file
+                             base_model=base_model,
                              )
         assert len(set(list(client_kwargs.keys())).symmetric_difference(eval_func_param_names)) == 0
         api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing

From fd724130b632c0a9484cb1782e5533125105abe6 Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 22:42:43 +0000
Subject: [PATCH 4/6] fixed merge error

---
 src/gen.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/gen.py b/src/gen.py
index 021854c69..1055883eb 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -4529,6 +4529,8 @@ def evaluate(
         if not stream_output and img_file == 1:
             from src.vision.utils_vision import get_llava_response
             response, _ = get_llava_response(**llava_kwargs)
+            yield dict(response=response, sources=[], save_dict={}, error='', llm_answers={},
+                       response_no_refs=response, sources_str='', prompt_raw='')
         elif hf_client:
             # quick sanity check to avoid long timeouts, just see if can reach server
             requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))

From 8a2933f73302741b6048e6e959f2e865f508e7fd Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 22:44:49 +0000
Subject: [PATCH 5/6] fixed merge error

---
 src/gen.py | 36 ------------------------------------
 1 file changed, 36 deletions(-)

diff --git a/src/gen.py b/src/gen.py
index 1055883eb..5eca73e86 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -4531,42 +4531,6 @@ def evaluate(
             response, _ = get_llava_response(**llava_kwargs)
             yield dict(response=response, sources=[], save_dict={}, error='', llm_answers={},
                        response_no_refs=response, sources_str='', prompt_raw='')
-        elif hf_client:
-            # quick sanity check to avoid long timeouts, just see if can reach server
-            requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
-            # HF inference server needs control over input tokens
-            where_from = "hf_client"
-            response = ''
-            sources = []
-
-            # prompt must include all human-bot like tokens, already added by prompt
-            # https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
-            terminate_response = prompter.terminate_response or []
-            stop_sequences = list(set(terminate_response + [prompter.PreResponse]))
-            stop_sequences = [x for x in stop_sequences if x]
-            gen_server_kwargs = dict(do_sample=do_sample,
-                                     max_new_tokens=max_new_tokens,
-                                     # best_of=None,
-                                     repetition_penalty=repetition_penalty,
-                                     return_full_text=False,
-                                     seed=SEED,
-                                     stop_sequences=stop_sequences,
-                                     temperature=temperature,
-                                     top_k=top_k,
-                                     top_p=top_p,
-                                     # truncate=False,  # behaves oddly
-                                     # typical_p=top_p,
-                                     # watermark=False,
-                                     # decoder_input_details=False,
-                                     )
-            # work-around for timeout at constructor time, will be issue if multi-threading,
-            # so just do something reasonable or max_time if larger
-            # lower bound because client is re-used if multi-threading
-            hf_client.timeout = max(300, max_time)
-            if not stream_output:
-                text = hf_client.generate(prompt, **gen_server_kwargs).generated_text
-                response = prompter.get_response(prompt + text, prompt=prompt,
-                                                 sanitize_bot_response=sanitize_bot_response)
         else:
             response = ''
             tgen0 = time.time()

From a8432cbcb68ee7f78f8b63122d20c85d7e1f95da Mon Sep 17 00:00:00 2001
From: Ryan Chesler
Date: Mon, 1 Apr 2024 22:45:28 +0000
Subject: [PATCH 6/6] fixed merge error

---
 src/gen.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/gen.py b/src/gen.py
index 5eca73e86..484f8bcd5 100644
--- a/src/gen.py
+++ b/src/gen.py
@@ -4529,6 +4529,7 @@ def evaluate(
         if not stream_output and img_file == 1:
             from src.vision.utils_vision import get_llava_response
             response, _ = get_llava_response(**llava_kwargs)
+            yield dict(response=response, sources=[], save_dict={}, error='', llm_answers={}, response_no_refs=response, sources_str='', prompt_raw='')
         else: