feat: Prompt management support via prompts module
PiperOrigin-RevId: 696993909
matthew29tang authored and copybara-github committed Nov 20, 2024
1 parent b355881 commit 56c3f66
Showing 3 changed files with 867 additions and 7 deletions.
14 changes: 14 additions & 0 deletions vertexai/preview/prompts.py
@@ -16,7 +16,21 @@
from vertexai.prompts._prompts import (
Prompt,
)
from vertexai.prompts._prompts_hub import (
from_id,
create_version,
restore_version,
list_prompts,
list_prompt_versions,
delete_prompt_resource,
)

__all__ = [
"Prompt",
"create_version",
"delete_prompt_resource",
"from_id",
"list_prompts",
"list_prompt_versions",
"restore_version",
]
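
The hub functions re-exported above come from `vertexai.prompts._prompts_hub`, whose implementation is not part of the hunks shown in this view. A rough usage sketch of the intended flow follows; the argument names and return values are assumptions, not confirmed by this diff.

```python
# Hypothetical usage sketch -- signatures for the _prompts_hub functions are
# assumptions, since that module's source is not shown in this commit view.
import vertexai
from vertexai.preview import prompts
from vertexai.preview.prompts import Prompt

vertexai.init(project="my-project", location="us-central1")  # placeholder project

# Build a local prompt template (Prompt itself is defined in _prompts.py below).
prompt = Prompt(
    prompt_data="Translate {text} into French.",
    variables=[{"text": "Good morning"}],
    prompt_name="translation-prompt",  # display name used when saved online
)

# Assumed: saving returns a copy whose prompt_id / version_id are populated.
saved = prompts.create_version(prompt=prompt)
print(saved.prompt_id, saved.version_id)

# Assumed: a stored prompt can be reloaded later by its ID.
restored = prompts.from_id(saved.prompt_id)

# Assumed: listing and cleanup helpers take the prompt ID.
for prompt_metadata in prompts.list_prompts():
    print(prompt_metadata)
prompts.delete_prompt_resource(prompt_id=saved.prompt_id)
```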
123 changes: 116 additions & 7 deletions vertexai/prompts/_prompts.py
Expand Up @@ -15,8 +15,11 @@
# limitations under the License.
#

from copy import deepcopy

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform.compat.services import dataset_service_client
from vertexai.generative_models import (
Content,
Image,
@@ -51,6 +54,7 @@

_LOGGER = base.Logger(__name__)

DEFAULT_MODEL_NAME = "gemini-1.5-flash-002"
VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})"


@@ -126,9 +130,10 @@ class Prompt:

def __init__(
self,
prompt_data: PartsType,
prompt_data: Optional[PartsType] = None,
*,
variables: Optional[List[Dict[str, PartsType]]] = None,
prompt_name: Optional[str] = None,
generation_config: Optional[GenerationConfig] = None,
model_name: Optional[str] = None,
safety_settings: Optional[SafetySetting] = None,
@@ -141,9 +146,11 @@ def __init__(
Args:
prompt_data: A PartsType prompt which may be a template with variables or a prompt with no variables.
variables: A list of dictionaries containing the variable names and values.
prompt_name: The name of the prompt if stored in an online resource.
generation_config: A GenerationConfig object containing parameters for generation.
model_name: Model Garden model resource name.
Alternatively, a tuned model endpoint resource name can be provided.
If no model is provided, the default latest model will be used.
safety_settings: A SafetySetting object containing safety settings for generation.
system_instruction: A PartsType object representing the system instruction.
tools: A list of Tool objects for function calling.
@@ -158,8 +165,16 @@
self._tools = None
self._tool_config = None

# Prompt Management
self._dataset_client_value = None
self._dataset = None
self._prompt_name = None
self._version_id = None
self._version_name = None

self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
self.prompt_name = prompt_name
self.model_name = model_name
self.generation_config = generation_config
self.safety_settings = safety_settings
@@ -168,20 +183,27 @@
self.tool_config = tool_config

@property
def prompt_data(self) -> PartsType:
def prompt_data(self) -> Optional[PartsType]:
return self._prompt_data

@property
def variables(self) -> Optional[List[Dict[str, PartsType]]]:
return self._variables

@property
def prompt_name(self) -> Optional[str]:
return self._prompt_name

@property
def generation_config(self) -> Optional[GenerationConfig]:
return self._generation_config

@property
def model_name(self) -> Optional[str]:
return self._model_name
if self._model_name:
return self._model_name
else:
return Prompt._format_model_resource_name(DEFAULT_MODEL_NAME)

@property
def safety_settings(self) -> Optional[List[SafetySetting]]:
@@ -199,14 +221,29 @@ def tools(self) -> Optional[List[Tool]]:
def tool_config(self) -> Optional[ToolConfig]:
return self._tool_config

@property
def prompt_id(self) -> Optional[str]:
if self._dataset:
return self._dataset.name.split("/")[-1]
return None

@property
def version_id(self) -> Optional[str]:
return self._version_id

@property
def version_name(self) -> Optional[str]:
return self._version_name

@prompt_data.setter
def prompt_data(self, prompt_data: PartsType) -> None:
"""Overwrites the existing saved local prompt_data.
Args:
prompt_data: A PartsType prompt.
"""
self._validate_parts_type_data(prompt_data)
if prompt_data is not None:
self._validate_parts_type_data(prompt_data)
self._prompt_data = prompt_data

@variables.setter
@@ -226,6 +263,14 @@ def variables(self, variables: List[Dict[str, PartsType]]) -> None:
f"Variables must be a list of dictionaries, not {type(variables)}"
)

@prompt_name.setter
def prompt_name(self, prompt_name: Optional[str]) -> None:
"""Overwrites the existing saved local prompt_name."""
if prompt_name:
self._prompt_name = prompt_name
else:
self._prompt_name = None

@model_name.setter
def model_name(self, model_name: Optional[str]) -> None:
"""Overwrites the existing saved local model_name."""
@@ -370,6 +415,10 @@ def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
)
```
"""
# If prompt_data is None, throw an error.
if self.prompt_data is None:
raise ValueError("prompt_data must not be empty.")

variables_dict = variables_dict.copy()

# If there are no variables, return the prompt_data as a Content object.
@@ -541,10 +590,15 @@ def generate_content(
)
```
"""
if not (model_name or self._model_name):
_LOGGER.info(
"No model name specified, falling back to default model: %s",
self.model_name,
)
model_name = model_name or self.model_name

generation_config = generation_config or self.generation_config
safety_settings = safety_settings or self.safety_settings
model_name = model_name or self.model_name
tools = tools or self.tools
tool_config = tool_config or self.tool_config
system_instruction = system_instruction or self.system_instruction
@@ -567,14 +621,69 @@
stream=stream,
)

@property
def _dataset_client(self) -> dataset_service_client.DatasetServiceClient:
if not getattr(self, "_dataset_client_value", None):
self._dataset_client_value = (
aiplatform_initializer.global_config.create_client(
client_class=dataset_service_client.DatasetServiceClient,
)
)
return self._dataset_client_value

@classmethod
def _clone(cls, prompt: "Prompt") -> "Prompt":
"""Returns a copy of the Prompt."""
return Prompt(
prompt_data=prompt.prompt_data,
variables=deepcopy(prompt.variables),
generation_config=deepcopy(prompt.generation_config),
safety_settings=deepcopy(prompt.safety_settings),
tools=deepcopy(prompt.tools),
tool_config=deepcopy(prompt.tool_config),
system_instruction=prompt.system_instruction,
model_name=prompt.model_name,
)

def get_unassembled_prompt_data(self) -> PartsType:
"""Returns the prompt data, without any variables replaced."""
return self.prompt_data

def __str__(self) -> str:
"""Returns the prompt data as a string, without any variables replaced."""
return str(self.prompt_data)
return str(self.prompt_data or "")

def __repr__(self) -> str:
"""Returns a string representation of the unassembled prompt."""
return f"Prompt(prompt_data='{self.prompt_data}', variables={self.variables})"
result = "Prompt("
if self.prompt_data:
result += f"prompt_data='{self.prompt_data}', "
if self.variables and self.variables[0]:
result += f"variables={self.variables}), "
if self.system_instruction:
result += f"system_instruction={self.system_instruction}), "
if self._model_name:
# Don't display default model in repr
result += f"model_name={self._model_name}), "
if self.generation_config:
result += f"generation_config={self.generation_config}), "
if self.safety_settings:
result += f"safety_settings={self.safety_settings}), "
if self.tools:
result += f"tools={self.tools}), "
if self.tool_config:
result += f"tool_config={self.tool_config}, "
if self.prompt_id:
result += f"prompt_id={self.prompt_id}, "
if self.version_id:
result += f"version_id={self.version_id}, "
if self.prompt_name:
result += f"prompt_name={self.prompt_name}, "
if self.version_name:
result += f"version_name={self.version_name}, "

# Remove trailing ", "
if result[-2:] == ", ":
result = result[:-2]
result += ")"
return result
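
Taken together, the changes above make `prompt_data` optional, add `prompt_name` plus the prompt-management properties (`prompt_id`, `version_id`, `version_name`), and fall back to `DEFAULT_MODEL_NAME` when no model is set. A minimal local sketch of that constructor behavior follows; it makes no API calls, and the expected output noted in the comments is indicative only, since `_format_model_resource_name` lies outside the shown hunks.

```python
# Local-only sketch of the new constructor behavior; nothing here touches the
# Vertex AI API, and the comments describe expected (not verified) output.
from vertexai.preview.prompts import Prompt

prompt = Prompt(
    prompt_data="Summarize the following article: {article}",
    variables=[{"article": "..."}],
    prompt_name="daily-summary",
)

# prompt_data is now optional, so a metadata-only Prompt can be constructed.
placeholder = Prompt(prompt_name="placeholder-only")

# With no explicit model, model_name falls back to the formatted default
# ("gemini-1.5-flash-002" run through _format_model_resource_name).
print(prompt.model_name)

# repr() lists prompt_name but omits the default model, per the code above.
print(repr(prompt))
```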