Skip to content

Commit

Permalink
add list_models test, revise output formatting. Reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjthomas9 committed Feb 12, 2024
1 parent 9eb3076 commit 2f06b50
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 32 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end
```
returns
```julia
models/gemini-pro
models/gemini-pro-vision
gemini-pro
gemini-pro-vision
```

75 changes: 45 additions & 30 deletions src/GoogleGenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Base.@kwdef struct GoogleProvider
end

struct GoogleTextResponse
candidates::Vector{Dict{Symbol, Any}}
safety_ratings::Dict{Pair{Symbol, String}, Pair{Symbol, String}}
candidates::Vector{Dict{Symbol,Any}}
safety_ratings::Dict{Pair{Symbol,String},Pair{Symbol,String}}
text::String
end

Expand All @@ -22,7 +22,7 @@ end
struct BlockedPromptException <: Exception end

function _extract_text(response::JSON3.Object)
all_texts = String[]
all_texts = String[]
for candidate in response.candidates
candidate_text = join([part.text for part in candidate.content.parts], "")
push!(all_texts, candidate_text)
Expand Down Expand Up @@ -52,78 +52,93 @@ function generate_content(provider::GoogleProvider, model_name::String, input::S
url = "$(provider.base_url)/models/$model_name:generateContent?key=$(provider.api_key)"
body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])])

response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body))

response = HTTP.post(
url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body)
)

if response.status >= 200 && response.status < 300
return _parse_response(response)
else
error("Request failed with status $(response.status): $(String(response.body))")
end
end
generate_content(api_key::String, model_name::String, input::String) = generate_content(GoogleProvider(; api_key), model_name, input)
function generate_content(api_key::String, model_name::String, input::String)
return generate_content(GoogleProvider(; api_key), model_name, input)
end

function count_tokens(provider::GoogleProvider, model_name::String, input::String)
url = "$(provider.base_url)/models/$model_name:countTokens?key=$(provider.api_key)"
body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])])
response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body))

response = HTTP.post(
url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body)
)

if response.status >= 200 && response.status < 300
parsed_response = JSON3.read(response.body)
total_tokens = get(parsed_response, "totalTokens")
total_tokens = get(parsed_response, "totalTokens")
return total_tokens
else
error("Request failed with status $(response.status): $(String(response.body))")
end
end
count_tokens(api_key::String, model_name::String, input::String) = count_tokens(GoogleProvider(; api_key), model_name, input)
function count_tokens(api_key::String, model_name::String, input::String)
return count_tokens(GoogleProvider(; api_key), model_name, input)
end

#TODO: Do we want an embeddings struct, or just the array of embeddings?
function embed_content(provider::GoogleProvider, model_name::String, input::String)
url = "$(provider.base_url)/models/$model_name:embedContent?key=$(provider.api_key)"
body = Dict(
"model" => "models/$model_name",
"content" => Dict("parts" => [Dict("text" => input)])
"content" => Dict("parts" => [Dict("text" => input)]),
)
response = HTTP.post(
url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body)
)
response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body))

if response.status >= 200 && response.status < 300
parsed_response = JSON3.read(response.body)
embedding_values = get(get(parsed_response, "embedding", Dict()), "values", Vector{Float64}())
embedding_values = get(
get(parsed_response, "embedding", Dict()), "values", Vector{Float64}()
)
return GoogleEmbeddingResponse(embedding_values)
else
error("Request failed with status $(response.status): $(String(response.body))")
end
end
embed_content(api_key::String, model_name::String, input::String) = embed_content(GoogleProvider(; api_key), model_name, input)
function embed_content(api_key::String, model_name::String, input::String)
return embed_content(GoogleProvider(; api_key), model_name, input)
end

function list_models(provider::GoogleProvider)
url = "$(provider.base_url)/models?key=$(provider.api_key)"

response = HTTP.get(url, headers = Dict("Content-Type" => "application/json"))
response = HTTP.get(url; headers=Dict("Content-Type" => "application/json"))

if response.status >= 200 && response.status < 300
parsed_response = JSON3.read(response.body)
models = [Dict(
:name => model.name,
:version => model.version,
:display_name => model.displayName,
:description => model.description,
:input_token_limit => model.inputTokenLimit,
:output_token_limit => model.outputTokenLimit,
:supported_generation_methods => model.supportedGenerationMethods,
:temperature => get(model, :temperature, nothing),
:topP => get(model, :topP, nothing),
:topK => get(model, :topK, nothing)
) for model in parsed_response.models]
models = [
Dict(
:name => replace(model.name, "models/" => ""),
:version => model.version,
:display_name => model.displayName,
:description => model.description,
:input_token_limit => model.inputTokenLimit,
:output_token_limit => model.outputTokenLimit,
:supported_generation_methods => model.supportedGenerationMethods,
:temperature => get(model, :temperature, nothing),
:topP => get(model, :topP, nothing),
:topK => get(model, :topK, nothing),
) for model in parsed_response.models
]
return models
else
error("Request failed with status $(response.status): $(String(response.body))")
end
end
list_models(api_key::String) = list_models(GoogleProvider(; api_key))



export GoogleProvider, GoogleResponse, generate_content, count_tokens, embed_content, list_models
export GoogleProvider,
GoogleResponse, generate_content, count_tokens, embed_content, list_models

end # module GoogleGenAI
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ const secret_key = ENV["GOOGLE_API_KEY"]
embeddings = embed_content(secret_key, "embedding-001", "Hello")
@test typeof(embeddings) == GoogleGenAI.GoogleEmbeddingResponse
@test size(embeddings.values) == (768,)

models = list_models(secret_key)
@test length(models) > 0
@test haskey(models[1], :name)
end

Aqua.test_all(GoogleGenAI)

0 comments on commit 2f06b50

Please sign in to comment.