-
Notifications
You must be signed in to change notification settings - Fork 179
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
652 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
# Use pgai with Gemini | ||
|
||
This page shows you how to: | ||
|
||
- [Configure pgai for Gemini](#configure-pgai-for-gemini) | ||
- [Add AI functionality to your database](#usage) | ||
|
||
## Configure pgai for Gemini | ||
|
||
Gemini functions in pgai require an [Gemini API key](https://ai.google.dev/gemini-api/docs/api-key). | ||
|
||
- [Handle API keys using pgai from psql](#handle-api-keys-using-pgai-from-psql) | ||
- [Handle API keys using pgai from python](#handle-api-keys-using-pgai-from-python) | ||
|
||
### Handle API keys using pgai from psql | ||
|
||
The api key is an [optional parameter to pgai functions](https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html). | ||
You can either: | ||
|
||
* [Run AI queries by passing your API key implicitly as a session parameter](#run-ai-queries-by-passing-your-api-key-implicitly-as-a-session-parameter) | ||
* [Run AI queries by passing your API key explicitly as a function argument](#run-ai-queries-by-passing-your-api-key-explicitly-as-a-function-argument) | ||
|
||
#### Run AI queries by passing your API key implicitly as a session parameter | ||
|
||
To use a [session level parameter when connecting to your database with psql](https://www.postgresql.org/docs/current/config-setting.html#CONFIG-SETTING-SHELL) | ||
to run your AI queries: | ||
|
||
1. Set your Gemini key as an environment variable in your shell: | ||
```bash | ||
export ANTHROPIC_API_KEY="this-is-my-super-secret-api-key-dont-tell" | ||
``` | ||
1. Use the session level parameter when you connect to your database: | ||
|
||
```bash | ||
PGOPTIONS="-c ai.gemini_api_key=$GEMINI_API_KEY" psql -d "postgres://<username>:<password>@<host>:<port>/<database-name>" | ||
``` | ||
|
||
1. Run your AI query: | ||
|
||
`ai.gemini_api_key` is set for the duration of your psql session, you do not need to specify it for pgai functions. | ||
|
||
```sql | ||
select ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
); | ||
``` | ||
#### Run AI queries by passing your API key explicitly as a function argument | ||
1. Set your Gemini key as an environment variable in your shell: | ||
```bash | ||
export ANTHROPIC_API_KEY="this-is-my-super-secret-api-key-dont-tell" | ||
``` | ||
2. Connect to your database and set your api key as a [psql variable](https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-VARIABLES): | ||
```bash | ||
psql -d "postgres://<username>:<password>@<host>:<port>/<database-name>" -v gemini_api_key=$ANTHROPIC_API_KEY | ||
``` | ||
Your API key is now available as a psql variable named `gemini_api_key` in your psql session. | ||
You can also log into the database, then set `gemini_api_key` using the `\getenv` [metacommand](https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-META-COMMAND-GETENV): | ||
```sql | ||
\getenv gemini_api_key ANTHROPIC_API_KEY | ||
``` | ||
4. Pass your API key to your parameterized query: | ||
```sql | ||
SELECT ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
, api_key=>$1 | ||
) AS actual | ||
\bind :gemini_api_key | ||
\g | ||
``` | ||
Use [\bind](https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-META-COMMAND-BIND) to pass the value of `gemini_api_key` to the parameterized query. | ||
The `\bind` metacommand is available in psql version 16+. | ||
### Handle API keys using pgai from python | ||
1. In your Python environment, include the dotenv and postgres driver packages: | ||
```bash | ||
pip install python-dotenv | ||
pip install psycopg2-binary | ||
``` | ||
2. Set your Gemini API key in a .env file or as an environment variable: | ||
```bash | ||
ANTHROPIC_API_KEY="this-is-my-super-secret-api-key-dont-tell" | ||
DB_URL="your connection string" | ||
``` | ||
3. Pass your API key as a parameter to your queries: | ||
```python | ||
import os | ||
from dotenv import load_dotenv | ||
load_dotenv() | ||
ANTHROPIC_API_KEY = os.environ["ANTHROPIC_API_KEY"] | ||
DB_URL = os.environ["DB_URL"] | ||
import psycopg2 | ||
from psycopg2.extras import Json | ||
messages = [{'role': 'user', 'content': 'Name five famous people from Birmingham, Alabama.'}] | ||
with psycopg2.connect(DB_URL) as conn: | ||
with conn.cursor() as cur: | ||
# pass the API key as a parameter to the query. don't use string manipulations | ||
cur.execute(""" | ||
SELECT ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, %s | ||
, api_key=>%s | ||
) | ||
""", (Json(messages), ANTHROPIC_API_KEY)) | ||
records = cur.fetchall() | ||
``` | ||
Do not use string manipulation to embed the key as a literal in the SQL query. | ||
## Usage | ||
This section shows you how to use AI directly from your database using SQL. | ||
- [Generate](#generate): generate a response to a prompt | ||
### Generate | ||
[Generate a response for the prompt provided](https://ai.google.dev/api?lang=python): | ||
```sql | ||
-- the following two metacommands cause the raw query results to be printed | ||
-- without any decoration | ||
\pset tuples_only on | ||
\pset format unaligned | ||
select jsonb_extract_path_text | ||
( | ||
ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
) | ||
, 'content', '0', 'text' | ||
); | ||
``` | ||
The data returned looks like: | ||
```text | ||
Here are five famous people from Birmingham, Alabama: | ||
1. Condoleezza Rice - Former U.S. Secretary of State and National Security Advisor | ||
2. Courteney Cox - Actress, best known for her role as Monica Geller on the TV show "Friends" | ||
3. Charles Barkley - Former NBA player and current television analyst | ||
4. Vonetta Flowers - Olympic gold medalist in bobsledding, the first African American to win a gold medal at the Winter Olympics | ||
5. Carl Lewis - Olympic track and field athlete who won nine gold medals across four Olympic Games | ||
These individuals have made significant contributions in various fields, including politics, entertainment, sports, and athletics, and have helped put Birmingham, Alabama on the map in their respective areas. | ||
``` | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
from typing import Optional | ||
import google.generativeai as genai | ||
from google.generativeai import GenerativeModel as Gemini | ||
|
||
DEFAULT_KEY_NAME = "GEMINI_API_KEY" | ||
|
||
|
||
def make_client( | ||
api_key: str, | ||
base_url: Optional[str] = None, | ||
timeout: Optional[float] = None, | ||
max_retries: Optional[int] = None, | ||
) -> Gemini: | ||
args = {} | ||
if timeout is not None: | ||
args["timeout"] = timeout | ||
if max_retries is not None: | ||
args["max_retries"] = max_retries | ||
genai.configure(api_key=api_key) | ||
return Gemini(**args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
------------------------------------------------------------------------------- | ||
-- gemini_generate | ||
-- https://ai.google.dev/api | ||
create or replace function ai.gemini_generate | ||
( model text | ||
, messages jsonb | ||
, max_tokens int default 1024 | ||
, api_key text default null | ||
, api_key_name text default null | ||
, base_url text default null | ||
, timeout float8 default null | ||
, max_retries int default null | ||
, system_prompt text default null | ||
, user_id text default null | ||
, stop_sequences text[] default null | ||
, temperature float8 default null | ||
, tool_choice jsonb default null | ||
, tools jsonb default null | ||
, top_k int default null | ||
, top_p float8 default null | ||
) returns jsonb | ||
as $python$ | ||
#ADD-PYTHON-LIB-DIR | ||
import ai.gemini | ||
import ai.secrets | ||
api_key_resolved = ai.secrets.get_secret(plpy, api_key, api_key_name, ai.gemini.DEFAULT_KEY_NAME, SD) | ||
client = ai.gemini.make_client(api_key=api_key_resolved, base_url=base_url, timeout=timeout, max_retries=max_retries) | ||
|
||
import json | ||
messages_1 = json.loads(messages) | ||
|
||
args = {} | ||
if system_prompt is not None: | ||
args["system"] = system_prompt | ||
if user_id is not None: | ||
args["metadata"] = {"user_id", user_id} | ||
if stop_sequences is not None: | ||
args["stop_sequences"] = stop_sequences | ||
if temperature is not None: | ||
args["temperature"] = temperature | ||
if tool_choice is not None: | ||
args["tool_choice"] = json.loads(tool_choice) | ||
if tools is not None: | ||
args["tools"] = json.loads(tools) | ||
if top_k is not None: | ||
args["top_k"] = top_k | ||
if top_p is not None: | ||
args["top_p"] = top_p | ||
|
||
message = client.messages.create(model=model, messages=messages_1, max_tokens=max_tokens, **args) | ||
return message.to_json() | ||
$python$ | ||
language plpython3u volatile parallel safe security invoker | ||
set search_path to pg_catalog, pg_temp | ||
; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
|
||
import psycopg | ||
import pytest | ||
|
||
|
||
# skip tests in this module if disabled | ||
enable_gemini_tests = os.getenv("GEMINI_API_KEY") | ||
if not enable_gemini_tests or enable_gemini_tests == "0": | ||
pytest.skip(allow_module_level=True) | ||
|
||
|
||
@pytest.fixture() | ||
def gemini_api_key() -> str: | ||
gemini_api_key = os.environ["GEMINI_API_KEY"] | ||
return gemini_api_key | ||
|
||
|
||
@pytest.fixture() | ||
def cur() -> psycopg.Cursor: | ||
with psycopg.connect("postgres://[email protected]:5432/test") as con: | ||
with con.cursor() as cur: | ||
yield cur | ||
|
||
|
||
@pytest.fixture() | ||
def cur_with_api_key(gemini_api_key, cur) -> psycopg.Cursor: | ||
with cur: | ||
cur.execute( | ||
"select set_config('ai.gemini_api_key', %s, false) is not null", | ||
(gemini_api_key,), | ||
) | ||
yield cur | ||
|
||
|
||
@pytest.fixture() | ||
def cur_with_external_functions_executor_url(cur) -> psycopg.Cursor: | ||
with cur: | ||
cur.execute( | ||
"select set_config('ai.external_functions_executor_url', 'http://localhost:8000', false) is not null", | ||
) | ||
yield cur | ||
|
||
|
||
def test_gemini_generate(cur, gemini_api_key): | ||
cur.execute( | ||
""" | ||
with x as | ||
( | ||
select ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
, api_key=>%s | ||
) as actual | ||
) | ||
select jsonb_extract_path_text(x.actual, 'content', '0', 'text') is not null | ||
and x.actual->>'stop_reason' = 'end_turn' | ||
from x | ||
""", | ||
(gemini_api_key,), | ||
) | ||
actual = cur.fetchone()[0] | ||
assert actual is True | ||
|
||
|
||
def test_gemini_generate_api_key_name(cur_with_external_functions_executor_url): | ||
cur_with_external_functions_executor_url.execute( | ||
""" | ||
with x as | ||
( | ||
select ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
, api_key_name=> 'GEMINI_API_KEY_REAL' | ||
) as actual | ||
) | ||
select jsonb_extract_path_text(x.actual, 'content', '0', 'text') is not null | ||
and x.actual->>'stop_reason' = 'end_turn' | ||
from x | ||
""" | ||
) | ||
actual = cur_with_external_functions_executor_url.fetchone()[0] | ||
assert actual is True | ||
|
||
|
||
def test_gemini_generate_no_key(cur_with_api_key): | ||
cur_with_api_key.execute(""" | ||
with x as | ||
( | ||
select ai.gemini_generate | ||
( 'gemini-1.5-flash' | ||
, jsonb_build_array | ||
( jsonb_build_object | ||
( 'role', 'user' | ||
, 'content', 'Name five famous people from Birmingham, Alabama.' | ||
) | ||
) | ||
) as actual | ||
) | ||
select jsonb_extract_path_text(x.actual, 'content', '0', 'text') is not null | ||
and x.actual->>'stop_reason' = 'end_turn' | ||
from x | ||
""") | ||
actual = cur_with_api_key.fetchone()[0] | ||
assert actual is True |
Oops, something went wrong.