Skip to content

Commit

Permalink
fix mix-up between 'formatted' and 'format' params for ollama api call (r…
Browse files Browse the repository at this point in the history
…un-llama#9594)

* fix mix-up between 'formatted' and 'format' params for the ollama api call; refactored the ollama client

* fix linting errors and renamed formatted param in ollama client

* cr

---------

Co-authored-by: Adirtha Borgohain <[email protected]>
Co-authored-by: Haotian Zhang <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2023
1 parent 22cef9c commit d68c3bb
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions llama_index/llms/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,17 @@ def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
    """Run a (non-streaming) chat turn against the Ollama model.

    The message sequence is rendered into a single prompt string, then sent
    through ``complete``. ``prompt_formatted=True`` signals that the prompt
    has already been rendered, so ``complete`` must not pass it through
    ``completion_to_prompt`` again. (The flag is deliberately named
    ``prompt_formatted`` — not ``formatted`` — to avoid colliding with the
    Ollama API's own ``format`` option.)
    """
    prompt = self.messages_to_prompt(messages)
    completion_response = self.complete(prompt, prompt_formatted=True, **kwargs)
    return completion_response_to_chat_response(completion_response)

@llm_chat_callback()
def stream_chat(
    self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
    """Run a streaming chat turn against the Ollama model.

    Mirrors ``chat``: the messages are rendered into one prompt string and
    forwarded to ``stream_complete`` with ``prompt_formatted=True`` so the
    already-rendered prompt is not re-templated. The resulting completion
    stream is adapted into a chat-response generator.
    """
    prompt = self.messages_to_prompt(messages)
    completion_response = self.stream_complete(
        prompt, prompt_formatted=True, **kwargs
    )
    return stream_completion_response_to_chat_response(completion_response)

@llm_completion_callback()
Expand All @@ -111,28 +113,18 @@ def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
"Please install requests with `pip install requests`"
)
all_kwargs = self._get_all_kwargs(**kwargs)
del all_kwargs["formatted"] # ollama throws 400 if it receives this option

if not kwargs.get("formatted", False):
if not kwargs.pop("prompt_formatted", False):
prompt = self.completion_to_prompt(prompt)

ollama_request_json: Dict[str, Any] = {
"prompt": prompt,
"model": self.model,
"options": all_kwargs,
}
if all_kwargs.get("system"):
ollama_request_json["system"] = all_kwargs["system"]
del all_kwargs["system"]

if all_kwargs.get("formatted"):
ollama_request_json["format"] = "json" if all_kwargs["formatted"] else None
del all_kwargs["formatted"]

response = requests.post(
url=f"{self.base_url}/api/generate/",
headers={"Content-Type": "application/json"},
json=ollama_request_json,
json={
"prompt": prompt,
"model": self.model,
**all_kwargs,
},
stream=True,
)
response.encoding = "utf-8"
Expand Down

0 comments on commit d68c3bb

Please sign in to comment.