diff --git a/README.md b/README.md index e59e9f3..0243d18 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,15 @@ returns "Hello there! How may I assist you today? Feel free to ask me any questions you may have or give me a command. I'm here to help! 😊" ``` +```julia +response = generate_content(secret_key, model, prompt; max_output_tokens=10) +println(response.text) +``` +returns +```julia +"Hello! How can I assist you today?" +``` + ### Count Tokens ```julia using GoogleGenAI diff --git a/src/GoogleGenAI.jl b/src/GoogleGenAI.jl index 5537b01..e824c47 100644 --- a/src/GoogleGenAI.jl +++ b/src/GoogleGenAI.jl @@ -39,7 +39,7 @@ function _parse_response(response::HTTP.Messages.Response) return GoogleTextResponse(candidates, safety_rating, concatenated_texts) end -#TODO: Add support for the following +#TODO: Add Documentation and tests (this is from the python api) # temperature: The temperature for randomness in generation. Defaults to None. # candidate_count: The number of candidates to consider. Defaults to None. # max_output_tokens: The maximum number of output tokens. Defaults to None. @@ -48,22 +48,28 @@ end # safety_settings: Safety settings for generated text. Defaults to None. # stop_sequences: Stop sequences to halt text generation. Can be a string # or iterable of strings. Defaults to None. -function generate_content(provider::GoogleProvider, model_name::String, input::String) +function generate_content(provider::GoogleProvider, model_name::String, input::String; kwargs...) url = "$(provider.base_url)/models/$model_name:generateContent?key=$(provider.api_key)" - body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])]) + generation_config = Dict{String,Any}() + for (key, value) in kwargs + generation_config[string(key)] = value + end + body = Dict( + "contents" => [Dict("parts" => [Dict("text" => input)])], + "generationConfig" => generation_config + ) response = HTTP.post( url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body) ) - if response.status >= 200 && response.status < 300 return _parse_response(response) else error("Request failed with status $(response.status): $(String(response.body))") end end -function generate_content(api_key::String, model_name::String, input::String) - return generate_content(GoogleProvider(; api_key), model_name, input) +function generate_content(api_key::String, model_name::String, input::String; kwargs...) + return generate_content(GoogleProvider(; api_key), model_name, input; kwargs...) end function count_tokens(provider::GoogleProvider, model_name::String, input::String) diff --git a/test/runtests.jl b/test/runtests.jl index 92cc95f..da0fc14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using Test const secret_key = ENV["GOOGLE_API_KEY"] @testset "GoogleGenAI.jl" begin - response = generate_content(secret_key, "gemini-pro", "Hello") + response = generate_content(secret_key, "gemini-pro", "Hello"; max_output_tokens=25) @test typeof(response) == GoogleGenAI.GoogleTextResponse n_tokens = count_tokens(secret_key, "gemini-pro", "Hello")