Skip to content

Commit

Permalink
add the ability to generate content from text+images
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Feb 24, 2024
1 parent 9e96ae3 commit 6481a7c
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 32 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ authors = ["Tyler Thomas <[email protected]>"]
version = "0.1.0"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"

[compat]
Aqua = "0.8"
Base64 = "1"
Dates = "1"
HTTP = "1"
JSON3 = "1"
Expand Down
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ returns
"Hello! How can I assist you today?"
```

```julia
using GoogleGenAI

secret_key = ENV["GOOGLE_API_KEY"]
model = "gemini-pro-vision"
prompt = "What is this image?"
image_path = "test/example.jpg"
response = generate_content(secret_key, model, prompt, image_path)
println(response.text)
```
returns
```
"The logo for the Julia programming language."
```

### Count Tokens
```julia
using GoogleGenAI
Expand Down
105 changes: 74 additions & 31 deletions src/GoogleGenAI.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module GoogleGenAI

using Base64
using JSON3
using HTTP

Expand All @@ -11,7 +12,7 @@ abstract type AbstractGoogleProvider end
base_url::String = "https://generativelanguage.googleapis.com"
api_version::String = "v1beta"
end
A configuration object used to set up and authenticate requests to the Google Generative Language API.
# Fields
Expand Down Expand Up @@ -80,16 +81,17 @@ function _parse_response(response::HTTP.Messages.Response)
end

"""
generate_content(provider::AbstractGoogleProvider, model_name::String, input::String; kwargs...) -> GoogleTextResponse
generate_content(api_key::String, model_name::String, input::String; kwargs...) -> GoogleTextResponse
generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; kwargs...) -> GoogleTextResponse
generate_content(api_key::String, model_name::String, prompt::String, image_path::String; kwargs...) -> GoogleTextResponse
Generate text using the specified model.
Generate content from a text prompt and, optionally, an image.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `provider::AbstractGoogleProvider`: The provider instance for API requests.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
- `model_name::String`: The model to use for content generation.
- `prompt::String`: The text prompt to accompany the image.
- `image_path::String` (optional): The path to the image file to include in the request.
# Keyword Arguments
- `temperature::Float64` (optional): Controls the randomness in the generation process. Higher values result in more random outputs. Typically ranges between 0 and 1.
Expand All @@ -99,10 +101,10 @@ Generate text using the specified model.
- `safety_settings::Vector{Dict}` (optional): Settings to control the safety aspects of the generated content, such as filtering out unsafe or inappropriate content.
# Returns
- `GoogleTextResponse`
- `GoogleTextResponse`: The generated content response.
"""
function generate_content(
provider::AbstractGoogleProvider, model_name::String, input::String; kwargs...
provider::AbstractGoogleProvider, model_name::String, prompt::String; kwargs...
)
endpoint = "models/$model_name:generateContent"

Expand All @@ -119,74 +121,114 @@ function generate_content(
safety_settings = nothing
end
body = Dict(
"contents" => [Dict("parts" => [Dict("text" => input)])],
"contents" => [Dict("parts" => [Dict("text" => prompt)])],
"generationConfig" => generation_config,
"safetySettings" => safety_settings,
)

response = _request(provider, endpoint, :POST, body)
return _parse_response(response)
end
function generate_content(api_key::String, model_name::String, input::String; kwargs...)
return generate_content(GoogleProvider(; api_key), model_name, input; kwargs...)
function generate_content(api_key::String, model_name::String, prompt::String; kwargs...)
return generate_content(GoogleProvider(; api_key), model_name, prompt; kwargs...)
end

# Infer the MIME type to declare for an inline image from the file extension of
# `image_path` (case-insensitive). Unrecognized or missing extensions fall back to
# "image/jpeg", the value that was previously sent for every file.
function _image_mime_type(image_path::AbstractString)
    extension = lowercase(last(splitext(image_path)))
    extension == ".png" && return "image/png"
    extension == ".webp" && return "image/webp"
    extension == ".heic" && return "image/heic"
    extension == ".heif" && return "image/heif"
    return "image/jpeg"
end

"""
    generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; kwargs...)

Send `prompt` together with the image stored at `image_path` to the `generateContent`
endpoint of `model_name` and return the parsed response.

The image file is read in full and embedded in the request body as Base64 `inline_data`.
Its MIME type is chosen from the file extension (`.png`, `.webp`, `.heic`, `.heif`);
any other extension is sent as `image/jpeg`.

# Arguments
- `provider::AbstractGoogleProvider`: The provider used to issue the request.
- `model_name::String`: The model name interpolated into the endpoint path.
- `prompt::String`: The text part of the request.
- `image_path::String`: Path of the image file to attach.

# Keyword Arguments
- `safety_settings`: Sent as `"safetySettings"`; `nothing` is sent when it is not given.
- Every other keyword is forwarded under its own name inside `"generationConfig"`.

# Returns
- The result of `_parse_response` for the HTTP response.
"""
function generate_content(
    provider::AbstractGoogleProvider,
    model_name::String,
    prompt::String,
    image_path::String;
    kwargs...,
)
    endpoint = "models/$model_name:generateContent"

    # Read the whole file and Base64-encode it for inline transport in the JSON body.
    image_data = base64encode(read(image_path))
    image_part = Dict(
        "inline_data" =>
            Dict("mime_type" => _image_mime_type(image_path), "data" => image_data),
    )

    # Everything except `safety_settings` is treated as a generation option.
    generation_config = Dict{String,Any}(
        string(k) => v for (k, v) in kwargs if k !== :safety_settings
    )

    body = Dict(
        "contents" => [Dict("parts" => [Dict("text" => prompt), image_part])],
        "generationConfig" => generation_config,
        "safetySettings" => get(kwargs, :safety_settings, nothing),
    )

    response = _request(provider, endpoint, :POST, body)
    return _parse_response(response)
end
# Convenience method: wrap a raw API key in a `GoogleProvider` and forward the
# text+image request, including all keyword arguments, to the provider-based method.
function generate_content(
    api_key::String,
    model_name::String,
    prompt::String,
    image_path::String;
    kwargs...,
)
    provider = GoogleProvider(; api_key=api_key)
    return generate_content(provider, model_name, prompt, image_path; kwargs...)
end

"""
count_tokens(provider::AbstractGoogleProvider, model_name::String, input::String) -> Int
count_tokens(api_key::String, model_name::String, input::String) -> Int
count_tokens(provider::AbstractGoogleProvider, model_name::String, prompt::String) -> Int
count_tokens(api_key::String, model_name::String, prompt::String) -> Int
Calculate the number of tokens generated by the specified model for a given input string.
Count the number of tokens that the specified model's tokenizer produces for a given prompt string.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
- `prompt::String`: The prompt whose tokens are counted.
# Returns
- `Int`: The total number of tokens that the given input string would be broken into by the specified model's tokenizer.
- `Int`: The total number of tokens that the given prompt string would be broken into by the specified model's tokenizer.
"""
function count_tokens(provider::AbstractGoogleProvider, model_name::String, input::String)
function count_tokens(provider::AbstractGoogleProvider, model_name::String, prompt::String)
endpoint = "models/$model_name:countTokens"
body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])])
body = Dict("contents" => [Dict("parts" => [Dict("text" => prompt)])])
response = _request(provider, endpoint, :POST, body)
total_tokens = get(JSON3.read(response.body), "totalTokens", 0)
return total_tokens
end
function count_tokens(api_key::String, model_name::String, input::String)
return count_tokens(GoogleProvider(; api_key), model_name, input)
function count_tokens(api_key::String, model_name::String, prompt::String)
return count_tokens(GoogleProvider(; api_key), model_name, prompt)
end

#TODO: Do we want an embeddings struct, or just the array of embeddings?
"""
embed_content(provider::AbstractGoogleProvider, model_name::String, input::String) -> GoogleEmbeddingResponse
embed_content(api_key::String, model_name::String, input::String) -> GoogleEmbeddingResponse
embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String) -> GoogleEmbeddingResponse
embed_content(api_key::String, model_name::String, prompt::String) -> GoogleEmbeddingResponse
Generate an embedding for the given input text using the specified model.
Generate an embedding for the given prompt text using the specified model.
# Arguments
- `provider::AbstractGoogleProvider`: The provider instance containing API key and base URL information.
- `api_key::String`: Your Google API key as a string.
- `model_name::String`: The name of the model to use for generating content.
- `input::String`: The input prompt based on which the text is generated.
- `prompt::String`: The text to generate an embedding for.
# Returns
- `GoogleEmbeddingResponse`
"""
function embed_content(provider::AbstractGoogleProvider, model_name::String, input::String)
function embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String)
endpoint = "models/$model_name:embedContent"
body = Dict(
"model" => "models/$model_name",
"content" => Dict("parts" => [Dict("text" => input)]),
"content" => Dict("parts" => [Dict("text" => prompt)]),
)
response = _request(provider, endpoint, :POST, body)
embedding_values = get(
get(JSON3.read(response.body), "embedding", Dict()), "values", Vector{Float64}()
)
return GoogleEmbeddingResponse(embedding_values, response.status)
end
function embed_content(api_key::String, model_name::String, input::String)
return embed_content(GoogleProvider(; api_key), model_name, input)
function embed_content(api_key::String, model_name::String, prompt::String)
return embed_content(GoogleProvider(; api_key), model_name, prompt)
end

"""
Expand All @@ -212,7 +254,7 @@ function list_models(provider::AbstractGoogleProvider)
:version => model.version,
:display_name => model.displayName,
:description => model.description,
:input_token_limit => model.inputTokenLimit,
:prompt_token_limit => model.inputTokenLimit,
:output_token_limit => model.outputTokenLimit,
:supported_generation_methods => model.supportedGenerationMethods,
:temperature => get(model, :temperature, nothing),
Expand All @@ -224,6 +266,7 @@ function list_models(provider::AbstractGoogleProvider)
end
list_models(api_key::String) = list_models(GoogleProvider(; api_key))

export GoogleProvider, generate_content, count_tokens, embed_content, list_models
export GoogleProvider,
generate_content, count_tokens, embed_content, list_models

end # module GoogleGenAI
Binary file added test/example.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@ using Test
const secret_key = ENV["GOOGLE_API_KEY"]

@testset "GoogleGenAI.jl" begin
response = generate_content(secret_key, "gemini-pro", "Hello"; max_output_tokens=25)
response = generate_content(secret_key, "gemini-pro", "Hello"; max_output_tokens=50)
@test typeof(response) == GoogleGenAI.GoogleTextResponse

response = generate_content(
secret_key,
"gemini-pro-vision",
"What is this picture?",
"example.jpg";
max_output_tokens=50,
)
@test typeof(response) == GoogleGenAI.GoogleTextResponse

n_tokens = count_tokens(secret_key, "gemini-pro", "Hello")
Expand Down

0 comments on commit 6481a7c

Please sign in to comment.