Skip to content

Commit

Permalink
add basic kwarg support for generate content
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Feb 15, 2024
1 parent 998bdca commit 4ee1c77
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ returns
"Hello there! How may I assist you today? Feel free to ask me any questions you may have or give me a command. I'm here to help! 😊"
```

```julia
response = generate_content(secret_key, model, prompt; max_output_tokens=10)
println(response.text)
```
returns
```julia
"Hello! How can I assist you today?"
```

### Count Tokens
```julia
using GoogleGenAI
Expand Down
18 changes: 12 additions & 6 deletions src/GoogleGenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function _parse_response(response::HTTP.Messages.Response)
return GoogleTextResponse(candidates, safety_rating, concatenated_texts)
end

#TODO: Add support for the following
#TODO: Add Documentation and tests (this is from the python api)
# temperature: The temperature for randomness in generation. Defaults to None.
# candidate_count: The number of candidates to consider. Defaults to None.
# max_output_tokens: The maximum number of output tokens. Defaults to None.
Expand All @@ -48,22 +48,28 @@ end
# safety_settings: Safety settings for generated text. Defaults to None.
# stop_sequences: Stop sequences to halt text generation. Can be a string
# or iterable of strings. Defaults to None.
function generate_content(provider::GoogleProvider, model_name::String, input::String)
function generate_content(provider::GoogleProvider, model_name::String, input::String; kwargs...)
url = "$(provider.base_url)/models/$model_name:generateContent?key=$(provider.api_key)"
body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])])
generation_config = Dict{String,Any}()
for (key, value) in kwargs
generation_config[string(key)] = value
end

body = Dict(
"contents" => [Dict("parts" => [Dict("text" => input)])],
"generationConfig" => generation_config
)
response = HTTP.post(
url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body)
)

if response.status >= 200 && response.status < 300
return _parse_response(response)
else
error("Request failed with status $(response.status): $(String(response.body))")
end
end
function generate_content(api_key::String, model_name::String, input::String)
return generate_content(GoogleProvider(; api_key), model_name, input)
function generate_content(api_key::String, model_name::String, input::String; kwargs...)
return generate_content(GoogleProvider(; api_key), model_name, input; kwargs...)
end

function count_tokens(provider::GoogleProvider, model_name::String, input::String)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Test
const secret_key = ENV["GOOGLE_API_KEY"]

@testset "GoogleGenAI.jl" begin
response = generate_content(secret_key, "gemini-pro", "Hello")
response = generate_content(secret_key, "gemini-pro", "Hello"; max_output_tokens=25)
@test typeof(response) == GoogleGenAI.GoogleTextResponse

n_tokens = count_tokens(secret_key, "gemini-pro", "Hello")
Expand Down

0 comments on commit 4ee1c77

Please sign in to comment.