From aa2c6f6468abc7cd3f8f19e806560c34fdc18f37 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Tue, 26 Mar 2024 09:53:42 -0700 Subject: [PATCH] create new http_kwargs, move api_kwars to namedtuple --- Project.toml | 2 +- README.md | 15 +++-- docs/src/index.md | 15 +++-- src/GoogleGenAI.jl | 144 +++++++++++++++++++++++++++++++-------------- test/runtests.jl | 7 ++- 5 files changed, 126 insertions(+), 57 deletions(-) diff --git a/Project.toml b/Project.toml index bed5ade..06511d6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GoogleGenAI" uuid = "903d41d1-eaca-47dd-943b-fee3930375ab" authors = ["Tyler Thomas "] -version = "0.1.0" +version = "0.2.0" [deps] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" diff --git a/README.md b/README.md index a025a39..de48dcb 100644 --- a/README.md +++ b/README.md @@ -44,12 +44,13 @@ outputs ``` ```julia -response = generate_content(secret_key, model, prompt; max_output_tokens=10) +api_kwargs = (max_output_tokens=50,) +response = generate_content(secret_key, model, prompt; api_kwargs) println(response.text) ``` outputs ```julia -"Hello! How can I assist you today?" +"Hello there, how may I assist you today?" ``` ```julia @@ -72,6 +73,7 @@ outputs ```julia # Define the provider with your API key (placeholder here) provider = GoogleProvider(api_key=ENV["GOOGLE_API_KEY"]) +api_kwargs = (max_output_tokens=50,) model_name = "gemini-pro" conversation = [ Dict(:role => "user", :parts => [Dict(:text => "When was Julia 1.0 released?")]) @@ -82,7 +84,7 @@ push!(conversation, Dict(:role => "model", :parts => [Dict(:text => response.tex println("Model: ", response.text) push!(conversation, Dict(:role => "user", :parts => [Dict(:text => "Who created the language?")])) -response = generate_content(provider, model_name, conversation, max_output_tokens=100) +response = generate_content(provider, model_name, conversation; api_kwargs) println("Model: ", response.text) ``` outputs @@ -144,6 +146,10 @@ end ``` outputs ```julia +gemini-1.0-pro +gemini-1.0-pro-001 +gemini-1.0-pro-latest +gemini-1.0-pro-vision-latest gemini-pro gemini-pro-vision ``` @@ -163,5 +169,6 @@ safety_settings = [ ] model = "gemini-pro" prompt = "Hello" -response = generate_content(secret_key, model, prompt; safety_settings=safety_settings) +api_kwargs = (safety_settings=safety_settings,) +response = generate_content(secret_key, model, prompt; api_kwargs) ``` diff --git a/docs/src/index.md b/docs/src/index.md index 423e40f..54f89fd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -38,12 +38,13 @@ outputs ``` ```julia -response = generate_content(secret_key, model, prompt; max_output_tokens=10) +api_kwargs = (max_output_tokens=50,) +response = generate_content(secret_key, model, prompt; api_kwargs) println(response.text) ``` outputs ```julia -"Hello! How can I assist you today?" +"Hello there, how may I assist you today?" ``` ```julia @@ -66,6 +67,7 @@ outputs ```julia # Define the provider with your API key (placeholder here) provider = GoogleProvider(api_key=ENV["GOOGLE_API_KEY"]) +api_kwargs = (max_output_tokens=50,) model_name = "gemini-pro" conversation = [ Dict(:role => "user", :parts => [Dict(:text => "When was Julia 1.0 released?")]) @@ -76,7 +78,7 @@ push!(conversation, Dict(:role => "model", :parts => [Dict(:text => response.tex println("Model: ", response.text) push!(conversation, Dict(:role => "user", :parts => [Dict(:text => "Who created the language?")])) -response = generate_content(provider, model_name, conversation, max_output_tokens=100) +response = generate_content(provider, model_name, conversation; api_kwargs) println("Model: ", response.text) ``` outputs @@ -138,6 +140,10 @@ end ``` outputs ```julia +gemini-1.0-pro +gemini-1.0-pro-001 +gemini-1.0-pro-latest +gemini-1.0-pro-vision-latest gemini-pro gemini-pro-vision ``` @@ -157,5 +163,6 @@ safety_settings = [ ] model = "gemini-pro" prompt = "Hello" -response = generate_content(secret_key, model, prompt; safety_settings=safety_settings) +api_kwargs = (safety_settings=safety_settings,) +response = generate_content(secret_key, model, prompt; api_kwargs) ``` diff --git a/src/GoogleGenAI.jl b/src/GoogleGenAI.jl index 40c95f7..8d254c0 100644 --- a/src/GoogleGenAI.jl +++ b/src/GoogleGenAI.jl @@ -35,7 +35,11 @@ function status_error(resp, log=nothing) end function _request( - provider::AbstractGoogleProvider, endpoint::String, method::Symbol, body::Dict + provider::AbstractGoogleProvider, + endpoint::String, + method::Symbol, + body::Dict; + http_kwargs..., ) if isempty(provider.api_key) throw(ArgumentError("api cannot be empty")) @@ -43,7 +47,9 @@ function _request( url = "$(provider.base_url)/$(provider.api_version)/$endpoint?key=$(provider.api_key)" headers = Dict("Content-Type" => "application/json") serialized_body = isempty(body) ? UInt8[] : JSON3.write(body) - response = HTTP.request(method, url; headers=headers, body=serialized_body) + response = HTTP.request( + method, url; headers=headers, body=serialized_body, http_kwargs... + ) if response.status >= 400 status_error(response, String(response.body)) end @@ -78,11 +84,11 @@ end #TODO: Should we use different function names? """ - generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; kwargs...) -> NamedTuple - generate_content(api_key::String, model_name::String, prompt::String, image_path::String; kwargs...) -> NamedTuple + generate_content(provider::AbstractGoogleProvider, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple + generate_content(api_key::String, model_name::String, prompt::String, image_path::String; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple - generate_content(provider::AbstractGoogleProvider, model_name::String, conversation::Vector{Dict{Symbol,Any}}; kwargs...) -> NamedTuple - generate_content(api_key::String, model_name::String, conversation::Vector{Dict{Symbol,Any}}; kwargs...) -> NamedTuple + generate_content(provider::AbstractGoogleProvider, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple + generate_content(api_key::String, model_name::String, conversation::Vector{Dict{Symbol,Any}}; api_kwargs=NamedTuple(), https_kwargs=NamedTuple()) -> NamedTuple Generate content based on a combination of text prompt and an image (optional). @@ -93,13 +99,16 @@ Generate content based on a combination of text prompt and an image (optional). - `prompt::String`: The text prompt to accompany the image. - `image_path::String` (optional): The path to the image file to include in the request. -# Keyword Arguments +# API Keyword Arguments - `temperature::Float64` (optional): Controls the randomness in the generation process. Higher values result in more random outputs. Typically ranges between 0 and 1. - `candidate_count::Int` (optional): The number of generation candidates to consider. Currently, only one candidate can be specified. - `max_output_tokens::Int` (optional): The maximum number of tokens that the generated content should contain. - `stop_sequences::Vector{String}` (optional): A list of sequences where the generation should stop. Useful for defining natural endpoints in generated content. - `safety_settings::Vector{Dict}` (optional): Settings to control the safety aspects of the generated content, such as filtering out unsafe or inappropriate content. +# HTTP Kwargs +- All keyword arguments supported by the `HTTP.request` function. Documentation can be found here: https://juliaweb.github.io/HTTP.jl/stable/reference/#HTTP.request. + # Returns - `NamedTuple`: A named tuple containing the following keys: - `candidates`: A vector of dictionaries, each representing a generation candidate. @@ -109,33 +118,42 @@ Generate content based on a combination of text prompt and an image (optional). - `finish_reason`: A string indicating the reason why the generation process was finished. """ function generate_content( - provider::AbstractGoogleProvider, model_name::String, prompt::String; kwargs... + provider::AbstractGoogleProvider, + model_name::String, + prompt::String; + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), ) endpoint = "models/$model_name:generateContent" + safety_settings = get(api_kwargs, :safety_settings, nothing) + generation_config = Dict{String,Any}() - for (key, value) in kwargs + for key in keys(api_kwargs) if key != :safety_settings - generation_config[string(key)] = value + generation_config[string(key)] = getproperty(api_kwargs, key) end end - if haskey(kwargs, :safety_settings) - safety_settings = kwargs[:safety_settings] - else - safety_settings = nothing - end body = Dict( "contents" => [Dict("parts" => [Dict("text" => prompt)])], "generationConfig" => generation_config, "safetySettings" => safety_settings, ) - response = _request(provider, endpoint, :POST, body) + response = _request(provider, endpoint, :POST, body; https_kwargs...) return _parse_response(response) end -function generate_content(api_key::String, model_name::String, prompt::String; kwargs...) - return generate_content(GoogleProvider(; api_key), model_name, prompt; kwargs...) +function generate_content( + api_key::String, + model_name::String, + prompt::String; + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), +) + return generate_content( + GoogleProvider(; api_key), model_name, prompt; api_kwargs, https_kwargs + ) end function generate_content( @@ -143,9 +161,17 @@ function generate_content( model_name::String, prompt::String, image_path::String; - kwargs..., + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), ) image_data = open(base64encode, image_path) + safety_settings = get(api_kwargs, :safety_settings, nothing) + generation_config = Dict{String,Any}() + for key in keys(api_kwargs) + if key != :safety_settings + generation_config[string(key)] = getproperty(api_kwargs, key) + end + end body = Dict( "contents" => [ Dict( @@ -158,19 +184,25 @@ function generate_content( ], ), ], - "generationConfig" => - Dict([string(k) => v for (k, v) in kwargs if k != :safety_settings]), - "safetySettings" => get(kwargs, :safety_settings, nothing), + "generationConfig" => generation_config, + "safetySettings" => safety_settings, ) - response = _request(provider, "models/$model_name:generateContent", :POST, body) + response = _request( + provider, "models/$model_name:generateContent", :POST, body; https_kwargs... + ) return _parse_response(response) end function generate_content( - api_key::String, model_name::String, prompt::String, image_path::String; kwargs... + api_key::String, + model_name::String, + prompt::String, + image_path::String; + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), ) return generate_content( - GoogleProvider(; api_key), model_name, prompt, image_path; kwargs... + GoogleProvider(; api_key), model_name, prompt, image_path; api_kwargs, https_kwargs ) end @@ -178,7 +210,8 @@ function generate_content( provider::AbstractGoogleProvider, model_name::String, conversation::Vector{Dict{Symbol,Any}}; - kwargs..., + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), ) endpoint = "models/$model_name:generateContent" @@ -189,27 +222,33 @@ function generate_content( push!(contents, Dict("role" => role, "parts" => parts)) end + safety_settings = get(api_kwargs, :safety_settings, nothing) generation_config = Dict{String,Any}() - for (key, value) in kwargs + for key in keys(api_kwargs) if key != :safety_settings - generation_config[string(key)] = value + generation_config[string(key)] = getproperty(api_kwargs, key) end end - - safety_settings = get(kwargs, :safety_settings, nothing) + body = Dict( "contents" => contents, "generationConfig" => generation_config, "safetySettings" => safety_settings, ) - response = _request(provider, endpoint, :POST, body) + response = _request(provider, endpoint, :POST, body; https_kwargs) return _parse_response(response) end function generate_content( - api_key::String, model_name::String, conversation::Vector{Dict{Symbol,Any}}; kwargs... + api_key::String, + model_name::String, + conversation::Vector{Dict{Symbol,Any}}; + api_kwargs=NamedTuple(), + https_kwargs=NamedTuple(), ) - return generate_content(GoogleProvider(; api_key), model_name, conversation; kwargs...) + return generate_content( + GoogleProvider(; api_key), model_name, conversation; api_kwargs, https_kwargs + ) end """ @@ -239,10 +278,10 @@ function count_tokens(api_key::String, model_name::String, prompt::String) end """ - embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String) -> NamedTuple - embed_content(api_key::String, model_name::String, prompt::String) -> NamedTuple - embed_content(provider::AbstractGoogleProvider, model_name::String, prompts::Vector{String}) -> Vector{NamedTuple} - embed_content(api_key::String, model_name::String, prompts::Vector{String}) -> Vector{NamedTuple} + embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String https_kwargs=NamedTuple()) -> NamedTuple + embed_content(api_key::String, model_name::String, prompt::String https_kwargs=NamedTuple()) -> NamedTuple + embed_content(provider::AbstractGoogleProvider, model_name::String, prompts::Vector{String} https_kwargs=NamedTuple()) -> Vector{NamedTuple} + embed_content(api_key::String, model_name::String, prompts::Vector{String}, https_kwargs=NamedTuple()) -> Vector{NamedTuple} Generate an embedding for the given prompt text using the specified model. @@ -252,29 +291,42 @@ Generate an embedding for the given prompt text using the specified model. - `model_name::String`: The name of the model to use for generating content. - `prompt::String`: The prompt prompt based on which the text is generated. +# HTTP Kwargs +- All keyword arguments supported by the `HTTP.request` function. Documentation can be found here: https://juliaweb.github.io/HTTP.jl/stable/reference/#HTTP.request. + # Returns - `NamedTuple`: A named tuple containing the following keys: - `values`: A vector of `Float64` representing the embedding values for the given prompt. - `response_status`: An integer representing the HTTP response status code. """ -function embed_content(provider::AbstractGoogleProvider, model_name::String, prompt::String) +function embed_content( + provider::AbstractGoogleProvider, + model_name::String, + prompt::String; + https_kwargs=NamedTuple(), +) endpoint = "models/$model_name:embedContent" body = Dict( "model" => "models/$model_name", "content" => Dict("parts" => [Dict("text" => prompt)]), ) - response = _request(provider, endpoint, :POST, body) + response = _request(provider, endpoint, :POST, body; https_kwargs...) embedding_values = get( get(JSON3.read(response.body), "embedding", Dict()), "values", Vector{Float64}() ) return (values=embedding_values, response_status=response.status) end -function embed_content(api_key::String, model_name::String, prompt::String) - return embed_content(GoogleProvider(; api_key), model_name, prompt) +function embed_content( + api_key::String, model_name::String, prompt::String, https_kwargs=NamedTuple() +) + return embed_content(GoogleProvider(; api_key), model_name, prompt; https_kwargs...) end function embed_content( - provider::AbstractGoogleProvider, model_name::String, prompts::Vector{String} + provider::AbstractGoogleProvider, + model_name::String, + prompts::Vector{String}, + https_kwargs=NamedTuple(), ) endpoint = "models/$model_name:batchEmbedContents" body = Dict( @@ -285,15 +337,17 @@ function embed_content( ) for prompt in prompts ], ) - response = _request(provider, endpoint, :POST, body) + response = _request(provider, endpoint, :POST, body; https_kwargs...) embedding_values = [ get(embedding, "values", Vector{Float64}()) for embedding in JSON3.read(response.body)["embeddings"] ] return (values=embedding_values, response_status=response.status) end -function embed_content(api_key::String, model_name::String, prompts::Vector{String}) - return embed_content(GoogleProvider(; api_key), model_name, prompts) +function embed_content( + api_key::String, model_name::String, prompts::Vector{String}, https_kwargs=NamedTuple() +) + return embed_content(GoogleProvider(; api_key), model_name, prompts; https_kwargs...) end """ diff --git a/test/runtests.jl b/test/runtests.jl index a7ff0ec..ca7c484 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,8 +6,9 @@ if haskey(ENV, "GOOGLE_API_KEY") const secret_key = ENV["GOOGLE_API_KEY"] @testset "GoogleGenAI.jl" begin + api_kwargs = (max_output_tokens=50,) # Generate text from text - response = generate_content(secret_key, "gemini-pro", "Hello"; max_output_tokens=50) + response = generate_content(secret_key, "gemini-pro", "Hello"; api_kwargs) # Generate text from text+image response = generate_content( @@ -15,13 +16,13 @@ if haskey(ENV, "GOOGLE_API_KEY") "gemini-pro-vision", "What is this picture?", "example.jpg"; - max_output_tokens=50, + api_kwargs ) # Multi-turn conversation conversation = [Dict(:role => "user", :parts => [Dict(:text => "Hello")])] response = generate_content( - secret_key, "gemini-pro", conversation; max_output_tokens=50 + secret_key, "gemini-pro", conversation; api_kwargs ) n_tokens = count_tokens(secret_key, "gemini-pro", "Hello")