diff --git a/README.md b/README.md index 98f6461..e59e9f3 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ end ``` returns ```julia -models/gemini-pro -models/gemini-pro-vision +gemini-pro +gemini-pro-vision ``` diff --git a/src/GoogleGenAI.jl b/src/GoogleGenAI.jl index cf605da..5537b01 100644 --- a/src/GoogleGenAI.jl +++ b/src/GoogleGenAI.jl @@ -9,8 +9,8 @@ Base.@kwdef struct GoogleProvider end struct GoogleTextResponse - candidates::Vector{Dict{Symbol, Any}} - safety_ratings::Dict{Pair{Symbol, String}, Pair{Symbol, String}} + candidates::Vector{Dict{Symbol,Any}} + safety_ratings::Dict{Pair{Symbol,String},Pair{Symbol,String}} text::String end @@ -22,7 +22,7 @@ end struct BlockedPromptException <: Exception end function _extract_text(response::JSON3.Object) - all_texts = String[] + all_texts = String[] for candidate in response.candidates candidate_text = join([part.text for part in candidate.content.parts], "") push!(all_texts, candidate_text) @@ -52,69 +52,85 @@ function generate_content(provider::GoogleProvider, model_name::String, input::S url = "$(provider.base_url)/models/$model_name:generateContent?key=$(provider.api_key)" body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])]) - response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body)) - + response = HTTP.post( + url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body) + ) + if response.status >= 200 && response.status < 300 return _parse_response(response) else error("Request failed with status $(response.status): $(String(response.body))") end end -generate_content(api_key::String, model_name::String, input::String) = generate_content(GoogleProvider(; api_key), model_name, input) +function generate_content(api_key::String, model_name::String, input::String) + return generate_content(GoogleProvider(; api_key), model_name, input) +end function count_tokens(provider::GoogleProvider, model_name::String, input::String) url = "$(provider.base_url)/models/$model_name:countTokens?key=$(provider.api_key)" body = Dict("contents" => [Dict("parts" => [Dict("text" => input)])]) - response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body)) - + response = HTTP.post( + url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body) + ) + if response.status >= 200 && response.status < 300 parsed_response = JSON3.read(response.body) - total_tokens = get(parsed_response, "totalTokens") + total_tokens = get(parsed_response, "totalTokens") return total_tokens else error("Request failed with status $(response.status): $(String(response.body))") end end -count_tokens(api_key::String, model_name::String, input::String) = count_tokens(GoogleProvider(; api_key), model_name, input) +function count_tokens(api_key::String, model_name::String, input::String) + return count_tokens(GoogleProvider(; api_key), model_name, input) +end #TODO: Do we want an embeddings struct, or just the array of embeddings? function embed_content(provider::GoogleProvider, model_name::String, input::String) url = "$(provider.base_url)/models/$model_name:embedContent?key=$(provider.api_key)" body = Dict( "model" => "models/$model_name", - "content" => Dict("parts" => [Dict("text" => input)]) + "content" => Dict("parts" => [Dict("text" => input)]), + ) + response = HTTP.post( + url; headers=Dict("Content-Type" => "application/json"), body=JSON3.write(body) ) - response = HTTP.post(url, headers = Dict("Content-Type" => "application/json"), body = JSON3.write(body)) if response.status >= 200 && response.status < 300 parsed_response = JSON3.read(response.body) - embedding_values = get(get(parsed_response, "embedding", Dict()), "values", Vector{Float64}()) + embedding_values = get( + get(parsed_response, "embedding", Dict()), "values", Vector{Float64}() + ) return GoogleEmbeddingResponse(embedding_values) else error("Request failed with status $(response.status): $(String(response.body))") end end -embed_content(api_key::String, model_name::String, input::String) = embed_content(GoogleProvider(; api_key), model_name, input) +function embed_content(api_key::String, model_name::String, input::String) + return embed_content(GoogleProvider(; api_key), model_name, input) +end function list_models(provider::GoogleProvider) url = "$(provider.base_url)/models?key=$(provider.api_key)" - response = HTTP.get(url, headers = Dict("Content-Type" => "application/json")) + response = HTTP.get(url; headers=Dict("Content-Type" => "application/json")) if response.status >= 200 && response.status < 300 parsed_response = JSON3.read(response.body) - models = [Dict( - :name => model.name, - :version => model.version, - :display_name => model.displayName, - :description => model.description, - :input_token_limit => model.inputTokenLimit, - :output_token_limit => model.outputTokenLimit, - :supported_generation_methods => model.supportedGenerationMethods, - :temperature => get(model, :temperature, nothing), - :topP => get(model, :topP, nothing), - :topK => get(model, :topK, nothing) - ) for model in parsed_response.models] + models = [ + Dict( + :name => replace(model.name, "models/" => ""), + :version => model.version, + :display_name => model.displayName, + :description => model.description, + :input_token_limit => model.inputTokenLimit, + :output_token_limit => model.outputTokenLimit, + :supported_generation_methods => model.supportedGenerationMethods, + :temperature => get(model, :temperature, nothing), + :topP => get(model, :topP, nothing), + :topK => get(model, :topK, nothing), + ) for model in parsed_response.models + ] return models else error("Request failed with status $(response.status): $(String(response.body))") @@ -122,8 +138,7 @@ function list_models(provider::GoogleProvider) end list_models(api_key::String) = list_models(GoogleProvider(; api_key)) - - -export GoogleProvider, GoogleResponse, generate_content, count_tokens, embed_content, list_models +export GoogleProvider, + GoogleResponse, generate_content, count_tokens, embed_content, list_models end # module GoogleGenAI diff --git a/test/runtests.jl b/test/runtests.jl index 23e9bf4..92cc95f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,10 @@ const secret_key = ENV["GOOGLE_API_KEY"] embeddings = embed_content(secret_key, "embedding-001", "Hello") @test typeof(embeddings) == GoogleGenAI.GoogleEmbeddingResponse @test size(embeddings.values) == (768,) + + models = list_models(secret_key) + @test length(models) > 0 + @test haskey(models[1], :name) end Aqua.test_all(GoogleGenAI)