From d68c3bbe52399e7133bcefdf02c34838b1931181 Mon Sep 17 00:00:00 2001
From: adirtha
Date: Tue, 19 Dec 2023 03:38:52 +0530
Subject: [PATCH] fix mix up b/w 'formatted' and 'format' params for ollama api
 call (#9594)

* fix mix up b/w 'formatted' and 'format' params for ollama api call, refactored ollama client

* fix linting errors and renamed formatted param in ollama client

* cr

---------

Co-authored-by: Adirtha Borgohain
Co-authored-by: Haotian Zhang
---
 llama_index/llms/ollama.py | 28 ++++++++++------------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/llama_index/llms/ollama.py b/llama_index/llms/ollama.py
index d89bbc2ce13c1..bff2bb412c04f 100644
--- a/llama_index/llms/ollama.py
+++ b/llama_index/llms/ollama.py
@@ -82,7 +82,7 @@ def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         prompt = self.messages_to_prompt(messages)
-        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        completion_response = self.complete(prompt, prompt_formatted=True, **kwargs)
         return completion_response_to_chat_response(completion_response)
 
     @llm_chat_callback()
@@ -90,7 +90,9 @@ def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
         prompt = self.messages_to_prompt(messages)
-        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        completion_response = self.stream_complete(
+            prompt, prompt_formatted=True, **kwargs
+        )
         return stream_completion_response_to_chat_response(completion_response)
 
     @llm_completion_callback()
@@ -111,28 +113,18 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
                 "Please install requests with `pip install requests`"
             )
         all_kwargs = self._get_all_kwargs(**kwargs)
-        del all_kwargs["formatted"]  # ollama throws 400 if it receives this option
-        if not kwargs.get("formatted", False):
+        if not kwargs.pop("prompt_formatted", False):
             prompt = self.completion_to_prompt(prompt)
 
-        ollama_request_json: Dict[str, Any] = {
-            "prompt": prompt,
-            "model": self.model,
-            "options": all_kwargs,
-        }
-        if all_kwargs.get("system"):
-            ollama_request_json["system"] = all_kwargs["system"]
-            del all_kwargs["system"]
-
-        if all_kwargs.get("formatted"):
-            ollama_request_json["format"] = "json" if all_kwargs["formatted"] else None
-            del all_kwargs["formatted"]
-
         response = requests.post(
             url=f"{self.base_url}/api/generate/",
             headers={"Content-Type": "application/json"},
-            json=ollama_request_json,
+            json={
+                "prompt": prompt,
+                "model": self.model,
+                **all_kwargs,
+            },
             stream=True,
         )
         response.encoding = "utf-8"