From 804ec0c8b65a31b3532869fd278adc3864159a9b Mon Sep 17 00:00:00 2001 From: Paul Swingle Date: Wed, 29 May 2024 12:54:39 -0700 Subject: [PATCH] add functions to add arbitrary role to spicemessages --- pyproject.toml | 2 +- spice/spice_message.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7f24c22..c93de33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ packages=["spice"] [project] name = "spiceai" -version = "0.3.9" +version = "0.3.10" license = {text = "Apache-2.0"} description = "A Python library for building AI-powered applications." readme = "README.md" diff --git a/spice/spice_message.py b/spice/spice_message.py index eeb7bae..9706d9c 100644 --- a/spice/spice_message.py +++ b/spice/spice_message.py @@ -5,7 +5,7 @@ from collections import UserList from json import JSONEncoder from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Iterable, Literal, Optional, TypedDict from openai.types.chat import ( ChatCompletionAssistantMessageParam, @@ -25,6 +25,10 @@ VALID_MIMETYPES = ["image/jpeg", "image/png", "image/gif", "image/webp"] +def create_message(role: Literal["user", "assistant", "system"], content: str) -> ChatCompletionMessageParam: + return {"role": role, "content": content} # pyright: ignore + + def user_message(content: str) -> ChatCompletionUserMessageParam: """Creates a user message with the given content.""" return {"role": "user", "content": content} @@ -112,6 +116,9 @@ def __init__(self, client: Spice, initlist: Optional[Iterable[SpiceMessage]] = N self._client = client super().__init__(initlist) + def add_message(self, role: Literal["user", "assistant", "system"], content: str): + self.data.append(create_message(role, content)) + def add_user_message(self, content: str): """Appends a user message with the given content.""" self.data.append(user_message(content)) @@ -136,6 +143,13 @@ def add_http_image_message(self, url: str): """Appends a user message with the image from the given url.""" self.data.append(http_image_message(url)) + def add_prompt(self, role: Literal["user", "assistant", "system"], name: str, **context: Any): + prompt = self._client.get_prompt(name) + rendered_prompt = self._client.get_rendered_prompt(name, **context) + message = _MetadataDict(create_message(role, rendered_prompt)) + message.prompt_metadata = {"name": name, "content": prompt, "context": context} + self.data.append(message) # pyright: ignore + def add_user_prompt(self, name: str, **context: Any): """Appends a user message with the given pre-loaded prompt using jinja to render the context.""" prompt = self._client.get_prompt(name)